Compare commits

...

13 Commits

Author SHA1 Message Date
Ethan Dickson
79bb0a4312 refactor(coderd/x/chatd): simplify queue stream reads 2026-04-10 10:34:44 +00:00
Ethan Dickson
6145135fa1 chore(coderd/x/chatd): fix formatting 2026-04-10 10:34:44 +00:00
Ethan Dickson
97e9a1ce4d chore(coderd/x/chatd): remove unused requireStreamQueueUpdateEvent helper 2026-04-10 10:34:44 +00:00
Ethan Dickson
ab93e58133 test(coderd/x/chatd): remove redundant singleflight tests
Drop three tests whose coverage is already provided by other tests:

- TestStreamFetchContextsUseWorkloadSpecificDeadlinePolicies: unit tests
  of trivial context helpers already exercised end-to-end by
  TestGetStreamChatMessagesUsesServerOwnedTimeoutDespiteLeaderDeadline.
- TestSubscribeQueueUpdateReloadsFreshSnapshotEvenDuringAnotherSubscriberInitialLoad:
  tests pre-existing subscribe behavior, not singleflight coalescing.
- TestSubscribeAuthorizedRefreshesStatusBeforeBufferedMessageParts:
  covered by the enterprise TestSubscribeRelaySnapshotDelivered which
  now passes a stale chat and asserts the refreshed status.
2026-04-10 10:34:44 +00:00
Ethan Dickson
52fea06766 fix(coderd/x/chatd): narrow stream singleflight to history reads 2026-04-10 10:34:44 +00:00
Ethan Dickson
a5f8f90aa8 fix(coderd/x/chatd): bound detached history fetches 2026-04-10 10:34:44 +00:00
Ethan Dickson
8658aa59a1 fix(coderd/x/chatd): fall back to stale chat on refresh failure 2026-04-10 10:34:43 +00:00
Ethan Dickson
d8665551f1 fix: refresh chat stream status and fetch policy 2026-04-10 10:34:43 +00:00
Ethan Dickson
eb08701107 fix(coderd/x/chatd): separate auth boundary from stream coalescing 2026-04-10 10:34:43 +00:00
Ethan Dickson
8efc89ad5b fix(coderd/x/chatd): authorize subscribers before shared stream fetches 2026-04-10 10:33:17 +00:00
Ethan Dickson
6e16297942 fix(coderd/x/chatd): eliminate scheduling races in singleflight tests
Replace the racy coalescing test (which required both goroutines
to enter DoChan before the leader's DB call completed — impossible
to guarantee without instrumenting singleflight internals) with a
deterministic clone-isolation test.

Fix the cancel-safety test by starting the leader first, waiting
for it to block in the DB mock, then calling the canceled waiter
synchronously with an already-canceled context. This eliminates
the scheduling race entirely — no Gosched, no barriers, no flakes.
2026-04-10 10:33:17 +00:00
Ethan Dickson
ab1f0306a6 test(coderd/x/chatd): trim singleflight tests to two representative cases
Keep only the messages coalescing test (verifies dedup + clone
isolation) and the canceled-waiter test (verifies WithoutCancel
safety). The deleted queued-messages and chat-state tests were
structurally identical.
2026-04-10 10:33:17 +00:00
Ethan Dickson
a9350b2ebe perf(coderd/x/chatd): coalesce concurrent subscriber DB reads with singleflight
When multiple users open the same chat stream, each subscriber
independently queries the database for the initial message snapshot,
queued messages, and chat status. The same duplication occurs when
pubsub notifications trigger message catch-up or queue refresh reads
across concurrent subscribers on the same replica.

Wrap the repeated DB reads in Subscribe() and the per-subscriber merge
goroutine with singleflight groups so that concurrent identical
requests collapse into a single database call. Each caller receives an
independent copy of the result to prevent shared-mutation bugs. The
singleflight fetch runs under context.WithoutCancel so that one
disconnecting subscriber does not cancel the in-flight read for the
others, matching the pattern already established in configcache.go.
2026-04-10 10:33:17 +00:00
4 changed files with 465 additions and 72 deletions

View File

