Compare commits
13 Commits
main
...
perf/chatd
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
79bb0a4312 | ||
|
|
6145135fa1 | ||
|
|
97e9a1ce4d | ||
|
|
ab93e58133 | ||
|
|
52fea06766 | ||
|
|
a5f8f90aa8 | ||
|
|
8658aa59a1 | ||
|
|
d8665551f1 | ||
|
|
eb08701107 | ||
|
|
8efc89ad5b | ||
|
|
6e16297942 | ||
|
|
ab1f0306a6 | ||
|
|
a9350b2ebe |
@@ -2192,7 +2192,7 @@ func (api *API) streamChat(rw http.ResponseWriter, r *http.Request) {
|
||||
|
||||
// Subscribe before accepting the WebSocket so that failures
|
||||
// can still be reported as normal HTTP errors.
|
||||
snapshot, events, cancelSub, ok := api.chatDaemon.Subscribe(ctx, chatID, r.Header, afterMessageID)
|
||||
snapshot, events, cancelSub, ok := api.chatDaemon.SubscribeAuthorized(ctx, chat, r.Header, afterMessageID)
|
||||
// SubscribeAuthorized only fails today when the receiver is nil, which
|
||||
// the chatDaemon == nil guard above already catches. This is
|
||||
// defensive against future SubscribeAuthorized failure modes.
|
||||
|
||||
@@ -22,6 +22,7 @@ import (
|
||||
"github.com/sqlc-dev/pqtype"
|
||||
"golang.org/x/sync/errgroup"
|
||||
"golang.org/x/xerrors"
|
||||
"tailscale.com/util/singleflight"
|
||||
|
||||
"cdr.dev/slog/v3"
|
||||
"github.com/coder/coder/v2/coderd/database"
|
||||
@@ -96,6 +97,13 @@ const (
|
||||
// cross-replica relay subscribers time to connect and
|
||||
// snapshot the buffer before it is garbage-collected.
|
||||
bufferRetainGracePeriod = 5 * time.Second
|
||||
// chatStreamHistoryFetchTimeout bounds server-owned shared
|
||||
// history reads. It is intentionally generous because initial
|
||||
// and catch-up scans may be larger than control-path lookups.
|
||||
chatStreamHistoryFetchTimeout = 30 * time.Second
|
||||
// chatStreamControlFetchTimeout bounds subscriber-owned
|
||||
// control-path DB reads when the caller has no deadline.
|
||||
chatStreamControlFetchTimeout = 5 * time.Second
|
||||
|
||||
// DefaultMaxChatsPerAcquire is the maximum number of chats to
|
||||
// acquire in a single processOnce call. Batching avoids
|
||||
@@ -137,6 +145,11 @@ type Server struct {
|
||||
// never contend with each other.
|
||||
chatStreams sync.Map // uuid.UUID -> *chatStreamState
|
||||
|
||||
// streamMessageFetches coalesces concurrent chat stream durable
|
||||
// history reads. It is not a cache: once a shared fetch
|
||||
// completes, future reads hit the database again.
|
||||
streamMessageFetches singleflight.Group[string, []database.ChatMessage]
|
||||
|
||||
// workspaceMCPToolsCache caches workspace MCP tool definitions
|
||||
// per chat to avoid re-fetching on every turn. The cache is
|
||||
// keyed by chat ID and invalidated when the agent changes.
|
||||
@@ -3161,6 +3174,73 @@ func (p *Server) heartbeatTick(ctx context.Context) {
|
||||
}
|
||||
}
|
||||
|
||||
func cloneChatMessagesForStream(messages []database.ChatMessage) []database.ChatMessage {
|
||||
cloned := slices.Clone(messages)
|
||||
for i := range cloned {
|
||||
cloned[i].Content.RawMessage = slices.Clone(cloned[i].Content.RawMessage)
|
||||
}
|
||||
return cloned
|
||||
}
|
||||
|
||||
// streamSharedHistoryFetchContext detaches subscriber cancellation from a
|
||||
// shared history fetch and runs it under a server-owned timeout budget.
|
||||
// Shared work should not inherit the winner's request deadline.
|
||||
func streamSharedHistoryFetchContext(ctx context.Context) (context.Context, context.CancelFunc) {
|
||||
return context.WithTimeout(context.WithoutCancel(ctx), chatStreamHistoryFetchTimeout)
|
||||
}
|
||||
|
||||
// streamSubscriberControlFetchContext keeps a control-path lookup tied to the
|
||||
// requesting subscriber while applying a fallback timeout when the caller has
|
||||
// no deadline.
|
||||
func streamSubscriberControlFetchContext(ctx context.Context) (context.Context, context.CancelFunc) {
|
||||
if _, ok := ctx.Deadline(); ok {
|
||||
return ctx, func() {}
|
||||
}
|
||||
return context.WithTimeout(ctx, chatStreamControlFetchTimeout)
|
||||
}
|
||||
|
||||
// getStreamChatMessages loads durable chat messages for an already-authorized
|
||||
// subscriber. Subscribe() must validate the caller before this helper is used.
|
||||
// The shared fetch intentionally runs as chatd so request identity and timeout
|
||||
// policy come from chatd rather than whichever caller won singleflight.
|
||||
func (p *Server) getStreamChatMessages(
|
||||
ctx context.Context,
|
||||
params database.GetChatMessagesByChatIDParams,
|
||||
) ([]database.ChatMessage, error) {
|
||||
messages, err := singleflightDoChan(
|
||||
ctx,
|
||||
&p.streamMessageFetches,
|
||||
fmt.Sprintf("chat-messages:%s:after:%d", params.ChatID, params.AfterID),
|
||||
func() ([]database.ChatMessage, error) {
|
||||
fetchCtx, cancel := streamSharedHistoryFetchContext(ctx)
|
||||
defer cancel()
|
||||
//nolint:gocritic // SubscribeAuthorized already validated the
|
||||
// caller; the shared singleflight fetch runs as chatd so the
|
||||
// leader's request identity cannot affect other authorized waiters.
|
||||
return p.db.GetChatMessagesByChatID(dbauthz.AsChatd(fetchCtx), params)
|
||||
},
|
||||
)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return cloneChatMessagesForStream(messages), nil
|
||||
}
|
||||
|
||||
func subscribeWithInitialError(chatID uuid.UUID, message string) (
|
||||
[]codersdk.ChatStreamEvent,
|
||||
<-chan codersdk.ChatStreamEvent,
|
||||
func(),
|
||||
bool,
|
||||
) {
|
||||
events := make(chan codersdk.ChatStreamEvent)
|
||||
close(events)
|
||||
return []codersdk.ChatStreamEvent{{
|
||||
Type: codersdk.ChatStreamEventTypeError,
|
||||
ChatID: chatID,
|
||||
Error: &codersdk.ChatStreamError{Message: message},
|
||||
}}, events, func() {}, true
|
||||
}
|
||||
|
||||
func (p *Server) Subscribe(
|
||||
ctx context.Context,
|
||||
chatID uuid.UUID,
|
||||
@@ -3175,9 +3255,40 @@ func (p *Server) Subscribe(
|
||||
if p == nil {
|
||||
return nil, nil, nil, false
|
||||
}
|
||||
if ctx == nil {
|
||||
ctx = context.Background()
|
||||
|
||||
chat, err := p.db.GetChatByID(ctx, chatID)
|
||||
if err != nil {
|
||||
if dbauthz.IsNotAuthorizedError(err) {
|
||||
return nil, nil, nil, false
|
||||
}
|
||||
p.logger.Warn(ctx, "failed to load chat for stream subscription",
|
||||
slog.F("chat_id", chatID),
|
||||
slog.Error(err),
|
||||
)
|
||||
return subscribeWithInitialError(chatID, "failed to load initial snapshot")
|
||||
}
|
||||
return p.SubscribeAuthorized(ctx, chat, requestHeader, afterMessageID)
|
||||
}
|
||||
|
||||
// SubscribeAuthorized subscribes an already-authorized chat to merged stream
|
||||
// updates. The passed chat row proves authorization, but SubscribeAuthorized
|
||||
// still reloads the chat after the stream subscriptions are armed so the
|
||||
// initial status and relay setup use fresh state.
|
||||
func (p *Server) SubscribeAuthorized(
|
||||
ctx context.Context,
|
||||
chat database.Chat,
|
||||
requestHeader http.Header,
|
||||
afterMessageID int64,
|
||||
) (
|
||||
[]codersdk.ChatStreamEvent,
|
||||
<-chan codersdk.ChatStreamEvent,
|
||||
func(),
|
||||
bool,
|
||||
) {
|
||||
if p == nil {
|
||||
return nil, nil, nil, false
|
||||
}
|
||||
chatID := chat.ID
|
||||
|
||||
// Subscribe to the local stream for message_parts and same-replica
|
||||
// persisted messages.
|
||||
@@ -3241,6 +3352,34 @@ func (p *Server) Subscribe(
|
||||
}
|
||||
}
|
||||
|
||||
cancel := func() {
|
||||
mergedCancel()
|
||||
for _, cancelFn := range allCancels {
|
||||
if cancelFn != nil {
|
||||
cancelFn()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Re-read the chat after the local/pubsub subscriptions are active so
|
||||
// the initial status event and any enterprise relay setup use fresh
|
||||
// state instead of the middleware-loaded row.
|
||||
refreshCtx, refreshCancel := streamSubscriberControlFetchContext(ctx)
|
||||
snapshotChat, err := func() (database.Chat, error) {
|
||||
defer refreshCancel()
|
||||
//nolint:gocritic // SubscribeAuthorized already validated the
|
||||
// caller; this refresh only loads the latest status/worker for
|
||||
// the already-authorized stream subscription.
|
||||
return p.db.GetChatByID(dbauthz.AsChatd(refreshCtx), chatID)
|
||||
}()
|
||||
if err != nil {
|
||||
p.logger.Warn(ctx, "failed to refresh chat for stream subscription; using stale state",
|
||||
slog.F("chat_id", chatID),
|
||||
slog.Error(err),
|
||||
)
|
||||
snapshotChat = chat
|
||||
}
|
||||
|
||||
// Build initial snapshot synchronously. The pubsub subscription
|
||||
// is already active so no notifications can be lost during this
|
||||
// window.
|
||||
@@ -3256,7 +3395,7 @@ func (p *Server) Subscribe(
|
||||
// caller already has messages up to that ID (e.g. from the REST
|
||||
// endpoint), so we only fetch newer ones to avoid sending
|
||||
// duplicate data.
|
||||
messages, err := p.db.GetChatMessagesByChatID(ctx, database.GetChatMessagesByChatIDParams{
|
||||
messages, err := p.getStreamChatMessages(ctx, database.GetChatMessagesByChatIDParams{
|
||||
ChatID: chatID,
|
||||
AfterID: afterMessageID,
|
||||
})
|
||||
@@ -3281,8 +3420,12 @@ func (p *Server) Subscribe(
|
||||
}
|
||||
}
|
||||
|
||||
// Load initial queue.
|
||||
queued, err := p.db.GetChatQueuedMessages(ctx, chatID)
|
||||
// Load initial queue. Queue snapshots are intentionally not
|
||||
// singleflighted because a chat-scoped key cannot distinguish the
|
||||
// pre- and post-notification queue state.
|
||||
queueCtx, queueCancel := streamSubscriberControlFetchContext(ctx)
|
||||
queued, err := p.db.GetChatQueuedMessages(queueCtx, chatID)
|
||||
queueCancel()
|
||||
if err != nil {
|
||||
p.logger.Error(ctx, "failed to load initial queued messages",
|
||||
slog.Error(err),
|
||||
@@ -3301,35 +3444,18 @@ func (p *Server) Subscribe(
|
||||
})
|
||||
}
|
||||
|
||||
// Get initial chat state to determine if we need a relay.
|
||||
chat, chatErr := p.db.GetChatByID(ctx, chatID)
|
||||
|
||||
// Include the current chat status in the snapshot so the
|
||||
// frontend can gate message_part processing correctly from
|
||||
// the very first batch, without waiting for a separate REST
|
||||
// query.
|
||||
if chatErr != nil {
|
||||
p.logger.Error(ctx, "failed to load initial chat state",
|
||||
slog.Error(chatErr),
|
||||
slog.F("chat_id", chatID),
|
||||
)
|
||||
initialSnapshot = append(initialSnapshot, codersdk.ChatStreamEvent{
|
||||
Type: codersdk.ChatStreamEventTypeError,
|
||||
ChatID: chatID,
|
||||
Error: &codersdk.ChatStreamError{Message: "failed to load initial snapshot"},
|
||||
})
|
||||
} else {
|
||||
statusEvent := codersdk.ChatStreamEvent{
|
||||
Type: codersdk.ChatStreamEventTypeStatus,
|
||||
ChatID: chatID,
|
||||
Status: &codersdk.ChatStreamStatus{
|
||||
Status: codersdk.ChatStatus(chat.Status),
|
||||
},
|
||||
}
|
||||
// Prepend so the frontend sees the status before any
|
||||
// message_part events.
|
||||
initialSnapshot = append([]codersdk.ChatStreamEvent{statusEvent}, initialSnapshot...)
|
||||
// Include the current chat status in the snapshot so the frontend can gate
|
||||
// message_part processing correctly from the very first batch, without
|
||||
// waiting for a separate REST query.
|
||||
statusEvent := codersdk.ChatStreamEvent{
|
||||
Type: codersdk.ChatStreamEventTypeStatus,
|
||||
ChatID: chatID,
|
||||
Status: &codersdk.ChatStreamStatus{
|
||||
Status: codersdk.ChatStatus(snapshotChat.Status),
|
||||
},
|
||||
}
|
||||
// Prepend so the frontend sees the status before any message_part events.
|
||||
initialSnapshot = append([]codersdk.ChatStreamEvent{statusEvent}, initialSnapshot...)
|
||||
|
||||
// Track the highest durable message ID delivered to this subscriber,
|
||||
// whether it came from the initial DB snapshot, the same-replica local
|
||||
@@ -3339,18 +3465,17 @@ func (p *Server) Subscribe(
|
||||
lastMessageID = messages[len(messages)-1].ID
|
||||
}
|
||||
|
||||
// When an enterprise SubscribeFn is provided and the chat
|
||||
// lookup succeeded, call it to get relay events (message_parts
|
||||
// from remote replicas). OSS now owns pubsub subscription,
|
||||
// message catch-up, queue updates, and status forwarding;
|
||||
// enterprise only manages relay dialing.
|
||||
// When an enterprise SubscribeFn is provided, call it to get relay events
|
||||
// (message_parts from remote replicas). OSS owns pubsub subscription,
|
||||
// message catch-up, queue updates, and status forwarding; enterprise only
|
||||
// manages relay dialing.
|
||||
var relayEvents <-chan codersdk.ChatStreamEvent
|
||||
var statusNotifications chan StatusNotification
|
||||
if p.subscribeFn != nil && chatErr == nil {
|
||||
if p.subscribeFn != nil {
|
||||
statusNotifications = make(chan StatusNotification, 10)
|
||||
relayEvents = p.subscribeFn(mergedCtx, SubscribeFnParams{
|
||||
ChatID: chatID,
|
||||
Chat: chat,
|
||||
Chat: snapshotChat,
|
||||
WorkerID: p.workerID,
|
||||
StatusNotifications: statusNotifications,
|
||||
RequestHeader: requestHeader,
|
||||
@@ -3407,7 +3532,7 @@ func (p *Server) Subscribe(
|
||||
}
|
||||
lastMessageID = event.Message.ID
|
||||
}
|
||||
} else if newMessages, msgErr := p.db.GetChatMessagesByChatID(mergedCtx, database.GetChatMessagesByChatIDParams{
|
||||
} else if newMessages, msgErr := p.getStreamChatMessages(mergedCtx, database.GetChatMessagesByChatIDParams{
|
||||
ChatID: chatID,
|
||||
AfterID: lastMessageID,
|
||||
}); msgErr != nil {
|
||||
@@ -3495,7 +3620,9 @@ func (p *Server) Subscribe(
|
||||
}
|
||||
}
|
||||
if notify.QueueUpdate {
|
||||
queuedMsgs, queueErr := p.db.GetChatQueuedMessages(mergedCtx, chatID)
|
||||
queueCtx, queueCancel := streamSubscriberControlFetchContext(mergedCtx)
|
||||
queuedMsgs, queueErr := p.db.GetChatQueuedMessages(queueCtx, chatID)
|
||||
queueCancel()
|
||||
if queueErr != nil {
|
||||
p.logger.Warn(mergedCtx, "failed to get queued messages after pubsub notification",
|
||||
slog.F("chat_id", chatID),
|
||||
@@ -3571,14 +3698,6 @@ func (p *Server) Subscribe(
|
||||
}
|
||||
}()
|
||||
|
||||
cancel := func() {
|
||||
mergedCancel()
|
||||
for _, cancelFn := range allCancels {
|
||||
if cancelFn != nil {
|
||||
cancelFn()
|
||||
}
|
||||
}
|
||||
}
|
||||
return initialSnapshot, mergedEvents, cancel, true
|
||||
}
|
||||
|
||||
|
||||
@@ -18,6 +18,7 @@ import (
|
||||
"cdr.dev/slog/v3"
|
||||
"cdr.dev/slog/v3/sloggers/slogtest"
|
||||
"github.com/coder/coder/v2/coderd/database"
|
||||
"github.com/coder/coder/v2/coderd/database/dbauthz"
|
||||
"github.com/coder/coder/v2/coderd/database/dbmock"
|
||||
dbpubsub "github.com/coder/coder/v2/coderd/database/pubsub"
|
||||
coderdpubsub "github.com/coder/coder/v2/coderd/pubsub"
|
||||
@@ -1293,14 +1294,14 @@ func TestSubscribeSkipsDatabaseCatchupForLocallyDeliveredMessage(t *testing.T) {
|
||||
ChatID: chatID,
|
||||
Role: database.ChatMessageRoleAssistant,
|
||||
}
|
||||
|
||||
gomock.InOrder(
|
||||
db.EXPECT().GetChatByID(gomock.Any(), chatID).Return(chat, nil),
|
||||
db.EXPECT().GetChatByID(gomock.Any(), chatID).Return(chat, nil),
|
||||
db.EXPECT().GetChatMessagesByChatID(gomock.Any(), database.GetChatMessagesByChatIDParams{
|
||||
ChatID: chatID,
|
||||
AfterID: 0,
|
||||
}).Return([]database.ChatMessage{initialMessage}, nil),
|
||||
db.EXPECT().GetChatQueuedMessages(gomock.Any(), chatID).Return(nil, nil),
|
||||
db.EXPECT().GetChatByID(gomock.Any(), chatID).Return(chat, nil),
|
||||
)
|
||||
|
||||
server := newSubscribeTestServer(t, db)
|
||||
@@ -1336,14 +1337,14 @@ func TestSubscribeUsesDurableCacheWhenLocalMessageWasNotDelivered(t *testing.T)
|
||||
ChatID: chatID,
|
||||
Role: codersdk.ChatMessageRoleAssistant,
|
||||
}
|
||||
|
||||
gomock.InOrder(
|
||||
db.EXPECT().GetChatByID(gomock.Any(), chatID).Return(chat, nil),
|
||||
db.EXPECT().GetChatByID(gomock.Any(), chatID).Return(chat, nil),
|
||||
db.EXPECT().GetChatMessagesByChatID(gomock.Any(), database.GetChatMessagesByChatIDParams{
|
||||
ChatID: chatID,
|
||||
AfterID: 0,
|
||||
}).Return([]database.ChatMessage{initialMessage}, nil),
|
||||
db.EXPECT().GetChatQueuedMessages(gomock.Any(), chatID).Return(nil, nil),
|
||||
db.EXPECT().GetChatByID(gomock.Any(), chatID).Return(chat, nil),
|
||||
)
|
||||
|
||||
server := newSubscribeTestServer(t, db)
|
||||
@@ -1387,14 +1388,14 @@ func TestSubscribeQueriesDatabaseWhenDurableCacheMisses(t *testing.T) {
|
||||
ChatID: chatID,
|
||||
Role: database.ChatMessageRoleAssistant,
|
||||
}
|
||||
|
||||
gomock.InOrder(
|
||||
db.EXPECT().GetChatByID(gomock.Any(), chatID).Return(chat, nil),
|
||||
db.EXPECT().GetChatByID(gomock.Any(), chatID).Return(chat, nil),
|
||||
db.EXPECT().GetChatMessagesByChatID(gomock.Any(), database.GetChatMessagesByChatIDParams{
|
||||
ChatID: chatID,
|
||||
AfterID: 0,
|
||||
}).Return([]database.ChatMessage{initialMessage}, nil),
|
||||
db.EXPECT().GetChatQueuedMessages(gomock.Any(), chatID).Return(nil, nil),
|
||||
db.EXPECT().GetChatByID(gomock.Any(), chatID).Return(chat, nil),
|
||||
db.EXPECT().GetChatMessagesByChatID(gomock.Any(), database.GetChatMessagesByChatIDParams{
|
||||
ChatID: chatID,
|
||||
AfterID: 1,
|
||||
@@ -1436,14 +1437,14 @@ func TestSubscribeFullRefreshStillUsesDatabaseCatchup(t *testing.T) {
|
||||
ChatID: chatID,
|
||||
Role: database.ChatMessageRoleUser,
|
||||
}
|
||||
|
||||
gomock.InOrder(
|
||||
db.EXPECT().GetChatByID(gomock.Any(), chatID).Return(chat, nil),
|
||||
db.EXPECT().GetChatByID(gomock.Any(), chatID).Return(chat, nil),
|
||||
db.EXPECT().GetChatMessagesByChatID(gomock.Any(), database.GetChatMessagesByChatIDParams{
|
||||
ChatID: chatID,
|
||||
AfterID: 0,
|
||||
}).Return([]database.ChatMessage{initialMessage}, nil),
|
||||
db.EXPECT().GetChatQueuedMessages(gomock.Any(), chatID).Return(nil, nil),
|
||||
db.EXPECT().GetChatByID(gomock.Any(), chatID).Return(chat, nil),
|
||||
db.EXPECT().GetChatMessagesByChatID(gomock.Any(), database.GetChatMessagesByChatIDParams{
|
||||
ChatID: chatID,
|
||||
AfterID: 0,
|
||||
@@ -1473,14 +1474,14 @@ func TestSubscribeDeliversRetryEventViaPubsubOnce(t *testing.T) {
|
||||
|
||||
chatID := uuid.New()
|
||||
chat := database.Chat{ID: chatID, Status: database.ChatStatusPending}
|
||||
|
||||
gomock.InOrder(
|
||||
db.EXPECT().GetChatByID(gomock.Any(), chatID).Return(chat, nil),
|
||||
db.EXPECT().GetChatByID(gomock.Any(), chatID).Return(chat, nil),
|
||||
db.EXPECT().GetChatMessagesByChatID(gomock.Any(), database.GetChatMessagesByChatIDParams{
|
||||
ChatID: chatID,
|
||||
AfterID: 0,
|
||||
}).Return(nil, nil),
|
||||
db.EXPECT().GetChatQueuedMessages(gomock.Any(), chatID).Return(nil, nil),
|
||||
db.EXPECT().GetChatByID(gomock.Any(), chatID).Return(chat, nil),
|
||||
)
|
||||
|
||||
server := newSubscribeTestServer(t, db)
|
||||
@@ -1517,14 +1518,14 @@ func TestSubscribePrefersStructuredErrorPayloadViaPubsub(t *testing.T) {
|
||||
|
||||
chatID := uuid.New()
|
||||
chat := database.Chat{ID: chatID, Status: database.ChatStatusPending}
|
||||
|
||||
gomock.InOrder(
|
||||
db.EXPECT().GetChatByID(gomock.Any(), chatID).Return(chat, nil),
|
||||
db.EXPECT().GetChatByID(gomock.Any(), chatID).Return(chat, nil),
|
||||
db.EXPECT().GetChatMessagesByChatID(gomock.Any(), database.GetChatMessagesByChatIDParams{
|
||||
ChatID: chatID,
|
||||
AfterID: 0,
|
||||
}).Return(nil, nil),
|
||||
db.EXPECT().GetChatQueuedMessages(gomock.Any(), chatID).Return(nil, nil),
|
||||
db.EXPECT().GetChatByID(gomock.Any(), chatID).Return(chat, nil),
|
||||
)
|
||||
|
||||
server := newSubscribeTestServer(t, db)
|
||||
@@ -1557,14 +1558,14 @@ func TestSubscribeFallsBackToLegacyErrorStringViaPubsub(t *testing.T) {
|
||||
|
||||
chatID := uuid.New()
|
||||
chat := database.Chat{ID: chatID, Status: database.ChatStatusPending}
|
||||
|
||||
gomock.InOrder(
|
||||
db.EXPECT().GetChatByID(gomock.Any(), chatID).Return(chat, nil),
|
||||
db.EXPECT().GetChatByID(gomock.Any(), chatID).Return(chat, nil),
|
||||
db.EXPECT().GetChatMessagesByChatID(gomock.Any(), database.GetChatMessagesByChatIDParams{
|
||||
ChatID: chatID,
|
||||
AfterID: 0,
|
||||
}).Return(nil, nil),
|
||||
db.EXPECT().GetChatQueuedMessages(gomock.Any(), chatID).Return(nil, nil),
|
||||
db.EXPECT().GetChatByID(gomock.Any(), chatID).Return(chat, nil),
|
||||
)
|
||||
|
||||
server := newSubscribeTestServer(t, db)
|
||||
@@ -1581,6 +1582,274 @@ func TestSubscribeFallsBackToLegacyErrorStringViaPubsub(t *testing.T) {
|
||||
requireNoStreamEvent(t, events, 200*time.Millisecond)
|
||||
}
|
||||
|
||||
// TestSubscribeAuthorizedFallsBackToStaleRowWhenRefreshFails verifies that
// when the chat refresh inside SubscribeAuthorized fails, the subscription
// still succeeds and the initial status event reflects the chat row that
// was passed in by the caller.
func TestSubscribeAuthorizedFallsBackToStaleRowWhenRefreshFails(t *testing.T) {
	t.Parallel()

	ctx := testutil.Context(t, testutil.WaitShort)
	ctrl := gomock.NewController(t)
	db := dbmock.NewMockStore(ctrl)
	server := newSubscribeTestServer(t, db)

	chatID := uuid.New()
	staleChat := database.Chat{ID: chatID, Status: database.ChatStatusPending}

	// Seed the local stream buffer with one message_part so the snapshot
	// has something to deliver after the status event.
	state := server.getOrCreateStreamState(chatID)
	state.mu.Lock()
	state.buffer = []codersdk.ChatStreamEvent{{
		Type:   codersdk.ChatStreamEventTypeMessagePart,
		ChatID: chatID,
		MessagePart: &codersdk.ChatStreamMessagePart{
			Role: "assistant",
			Part: codersdk.ChatMessageText("thinking"),
		},
	}}
	state.mu.Unlock()

	// The refresh lookup fails; the message and queue loads still succeed.
	gomock.InOrder(
		db.EXPECT().GetChatByID(gomock.Any(), chatID).Return(database.Chat{}, xerrors.New("refresh failed")),
		db.EXPECT().GetChatMessagesByChatID(gomock.Any(), database.GetChatMessagesByChatIDParams{
			ChatID:  chatID,
			AfterID: 0,
		}).Return(nil, nil),
		db.EXPECT().GetChatQueuedMessages(gomock.Any(), chatID).Return(nil, nil),
	)

	initialSnapshot, events, cancel, ok := server.SubscribeAuthorized(ctx, staleChat, nil, 0)
	require.True(t, ok)
	defer cancel()

	// Status (from the stale row) comes first, then the buffered part.
	require.Len(t, initialSnapshot, 2)
	require.Equal(t, codersdk.ChatStreamEventTypeStatus, initialSnapshot[0].Type)
	require.NotNil(t, initialSnapshot[0].Status)
	require.Equal(t, codersdk.ChatStatusPending, initialSnapshot[0].Status.Status)
	require.Equal(t, codersdk.ChatStreamEventTypeMessagePart, initialSnapshot[1].Type)
	require.NotNil(t, initialSnapshot[1].MessagePart)
	require.Equal(t, "thinking", initialSnapshot[1].MessagePart.Part.Text)
	requireNoStreamEvent(t, events, 200*time.Millisecond)
}
|
||||
|
||||
// TestGetStreamChatMessagesReturnsIndependentClones verifies that
// successive calls to getStreamChatMessages return independent
// copies so that one caller mutating a result does not corrupt
// another caller's view. This is the key correctness property
// of the singleflight wrappers — the coalescing itself is
// already tested by tailscale.com/util/singleflight.
func TestGetStreamChatMessagesReturnsIndependentClones(t *testing.T) {
	t.Parallel()

	ctx := testutil.Context(t, testutil.WaitShort)
	ctrl := gomock.NewController(t)
	db := dbmock.NewMockStore(ctrl)
	server := newSubscribeTestServer(t, db)

	chatID := uuid.New()
	params := database.GetChatMessagesByChatIDParams{ChatID: chatID, AfterID: 0}
	message := database.ChatMessage{
		ID:     1,
		ChatID: chatID,
		Content: pqtype.NullRawMessage{
			RawMessage: []byte(`"hello"`),
			Valid:      true,
		},
	}
	// The mock hands back the same message value on every call, so the
	// RawMessage bytes are shared unless getStreamChatMessages clones them.
	db.EXPECT().GetChatMessagesByChatID(gomock.Any(), params).
		Return([]database.ChatMessage{message}, nil).AnyTimes()

	first, err := server.getStreamChatMessages(ctx, params)
	require.NoError(t, err)
	second, err := server.getStreamChatMessages(ctx, params)
	require.NoError(t, err)

	require.Len(t, first, 1)
	require.Len(t, second, 1)

	// Mutate first; second must be unaffected.
	first[0].Content.RawMessage[0] = 'x'
	require.Equal(t, `"hello"`, string(second[0].Content.RawMessage))
}
|
||||
|
||||
// TestGetStreamChatMessagesCanceledWaiterDoesNotPoisonOthers verifies that
// a caller whose context is already canceled gets context.Canceled while a
// fetch for the same key is in flight, and that the in-flight caller still
// receives its result once the database call is released.
func TestGetStreamChatMessagesCanceledWaiterDoesNotPoisonOthers(t *testing.T) {
	t.Parallel()

	ctx := testutil.Context(t, testutil.WaitMedium)
	ctrl := gomock.NewController(t)
	db := dbmock.NewMockStore(ctrl)
	server := newSubscribeTestServer(t, db)

	chatID := uuid.New()
	params := database.GetChatMessagesByChatIDParams{ChatID: chatID, AfterID: 0}
	message := database.ChatMessage{ID: 1, ChatID: chatID}

	// started is closed on the first DB call; release unblocks every call.
	started := make(chan struct{})
	release := make(chan struct{})
	db.EXPECT().GetChatMessagesByChatID(gomock.Any(), params).DoAndReturn(func(context.Context, database.GetChatMessagesByChatIDParams) ([]database.ChatMessage, error) {
		// Guard the close so a repeated invocation does not panic.
		select {
		case <-started:
		default:
			close(started)
		}
		<-release
		return []database.ChatMessage{message}, nil
	}).AnyTimes()

	// Start the leader goroutine and wait for it to block inside
	// the DB mock so the singleflight key is definitely in-flight.
	survivorCh := make(chan struct {
		messages []database.ChatMessage
		err      error
	}, 1)
	go func() {
		msgs, err := server.getStreamChatMessages(ctx, params)
		survivorCh <- struct {
			messages []database.ChatMessage
			err      error
		}{messages: msgs, err: err}
	}()
	waitForSignal(t, started)

	// Now cancel a second caller while the leader is still
	// blocked. Because singleflightDoChan selects on ctx.Done,
	// the canceled caller must return immediately with
	// context.Canceled without waiting for the DB call.
	cancelCtx, cancel := context.WithCancel(ctx)
	cancel()
	_, err := server.getStreamChatMessages(cancelCtx, params)
	require.ErrorIs(t, err, context.Canceled)

	// Release the leader; it must succeed independently.
	close(release)

	select {
	case result := <-survivorCh:
		require.NoError(t, result.err)
		require.Len(t, result.messages, 1)
	case <-time.After(testutil.WaitShort):
		t.Fatal("surviving waiter did not return")
	}
}
|
||||
|
||||
// TestGetStreamChatMessagesUsesServerOwnedTimeoutDespiteLeaderDeadline
// verifies that the shared history fetch runs under a deadline of roughly
// chatStreamHistoryFetchTimeout rather than the first caller's short
// deadline: the first caller times out on its own context, the fetch
// context stays live, and a second caller still receives the result.
func TestGetStreamChatMessagesUsesServerOwnedTimeoutDespiteLeaderDeadline(t *testing.T) {
	t.Parallel()

	ctrl := gomock.NewController(t)
	db := dbmock.NewMockStore(ctrl)
	server := newSubscribeTestServer(t, db)

	chatID := uuid.New()
	params := database.GetChatMessagesByChatIDParams{ChatID: chatID, AfterID: 0}
	message := database.ChatMessage{ID: 1, ChatID: chatID}

	// The mock publishes the context it was called with and then blocks
	// until release is closed. Times(1) asserts a single database read.
	fetchCtxCh := make(chan context.Context, 1)
	release := make(chan struct{})
	db.EXPECT().GetChatMessagesByChatID(gomock.Any(), params).DoAndReturn(func(fetchCtx context.Context, _ database.GetChatMessagesByChatIDParams) ([]database.ChatMessage, error) {
		fetchCtxCh <- fetchCtx
		<-release
		return []database.ChatMessage{message}, nil
	}).Times(1)

	// started anchors the expected fetch deadline checked below.
	started := time.Now()
	leaderCtx, leaderCancel := context.WithTimeout(context.Background(), testutil.IntervalFast)
	defer leaderCancel()

	leaderErrCh := make(chan error, 1)
	go func() {
		_, err := server.getStreamChatMessages(leaderCtx, params)
		leaderErrCh <- err
	}()

	var fetchCtx context.Context
	select {
	case fetchCtx = <-fetchCtxCh:
	case <-time.After(testutil.WaitShort):
		t.Fatal("timed out waiting for shared history fetch to start")
	}

	// The fetch context must carry the history timeout, not the leader's.
	deadline, ok := fetchCtx.Deadline()
	require.True(t, ok)
	require.WithinDuration(t, started.Add(chatStreamHistoryFetchTimeout), deadline, 500*time.Millisecond)

	// Second caller with no deadline, started while the fetch is blocked.
	followerCh := make(chan struct {
		messages []database.ChatMessage
		err      error
	}, 1)
	go func() {
		msgs, err := server.getStreamChatMessages(context.Background(), params)
		followerCh <- struct {
			messages []database.ChatMessage
			err      error
		}{messages: msgs, err: err}
	}()

	// Once the leader's deadline passes, the fetch context must still be
	// live; only then is the database call released.
	<-leaderCtx.Done()
	require.NoError(t, fetchCtx.Err())
	close(release)

	select {
	case err := <-leaderErrCh:
		require.ErrorIs(t, err, context.DeadlineExceeded)
	case <-time.After(testutil.WaitShort):
		t.Fatal("timed out waiting for leader to observe its deadline")
	}

	select {
	case result := <-followerCh:
		require.NoError(t, result.err)
		require.Len(t, result.messages, 1)
		require.Equal(t, int64(1), result.messages[0].ID)
	case <-time.After(testutil.WaitShort):
		t.Fatal("timed out waiting for follower to receive shared fetch result")
	}
}
|
||||
|
||||
// TestSubscribeRejectsUnauthorizedCallerBeforeSharedFetches verifies that
// Subscribe returns ok=false with nil results when the chat lookup reports
// a not-authorized error, and that no stream state is created for the chat.
// GetChatByID is the only database call the mock expects.
func TestSubscribeRejectsUnauthorizedCallerBeforeSharedFetches(t *testing.T) {
	t.Parallel()

	ctx := testutil.Context(t, testutil.WaitShort)
	ctrl := gomock.NewController(t)
	db := dbmock.NewMockStore(ctrl)
	server := newSubscribeTestServer(t, db)

	chatID := uuid.New()
	db.EXPECT().GetChatByID(gomock.Any(), chatID).
		Return(database.Chat{}, dbauthz.NotAuthorizedError{Err: xerrors.New("not authorized")})

	snapshot, events, cancel, ok := server.Subscribe(ctx, chatID, nil, 0)
	require.False(t, ok)
	require.Nil(t, snapshot)
	require.Nil(t, events)
	require.Nil(t, cancel)

	// Rejection must not leave stream state registered for the chat.
	_, exists := server.chatStreams.Load(chatID)
	require.False(t, exists)
}
|
||||
|
||||
// TestSubscribeSurfacesTransientLookupFailureAsInitialError verifies that a
// chat lookup failure other than not-authorized makes Subscribe return
// ok=true with a single error event in the snapshot, an already-closed
// event channel, a non-nil cancel function, and no stream state.
func TestSubscribeSurfacesTransientLookupFailureAsInitialError(t *testing.T) {
	t.Parallel()

	ctx := testutil.Context(t, testutil.WaitShort)
	ctrl := gomock.NewController(t)
	db := dbmock.NewMockStore(ctrl)
	server := newSubscribeTestServer(t, db)

	chatID := uuid.New()
	db.EXPECT().GetChatByID(gomock.Any(), chatID).
		Return(database.Chat{}, xerrors.New("transient lookup failure"))

	snapshot, events, cancel, ok := server.Subscribe(ctx, chatID, nil, 0)
	require.True(t, ok)
	require.NotNil(t, cancel)
	require.Len(t, snapshot, 1)
	require.Equal(t, codersdk.ChatStreamEventTypeError, snapshot[0].Type)
	require.Equal(t, chatID, snapshot[0].ChatID)
	require.Equal(t, "failed to load initial snapshot", snapshot[0].Error.Message)

	// The event channel is closed, so the receive reports open=false.
	_, open := <-events
	require.False(t, open)

	_, exists := server.chatStreams.Load(chatID)
	require.False(t, exists)
}
|
||||
|
||||
func newSubscribeTestServer(t *testing.T, db database.Store) *Server {
|
||||
t.Helper()
|
||||
|
||||
|
||||
@@ -430,8 +430,13 @@ func TestSubscribeRelaySnapshotDelivered(t *testing.T) {
|
||||
user, model := seedChatDependencies(ctx, t, db)
|
||||
|
||||
chat := seedRemoteRunningChat(ctx, t, db, user, model, workerID, "relay-snapshot")
|
||||
staleChat := chat
|
||||
staleChat.Status = database.ChatStatusWaiting
|
||||
staleChat.WorkerID = uuid.NullUUID{}
|
||||
staleChat.StartedAt = sql.NullTime{}
|
||||
staleChat.HeartbeatAt = sql.NullTime{}
|
||||
|
||||
initialSnapshot, events, cancel, ok := subscriber.Subscribe(ctx, chat.ID, nil, 0)
|
||||
initialSnapshot, events, cancel, ok := subscriber.SubscribeAuthorized(ctx, staleChat, nil, 0)
|
||||
require.True(t, ok)
|
||||
t.Cleanup(cancel)
|
||||
|
||||
@@ -455,15 +460,15 @@ func TestSubscribeRelaySnapshotDelivered(t *testing.T) {
|
||||
|
||||
require.Equal(t, []string{"snap-one", "snap-two", "live-part"}, receivedTexts)
|
||||
|
||||
// The initial snapshot should still contain the status event
|
||||
// from the OSS preamble.
|
||||
var hasStatus bool
|
||||
// The initial snapshot should contain the refreshed running status,
|
||||
// not the stale waiting status passed into SubscribeAuthorized.
|
||||
var snapshotStatus codersdk.ChatStatus
|
||||
for _, event := range initialSnapshot {
|
||||
if event.Type == codersdk.ChatStreamEventTypeStatus {
|
||||
hasStatus = true
|
||||
if event.Type == codersdk.ChatStreamEventTypeStatus && event.Status != nil {
|
||||
snapshotStatus = event.Status.Status
|
||||
}
|
||||
}
|
||||
require.True(t, hasStatus, "initial snapshot should contain status event")
|
||||
require.Equal(t, codersdk.ChatStatusRunning, snapshotStatus)
|
||||
}
|
||||
|
||||
func TestSubscribeRetryEventAcrossInstances(t *testing.T) {
|
||||
|
||||
Reference in New Issue
Block a user