Compare commits
2 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 18a94b4787 | |||
| 2dbb90eee9 |
+41
-16
@@ -146,11 +146,26 @@ type SubscribeFnParams struct {
|
||||
Logger slog.Logger
|
||||
}
|
||||
|
||||
// chatStreamSubscription wraps a subscriber's event channel with
|
||||
// an overflow signal. When publishToStream cannot deliver an event
|
||||
// because the channel buffer is full, it signals overflow so the
|
||||
// consumer (the merge goroutine in Subscribe) can terminate and
|
||||
// let the client reconnect with a buffer replay.
|
||||
type chatStreamSubscription struct {
|
||||
events chan codersdk.ChatStreamEvent
|
||||
overflow chan struct{}
|
||||
overflowOnce sync.Once
|
||||
}
|
||||
|
||||
func (s *chatStreamSubscription) signalOverflow() {
|
||||
s.overflowOnce.Do(func() { close(s.overflow) })
|
||||
}
|
||||
|
||||
type chatStreamState struct {
|
||||
mu sync.Mutex
|
||||
buffer []codersdk.ChatStreamEvent
|
||||
buffering bool
|
||||
subscribers map[uuid.UUID]chan codersdk.ChatStreamEvent
|
||||
subscribers map[uuid.UUID]*chatStreamSubscription
|
||||
}
|
||||
|
||||
// MaxQueueSize is the maximum number of queued user messages per chat.
|
||||
@@ -1117,18 +1132,19 @@ func (p *Server) publishToStream(chatID uuid.UUID, event codersdk.ChatStreamEven
|
||||
}
|
||||
state.buffer = append(state.buffer, event)
|
||||
}
|
||||
subscribers := make([]chan codersdk.ChatStreamEvent, 0, len(state.subscribers))
|
||||
for _, ch := range state.subscribers {
|
||||
subscribers = append(subscribers, ch)
|
||||
subscribers := make([]*chatStreamSubscription, 0, len(state.subscribers))
|
||||
for _, sub := range state.subscribers {
|
||||
subscribers = append(subscribers, sub)
|
||||
}
|
||||
state.mu.Unlock()
|
||||
|
||||
for _, ch := range subscribers {
|
||||
for _, sub := range subscribers {
|
||||
select {
|
||||
case ch <- event:
|
||||
case sub.events <- event:
|
||||
default:
|
||||
p.logger.Warn(context.Background(), "dropping chat stream event",
|
||||
p.logger.Warn(context.Background(), "dropping chat stream event, signaling overflow",
|
||||
slog.F("chat_id", chatID), slog.F("type", event.Type))
|
||||
sub.signalOverflow()
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1143,29 +1159,33 @@ func (p *Server) publishToStream(chatID uuid.UUID, event codersdk.ChatStreamEven
|
||||
func (p *Server) subscribeToStream(chatID uuid.UUID) (
|
||||
[]codersdk.ChatStreamEvent,
|
||||
<-chan codersdk.ChatStreamEvent,
|
||||
<-chan struct{},
|
||||
func(),
|
||||
) {
|
||||
state := p.getOrCreateStreamState(chatID)
|
||||
state.mu.Lock()
|
||||
snapshot := append([]codersdk.ChatStreamEvent(nil), state.buffer...)
|
||||
id := uuid.New()
|
||||
ch := make(chan codersdk.ChatStreamEvent, 128)
|
||||
state.subscribers[id] = ch
|
||||
sub := &chatStreamSubscription{
|
||||
events: make(chan codersdk.ChatStreamEvent, 128),
|
||||
overflow: make(chan struct{}),
|
||||
}
|
||||
state.subscribers[id] = sub
|
||||
state.mu.Unlock()
|
||||
|
||||
cancel := func() {
|
||||
state.mu.Lock()
|
||||
// Remove the subscriber but do not close the channel.
|
||||
// publishToStream copies subscriber references under
|
||||
// the per-chat lock then sends outside; closing here
|
||||
// races with that send and can panic. The channel
|
||||
// Remove the subscriber but do not close the events
|
||||
// channel. publishToStream copies subscriber references
|
||||
// under the per-chat lock then sends outside; closing
|
||||
// here races with that send and can panic. The channel
|
||||
// becomes unreachable once removed and will be GC'd.
|
||||
delete(state.subscribers, id)
|
||||
p.cleanupStreamIfIdle(chatID, state)
|
||||
state.mu.Unlock()
|
||||
}
|
||||
|
||||
return snapshot, ch, cancel
|
||||
return snapshot, sub.events, sub.overflow, cancel
|
||||
}
|
||||
|
||||
// getOrCreateStreamState returns the per-chat stream state,
|
||||
@@ -1178,7 +1198,7 @@ func (p *Server) getOrCreateStreamState(chatID uuid.UUID) *chatStreamState {
|
||||
return state
|
||||
}
|
||||
val, _ := p.chatStreams.LoadOrStore(chatID, &chatStreamState{
|
||||
subscribers: make(map[uuid.UUID]chan codersdk.ChatStreamEvent),
|
||||
subscribers: make(map[uuid.UUID]*chatStreamSubscription),
|
||||
})
|
||||
state, _ := val.(*chatStreamState)
|
||||
return state
|
||||
@@ -1212,7 +1232,7 @@ func (p *Server) Subscribe(
|
||||
}
|
||||
|
||||
// Subscribe to local stream for message_parts (ephemeral).
|
||||
localSnapshot, localParts, localCancel := p.subscribeToStream(chatID)
|
||||
localSnapshot, localParts, localOverflow, localCancel := p.subscribeToStream(chatID)
|
||||
|
||||
// Merge all event sources.
|
||||
mergedCtx, mergedCancel := context.WithCancel(ctx)
|
||||
@@ -1407,6 +1427,11 @@ func (p *Server) Subscribe(
|
||||
select {
|
||||
case <-mergedCtx.Done():
|
||||
return
|
||||
case <-localOverflow:
|
||||
p.logger.Warn(mergedCtx, "local subscriber overflow, closing stream for reconnect",
|
||||
slog.F("chat_id", chatID),
|
||||
)
|
||||
return
|
||||
case psErr := <-errCh:
|
||||
p.logger.Error(mergedCtx, "chat stream pubsub error",
|
||||
slog.F("chat_id", chatID),
|
||||
|
||||
@@ -6,9 +6,13 @@ import (
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/stretchr/testify/require"
|
||||
"go.uber.org/mock/gomock"
|
||||
"golang.org/x/xerrors"
|
||||
|
||||
"cdr.dev/slog/v3/sloggers/slogtest"
|
||||
"github.com/coder/coder/v2/coderd/database"
|
||||
"github.com/coder/coder/v2/coderd/database/dbmock"
|
||||
"github.com/coder/coder/v2/codersdk"
|
||||
)
|
||||
|
||||
func TestRefreshChatWorkspaceSnapshot_NoReloadWhenWorkspacePresent(t *testing.T) {
|
||||
@@ -84,3 +88,212 @@ func TestRefreshChatWorkspaceSnapshot_ReturnsReloadError(t *testing.T) {
|
||||
require.ErrorContains(t, err, loadErr.Error())
|
||||
require.Equal(t, chat, refreshed)
|
||||
}
|
||||
|
||||
func TestPublishToStreamOverflowTriggersOnFullBuffer(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
chatID := uuid.New()
|
||||
srv := &Server{
|
||||
logger: slogtest.Make(t, &slogtest.Options{IgnoreErrors: true}),
|
||||
}
|
||||
|
||||
state := srv.getOrCreateStreamState(chatID)
|
||||
state.mu.Lock()
|
||||
state.buffering = true
|
||||
state.mu.Unlock()
|
||||
|
||||
_, ch, overflow, cancel := srv.subscribeToStream(chatID)
|
||||
t.Cleanup(cancel)
|
||||
|
||||
// Publish more events than the subscriber channel buffer (128)
|
||||
// can hold, without consuming from the channel.
|
||||
const totalPublished = 200
|
||||
for range totalPublished {
|
||||
srv.publishToStream(chatID, codersdk.ChatStreamEvent{
|
||||
Type: codersdk.ChatStreamEventTypeMessagePart,
|
||||
ChatID: chatID,
|
||||
})
|
||||
}
|
||||
|
||||
// The overflow channel should be signaled because the
|
||||
// subscriber's buffer filled up.
|
||||
select {
|
||||
case <-overflow:
|
||||
default:
|
||||
t.Fatal("expected overflow signal when subscriber buffer is full")
|
||||
}
|
||||
|
||||
// The subscriber channel should have exactly 128 buffered
|
||||
// events (its capacity) before overflow was triggered.
|
||||
var received int
|
||||
drain:
|
||||
for {
|
||||
select {
|
||||
case <-ch:
|
||||
received++
|
||||
default:
|
||||
break drain
|
||||
}
|
||||
}
|
||||
require.Equal(t, 128, received)
|
||||
}
|
||||
|
||||
func TestPublishToStreamSuccessfulDelivery(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
chatID := uuid.New()
|
||||
srv := &Server{
|
||||
logger: slogtest.Make(t, &slogtest.Options{IgnoreErrors: true}),
|
||||
}
|
||||
|
||||
state := srv.getOrCreateStreamState(chatID)
|
||||
state.mu.Lock()
|
||||
state.buffering = true
|
||||
state.mu.Unlock()
|
||||
|
||||
_, ch, overflow, cancel := srv.subscribeToStream(chatID)
|
||||
t.Cleanup(cancel)
|
||||
|
||||
// Publish events while actively consuming. The subscriber
|
||||
// should never overflow because the channel never fills.
|
||||
const totalPublished = 50
|
||||
for range totalPublished {
|
||||
srv.publishToStream(chatID, codersdk.ChatStreamEvent{
|
||||
Type: codersdk.ChatStreamEventTypeMessagePart,
|
||||
ChatID: chatID,
|
||||
})
|
||||
// Immediately consume the event.
|
||||
select {
|
||||
case <-ch:
|
||||
default:
|
||||
t.Fatal("expected event to be available immediately")
|
||||
}
|
||||
}
|
||||
|
||||
// Overflow must not have been signaled.
|
||||
select {
|
||||
case <-overflow:
|
||||
t.Fatal("overflow signaled unexpectedly")
|
||||
default:
|
||||
}
|
||||
}
|
||||
|
||||
// Removed: TestPublishToStreamOverflowSignalsSubscriber was a
|
||||
// duplicate of TestPublishToStreamOverflowTriggersOnFullBuffer.
|
||||
|
||||
func TestPublishToStreamBufferFullDropsOldest(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
chatID := uuid.New()
|
||||
srv := &Server{
|
||||
logger: slogtest.Make(t, &slogtest.Options{IgnoreErrors: true}),
|
||||
}
|
||||
|
||||
state := srv.getOrCreateStreamState(chatID)
|
||||
state.mu.Lock()
|
||||
state.buffering = true
|
||||
state.mu.Unlock()
|
||||
|
||||
// Publish more than maxStreamBufferSize events with no
|
||||
// subscribers to exercise the buffer-full oldest-drop path.
|
||||
for i := range maxStreamBufferSize + 100 {
|
||||
srv.publishToStream(chatID, codersdk.ChatStreamEvent{
|
||||
Type: codersdk.ChatStreamEventTypeMessagePart,
|
||||
ChatID: chatID,
|
||||
})
|
||||
_ = i
|
||||
}
|
||||
|
||||
state.mu.Lock()
|
||||
bufLen := len(state.buffer)
|
||||
state.mu.Unlock()
|
||||
|
||||
// The buffer should be capped at maxStreamBufferSize.
|
||||
require.Equal(t, maxStreamBufferSize, bufLen)
|
||||
}
|
||||
|
||||
func TestPublishToStreamNotBufferingEarlyReturn(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
chatID := uuid.New()
|
||||
srv := &Server{
|
||||
logger: slogtest.Make(t, &slogtest.Options{IgnoreErrors: true}),
|
||||
}
|
||||
|
||||
// Do NOT enable buffering. The message_part event should hit
|
||||
// the early return path and not be buffered.
|
||||
state := srv.getOrCreateStreamState(chatID)
|
||||
|
||||
srv.publishToStream(chatID, codersdk.ChatStreamEvent{
|
||||
Type: codersdk.ChatStreamEventTypeMessagePart,
|
||||
ChatID: chatID,
|
||||
})
|
||||
|
||||
state.mu.Lock()
|
||||
bufLen := len(state.buffer)
|
||||
state.mu.Unlock()
|
||||
|
||||
require.Equal(t, 0, bufLen)
|
||||
|
||||
// The stream state should have been cleaned up since there
|
||||
// are no subscribers and buffering is off.
|
||||
_, loaded := srv.chatStreams.Load(chatID)
|
||||
require.False(t, loaded, "stream state should be cleaned up when idle")
|
||||
}
|
||||
|
||||
func TestSubscribeMergeDetectsOverflow(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
chatID := uuid.New()
|
||||
ctrl := gomock.NewController(t)
|
||||
mockDB := dbmock.NewMockStore(ctrl)
|
||||
|
||||
// Subscribe calls GetChatMessagesByChatID, GetChatQueuedMessages,
|
||||
// and GetChatByID during snapshot construction.
|
||||
mockDB.EXPECT().GetChatMessagesByChatID(gomock.Any(), gomock.Any()).
|
||||
Return(nil, nil).AnyTimes()
|
||||
mockDB.EXPECT().GetChatQueuedMessages(gomock.Any(), gomock.Any()).
|
||||
Return(nil, nil).AnyTimes()
|
||||
mockDB.EXPECT().GetChatByID(gomock.Any(), gomock.Any()).
|
||||
Return(database.Chat{ID: chatID, Status: database.ChatStatusPending}, nil).AnyTimes()
|
||||
|
||||
srv := &Server{
|
||||
logger: slogtest.Make(t, &slogtest.Options{IgnoreErrors: true}),
|
||||
db: mockDB,
|
||||
}
|
||||
|
||||
// Enable buffering so message_part events are accepted.
|
||||
state := srv.getOrCreateStreamState(chatID)
|
||||
state.mu.Lock()
|
||||
state.buffering = true
|
||||
state.mu.Unlock()
|
||||
|
||||
// Use the full Subscribe path. pubsub is nil so the merge
|
||||
// goroutine will use local-only forwarding.
|
||||
_, mergedEvents, cancel, ok := srv.Subscribe(
|
||||
t.Context(), chatID, nil, 0,
|
||||
)
|
||||
require.True(t, ok)
|
||||
t.Cleanup(cancel)
|
||||
|
||||
// Publish enough events to overflow the local subscriber
|
||||
// channel (buffer=128). The merge goroutine reads from the
|
||||
// local channel, but we publish fast enough to fill it.
|
||||
// Since mergedEvents also has a buffer of 128, we need to
|
||||
// saturate both channels. Publish 300 events to be safe.
|
||||
for range 300 {
|
||||
srv.publishToStream(chatID, codersdk.ChatStreamEvent{
|
||||
Type: codersdk.ChatStreamEventTypeMessagePart,
|
||||
ChatID: chatID,
|
||||
})
|
||||
}
|
||||
|
||||
// The merged events channel should close because the merge
|
||||
// goroutine detects the overflow signal and returns.
|
||||
for ev := range mergedEvents {
|
||||
_ = ev
|
||||
}
|
||||
|
||||
// If we reach here, mergedEvents was closed. That confirms
|
||||
// the merge goroutine detected the overflow and terminated.
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user