Compare commits

...

2 Commits

Author SHA1 Message Date
Mathias Fredriksson 18a94b4787 fix(chatd): detect subscriber overflow and signal reconnect
publishToStream silently dropped events via non-blocking send when a
subscriber's 128-slot channel was full. Replace bare channels with a
subscription struct that includes an overflow signal. When events
cannot be delivered, signal overflow so the merge goroutine terminates
and the client reconnects with a buffer replay.

Add tests covering the subscriber send path, overflow detection, and
merge goroutine integration.
2026-03-16 14:31:49 +00:00
Mathias Fredriksson 2dbb90eee9 test(chatd): assert publishToStream delivers all events to subscribers
publishToStream uses a non-blocking send to subscriber channels
(buffer=128). When a subscriber can't consume fast enough, events
are silently dropped with no sequence numbering, gap detection,
or re-delivery mechanism. This test publishes 200 events without
consuming and verifies all reach the subscriber. It fails because
only 128 of 200 survive.
2026-03-13 22:35:55 +00:00
2 changed files with 254 additions and 16 deletions
+41 -16
View File
@@ -146,11 +146,26 @@ type SubscribeFnParams struct {
Logger slog.Logger
}
// chatStreamSubscription wraps a subscriber's event channel with
// an overflow signal. When publishToStream cannot deliver an event
// because the channel buffer is full, it signals overflow so the
// consumer (the merge goroutine in Subscribe) can terminate and
// let the client reconnect with a buffer replay.
type chatStreamSubscription struct {
	// events receives published chat stream events. It is created
	// with a 128-slot buffer so a briefly slow consumer does not
	// block publishToStream.
	events chan codersdk.ChatStreamEvent
	// overflow is closed (via signalOverflow) when publishToStream
	// could not deliver an event because events was full.
	overflow chan struct{}
	// overflowOnce ensures overflow is closed at most once even if
	// multiple deliveries fail.
	overflowOnce sync.Once
}
// signalOverflow closes the overflow channel to notify the consumer
// that an event could not be delivered. It is safe to call repeatedly
// and from multiple goroutines; only the first call has any effect.
func (s *chatStreamSubscription) signalOverflow() {
	closeOverflow := func() {
		close(s.overflow)
	}
	s.overflowOnce.Do(closeOverflow)
}
type chatStreamState struct {
mu sync.Mutex
buffer []codersdk.ChatStreamEvent
buffering bool
subscribers map[uuid.UUID]chan codersdk.ChatStreamEvent
subscribers map[uuid.UUID]*chatStreamSubscription
}
// MaxQueueSize is the maximum number of queued user messages per chat.
@@ -1117,18 +1132,19 @@ func (p *Server) publishToStream(chatID uuid.UUID, event codersdk.ChatStreamEven
}
state.buffer = append(state.buffer, event)
}
subscribers := make([]chan codersdk.ChatStreamEvent, 0, len(state.subscribers))
for _, ch := range state.subscribers {
subscribers = append(subscribers, ch)
subscribers := make([]*chatStreamSubscription, 0, len(state.subscribers))
for _, sub := range state.subscribers {
subscribers = append(subscribers, sub)
}
state.mu.Unlock()
for _, ch := range subscribers {
for _, sub := range subscribers {
select {
case ch <- event:
case sub.events <- event:
default:
p.logger.Warn(context.Background(), "dropping chat stream event",
p.logger.Warn(context.Background(), "dropping chat stream event, signaling overflow",
slog.F("chat_id", chatID), slog.F("type", event.Type))
sub.signalOverflow()
}
}
@@ -1143,29 +1159,33 @@ func (p *Server) publishToStream(chatID uuid.UUID, event codersdk.ChatStreamEven
func (p *Server) subscribeToStream(chatID uuid.UUID) (
[]codersdk.ChatStreamEvent,
<-chan codersdk.ChatStreamEvent,
<-chan struct{},
func(),
) {
state := p.getOrCreateStreamState(chatID)
state.mu.Lock()
snapshot := append([]codersdk.ChatStreamEvent(nil), state.buffer...)
id := uuid.New()
ch := make(chan codersdk.ChatStreamEvent, 128)
state.subscribers[id] = ch
sub := &chatStreamSubscription{
events: make(chan codersdk.ChatStreamEvent, 128),
overflow: make(chan struct{}),
}
state.subscribers[id] = sub
state.mu.Unlock()
cancel := func() {
state.mu.Lock()
// Remove the subscriber but do not close the channel.
// publishToStream copies subscriber references under
// the per-chat lock then sends outside; closing here
// races with that send and can panic. The channel
// Remove the subscriber but do not close the events
// channel. publishToStream copies subscriber references
// under the per-chat lock then sends outside; closing
// here races with that send and can panic. The channel
// becomes unreachable once removed and will be GC'd.
delete(state.subscribers, id)
p.cleanupStreamIfIdle(chatID, state)
state.mu.Unlock()
}
return snapshot, ch, cancel
return snapshot, sub.events, sub.overflow, cancel
}
// getOrCreateStreamState returns the per-chat stream state,
@@ -1178,7 +1198,7 @@ func (p *Server) getOrCreateStreamState(chatID uuid.UUID) *chatStreamState {
return state
}
val, _ := p.chatStreams.LoadOrStore(chatID, &chatStreamState{
subscribers: make(map[uuid.UUID]chan codersdk.ChatStreamEvent),
subscribers: make(map[uuid.UUID]*chatStreamSubscription),
})
state, _ := val.(*chatStreamState)
return state
@@ -1212,7 +1232,7 @@ func (p *Server) Subscribe(
}
// Subscribe to local stream for message_parts (ephemeral).
localSnapshot, localParts, localCancel := p.subscribeToStream(chatID)
localSnapshot, localParts, localOverflow, localCancel := p.subscribeToStream(chatID)
// Merge all event sources.
mergedCtx, mergedCancel := context.WithCancel(ctx)
@@ -1407,6 +1427,11 @@ func (p *Server) Subscribe(
select {
case <-mergedCtx.Done():
return
case <-localOverflow:
p.logger.Warn(mergedCtx, "local subscriber overflow, closing stream for reconnect",
slog.F("chat_id", chatID),
)
return
case psErr := <-errCh:
p.logger.Error(mergedCtx, "chat stream pubsub error",
slog.F("chat_id", chatID),
+213
View File
@@ -6,9 +6,13 @@ import (
"github.com/google/uuid"
"github.com/stretchr/testify/require"
"go.uber.org/mock/gomock"
"golang.org/x/xerrors"
"cdr.dev/slog/v3/sloggers/slogtest"
"github.com/coder/coder/v2/coderd/database"
"github.com/coder/coder/v2/coderd/database/dbmock"
"github.com/coder/coder/v2/codersdk"
)
func TestRefreshChatWorkspaceSnapshot_NoReloadWhenWorkspacePresent(t *testing.T) {
@@ -84,3 +88,212 @@ func TestRefreshChatWorkspaceSnapshot_ReturnsReloadError(t *testing.T) {
require.ErrorContains(t, err, loadErr.Error())
require.Equal(t, chat, refreshed)
}
// TestPublishToStreamOverflowTriggersOnFullBuffer fills a subscriber's
// channel without consuming and verifies that the overflow signal
// fires and that exactly the channel's capacity worth of events were
// delivered before delivery started failing.
func TestPublishToStreamOverflowTriggersOnFullBuffer(t *testing.T) {
	t.Parallel()

	chatID := uuid.New()
	srv := &Server{
		logger: slogtest.Make(t, &slogtest.Options{IgnoreErrors: true}),
	}

	// message_part events are only accepted while buffering is on.
	state := srv.getOrCreateStreamState(chatID)
	state.mu.Lock()
	state.buffering = true
	state.mu.Unlock()

	_, ch, overflow, cancel := srv.subscribeToStream(chatID)
	t.Cleanup(cancel)

	// Publish more events than the subscriber channel buffer (128)
	// can hold, never consuming from the channel.
	const totalPublished = 200
	for i := 0; i < totalPublished; i++ {
		srv.publishToStream(chatID, codersdk.ChatStreamEvent{
			Type:   codersdk.ChatStreamEventTypeMessagePart,
			ChatID: chatID,
		})
	}

	// publishToStream signals overflow synchronously on a failed
	// send, so the overflow channel must already be closed.
	select {
	case <-overflow:
	default:
		t.Fatal("expected overflow signal when subscriber buffer is full")
	}

	// No goroutine is sending anymore, so len(ch) is stable: drain
	// everything that made it into the buffer and count it. Exactly
	// the channel capacity (128) should have been delivered.
	drained := 0
	for len(ch) > 0 {
		<-ch
		drained++
	}
	require.Equal(t, 128, drained)
}
// TestPublishToStreamSuccessfulDelivery publishes while draining one
// event at a time, so the subscriber channel never fills and the
// overflow signal must never fire.
func TestPublishToStreamSuccessfulDelivery(t *testing.T) {
	t.Parallel()

	chatID := uuid.New()
	srv := &Server{
		logger: slogtest.Make(t, &slogtest.Options{IgnoreErrors: true}),
	}

	// message_part events are only accepted while buffering is on.
	state := srv.getOrCreateStreamState(chatID)
	state.mu.Lock()
	state.buffering = true
	state.mu.Unlock()

	_, ch, overflow, cancel := srv.subscribeToStream(chatID)
	t.Cleanup(cancel)

	// Alternate publish/consume so at most one event is ever
	// buffered; the subscriber can never overflow.
	const totalPublished = 50
	for i := 0; i < totalPublished; i++ {
		srv.publishToStream(chatID, codersdk.ChatStreamEvent{
			Type:   codersdk.ChatStreamEventTypeMessagePart,
			ChatID: chatID,
		})
		// publishToStream delivers synchronously, so the event is
		// already buffered by the time it returns.
		select {
		case <-ch:
		default:
			t.Fatal("expected event to be available immediately")
		}
	}

	// Overflow must not have been signaled.
	select {
	case <-overflow:
		t.Fatal("overflow signaled unexpectedly")
	default:
	}
}
// Removed: TestPublishToStreamOverflowSignalsSubscriber was a
// duplicate of TestPublishToStreamOverflowTriggersOnFullBuffer.
// TestPublishToStreamBufferFullDropsOldest verifies that the replay
// buffer is capped at maxStreamBufferSize: publishing past the cap
// with no subscribers must not grow the buffer (the implementation
// drops the oldest entries — presumably; this test only asserts the
// cap, not which entries survive).
func TestPublishToStreamBufferFullDropsOldest(t *testing.T) {
	t.Parallel()

	chatID := uuid.New()
	srv := &Server{
		logger: slogtest.Make(t, &slogtest.Options{IgnoreErrors: true}),
	}

	// Enable buffering so message_part events are accepted.
	state := srv.getOrCreateStreamState(chatID)
	state.mu.Lock()
	state.buffering = true
	state.mu.Unlock()

	// Publish more than maxStreamBufferSize events with no
	// subscribers to exercise the buffer-full oldest-drop path.
	// The loop variable was unused; range-over-int matches the
	// style used by the other tests in this file.
	for range maxStreamBufferSize + 100 {
		srv.publishToStream(chatID, codersdk.ChatStreamEvent{
			Type:   codersdk.ChatStreamEventTypeMessagePart,
			ChatID: chatID,
		})
	}

	state.mu.Lock()
	bufLen := len(state.buffer)
	state.mu.Unlock()

	// The buffer should be capped at maxStreamBufferSize.
	require.Equal(t, maxStreamBufferSize, bufLen)
}
// TestPublishToStreamNotBufferingEarlyReturn publishes a message_part
// event while buffering is off and verifies it is not buffered and
// that the idle stream state is cleaned up.
func TestPublishToStreamNotBufferingEarlyReturn(t *testing.T) {
	t.Parallel()

	chatID := uuid.New()
	srv := &Server{
		logger: slogtest.Make(t, &slogtest.Options{IgnoreErrors: true}),
	}

	// Buffering is deliberately left off: the message_part event
	// should take the early-return path and never be buffered.
	state := srv.getOrCreateStreamState(chatID)

	srv.publishToStream(chatID, codersdk.ChatStreamEvent{
		Type:   codersdk.ChatStreamEventTypeMessagePart,
		ChatID: chatID,
	})

	state.mu.Lock()
	got := len(state.buffer)
	state.mu.Unlock()
	require.Equal(t, 0, got)

	// With no subscribers and buffering off, the idle stream state
	// should have been removed from the map.
	_, loaded := srv.chatStreams.Load(chatID)
	require.False(t, loaded, "stream state should be cleaned up when idle")
}
// TestSubscribeMergeDetectsOverflow exercises the full Subscribe path
// and verifies that the merge goroutine terminates — observed as the
// merged events channel closing — once the local subscriber overflows.
func TestSubscribeMergeDetectsOverflow(t *testing.T) {
	t.Parallel()

	chatID := uuid.New()
	ctrl := gomock.NewController(t)
	mockDB := dbmock.NewMockStore(ctrl)
	// Subscribe calls GetChatMessagesByChatID, GetChatQueuedMessages,
	// and GetChatByID during snapshot construction.
	mockDB.EXPECT().GetChatMessagesByChatID(gomock.Any(), gomock.Any()).
		Return(nil, nil).AnyTimes()
	mockDB.EXPECT().GetChatQueuedMessages(gomock.Any(), gomock.Any()).
		Return(nil, nil).AnyTimes()
	mockDB.EXPECT().GetChatByID(gomock.Any(), gomock.Any()).
		Return(database.Chat{ID: chatID, Status: database.ChatStatusPending}, nil).AnyTimes()

	srv := &Server{
		logger: slogtest.Make(t, &slogtest.Options{IgnoreErrors: true}),
		db:     mockDB,
	}

	// Enable buffering so message_part events are accepted.
	state := srv.getOrCreateStreamState(chatID)
	state.mu.Lock()
	state.buffering = true
	state.mu.Unlock()

	// Use the full Subscribe path. pubsub is nil so the merge
	// goroutine will use local-only forwarding.
	_, mergedEvents, cancel, ok := srv.Subscribe(
		t.Context(), chatID, nil, 0,
	)
	require.True(t, ok)
	t.Cleanup(cancel)

	// Publish enough events to overflow the local subscriber
	// channel (buffer=128). Since mergedEvents is also buffered
	// (128), publish 300 events to saturate both channels.
	for range 300 {
		srv.publishToStream(chatID, codersdk.ChatStreamEvent{
			Type:   codersdk.ChatStreamEventTypeMessagePart,
			ChatID: chatID,
		})
	}

	// Drain mergedEvents until it closes. Reaching the end of the
	// loop confirms the merge goroutine detected the overflow and
	// terminated; if it never does, the test deadline fires.
	for range mergedEvents {
	}
}