@@ -2192,7 +2192,7 @@ func (api *API) streamChat(rw http.ResponseWriter, r *http.Request) {
// Subscribe before accepting the WebSocket so that failures
// can still be reported as normal HTTP errors.
snapshot, events, cancelSub, ok := api.chatDaemon.Subscribe(ctx, chatID, r.Header, afterMessageID)
snapshot, events, cancelSub, ok := api.chatDaemon.SubscribeAuthorized(ctx, chat, r.Header, afterMessageID)
// Subscribe only fails today when the receiver is nil, which
// the chatDaemon == nil guard above already catches. This is
// defensive against future Subscribe failure modes.

View File

@@ -22,6 +22,7 @@ import (
"github.com/sqlc-dev/pqtype"
"golang.org/x/sync/errgroup"
"golang.org/x/xerrors"
"tailscale.com/util/singleflight"
"cdr.dev/slog/v3"
"github.com/coder/coder/v2/coderd/database"
@@ -96,6 +97,13 @@ const (
// cross-replica relay subscribers time to connect and
// snapshot the buffer before it is garbage-collected.
bufferRetainGracePeriod = 5 * time.Second
// chatStreamHistoryFetchTimeout bounds server-owned shared
// history reads. It is intentionally generous because initial
// and catch-up scans may be larger than control-path lookups.
chatStreamHistoryFetchTimeout = 30 * time.Second
// chatStreamControlFetchTimeout bounds subscriber-owned
// control-path DB reads when the caller has no deadline.
chatStreamControlFetchTimeout = 5 * time.Second
// DefaultMaxChatsPerAcquire is the maximum number of chats to
// acquire in a single processOnce call. Batching avoids
@@ -137,6 +145,11 @@ type Server struct {
// never contend with each other.
chatStreams sync.Map // uuid.UUID -> *chatStreamState
// streamMessageFetches coalesces concurrent chat stream durable
// history reads. It is not a cache: once a shared fetch
// completes, future reads hit the database again.
streamMessageFetches singleflight.Group[string, []database.ChatMessage]
// workspaceMCPToolsCache caches workspace MCP tool definitions
// per chat to avoid re-fetching on every turn. The cache is
// keyed by chat ID and invalidated when the agent changes.
@@ -3161,6 +3174,73 @@ func (p *Server) heartbeatTick(ctx context.Context) {
}
}
// cloneChatMessagesForStream deep-copies a slice of chat messages so each
// stream subscriber gets storage it can mutate without affecting the shared
// singleflight result. Both the slice header and every message's raw content
// bytes are duplicated; a nil input yields nil.
func cloneChatMessagesForStream(messages []database.ChatMessage) []database.ChatMessage {
	out := slices.Clone(messages)
	for i, msg := range out {
		out[i].Content.RawMessage = slices.Clone(msg.Content.RawMessage)
	}
	return out
}
// streamSharedHistoryFetchContext builds the context a shared history read
// runs under: the subscriber's cancellation is stripped (so one departing
// waiter cannot kill the fetch for other waiters) and a server-owned
// chatStreamHistoryFetchTimeout budget is applied in its place.
func streamSharedHistoryFetchContext(ctx context.Context) (context.Context, context.CancelFunc) {
	detached := context.WithoutCancel(ctx)
	return context.WithTimeout(detached, chatStreamHistoryFetchTimeout)
}
// streamSubscriberControlFetchContext returns a context for control-path DB
// reads that stays bound to the requesting subscriber. When the caller brought
// no deadline of its own, a chatStreamControlFetchTimeout fallback is applied;
// otherwise the caller's context passes through unchanged with a no-op cancel.
func streamSubscriberControlFetchContext(ctx context.Context) (context.Context, context.CancelFunc) {
	if _, hasDeadline := ctx.Deadline(); !hasDeadline {
		return context.WithTimeout(ctx, chatStreamControlFetchTimeout)
	}
	return ctx, func() {}
}
// getStreamChatMessages fetches durable messages for a chat stream subscriber
// whose authorization has already been checked by Subscribe(). Concurrent
// identical reads collapse into one database call via p.streamMessageFetches;
// the winning fetch deliberately runs under chatd's identity and the
// server-owned history timeout so no individual caller's request context
// shapes the shared work. Each caller receives its own deep copy of the
// shared result, preventing cross-subscriber mutation bugs.
func (p *Server) getStreamChatMessages(
	ctx context.Context,
	params database.GetChatMessagesByChatIDParams,
) ([]database.ChatMessage, error) {
	key := fmt.Sprintf("chat-messages:%s:after:%d", params.ChatID, params.AfterID)
	fetch := func() ([]database.ChatMessage, error) {
		fetchCtx, cancel := streamSharedHistoryFetchContext(ctx)
		defer cancel()
		//nolint:gocritic // SubscribeAuthorized already validated the
		// caller; the shared singleflight fetch runs as chatd so the
		// leader's request identity cannot affect other authorized waiters.
		return p.db.GetChatMessagesByChatID(dbauthz.AsChatd(fetchCtx), params)
	}
	shared, err := singleflightDoChan(ctx, &p.streamMessageFetches, key, fetch)
	if err != nil {
		return nil, err
	}
	return cloneChatMessagesForStream(shared), nil
}
// subscribeWithInitialError returns a Subscribe-shaped result whose initial
// snapshot carries a single error event and whose event channel is already
// closed, letting callers report a failure while keeping ok == true so the
// frontend still receives the error through the normal stream plumbing.
func subscribeWithInitialError(chatID uuid.UUID, message string) (
	[]codersdk.ChatStreamEvent,
	<-chan codersdk.ChatStreamEvent,
	func(),
	bool,
) {
	errorEvent := codersdk.ChatStreamEvent{
		Type:   codersdk.ChatStreamEventTypeError,
		ChatID: chatID,
		Error:  &codersdk.ChatStreamError{Message: message},
	}
	closed := make(chan codersdk.ChatStreamEvent)
	close(closed)
	return []codersdk.ChatStreamEvent{errorEvent}, closed, func() {}, true
}
func (p *Server) Subscribe(
ctx context.Context,
chatID uuid.UUID,
@@ -3175,9 +3255,40 @@ func (p *Server) Subscribe(
if p == nil {
return nil, nil, nil, false
}
if ctx == nil {
ctx = context.Background()
chat, err := p.db.GetChatByID(ctx, chatID)
if err != nil {
if dbauthz.IsNotAuthorizedError(err) {
return nil, nil, nil, false
}
p.logger.Warn(ctx, "failed to load chat for stream subscription",
slog.F("chat_id", chatID),
slog.Error(err),
)
return subscribeWithInitialError(chatID, "failed to load initial snapshot")
}
return p.SubscribeAuthorized(ctx, chat, requestHeader, afterMessageID)
}
// SubscribeAuthorized subscribes an already-authorized chat to merged stream
// updates. The passed chat row proves authorization, but SubscribeAuthorized
// still reloads the chat after the stream subscriptions are armed so the
// initial status and relay setup use fresh state.
func (p *Server) SubscribeAuthorized(
ctx context.Context,
chat database.Chat,
requestHeader http.Header,
afterMessageID int64,
) (
[]codersdk.ChatStreamEvent,
<-chan codersdk.ChatStreamEvent,
func(),
bool,
) {
if p == nil {
return nil, nil, nil, false
}
chatID := chat.ID
// Subscribe to the local stream for message_parts and same-replica
// persisted messages.
@@ -3241,6 +3352,34 @@ func (p *Server) Subscribe(
}
}
cancel := func() {
mergedCancel()
for _, cancelFn := range allCancels {
if cancelFn != nil {
cancelFn()
}
}
}
// Re-read the chat after the local/pubsub subscriptions are active so
// the initial status event and any enterprise relay setup use fresh
// state instead of the middleware-loaded row.
refreshCtx, refreshCancel := streamSubscriberControlFetchContext(ctx)
snapshotChat, err := func() (database.Chat, error) {
defer refreshCancel()
//nolint:gocritic // SubscribeAuthorized already validated the
// caller; this refresh only loads the latest status/worker for
// the already-authorized stream subscription.
return p.db.GetChatByID(dbauthz.AsChatd(refreshCtx), chatID)
}()
if err != nil {
p.logger.Warn(ctx, "failed to refresh chat for stream subscription; using stale state",
slog.F("chat_id", chatID),
slog.Error(err),
)
snapshotChat = chat
}
// Build initial snapshot synchronously. The pubsub subscription
// is already active so no notifications can be lost during this
// window.
@@ -3256,7 +3395,7 @@ func (p *Server) Subscribe(
// caller already has messages up to that ID (e.g. from the REST
// endpoint), so we only fetch newer ones to avoid sending
// duplicate data.
messages, err := p.db.GetChatMessagesByChatID(ctx, database.GetChatMessagesByChatIDParams{
messages, err := p.getStreamChatMessages(ctx, database.GetChatMessagesByChatIDParams{
ChatID: chatID,
AfterID: afterMessageID,
})
@@ -3281,8 +3420,12 @@ func (p *Server) Subscribe(
}
}
// Load initial queue.
queued, err := p.db.GetChatQueuedMessages(ctx, chatID)
// Load initial queue. Queue snapshots are intentionally not
// singleflighted because a chat-scoped key cannot distinguish the
// pre- and post-notification queue state.
queueCtx, queueCancel := streamSubscriberControlFetchContext(ctx)
queued, err := p.db.GetChatQueuedMessages(queueCtx, chatID)
queueCancel()
if err != nil {
p.logger.Error(ctx, "failed to load initial queued messages",
slog.Error(err),
@@ -3301,35 +3444,18 @@ func (p *Server) Subscribe(
})
}
// Get initial chat state to determine if we need a relay.
chat, chatErr := p.db.GetChatByID(ctx, chatID)
// Include the current chat status in the snapshot so the
// frontend can gate message_part processing correctly from
// the very first batch, without waiting for a separate REST
// query.
if chatErr != nil {
p.logger.Error(ctx, "failed to load initial chat state",
slog.Error(chatErr),
slog.F("chat_id", chatID),
)
initialSnapshot = append(initialSnapshot, codersdk.ChatStreamEvent{
Type: codersdk.ChatStreamEventTypeError,
ChatID: chatID,
Error: &codersdk.ChatStreamError{Message: "failed to load initial snapshot"},
})
} else {
statusEvent := codersdk.ChatStreamEvent{
Type: codersdk.ChatStreamEventTypeStatus,
ChatID: chatID,
Status: &codersdk.ChatStreamStatus{
Status: codersdk.ChatStatus(chat.Status),
},
}
// Prepend so the frontend sees the status before any
// message_part events.
initialSnapshot = append([]codersdk.ChatStreamEvent{statusEvent}, initialSnapshot...)
// Include the current chat status in the snapshot so the frontend can gate
// message_part processing correctly from the very first batch, without
// waiting for a separate REST query.
statusEvent := codersdk.ChatStreamEvent{
Type: codersdk.ChatStreamEventTypeStatus,
ChatID: chatID,
Status: &codersdk.ChatStreamStatus{
Status: codersdk.ChatStatus(snapshotChat.Status),
},
}
// Prepend so the frontend sees the status before any message_part events.
initialSnapshot = append([]codersdk.ChatStreamEvent{statusEvent}, initialSnapshot...)
// Track the highest durable message ID delivered to this subscriber,
// whether it came from the initial DB snapshot, the same-replica local
@@ -3339,18 +3465,17 @@ func (p *Server) Subscribe(
lastMessageID = messages[len(messages)-1].ID
}
// When an enterprise SubscribeFn is provided and the chat
// lookup succeeded, call it to get relay events (message_parts
// from remote replicas). OSS now owns pubsub subscription,
// message catch-up, queue updates, and status forwarding;
// enterprise only manages relay dialing.
// When an enterprise SubscribeFn is provided, call it to get relay events
// (message_parts from remote replicas). OSS owns pubsub subscription,
// message catch-up, queue updates, and status forwarding; enterprise only
// manages relay dialing.
var relayEvents <-chan codersdk.ChatStreamEvent
var statusNotifications chan StatusNotification
if p.subscribeFn != nil && chatErr == nil {
if p.subscribeFn != nil {
statusNotifications = make(chan StatusNotification, 10)
relayEvents = p.subscribeFn(mergedCtx, SubscribeFnParams{
ChatID: chatID,
Chat: chat,
Chat: snapshotChat,
WorkerID: p.workerID,
StatusNotifications: statusNotifications,
RequestHeader: requestHeader,
@@ -3407,7 +3532,7 @@ func (p *Server) Subscribe(
}
lastMessageID = event.Message.ID
}
} else if newMessages, msgErr := p.db.GetChatMessagesByChatID(mergedCtx, database.GetChatMessagesByChatIDParams{
} else if newMessages, msgErr := p.getStreamChatMessages(mergedCtx, database.GetChatMessagesByChatIDParams{
ChatID: chatID,
AfterID: lastMessageID,
}); msgErr != nil {
@@ -3495,7 +3620,9 @@ func (p *Server) Subscribe(
}
}
if notify.QueueUpdate {
queuedMsgs, queueErr := p.db.GetChatQueuedMessages(mergedCtx, chatID)
queueCtx, queueCancel := streamSubscriberControlFetchContext(mergedCtx)
queuedMsgs, queueErr := p.db.GetChatQueuedMessages(queueCtx, chatID)
queueCancel()
if queueErr != nil {
p.logger.Warn(mergedCtx, "failed to get queued messages after pubsub notification",
slog.F("chat_id", chatID),
@@ -3571,14 +3698,6 @@ func (p *Server) Subscribe(
}
}()
cancel := func() {
mergedCancel()
for _, cancelFn := range allCancels {
if cancelFn != nil {
cancelFn()
}
}
}
return initialSnapshot, mergedEvents, cancel, true
}

View File

@@ -18,6 +18,7 @@ import (
"cdr.dev/slog/v3"
"cdr.dev/slog/v3/sloggers/slogtest"
"github.com/coder/coder/v2/coderd/database"
"github.com/coder/coder/v2/coderd/database/dbauthz"
"github.com/coder/coder/v2/coderd/database/dbmock"
dbpubsub "github.com/coder/coder/v2/coderd/database/pubsub"
coderdpubsub "github.com/coder/coder/v2/coderd/pubsub"
@@ -1293,14 +1294,14 @@ func TestSubscribeSkipsDatabaseCatchupForLocallyDeliveredMessage(t *testing.T) {
ChatID: chatID,
Role: database.ChatMessageRoleAssistant,
}
gomock.InOrder(
db.EXPECT().GetChatByID(gomock.Any(), chatID).Return(chat, nil),
db.EXPECT().GetChatByID(gomock.Any(), chatID).Return(chat, nil),
db.EXPECT().GetChatMessagesByChatID(gomock.Any(), database.GetChatMessagesByChatIDParams{
ChatID: chatID,
AfterID: 0,
}).Return([]database.ChatMessage{initialMessage}, nil),
db.EXPECT().GetChatQueuedMessages(gomock.Any(), chatID).Return(nil, nil),
db.EXPECT().GetChatByID(gomock.Any(), chatID).Return(chat, nil),
)
server := newSubscribeTestServer(t, db)
@@ -1336,14 +1337,14 @@ func TestSubscribeUsesDurableCacheWhenLocalMessageWasNotDelivered(t *testing.T)
ChatID: chatID,
Role: codersdk.ChatMessageRoleAssistant,
}
gomock.InOrder(
db.EXPECT().GetChatByID(gomock.Any(), chatID).Return(chat, nil),
db.EXPECT().GetChatByID(gomock.Any(), chatID).Return(chat, nil),
db.EXPECT().GetChatMessagesByChatID(gomock.Any(), database.GetChatMessagesByChatIDParams{
ChatID: chatID,
AfterID: 0,
}).Return([]database.ChatMessage{initialMessage}, nil),
db.EXPECT().GetChatQueuedMessages(gomock.Any(), chatID).Return(nil, nil),
db.EXPECT().GetChatByID(gomock.Any(), chatID).Return(chat, nil),
)
server := newSubscribeTestServer(t, db)
@@ -1387,14 +1388,14 @@ func TestSubscribeQueriesDatabaseWhenDurableCacheMisses(t *testing.T) {
ChatID: chatID,
Role: database.ChatMessageRoleAssistant,
}
gomock.InOrder(
db.EXPECT().GetChatByID(gomock.Any(), chatID).Return(chat, nil),
db.EXPECT().GetChatByID(gomock.Any(), chatID).Return(chat, nil),
db.EXPECT().GetChatMessagesByChatID(gomock.Any(), database.GetChatMessagesByChatIDParams{
ChatID: chatID,
AfterID: 0,
}).Return([]database.ChatMessage{initialMessage}, nil),
db.EXPECT().GetChatQueuedMessages(gomock.Any(), chatID).Return(nil, nil),
db.EXPECT().GetChatByID(gomock.Any(), chatID).Return(chat, nil),
db.EXPECT().GetChatMessagesByChatID(gomock.Any(), database.GetChatMessagesByChatIDParams{
ChatID: chatID,
AfterID: 1,
@@ -1436,14 +1437,14 @@ func TestSubscribeFullRefreshStillUsesDatabaseCatchup(t *testing.T) {
ChatID: chatID,
Role: database.ChatMessageRoleUser,
}
gomock.InOrder(
db.EXPECT().GetChatByID(gomock.Any(), chatID).Return(chat, nil),
db.EXPECT().GetChatByID(gomock.Any(), chatID).Return(chat, nil),
db.EXPECT().GetChatMessagesByChatID(gomock.Any(), database.GetChatMessagesByChatIDParams{
ChatID: chatID,
AfterID: 0,
}).Return([]database.ChatMessage{initialMessage}, nil),
db.EXPECT().GetChatQueuedMessages(gomock.Any(), chatID).Return(nil, nil),
db.EXPECT().GetChatByID(gomock.Any(), chatID).Return(chat, nil),
db.EXPECT().GetChatMessagesByChatID(gomock.Any(), database.GetChatMessagesByChatIDParams{
ChatID: chatID,
AfterID: 0,
@@ -1473,14 +1474,14 @@ func TestSubscribeDeliversRetryEventViaPubsubOnce(t *testing.T) {
chatID := uuid.New()
chat := database.Chat{ID: chatID, Status: database.ChatStatusPending}
gomock.InOrder(
db.EXPECT().GetChatByID(gomock.Any(), chatID).Return(chat, nil),
db.EXPECT().GetChatByID(gomock.Any(), chatID).Return(chat, nil),
db.EXPECT().GetChatMessagesByChatID(gomock.Any(), database.GetChatMessagesByChatIDParams{
ChatID: chatID,
AfterID: 0,
}).Return(nil, nil),
db.EXPECT().GetChatQueuedMessages(gomock.Any(), chatID).Return(nil, nil),
db.EXPECT().GetChatByID(gomock.Any(), chatID).Return(chat, nil),
)
server := newSubscribeTestServer(t, db)
@@ -1517,14 +1518,14 @@ func TestSubscribePrefersStructuredErrorPayloadViaPubsub(t *testing.T) {
chatID := uuid.New()
chat := database.Chat{ID: chatID, Status: database.ChatStatusPending}
gomock.InOrder(
db.EXPECT().GetChatByID(gomock.Any(), chatID).Return(chat, nil),
db.EXPECT().GetChatByID(gomock.Any(), chatID).Return(chat, nil),
db.EXPECT().GetChatMessagesByChatID(gomock.Any(), database.GetChatMessagesByChatIDParams{
ChatID: chatID,
AfterID: 0,
}).Return(nil, nil),
db.EXPECT().GetChatQueuedMessages(gomock.Any(), chatID).Return(nil, nil),
db.EXPECT().GetChatByID(gomock.Any(), chatID).Return(chat, nil),
)
server := newSubscribeTestServer(t, db)
@@ -1557,14 +1558,14 @@ func TestSubscribeFallsBackToLegacyErrorStringViaPubsub(t *testing.T) {
chatID := uuid.New()
chat := database.Chat{ID: chatID, Status: database.ChatStatusPending}
gomock.InOrder(
db.EXPECT().GetChatByID(gomock.Any(), chatID).Return(chat, nil),
db.EXPECT().GetChatByID(gomock.Any(), chatID).Return(chat, nil),
db.EXPECT().GetChatMessagesByChatID(gomock.Any(), database.GetChatMessagesByChatIDParams{
ChatID: chatID,
AfterID: 0,
}).Return(nil, nil),
db.EXPECT().GetChatQueuedMessages(gomock.Any(), chatID).Return(nil, nil),
db.EXPECT().GetChatByID(gomock.Any(), chatID).Return(chat, nil),
)
server := newSubscribeTestServer(t, db)
@@ -1581,6 +1582,274 @@ func TestSubscribeFallsBackToLegacyErrorStringViaPubsub(t *testing.T) {
requireNoStreamEvent(t, events, 200*time.Millisecond)
}
// TestSubscribeAuthorizedFallsBackToStaleRowWhenRefreshFails verifies that
// when the post-subscription chat refresh fails, SubscribeAuthorized falls
// back to the caller-supplied (stale) chat row instead of failing the
// subscription: the initial snapshot still leads with a status event derived
// from the stale row, followed by any buffered message parts.
func TestSubscribeAuthorizedFallsBackToStaleRowWhenRefreshFails(t *testing.T) {
	t.Parallel()
	ctx := testutil.Context(t, testutil.WaitShort)
	ctrl := gomock.NewController(t)
	db := dbmock.NewMockStore(ctrl)
	server := newSubscribeTestServer(t, db)
	chatID := uuid.New()
	staleChat := database.Chat{ID: chatID, Status: database.ChatStatusPending}
	// Pre-seed the stream buffer with a message part so the snapshot
	// ordering (status first, buffered parts after) can be asserted.
	state := server.getOrCreateStreamState(chatID)
	state.mu.Lock()
	state.buffer = []codersdk.ChatStreamEvent{{
		Type:   codersdk.ChatStreamEventTypeMessagePart,
		ChatID: chatID,
		MessagePart: &codersdk.ChatStreamMessagePart{
			Role: "assistant",
			Part: codersdk.ChatMessageText("thinking"),
		},
	}}
	state.mu.Unlock()
	gomock.InOrder(
		// The refresh read fails; SubscribeAuthorized must tolerate it and
		// continue with the remaining snapshot reads.
		db.EXPECT().GetChatByID(gomock.Any(), chatID).Return(database.Chat{}, xerrors.New("refresh failed")),
		db.EXPECT().GetChatMessagesByChatID(gomock.Any(), database.GetChatMessagesByChatIDParams{
			ChatID:  chatID,
			AfterID: 0,
		}).Return(nil, nil),
		db.EXPECT().GetChatQueuedMessages(gomock.Any(), chatID).Return(nil, nil),
	)
	initialSnapshot, events, cancel, ok := server.SubscribeAuthorized(ctx, staleChat, nil, 0)
	require.True(t, ok)
	defer cancel()
	// Status from the stale row must precede the buffered message part.
	require.Len(t, initialSnapshot, 2)
	require.Equal(t, codersdk.ChatStreamEventTypeStatus, initialSnapshot[0].Type)
	require.NotNil(t, initialSnapshot[0].Status)
	require.Equal(t, codersdk.ChatStatusPending, initialSnapshot[0].Status.Status)
	require.Equal(t, codersdk.ChatStreamEventTypeMessagePart, initialSnapshot[1].Type)
	require.NotNil(t, initialSnapshot[1].MessagePart)
	require.Equal(t, "thinking", initialSnapshot[1].MessagePart.Part.Text)
	requireNoStreamEvent(t, events, 200*time.Millisecond)
}
// TestGetStreamChatMessagesReturnsIndependentClones verifies that
// successive calls to getStreamChatMessages return independent
// copies so that one caller mutating a result does not corrupt
// another caller's view. This is the key correctness property
// of the singleflight wrappers — the coalescing itself is
// already tested by tailscale.com/util/singleflight.
func TestGetStreamChatMessagesReturnsIndependentClones(t *testing.T) {
	t.Parallel()
	ctx := testutil.Context(t, testutil.WaitShort)
	ctrl := gomock.NewController(t)
	db := dbmock.NewMockStore(ctrl)
	server := newSubscribeTestServer(t, db)
	chatID := uuid.New()
	params := database.GetChatMessagesByChatIDParams{ChatID: chatID, AfterID: 0}
	message := database.ChatMessage{
		ID:     1,
		ChatID: chatID,
		Content: pqtype.NullRawMessage{
			RawMessage: []byte(`"hello"`),
			Valid:      true,
		},
	}
	// AnyTimes: whether the two calls coalesce into one DB read or hit the
	// database twice is irrelevant here; only result isolation matters.
	db.EXPECT().GetChatMessagesByChatID(gomock.Any(), params).
		Return([]database.ChatMessage{message}, nil).AnyTimes()
	first, err := server.getStreamChatMessages(ctx, params)
	require.NoError(t, err)
	second, err := server.getStreamChatMessages(ctx, params)
	require.NoError(t, err)
	require.Len(t, first, 1)
	require.Len(t, second, 1)
	// Mutate first; second must be unaffected.
	first[0].Content.RawMessage[0] = 'x'
	require.Equal(t, `"hello"`, string(second[0].Content.RawMessage))
}
// TestGetStreamChatMessagesCanceledWaiterDoesNotPoisonOthers verifies the
// cancel-safety of the singleflight wrapper: a waiter whose context is
// already canceled returns context.Canceled immediately, while the in-flight
// leader on the same key completes independently and still gets its result.
func TestGetStreamChatMessagesCanceledWaiterDoesNotPoisonOthers(t *testing.T) {
	t.Parallel()
	ctx := testutil.Context(t, testutil.WaitMedium)
	ctrl := gomock.NewController(t)
	db := dbmock.NewMockStore(ctrl)
	server := newSubscribeTestServer(t, db)
	chatID := uuid.New()
	params := database.GetChatMessagesByChatIDParams{ChatID: chatID, AfterID: 0}
	message := database.ChatMessage{ID: 1, ChatID: chatID}
	started := make(chan struct{})
	release := make(chan struct{})
	db.EXPECT().GetChatMessagesByChatID(gomock.Any(), params).DoAndReturn(func(context.Context, database.GetChatMessagesByChatIDParams) ([]database.ChatMessage, error) {
		// Close started exactly once (the non-blocking select tolerates
		// re-entry), then block until released so the singleflight key
		// stays in-flight while the canceled waiter below is exercised.
		select {
		case <-started:
		default:
			close(started)
		}
		<-release
		return []database.ChatMessage{message}, nil
	}).AnyTimes()
	// Start the leader goroutine and wait for it to block inside
	// the DB mock so the singleflight key is definitely in-flight.
	survivorCh := make(chan struct {
		messages []database.ChatMessage
		err      error
	}, 1)
	go func() {
		msgs, err := server.getStreamChatMessages(ctx, params)
		survivorCh <- struct {
			messages []database.ChatMessage
			err      error
		}{messages: msgs, err: err}
	}()
	waitForSignal(t, started)
	// Now cancel a second caller while the leader is still
	// blocked. Because singleflightDoChan selects on ctx.Done,
	// the canceled caller must return immediately with
	// context.Canceled without waiting for the DB call.
	cancelCtx, cancel := context.WithCancel(ctx)
	cancel()
	_, err := server.getStreamChatMessages(cancelCtx, params)
	require.ErrorIs(t, err, context.Canceled)
	// Release the leader; it must succeed independently.
	close(release)
	select {
	case result := <-survivorCh:
		require.NoError(t, result.err)
		require.Len(t, result.messages, 1)
	case <-time.After(testutil.WaitShort):
		t.Fatal("surviving waiter did not return")
	}
}
// TestGetStreamChatMessagesUsesServerOwnedTimeoutDespiteLeaderDeadline
// verifies that the shared history fetch runs under the server-owned
// chatStreamHistoryFetchTimeout budget rather than the leader's request
// deadline: the leader may time out and observe DeadlineExceeded, yet the
// detached fetch keeps running and a follower still receives its result.
func TestGetStreamChatMessagesUsesServerOwnedTimeoutDespiteLeaderDeadline(t *testing.T) {
	t.Parallel()
	ctrl := gomock.NewController(t)
	db := dbmock.NewMockStore(ctrl)
	server := newSubscribeTestServer(t, db)
	chatID := uuid.New()
	params := database.GetChatMessagesByChatIDParams{ChatID: chatID, AfterID: 0}
	message := database.ChatMessage{ID: 1, ChatID: chatID}
	fetchCtxCh := make(chan context.Context, 1)
	release := make(chan struct{})
	// Capture the context the shared fetch runs under, then block so the
	// leader's (much shorter) deadline can expire while the fetch is live.
	db.EXPECT().GetChatMessagesByChatID(gomock.Any(), params).DoAndReturn(func(fetchCtx context.Context, _ database.GetChatMessagesByChatIDParams) ([]database.ChatMessage, error) {
		fetchCtxCh <- fetchCtx
		<-release
		return []database.ChatMessage{message}, nil
	}).Times(1)
	started := time.Now()
	leaderCtx, leaderCancel := context.WithTimeout(context.Background(), testutil.IntervalFast)
	defer leaderCancel()
	leaderErrCh := make(chan error, 1)
	go func() {
		_, err := server.getStreamChatMessages(leaderCtx, params)
		leaderErrCh <- err
	}()
	var fetchCtx context.Context
	select {
	case fetchCtx = <-fetchCtxCh:
	case <-time.After(testutil.WaitShort):
		t.Fatal("timed out waiting for shared history fetch to start")
	}
	// The fetch context must carry the server-owned history budget, not the
	// leader's short request deadline.
	deadline, ok := fetchCtx.Deadline()
	require.True(t, ok)
	require.WithinDuration(t, started.Add(chatStreamHistoryFetchTimeout), deadline, 500*time.Millisecond)
	followerCh := make(chan struct {
		messages []database.ChatMessage
		err      error
	}, 1)
	go func() {
		msgs, err := server.getStreamChatMessages(context.Background(), params)
		followerCh <- struct {
			messages []database.ChatMessage
			err      error
		}{messages: msgs, err: err}
	}()
	// After the leader's deadline fires, the detached fetch context must
	// still be alive (not canceled).
	<-leaderCtx.Done()
	require.NoError(t, fetchCtx.Err())
	close(release)
	select {
	case err := <-leaderErrCh:
		require.ErrorIs(t, err, context.DeadlineExceeded)
	case <-time.After(testutil.WaitShort):
		t.Fatal("timed out waiting for leader to observe its deadline")
	}
	select {
	case result := <-followerCh:
		require.NoError(t, result.err)
		require.Len(t, result.messages, 1)
		require.Equal(t, int64(1), result.messages[0].ID)
	case <-time.After(testutil.WaitShort):
		t.Fatal("timed out waiting for follower to receive shared fetch result")
	}
}
// TestSubscribeRejectsUnauthorizedCallerBeforeSharedFetches verifies that an
// authorization failure on the initial chat lookup causes Subscribe to bail
// out entirely: no snapshot, no event channel, no cancel func, and no stream
// state left registered for the chat.
func TestSubscribeRejectsUnauthorizedCallerBeforeSharedFetches(t *testing.T) {
	t.Parallel()
	ctx := testutil.Context(t, testutil.WaitShort)
	ctrl := gomock.NewController(t)
	store := dbmock.NewMockStore(ctrl)
	server := newSubscribeTestServer(t, store)
	chatID := uuid.New()
	authzErr := dbauthz.NotAuthorizedError{Err: xerrors.New("not authorized")}
	store.EXPECT().GetChatByID(gomock.Any(), chatID).Return(database.Chat{}, authzErr)
	snapshot, events, cancel, ok := server.Subscribe(ctx, chatID, nil, 0)
	require.False(t, ok)
	require.Nil(t, snapshot)
	require.Nil(t, events)
	require.Nil(t, cancel)
	// No per-chat stream state may remain behind for the rejected caller.
	_, exists := server.chatStreams.Load(chatID)
	require.False(t, exists)
}
// TestSubscribeSurfacesTransientLookupFailureAsInitialError verifies that a
// non-authorization lookup failure still yields ok == true: the snapshot
// carries a single error event, the event channel is already closed, and no
// stream state is registered for the chat.
func TestSubscribeSurfacesTransientLookupFailureAsInitialError(t *testing.T) {
	t.Parallel()
	ctx := testutil.Context(t, testutil.WaitShort)
	ctrl := gomock.NewController(t)
	store := dbmock.NewMockStore(ctrl)
	server := newSubscribeTestServer(t, store)
	chatID := uuid.New()
	store.EXPECT().GetChatByID(gomock.Any(), chatID).
		Return(database.Chat{}, xerrors.New("transient lookup failure"))
	snapshot, events, cancel, ok := server.Subscribe(ctx, chatID, nil, 0)
	require.True(t, ok)
	require.NotNil(t, cancel)
	require.Len(t, snapshot, 1)
	errEvent := snapshot[0]
	require.Equal(t, codersdk.ChatStreamEventTypeError, errEvent.Type)
	require.Equal(t, chatID, errEvent.ChatID)
	require.Equal(t, "failed to load initial snapshot", errEvent.Error.Message)
	// The event channel must already be closed.
	_, open := <-events
	require.False(t, open)
	_, exists := server.chatStreams.Load(chatID)
	require.False(t, exists)
}
func newSubscribeTestServer(t *testing.T, db database.Store) *Server {
t.Helper()

View File

@@ -430,8 +430,13 @@ func TestSubscribeRelaySnapshotDelivered(t *testing.T) {
user, model := seedChatDependencies(ctx, t, db)
chat := seedRemoteRunningChat(ctx, t, db, user, model, workerID, "relay-snapshot")
staleChat := chat
staleChat.Status = database.ChatStatusWaiting
staleChat.WorkerID = uuid.NullUUID{}
staleChat.StartedAt = sql.NullTime{}
staleChat.HeartbeatAt = sql.NullTime{}
initialSnapshot, events, cancel, ok := subscriber.Subscribe(ctx, chat.ID, nil, 0)
initialSnapshot, events, cancel, ok := subscriber.SubscribeAuthorized(ctx, staleChat, nil, 0)
require.True(t, ok)
t.Cleanup(cancel)
@@ -455,15 +460,15 @@ func TestSubscribeRelaySnapshotDelivered(t *testing.T) {
require.Equal(t, []string{"snap-one", "snap-two", "live-part"}, receivedTexts)
// The initial snapshot should still contain the status event
// from the OSS preamble.
var hasStatus bool
// The initial snapshot should contain the refreshed running status,
// not the stale waiting status passed into SubscribeAuthorized.
var snapshotStatus codersdk.ChatStatus
for _, event := range initialSnapshot {
if event.Type == codersdk.ChatStreamEventTypeStatus {
hasStatus = true
if event.Type == codersdk.ChatStreamEventTypeStatus && event.Status != nil {
snapshotStatus = event.Status.Status
}
}
require.True(t, hasStatus, "initial snapshot should contain status event")
require.Equal(t, codersdk.ChatStatusRunning, snapshotStatus)
}
func TestSubscribeRetryEventAcrossInstances(t *testing.T) {