Compare commits

...

2 Commits

Author SHA1 Message Date
Ethan Dickson 3c02ecbee9 fix(coderd/x/chatd): make heartbeat register/unregister ownership-aware
Fix a narrow race where a same-worker back-to-back run could lose
heartbeat coverage. When processChat cleanup auto-promotes a queued
message (advancing generation and setting status=pending), AcquireChats
can grab the chat before the old goroutine's deferred unregisterHeartbeat
fires.

registerHeartbeat now replaces older-generation entries for the same
chat instead of silently skipping duplicates. The old entry is canceled
with ErrInterrupted before replacement.

unregisterHeartbeat now takes *heartbeatEntry instead of a chat ID and
uses pointer identity to guard deletion. If a newer run already replaced
the entry, the old run's deferred unregister is a no-op.
2026-04-10 10:59:50 +00:00
Ethan Dickson 44f0e00ee3 feat(coderd/x/chatd): two-plane control architecture with run_generation fencing
Replace the per-chat status-driven worker control with a two-plane
architecture:

1. Status plane (per-chat, batched OK, observational only) — stays on
   existing chat:stream:<chatID>, used by UI subscribers and relay.

2. Control plane (per-worker, low-volume, explicit invalidation) — new
   channel chat:control:<workerID>, carries structured ChatControlMessage
   with chat_id, run_generation, and reason.

Add run_generation BIGINT column to chats table as a monotonic fencing
token. Workers keep the generation in memory and use it at ownership
checkpoints (acquire, heartbeat, persist, cleanup). Control messages
are the explicit signal to stop; status fanout is never authoritative
for workers.

Key changes:
- Add AdvanceChatRunGenerationAndUpdateStatus SQL helper for mutations
  that invalidate the current run
- Fence heartbeat batching by (id, run_generation) pairs using
  generate_subscripts CTE
- Fence cleanup against newer generations to prevent same-worker
  old-run clobber
- Add shouldStartOwnedRun check after heartbeat registration
- Remove per-chat subscribeChatControl and DB-read-on-notify stopgap
- Replace with single per-worker subscribeWorkerControl
- Short-circuit local worker control via handleChatControlMessage

This makes worker correctness independent of status-notification timing,
batching, duplication, or reordering.
2026-04-10 10:59:50 +00:00
20 changed files with 1156 additions and 242 deletions
+11
View File
@@ -1561,6 +1561,17 @@ func (q *querier) ActivityBumpWorkspace(ctx context.Context, arg database.Activi
return update(q.log, q.auth, fetch, q.db.ActivityBumpWorkspace)(ctx, arg)
}
// AdvanceChatRunGenerationAndUpdateStatus is the authorization-checked
// wrapper around the store method of the same name. It looks up the chat
// named by arg.ID, requires policy.ActionUpdate on that chat, and only then
// forwards arg to the wrapped store. A failed lookup or a failed
// authorization check is returned unchanged together with a zero-value
// database.Chat.
func (q *querier) AdvanceChatRunGenerationAndUpdateStatus(ctx context.Context, arg database.AdvanceChatRunGenerationAndUpdateStatusParams) (database.Chat, error) {
	var none database.Chat

	// The existing row is the object the update permission is checked on.
	target, err := q.db.GetChatByID(ctx, arg.ID)
	if err != nil {
		return none, err
	}

	err = q.authorizeContext(ctx, policy.ActionUpdate, target)
	if err != nil {
		return none, err
	}

	return q.db.AdvanceChatRunGenerationAndUpdateStatus(ctx, arg)
}
func (q *querier) AllUserIDs(ctx context.Context, includeSystem bool) ([]uuid.UUID, error) {
// Although this technically only reads users, only system-related functions
// should be allowed to call this.
+15 -3
View File
@@ -879,9 +879,10 @@ func (s *MethodTestSuite) TestChats() {
s.Run("UpdateChatHeartbeats", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
resultID := uuid.New()
arg := database.UpdateChatHeartbeatsParams{
IDs: []uuid.UUID{resultID},
WorkerID: uuid.New(),
Now: time.Now(),
IDs: []uuid.UUID{resultID},
RunGenerations: []int64{1},
WorkerID: uuid.New(),
Now: time.Now(),
}
dbm.EXPECT().UpdateChatHeartbeats(gomock.Any(), arg).Return([]uuid.UUID{resultID}, nil).AnyTimes()
check.Args(arg).Asserts(rbac.ResourceChat, policy.ActionUpdate).Returns([]uuid.UUID{resultID})
@@ -946,6 +947,17 @@ func (s *MethodTestSuite) TestChats() {
dbm.EXPECT().UpdateChatStatus(gomock.Any(), arg).Return(chat, nil).AnyTimes()
check.Args(arg).Asserts(chat, policy.ActionUpdate).Returns(chat)
}))
// AdvanceChatRunGenerationAndUpdateStatus: the mocked store returns `chat`
// from GetChatByID and `updated` from the advance call. The check expects
// policy.ActionUpdate to be asserted on the chat that was looked up, and
// `updated` to be handed back to the caller.
s.Run("AdvanceChatRunGenerationAndUpdateStatus", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
chat := testutil.Fake(s.T(), faker, database.Chat{})
arg := database.AdvanceChatRunGenerationAndUpdateStatusParams{
ID: chat.ID,
Status: database.ChatStatusPending,
}
// The returned chat shares the ID of the looked-up chat; every other
// field is generated independently.
updated := testutil.Fake(s.T(), faker, database.Chat{ID: chat.ID})
dbm.EXPECT().GetChatByID(gomock.Any(), chat.ID).Return(chat, nil).AnyTimes()
dbm.EXPECT().AdvanceChatRunGenerationAndUpdateStatus(gomock.Any(), arg).Return(updated, nil).AnyTimes()
check.Args(arg).Asserts(chat, policy.ActionUpdate).Returns(updated)
}))
s.Run("UpdateChatBuildAgentBinding", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
chat := testutil.Fake(s.T(), faker, database.Chat{})
arg := database.UpdateChatBuildAgentBindingParams{
@@ -152,6 +152,14 @@ func (m queryMetricsStore) ActivityBumpWorkspace(ctx context.Context, arg databa
return r0
}
// AdvanceChatRunGenerationAndUpdateStatus forwards the call to the wrapped
// store, then records how long it took in the latency histogram and bumps
// the per-route/per-method query counter. The wrapped store's results are
// returned untouched.
func (m queryMetricsStore) AdvanceChatRunGenerationAndUpdateStatus(ctx context.Context, arg database.AdvanceChatRunGenerationAndUpdateStatusParams) (database.Chat, error) {
	const queryName = "AdvanceChatRunGenerationAndUpdateStatus"

	begin := time.Now()
	chat, err := m.s.AdvanceChatRunGenerationAndUpdateStatus(ctx, arg)

	m.queryLatencies.WithLabelValues(queryName).Observe(time.Since(begin).Seconds())
	m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), queryName).Inc()

	return chat, err
}
func (m queryMetricsStore) AllUserIDs(ctx context.Context, includeSystem bool) ([]uuid.UUID, error) {
start := time.Now()
r0, r1 := m.s.AllUserIDs(ctx, includeSystem)
+15
View File
@@ -132,6 +132,21 @@ func (mr *MockStoreMockRecorder) ActivityBumpWorkspace(ctx, arg any) *gomock.Cal
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ActivityBumpWorkspace", reflect.TypeOf((*MockStore)(nil).ActivityBumpWorkspace), ctx, arg)
}
// AdvanceChatRunGenerationAndUpdateStatus mocks base method. The call is
// dispatched through the mock controller and the two recorded return values
// are converted back to (database.Chat, error); a value of the wrong type
// yields that type's zero value.
func (m *MockStore) AdvanceChatRunGenerationAndUpdateStatus(ctx context.Context, arg database.AdvanceChatRunGenerationAndUpdateStatusParams) (database.Chat, error) {
	m.ctrl.T.Helper()
	results := m.ctrl.Call(m, "AdvanceChatRunGenerationAndUpdateStatus", ctx, arg)
	chat, _ := results[0].(database.Chat)
	callErr, _ := results[1].(error)
	return chat, callErr
}
// AdvanceChatRunGenerationAndUpdateStatus indicates an expected call of
// AdvanceChatRunGenerationAndUpdateStatus. It registers the expectation with
// the mock controller, passing the mocked method's reflected type.
func (mr *MockStoreMockRecorder) AdvanceChatRunGenerationAndUpdateStatus(ctx, arg any) *gomock.Call {
	mr.mock.ctrl.T.Helper()
	methodType := reflect.TypeOf((*MockStore)(nil).AdvanceChatRunGenerationAndUpdateStatus)
	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AdvanceChatRunGenerationAndUpdateStatus", methodType, ctx, arg)
}
// AllUserIDs mocks base method.
func (m *MockStore) AllUserIDs(ctx context.Context, includeSystem bool) ([]uuid.UUID, error) {
m.ctrl.T.Helper()
+1
View File
@@ -1413,6 +1413,7 @@ CREATE TABLE chats (
workspace_id uuid,
title text DEFAULT 'New Chat'::text NOT NULL,
status chat_status DEFAULT 'waiting'::chat_status NOT NULL,
run_generation bigint DEFAULT 0 NOT NULL,
worker_id uuid,
started_at timestamp with time zone,
heartbeat_at timestamp with time zone,
@@ -0,0 +1,2 @@
-- Removes the run_generation column from chats.
ALTER TABLE chats
DROP COLUMN run_generation;
@@ -0,0 +1,2 @@
-- Adds run_generation to chats as a NOT NULL BIGINT that defaults to 0.
ALTER TABLE chats
ADD COLUMN run_generation BIGINT NOT NULL DEFAULT 0;
+1
View File
@@ -780,6 +780,7 @@ func (q *sqlQuerier) GetAuthorizedChats(ctx context.Context, arg GetChatsParams,
&i.Chat.WorkspaceID,
&i.Chat.Title,
&i.Chat.Status,
&i.Chat.RunGeneration,
&i.Chat.WorkerID,
&i.Chat.StartedAt,
&i.Chat.HeartbeatAt,
+1
View File
@@ -4227,6 +4227,7 @@ type Chat struct {
WorkspaceID uuid.NullUUID `db:"workspace_id" json:"workspace_id"`
Title string `db:"title" json:"title"`
Status ChatStatus `db:"status" json:"status"`
RunGeneration int64 `db:"run_generation" json:"run_generation"`
WorkerID uuid.NullUUID `db:"worker_id" json:"worker_id"`
StartedAt sql.NullTime `db:"started_at" json:"started_at"`
HeartbeatAt sql.NullTime `db:"heartbeat_at" json:"heartbeat_at"`
+4 -2
View File
@@ -52,6 +52,7 @@ type sqlcQuerier interface {
// We only bump if workspace shutdown is manual.
// We only bump when 5% of the deadline has elapsed.
ActivityBumpWorkspace(ctx context.Context, arg ActivityBumpWorkspaceParams) error
AdvanceChatRunGenerationAndUpdateStatus(ctx context.Context, arg AdvanceChatRunGenerationAndUpdateStatusParams) (Chat, error)
// AllUserIDs returns all UserIDs regardless of user status or deletion.
AllUserIDs(ctx context.Context, includeSystem bool) ([]uuid.UUID, error)
ArchiveChatByID(ctx context.Context, id uuid.UUID) ([]Chat, error)
@@ -917,8 +918,9 @@ type sqlcQuerier interface {
UpdateChatByID(ctx context.Context, arg UpdateChatByIDParams) (Chat, error)
// Bumps the heartbeat timestamp for the given set of chat IDs,
// provided they are still running and owned by the specified
// worker. Returns the IDs that were actually updated so the
// caller can detect stolen or completed chats via set-difference.
// worker and run generation. Returns the IDs that were actually
// updated so the caller can detect stolen, superseded, or
// completed chats via set-difference.
UpdateChatHeartbeats(ctx context.Context, arg UpdateChatHeartbeatsParams) ([]uuid.UUID, error)
UpdateChatLabelsByID(ctx context.Context, arg UpdateChatLabelsByIDParams) (Chat, error)
// Updates the cached injected context parts (AGENTS.md +
+140 -34
View File
@@ -4244,7 +4244,7 @@ WHERE
$3::int
)
RETURNING
id, owner_id, workspace_id, title, status, worker_id, started_at, heartbeat_at, created_at, updated_at, parent_chat_id, root_chat_id, last_model_config_id, archived, last_error, mode, mcp_server_ids, labels, build_id, agent_id, pin_order, last_read_message_id, last_injected_context, dynamic_tools
id, owner_id, workspace_id, title, status, run_generation, worker_id, started_at, heartbeat_at, created_at, updated_at, parent_chat_id, root_chat_id, last_model_config_id, archived, last_error, mode, mcp_server_ids, labels, build_id, agent_id, pin_order, last_read_message_id, last_injected_context, dynamic_tools
`
type AcquireChatsParams struct {
@@ -4270,6 +4270,7 @@ func (q *sqlQuerier) AcquireChats(ctx context.Context, arg AcquireChatsParams) (
&i.WorkspaceID,
&i.Title,
&i.Status,
&i.RunGeneration,
&i.WorkerID,
&i.StartedAt,
&i.HeartbeatAt,
@@ -4422,14 +4423,80 @@ func (q *sqlQuerier) AcquireStaleChatDiffStatuses(ctx context.Context, limitVal
return items, nil
}
// advanceChatRunGenerationAndUpdateStatus increments run_generation by one
// for the chat whose id is $6 and, in the same UPDATE, overwrites status,
// worker_id, started_at, heartbeat_at and last_error with the bound values
// ($1-$5) and sets updated_at to NOW(). The four nullable columns are
// assigned unconditionally, so binding NULL clears them. The full updated
// row is returned.
const advanceChatRunGenerationAndUpdateStatus = `-- name: AdvanceChatRunGenerationAndUpdateStatus :one
UPDATE
chats
SET
run_generation = run_generation + 1,
status = $1::chat_status,
worker_id = $2::uuid,
started_at = $3::timestamptz,
heartbeat_at = $4::timestamptz,
last_error = $5::text,
updated_at = NOW()
WHERE
id = $6::uuid
RETURNING
id, owner_id, workspace_id, title, status, run_generation, worker_id, started_at, heartbeat_at, created_at, updated_at, parent_chat_id, root_chat_id, last_model_config_id, archived, last_error, mode, mcp_server_ids, labels, build_id, agent_id, pin_order, last_read_message_id, last_injected_context, dynamic_tools
`
// AdvanceChatRunGenerationAndUpdateStatusParams carries the bind values for
// the advanceChatRunGenerationAndUpdateStatus statement. Fields are listed
// in the order they are bound ($1 through $6).
type AdvanceChatRunGenerationAndUpdateStatusParams struct {
// Status is bound to $1 and becomes the chat's new status.
Status ChatStatus `db:"status" json:"status"`
// WorkerID is bound to $2 and written to worker_id.
WorkerID uuid.NullUUID `db:"worker_id" json:"worker_id"`
// StartedAt is bound to $3 and written to started_at.
StartedAt sql.NullTime `db:"started_at" json:"started_at"`
// HeartbeatAt is bound to $4 and written to heartbeat_at.
HeartbeatAt sql.NullTime `db:"heartbeat_at" json:"heartbeat_at"`
// LastError is bound to $5 and written to last_error.
LastError sql.NullString `db:"last_error" json:"last_error"`
// ID is bound to $6 and selects the chat row to update.
ID uuid.UUID `db:"id" json:"id"`
}
// AdvanceChatRunGenerationAndUpdateStatus executes the
// advanceChatRunGenerationAndUpdateStatus statement with the values in arg
// and decodes the single returned row into a Chat. The scan destinations
// follow the column order of the statement's RETURNING list. The error
// reported by Scan is returned unchanged alongside whatever was decoded.
func (q *sqlQuerier) AdvanceChatRunGenerationAndUpdateStatus(ctx context.Context, arg AdvanceChatRunGenerationAndUpdateStatusParams) (Chat, error) {
	var chat Chat
	err := q.db.QueryRowContext(ctx, advanceChatRunGenerationAndUpdateStatus,
		arg.Status,
		arg.WorkerID,
		arg.StartedAt,
		arg.HeartbeatAt,
		arg.LastError,
		arg.ID,
	).Scan(
		&chat.ID,
		&chat.OwnerID,
		&chat.WorkspaceID,
		&chat.Title,
		&chat.Status,
		&chat.RunGeneration,
		&chat.WorkerID,
		&chat.StartedAt,
		&chat.HeartbeatAt,
		&chat.CreatedAt,
		&chat.UpdatedAt,
		&chat.ParentChatID,
		&chat.RootChatID,
		&chat.LastModelConfigID,
		&chat.Archived,
		&chat.LastError,
		&chat.Mode,
		pq.Array(&chat.MCPServerIDs),
		&chat.Labels,
		&chat.BuildID,
		&chat.AgentID,
		&chat.PinOrder,
		&chat.LastReadMessageID,
		&chat.LastInjectedContext,
		&chat.DynamicTools,
	)
	return chat, err
}
const archiveChatByID = `-- name: ArchiveChatByID :many
WITH chats AS (
UPDATE chats
SET archived = true, pin_order = 0, updated_at = NOW()
WHERE id = $1::uuid OR root_chat_id = $1::uuid
RETURNING id, owner_id, workspace_id, title, status, worker_id, started_at, heartbeat_at, created_at, updated_at, parent_chat_id, root_chat_id, last_model_config_id, archived, last_error, mode, mcp_server_ids, labels, build_id, agent_id, pin_order, last_read_message_id, last_injected_context, dynamic_tools
RETURNING id, owner_id, workspace_id, title, status, run_generation, worker_id, started_at, heartbeat_at, created_at, updated_at, parent_chat_id, root_chat_id, last_model_config_id, archived, last_error, mode, mcp_server_ids, labels, build_id, agent_id, pin_order, last_read_message_id, last_injected_context, dynamic_tools
)
SELECT id, owner_id, workspace_id, title, status, worker_id, started_at, heartbeat_at, created_at, updated_at, parent_chat_id, root_chat_id, last_model_config_id, archived, last_error, mode, mcp_server_ids, labels, build_id, agent_id, pin_order, last_read_message_id, last_injected_context, dynamic_tools
SELECT id, owner_id, workspace_id, title, status, run_generation, worker_id, started_at, heartbeat_at, created_at, updated_at, parent_chat_id, root_chat_id, last_model_config_id, archived, last_error, mode, mcp_server_ids, labels, build_id, agent_id, pin_order, last_read_message_id, last_injected_context, dynamic_tools
FROM chats
ORDER BY (id = $1::uuid) DESC, created_at ASC, id ASC
`
@@ -4449,6 +4516,7 @@ func (q *sqlQuerier) ArchiveChatByID(ctx context.Context, id uuid.UUID) ([]Chat,
&i.WorkspaceID,
&i.Title,
&i.Status,
&i.RunGeneration,
&i.WorkerID,
&i.StartedAt,
&i.HeartbeatAt,
@@ -4617,7 +4685,7 @@ func (q *sqlQuerier) DeleteOldChats(ctx context.Context, arg DeleteOldChatsParam
}
const getActiveChatsByAgentID = `-- name: GetActiveChatsByAgentID :many
SELECT id, owner_id, workspace_id, title, status, worker_id, started_at, heartbeat_at, created_at, updated_at, parent_chat_id, root_chat_id, last_model_config_id, archived, last_error, mode, mcp_server_ids, labels, build_id, agent_id, pin_order, last_read_message_id, last_injected_context, dynamic_tools
SELECT id, owner_id, workspace_id, title, status, run_generation, worker_id, started_at, heartbeat_at, created_at, updated_at, parent_chat_id, root_chat_id, last_model_config_id, archived, last_error, mode, mcp_server_ids, labels, build_id, agent_id, pin_order, last_read_message_id, last_injected_context, dynamic_tools
FROM chats
WHERE agent_id = $1::uuid
AND archived = false
@@ -4643,6 +4711,7 @@ func (q *sqlQuerier) GetActiveChatsByAgentID(ctx context.Context, agentID uuid.U
&i.WorkspaceID,
&i.Title,
&i.Status,
&i.RunGeneration,
&i.WorkerID,
&i.StartedAt,
&i.HeartbeatAt,
@@ -4678,7 +4747,7 @@ func (q *sqlQuerier) GetActiveChatsByAgentID(ctx context.Context, agentID uuid.U
const getChatByID = `-- name: GetChatByID :one
SELECT
id, owner_id, workspace_id, title, status, worker_id, started_at, heartbeat_at, created_at, updated_at, parent_chat_id, root_chat_id, last_model_config_id, archived, last_error, mode, mcp_server_ids, labels, build_id, agent_id, pin_order, last_read_message_id, last_injected_context, dynamic_tools
id, owner_id, workspace_id, title, status, run_generation, worker_id, started_at, heartbeat_at, created_at, updated_at, parent_chat_id, root_chat_id, last_model_config_id, archived, last_error, mode, mcp_server_ids, labels, build_id, agent_id, pin_order, last_read_message_id, last_injected_context, dynamic_tools
FROM
chats
WHERE
@@ -4694,6 +4763,7 @@ func (q *sqlQuerier) GetChatByID(ctx context.Context, id uuid.UUID) (Chat, error
&i.WorkspaceID,
&i.Title,
&i.Status,
&i.RunGeneration,
&i.WorkerID,
&i.StartedAt,
&i.HeartbeatAt,
@@ -4718,7 +4788,7 @@ func (q *sqlQuerier) GetChatByID(ctx context.Context, id uuid.UUID) (Chat, error
}
const getChatByIDForUpdate = `-- name: GetChatByIDForUpdate :one
SELECT id, owner_id, workspace_id, title, status, worker_id, started_at, heartbeat_at, created_at, updated_at, parent_chat_id, root_chat_id, last_model_config_id, archived, last_error, mode, mcp_server_ids, labels, build_id, agent_id, pin_order, last_read_message_id, last_injected_context, dynamic_tools FROM chats WHERE id = $1::uuid FOR UPDATE
SELECT id, owner_id, workspace_id, title, status, run_generation, worker_id, started_at, heartbeat_at, created_at, updated_at, parent_chat_id, root_chat_id, last_model_config_id, archived, last_error, mode, mcp_server_ids, labels, build_id, agent_id, pin_order, last_read_message_id, last_injected_context, dynamic_tools FROM chats WHERE id = $1::uuid FOR UPDATE
`
func (q *sqlQuerier) GetChatByIDForUpdate(ctx context.Context, id uuid.UUID) (Chat, error) {
@@ -4730,6 +4800,7 @@ func (q *sqlQuerier) GetChatByIDForUpdate(ctx context.Context, id uuid.UUID) (Ch
&i.WorkspaceID,
&i.Title,
&i.Status,
&i.RunGeneration,
&i.WorkerID,
&i.StartedAt,
&i.HeartbeatAt,
@@ -5803,7 +5874,7 @@ func (q *sqlQuerier) GetChatUsageLimitUserOverride(ctx context.Context, userID u
const getChats = `-- name: GetChats :many
SELECT
chats.id, chats.owner_id, chats.workspace_id, chats.title, chats.status, chats.worker_id, chats.started_at, chats.heartbeat_at, chats.created_at, chats.updated_at, chats.parent_chat_id, chats.root_chat_id, chats.last_model_config_id, chats.archived, chats.last_error, chats.mode, chats.mcp_server_ids, chats.labels, chats.build_id, chats.agent_id, chats.pin_order, chats.last_read_message_id, chats.last_injected_context, chats.dynamic_tools,
chats.id, chats.owner_id, chats.workspace_id, chats.title, chats.status, chats.run_generation, chats.worker_id, chats.started_at, chats.heartbeat_at, chats.created_at, chats.updated_at, chats.parent_chat_id, chats.root_chat_id, chats.last_model_config_id, chats.archived, chats.last_error, chats.mode, chats.mcp_server_ids, chats.labels, chats.build_id, chats.agent_id, chats.pin_order, chats.last_read_message_id, chats.last_injected_context, chats.dynamic_tools,
EXISTS (
SELECT 1 FROM chat_messages cm
WHERE cm.chat_id = chats.id
@@ -5893,6 +5964,7 @@ func (q *sqlQuerier) GetChats(ctx context.Context, arg GetChatsParams) ([]GetCha
&i.Chat.WorkspaceID,
&i.Chat.Title,
&i.Chat.Status,
&i.Chat.RunGeneration,
&i.Chat.WorkerID,
&i.Chat.StartedAt,
&i.Chat.HeartbeatAt,
@@ -5928,7 +6000,7 @@ func (q *sqlQuerier) GetChats(ctx context.Context, arg GetChatsParams) ([]GetCha
}
const getChatsByWorkspaceIDs = `-- name: GetChatsByWorkspaceIDs :many
SELECT id, owner_id, workspace_id, title, status, worker_id, started_at, heartbeat_at, created_at, updated_at, parent_chat_id, root_chat_id, last_model_config_id, archived, last_error, mode, mcp_server_ids, labels, build_id, agent_id, pin_order, last_read_message_id, last_injected_context, dynamic_tools
SELECT id, owner_id, workspace_id, title, status, run_generation, worker_id, started_at, heartbeat_at, created_at, updated_at, parent_chat_id, root_chat_id, last_model_config_id, archived, last_error, mode, mcp_server_ids, labels, build_id, agent_id, pin_order, last_read_message_id, last_injected_context, dynamic_tools
FROM chats
WHERE archived = false
AND workspace_id = ANY($1::uuid[])
@@ -5950,6 +6022,7 @@ func (q *sqlQuerier) GetChatsByWorkspaceIDs(ctx context.Context, ids []uuid.UUID
&i.WorkspaceID,
&i.Title,
&i.Status,
&i.RunGeneration,
&i.WorkerID,
&i.StartedAt,
&i.HeartbeatAt,
@@ -6096,7 +6169,7 @@ func (q *sqlQuerier) GetLastChatMessageByRole(ctx context.Context, arg GetLastCh
const getStaleChats = `-- name: GetStaleChats :many
SELECT
id, owner_id, workspace_id, title, status, worker_id, started_at, heartbeat_at, created_at, updated_at, parent_chat_id, root_chat_id, last_model_config_id, archived, last_error, mode, mcp_server_ids, labels, build_id, agent_id, pin_order, last_read_message_id, last_injected_context, dynamic_tools
id, owner_id, workspace_id, title, status, run_generation, worker_id, started_at, heartbeat_at, created_at, updated_at, parent_chat_id, root_chat_id, last_model_config_id, archived, last_error, mode, mcp_server_ids, labels, build_id, agent_id, pin_order, last_read_message_id, last_injected_context, dynamic_tools
FROM
chats
WHERE
@@ -6125,6 +6198,7 @@ func (q *sqlQuerier) GetStaleChats(ctx context.Context, staleThreshold time.Time
&i.WorkspaceID,
&i.Title,
&i.Status,
&i.RunGeneration,
&i.WorkerID,
&i.StartedAt,
&i.HeartbeatAt,
@@ -6210,6 +6284,7 @@ INSERT INTO chats (
title,
mode,
status,
run_generation,
mcp_server_ids,
labels,
dynamic_tools
@@ -6224,12 +6299,13 @@ INSERT INTO chats (
$8::text,
$9::chat_mode,
$10::chat_status,
COALESCE($11::uuid[], '{}'::uuid[]),
COALESCE($12::jsonb, '{}'::jsonb),
$13::jsonb
$11::bigint,
COALESCE($12::uuid[], '{}'::uuid[]),
COALESCE($13::jsonb, '{}'::jsonb),
$14::jsonb
)
RETURNING
id, owner_id, workspace_id, title, status, worker_id, started_at, heartbeat_at, created_at, updated_at, parent_chat_id, root_chat_id, last_model_config_id, archived, last_error, mode, mcp_server_ids, labels, build_id, agent_id, pin_order, last_read_message_id, last_injected_context, dynamic_tools
id, owner_id, workspace_id, title, status, run_generation, worker_id, started_at, heartbeat_at, created_at, updated_at, parent_chat_id, root_chat_id, last_model_config_id, archived, last_error, mode, mcp_server_ids, labels, build_id, agent_id, pin_order, last_read_message_id, last_injected_context, dynamic_tools
`
type InsertChatParams struct {
@@ -6243,6 +6319,7 @@ type InsertChatParams struct {
Title string `db:"title" json:"title"`
Mode NullChatMode `db:"mode" json:"mode"`
Status ChatStatus `db:"status" json:"status"`
RunGeneration int64 `db:"run_generation" json:"run_generation"`
MCPServerIDs []uuid.UUID `db:"mcp_server_ids" json:"mcp_server_ids"`
Labels pqtype.NullRawMessage `db:"labels" json:"labels"`
DynamicTools pqtype.NullRawMessage `db:"dynamic_tools" json:"dynamic_tools"`
@@ -6260,6 +6337,7 @@ func (q *sqlQuerier) InsertChat(ctx context.Context, arg InsertChatParams) (Chat
arg.Title,
arg.Mode,
arg.Status,
arg.RunGeneration,
pq.Array(arg.MCPServerIDs),
arg.Labels,
arg.DynamicTools,
@@ -6271,6 +6349,7 @@ func (q *sqlQuerier) InsertChat(ctx context.Context, arg InsertChatParams) (Chat
&i.WorkspaceID,
&i.Title,
&i.Status,
&i.RunGeneration,
&i.WorkerID,
&i.StartedAt,
&i.HeartbeatAt,
@@ -6797,9 +6876,9 @@ WITH chats AS (
archived = false,
updated_at = NOW()
WHERE id = $1::uuid OR root_chat_id = $1::uuid
RETURNING id, owner_id, workspace_id, title, status, worker_id, started_at, heartbeat_at, created_at, updated_at, parent_chat_id, root_chat_id, last_model_config_id, archived, last_error, mode, mcp_server_ids, labels, build_id, agent_id, pin_order, last_read_message_id, last_injected_context, dynamic_tools
RETURNING id, owner_id, workspace_id, title, status, run_generation, worker_id, started_at, heartbeat_at, created_at, updated_at, parent_chat_id, root_chat_id, last_model_config_id, archived, last_error, mode, mcp_server_ids, labels, build_id, agent_id, pin_order, last_read_message_id, last_injected_context, dynamic_tools
)
SELECT id, owner_id, workspace_id, title, status, worker_id, started_at, heartbeat_at, created_at, updated_at, parent_chat_id, root_chat_id, last_model_config_id, archived, last_error, mode, mcp_server_ids, labels, build_id, agent_id, pin_order, last_read_message_id, last_injected_context, dynamic_tools
SELECT id, owner_id, workspace_id, title, status, run_generation, worker_id, started_at, heartbeat_at, created_at, updated_at, parent_chat_id, root_chat_id, last_model_config_id, archived, last_error, mode, mcp_server_ids, labels, build_id, agent_id, pin_order, last_read_message_id, last_injected_context, dynamic_tools
FROM chats
ORDER BY (id = $1::uuid) DESC, created_at ASC, id ASC
`
@@ -6823,6 +6902,7 @@ func (q *sqlQuerier) UnarchiveChatByID(ctx context.Context, id uuid.UUID) ([]Cha
&i.WorkspaceID,
&i.Title,
&i.Status,
&i.RunGeneration,
&i.WorkerID,
&i.StartedAt,
&i.HeartbeatAt,
@@ -6922,7 +7002,7 @@ UPDATE chats SET
updated_at = NOW()
WHERE
id = $3::uuid
RETURNING id, owner_id, workspace_id, title, status, worker_id, started_at, heartbeat_at, created_at, updated_at, parent_chat_id, root_chat_id, last_model_config_id, archived, last_error, mode, mcp_server_ids, labels, build_id, agent_id, pin_order, last_read_message_id, last_injected_context, dynamic_tools
RETURNING id, owner_id, workspace_id, title, status, run_generation, worker_id, started_at, heartbeat_at, created_at, updated_at, parent_chat_id, root_chat_id, last_model_config_id, archived, last_error, mode, mcp_server_ids, labels, build_id, agent_id, pin_order, last_read_message_id, last_injected_context, dynamic_tools
`
type UpdateChatBuildAgentBindingParams struct {
@@ -6940,6 +7020,7 @@ func (q *sqlQuerier) UpdateChatBuildAgentBinding(ctx context.Context, arg Update
&i.WorkspaceID,
&i.Title,
&i.Status,
&i.RunGeneration,
&i.WorkerID,
&i.StartedAt,
&i.HeartbeatAt,
@@ -6972,7 +7053,7 @@ SET
WHERE
id = $2::uuid
RETURNING
id, owner_id, workspace_id, title, status, worker_id, started_at, heartbeat_at, created_at, updated_at, parent_chat_id, root_chat_id, last_model_config_id, archived, last_error, mode, mcp_server_ids, labels, build_id, agent_id, pin_order, last_read_message_id, last_injected_context, dynamic_tools
id, owner_id, workspace_id, title, status, run_generation, worker_id, started_at, heartbeat_at, created_at, updated_at, parent_chat_id, root_chat_id, last_model_config_id, archived, last_error, mode, mcp_server_ids, labels, build_id, agent_id, pin_order, last_read_message_id, last_injected_context, dynamic_tools
`
type UpdateChatByIDParams struct {
@@ -6989,6 +7070,7 @@ func (q *sqlQuerier) UpdateChatByID(ctx context.Context, arg UpdateChatByIDParam
&i.WorkspaceID,
&i.Title,
&i.Status,
&i.RunGeneration,
&i.WorkerID,
&i.StartedAt,
&i.HeartbeatAt,
@@ -7013,29 +7095,46 @@ func (q *sqlQuerier) UpdateChatByID(ctx context.Context, arg UpdateChatByIDParam
}
const updateChatHeartbeats = `-- name: UpdateChatHeartbeats :many
WITH input AS (
SELECT
($3::uuid[])[i] AS id,
($4::bigint[])[i] AS run_generation
FROM
generate_subscripts($3::uuid[], 1) AS g(i)
)
UPDATE
chats
SET
heartbeat_at = $1::timestamptz
FROM
input
WHERE
id = ANY($2::uuid[])
AND worker_id = $3::uuid
AND status = 'running'::chat_status
RETURNING id
chats.id = input.id
AND chats.run_generation = input.run_generation
AND chats.worker_id = $2::uuid
AND chats.status = 'running'::chat_status
RETURNING chats.id
`
type UpdateChatHeartbeatsParams struct {
Now time.Time `db:"now" json:"now"`
IDs []uuid.UUID `db:"ids" json:"ids"`
WorkerID uuid.UUID `db:"worker_id" json:"worker_id"`
Now time.Time `db:"now" json:"now"`
WorkerID uuid.UUID `db:"worker_id" json:"worker_id"`
IDs []uuid.UUID `db:"ids" json:"ids"`
RunGenerations []int64 `db:"run_generations" json:"run_generations"`
}
// Bumps the heartbeat timestamp for the given set of chat IDs,
// provided they are still running and owned by the specified
// worker. Returns the IDs that were actually updated so the
// caller can detect stolen or completed chats via set-difference.
// worker and run generation. Returns the IDs that were actually
// updated so the caller can detect stolen, superseded, or
// completed chats via set-difference.
func (q *sqlQuerier) UpdateChatHeartbeats(ctx context.Context, arg UpdateChatHeartbeatsParams) ([]uuid.UUID, error) {
rows, err := q.db.QueryContext(ctx, updateChatHeartbeats, arg.Now, pq.Array(arg.IDs), arg.WorkerID)
rows, err := q.db.QueryContext(ctx, updateChatHeartbeats,
arg.Now,
arg.WorkerID,
pq.Array(arg.IDs),
pq.Array(arg.RunGenerations),
)
if err != nil {
return nil, err
}
@@ -7066,7 +7165,7 @@ SET
WHERE
id = $2::uuid
RETURNING
id, owner_id, workspace_id, title, status, worker_id, started_at, heartbeat_at, created_at, updated_at, parent_chat_id, root_chat_id, last_model_config_id, archived, last_error, mode, mcp_server_ids, labels, build_id, agent_id, pin_order, last_read_message_id, last_injected_context, dynamic_tools
id, owner_id, workspace_id, title, status, run_generation, worker_id, started_at, heartbeat_at, created_at, updated_at, parent_chat_id, root_chat_id, last_model_config_id, archived, last_error, mode, mcp_server_ids, labels, build_id, agent_id, pin_order, last_read_message_id, last_injected_context, dynamic_tools
`
type UpdateChatLabelsByIDParams struct {
@@ -7083,6 +7182,7 @@ func (q *sqlQuerier) UpdateChatLabelsByID(ctx context.Context, arg UpdateChatLab
&i.WorkspaceID,
&i.Title,
&i.Status,
&i.RunGeneration,
&i.WorkerID,
&i.StartedAt,
&i.HeartbeatAt,
@@ -7111,7 +7211,7 @@ UPDATE chats SET
last_injected_context = $1::jsonb
WHERE
id = $2::uuid
RETURNING id, owner_id, workspace_id, title, status, worker_id, started_at, heartbeat_at, created_at, updated_at, parent_chat_id, root_chat_id, last_model_config_id, archived, last_error, mode, mcp_server_ids, labels, build_id, agent_id, pin_order, last_read_message_id, last_injected_context, dynamic_tools
RETURNING id, owner_id, workspace_id, title, status, run_generation, worker_id, started_at, heartbeat_at, created_at, updated_at, parent_chat_id, root_chat_id, last_model_config_id, archived, last_error, mode, mcp_server_ids, labels, build_id, agent_id, pin_order, last_read_message_id, last_injected_context, dynamic_tools
`
type UpdateChatLastInjectedContextParams struct {
@@ -7132,6 +7232,7 @@ func (q *sqlQuerier) UpdateChatLastInjectedContext(ctx context.Context, arg Upda
&i.WorkspaceID,
&i.Title,
&i.Status,
&i.RunGeneration,
&i.WorkerID,
&i.StartedAt,
&i.HeartbeatAt,
@@ -7164,7 +7265,7 @@ SET
WHERE
id = $2::uuid
RETURNING
id, owner_id, workspace_id, title, status, worker_id, started_at, heartbeat_at, created_at, updated_at, parent_chat_id, root_chat_id, last_model_config_id, archived, last_error, mode, mcp_server_ids, labels, build_id, agent_id, pin_order, last_read_message_id, last_injected_context, dynamic_tools
id, owner_id, workspace_id, title, status, run_generation, worker_id, started_at, heartbeat_at, created_at, updated_at, parent_chat_id, root_chat_id, last_model_config_id, archived, last_error, mode, mcp_server_ids, labels, build_id, agent_id, pin_order, last_read_message_id, last_injected_context, dynamic_tools
`
type UpdateChatLastModelConfigByIDParams struct {
@@ -7181,6 +7282,7 @@ func (q *sqlQuerier) UpdateChatLastModelConfigByID(ctx context.Context, arg Upda
&i.WorkspaceID,
&i.Title,
&i.Status,
&i.RunGeneration,
&i.WorkerID,
&i.StartedAt,
&i.HeartbeatAt,
@@ -7231,7 +7333,7 @@ SET
WHERE
id = $2::uuid
RETURNING
id, owner_id, workspace_id, title, status, worker_id, started_at, heartbeat_at, created_at, updated_at, parent_chat_id, root_chat_id, last_model_config_id, archived, last_error, mode, mcp_server_ids, labels, build_id, agent_id, pin_order, last_read_message_id, last_injected_context, dynamic_tools
id, owner_id, workspace_id, title, status, run_generation, worker_id, started_at, heartbeat_at, created_at, updated_at, parent_chat_id, root_chat_id, last_model_config_id, archived, last_error, mode, mcp_server_ids, labels, build_id, agent_id, pin_order, last_read_message_id, last_injected_context, dynamic_tools
`
type UpdateChatMCPServerIDsParams struct {
@@ -7248,6 +7350,7 @@ func (q *sqlQuerier) UpdateChatMCPServerIDs(ctx context.Context, arg UpdateChatM
&i.WorkspaceID,
&i.Title,
&i.Status,
&i.RunGeneration,
&i.WorkerID,
&i.StartedAt,
&i.HeartbeatAt,
@@ -7402,7 +7505,7 @@ SET
WHERE
id = $6::uuid
RETURNING
id, owner_id, workspace_id, title, status, worker_id, started_at, heartbeat_at, created_at, updated_at, parent_chat_id, root_chat_id, last_model_config_id, archived, last_error, mode, mcp_server_ids, labels, build_id, agent_id, pin_order, last_read_message_id, last_injected_context, dynamic_tools
id, owner_id, workspace_id, title, status, run_generation, worker_id, started_at, heartbeat_at, created_at, updated_at, parent_chat_id, root_chat_id, last_model_config_id, archived, last_error, mode, mcp_server_ids, labels, build_id, agent_id, pin_order, last_read_message_id, last_injected_context, dynamic_tools
`
type UpdateChatStatusParams struct {
@@ -7430,6 +7533,7 @@ func (q *sqlQuerier) UpdateChatStatus(ctx context.Context, arg UpdateChatStatusP
&i.WorkspaceID,
&i.Title,
&i.Status,
&i.RunGeneration,
&i.WorkerID,
&i.StartedAt,
&i.HeartbeatAt,
@@ -7466,7 +7570,7 @@ SET
WHERE
id = $7::uuid
RETURNING
id, owner_id, workspace_id, title, status, worker_id, started_at, heartbeat_at, created_at, updated_at, parent_chat_id, root_chat_id, last_model_config_id, archived, last_error, mode, mcp_server_ids, labels, build_id, agent_id, pin_order, last_read_message_id, last_injected_context, dynamic_tools
id, owner_id, workspace_id, title, status, run_generation, worker_id, started_at, heartbeat_at, created_at, updated_at, parent_chat_id, root_chat_id, last_model_config_id, archived, last_error, mode, mcp_server_ids, labels, build_id, agent_id, pin_order, last_read_message_id, last_injected_context, dynamic_tools
`
type UpdateChatStatusPreserveUpdatedAtParams struct {
@@ -7496,6 +7600,7 @@ func (q *sqlQuerier) UpdateChatStatusPreserveUpdatedAt(ctx context.Context, arg
&i.WorkspaceID,
&i.Title,
&i.Status,
&i.RunGeneration,
&i.WorkerID,
&i.StartedAt,
&i.HeartbeatAt,
@@ -7526,7 +7631,7 @@ UPDATE chats SET
agent_id = $3::uuid,
updated_at = NOW()
WHERE id = $4::uuid
RETURNING id, owner_id, workspace_id, title, status, worker_id, started_at, heartbeat_at, created_at, updated_at, parent_chat_id, root_chat_id, last_model_config_id, archived, last_error, mode, mcp_server_ids, labels, build_id, agent_id, pin_order, last_read_message_id, last_injected_context, dynamic_tools
RETURNING id, owner_id, workspace_id, title, status, run_generation, worker_id, started_at, heartbeat_at, created_at, updated_at, parent_chat_id, root_chat_id, last_model_config_id, archived, last_error, mode, mcp_server_ids, labels, build_id, agent_id, pin_order, last_read_message_id, last_injected_context, dynamic_tools
`
type UpdateChatWorkspaceBindingParams struct {
@@ -7550,6 +7655,7 @@ func (q *sqlQuerier) UpdateChatWorkspaceBinding(ctx context.Context, arg UpdateC
&i.WorkspaceID,
&i.Title,
&i.Status,
&i.RunGeneration,
&i.WorkerID,
&i.StartedAt,
&i.HeartbeatAt,
+35 -6
View File
@@ -398,6 +398,7 @@ INSERT INTO chats (
title,
mode,
status,
run_generation,
mcp_server_ids,
labels,
dynamic_tools
@@ -412,6 +413,7 @@ INSERT INTO chats (
@title::text,
sqlc.narg('mode')::chat_mode,
@status::chat_status,
@run_generation::bigint,
COALESCE(@mcp_server_ids::uuid[], '{}'::uuid[]),
COALESCE(sqlc.narg('labels')::jsonb, '{}'::jsonb),
sqlc.narg('dynamic_tools')::jsonb
@@ -655,6 +657,22 @@ WHERE
RETURNING
*;
-- name: AdvanceChatRunGenerationAndUpdateStatus :one
-- Increments the chat's run_generation by one and, in the same
-- statement, overwrites status, worker_id, started_at, heartbeat_at,
-- and last_error, and bumps updated_at to NOW(). The nullable
-- arguments are written as given, so passing NULL clears the
-- corresponding column. Intended for mutations that invalidate the
-- current run: any worker still holding the previous generation no
-- longer matches generation-fenced checks. Returns the updated row.
UPDATE
chats
SET
run_generation = run_generation + 1,
status = @status::chat_status,
worker_id = sqlc.narg('worker_id')::uuid,
started_at = sqlc.narg('started_at')::timestamptz,
heartbeat_at = sqlc.narg('heartbeat_at')::timestamptz,
last_error = sqlc.narg('last_error')::text,
updated_at = NOW()
WHERE
id = @id::uuid
RETURNING
*;
-- name: UpdateChatStatusPreserveUpdatedAt :one
UPDATE
chats
@@ -688,17 +706,28 @@ WHERE
-- name: UpdateChatHeartbeats :many
-- Bumps the heartbeat timestamp for the given set of chat IDs,
-- provided they are still running and owned by the specified
-- worker. Returns the IDs that were actually updated so the
-- caller can detect stolen or completed chats via set-difference.
-- worker and run generation. Returns the IDs that were actually
-- updated so the caller can detect stolen, superseded, or
-- completed chats via set-difference.
WITH input AS (
SELECT
(@ids::uuid[])[i] AS id,
(@run_generations::bigint[])[i] AS run_generation
FROM
generate_subscripts(@ids::uuid[], 1) AS g(i)
)
UPDATE
chats
SET
heartbeat_at = @now::timestamptz
FROM
input
WHERE
id = ANY(@ids::uuid[])
AND worker_id = @worker_id::uuid
AND status = 'running'::chat_status
RETURNING id;
chats.id = input.id
AND chats.run_generation = input.run_generation
AND chats.worker_id = @worker_id::uuid
AND chats.status = 'running'::chat_status
RETURNING chats.id;
-- name: GetChatDiffStatusByChatID :one
SELECT
+58
View File
@@ -0,0 +1,58 @@
package pubsub
import (
"context"
"encoding/json"
"fmt"
"github.com/google/uuid"
)
// ChatControlReason identifies why a worker-scoped control message was sent.
// The string value is what travels in the "reason" field of the JSON-encoded
// ChatControlMessage.
type ChatControlReason string
const (
// ChatControlReasonInterrupt requests that the current generation stop
// running without fencing a newer generation. This is used for explicit stop
// actions that still allow the interrupted run to persist partial output.
ChatControlReasonInterrupt ChatControlReason = "interrupt"
// ChatControlReasonRestart requests that an older generation stop because a
// newer generation has been scheduled. The accompanying RunGeneration is the
// newer generation, not the one being stopped.
ChatControlReasonRestart ChatControlReason = "restart"
// ChatControlReasonArchive requests that the current generation stop because
// the chat is being archived.
ChatControlReasonArchive ChatControlReason = "archive"
// ChatControlReasonRecoverStale requests that the current generation stop
// because stale recovery fenced it off.
ChatControlReasonRecoverStale ChatControlReason = "recover_stale"
)
// ChatControlChannel builds the name of the pubsub channel that carries
// control messages addressed to a single worker. The name is the fixed
// "chat:control:" prefix followed by the worker's UUID in string form; a
// worker listens on exactly one such channel.
func ChatControlChannel(workerID uuid.UUID) string {
	return fmt.Sprintf("chat:control:%s", workerID.String())
}
// ChatControlMessage requests that a worker stop an active run for a chat.
// RunGeneration fences newer runs from stale control messages. The struct is
// the JSON payload published on the channel returned by ChatControlChannel.
type ChatControlMessage struct {
// ChatID is the chat whose active run the message targets.
ChatID uuid.UUID `json:"chat_id"`
// RunGeneration is the chat's run generation as recorded by the sender
// when it issued the message; receivers compare it with the generation
// of their local run.
RunGeneration int64 `json:"run_generation"`
// Reason says why the stop was requested. It is omitted from the JSON
// encoding when empty.
Reason ChatControlReason `json:"reason,omitempty"`
}
// HandleChatControl adapts a typed ChatControlMessage callback into a raw
// pubsub listener.
//
// The returned listener forwards a subscription error to cb unchanged. When
// there is no subscription error it JSON-decodes the message bytes; a decode
// failure is reported to cb as the error. In both failure cases cb receives a
// zero-value ChatControlMessage. On success cb receives the decoded payload
// and a nil error. cb is invoked exactly once per delivery.
func HandleChatControl(cb func(ctx context.Context, payload ChatControlMessage, err error)) func(ctx context.Context, message []byte, err error) {
	return func(ctx context.Context, message []byte, err error) {
		var decoded ChatControlMessage
		if err == nil {
			// Only attempt to decode when the subscription itself is healthy.
			err = json.Unmarshal(message, &decoded)
		}
		if err != nil {
			// Never hand a partially decoded payload to the callback.
			cb(ctx, ChatControlMessage{}, err)
			return
		}
		cb(ctx, decoded, nil)
	}
}
+1 -1
View File
@@ -17,7 +17,7 @@ func ChatStreamNotifyChannel(chatID uuid.UUID) string {
// ChatStreamNotifyMessage is the payload published on the per-chat
// stream notification channel. Durable message content is still read
// from the database, while transient control events can be carried
// from the database, while lightweight status and relay hints travel
// inline for cross-replica delivery.
type ChatStreamNotifyMessage struct {
// AfterMessageID tells subscribers to query messages after this
+310 -139
View File
@@ -131,6 +131,7 @@ type Server struct {
providerAPIKeys chatprovider.ProviderAPIKeys
configCache *chatConfigCache
configCacheUnsubscribe func()
controlUnsubscribe func()
// chatStreams stores per-chat stream state. Using sync.Map
// gives each chat independent locking — concurrent chats
@@ -720,6 +721,7 @@ type chatStreamState struct {
// heartbeatEntry describes one in-flight chat run registered with the
// server's heartbeat registry, which is keyed by chat ID. The same entry
// is what worker-scoped control messages are matched against.
type heartbeatEntry struct {
// cancelWithCause cancels the run's context. It is invoked with
// chatloop.ErrInterrupted when a newer generation replaces this entry
// or when a control message targets this run.
cancelWithCause context.CancelCauseFunc
// chatID is the chat being processed and the registry key.
chatID uuid.UUID
// runGeneration is the chat's run generation at the time the run was
// registered. It is sent alongside chatID in batched heartbeats and
// compared with the generation carried by control messages.
runGeneration int64
// workspaceID is copied from the chat when the run is registered.
// NOTE(review): its consumer is not visible in this section.
workspaceID uuid.NullUUID
// logger is the chat-scoped logger for this run.
logger slog.Logger
}
@@ -745,11 +747,12 @@ var (
// ErrEditedMessageNotUser indicates a non-user message edit attempt.
ErrEditedMessageNotUser = xerrors.New("only user messages can be edited")
// errChatTakenByOtherWorker is a sentinel used inside the
// processChat cleanup transaction to signal that another
// worker acquired the chat, so all post-TX side effects
// (status publish, pubsub, web push) must be skipped.
errChatTakenByOtherWorker = xerrors.New("chat acquired by another worker")
// errChatSuperseded is a sentinel used inside the processChat
// cleanup transaction to signal that a newer worker ownership or
// run generation has superseded the local run. In that case all
// post-TX side effects (status publish, pubsub, web push) must
// be skipped.
errChatSuperseded = xerrors.New("chat run superseded")
)
// UsageLimitExceededError indicates the user has exceeded their chat spend
@@ -894,8 +897,9 @@ func (p *Server) CreateChat(ctx context.Context, opts CreateOptions) (database.C
Mode: opts.ChatMode,
// Chats created with an initial user message start pending.
// Waiting is reserved for idle chats with no pending work.
Status: database.ChatStatusPending,
MCPServerIDs: opts.MCPServerIDs,
Status: database.ChatStatusPending,
RunGeneration: 1,
MCPServerIDs: opts.MCPServerIDs,
Labels: pqtype.NullRawMessage{
RawMessage: labelsJSON,
Valid: true,
@@ -1132,11 +1136,11 @@ func (p *Server) SendMessage(
})
// For interrupt behavior, signal the running loop to
// stop. setChatWaiting publishes a status notification
// that the worker's control subscriber detects, causing
// it to cancel with ErrInterrupted. The deferred cleanup
// in processChat then auto-promotes the queued message
// after persisting the partial assistant response.
// stop. setChatWaiting publishes a worker-scoped control
// message so the active run cancels with ErrInterrupted.
// The deferred cleanup in processChat then auto-promotes
// the queued message after persisting the partial
// assistant response.
if busyBehavior == SendMessageBusyBehaviorInterrupt {
updatedChat, err := p.setChatWaiting(ctx, opts.ChatID)
if err != nil {
@@ -1210,7 +1214,12 @@ func (p *Server) EditMessage(
return EditMessageResult{}, xerrors.Errorf("marshal message content: %w", err)
}
var result EditMessageResult
var (
result EditMessageResult
controlWorkerID uuid.UUID
controlMsg coderdpubsub.ChatControlMessage
publishControl bool
)
txErr := p.db.InTx(func(tx database.Store) error {
lockedChat, err := tx.GetChatByIDForUpdate(ctx, opts.ChatID)
if err != nil {
@@ -1272,7 +1281,7 @@ func (p *Server) EditMessage(
if err != nil {
return xerrors.Errorf("delete queued messages: %w", err)
}
updatedChat, err := tx.UpdateChatStatus(ctx, database.UpdateChatStatusParams{
updatedChat, err := tx.AdvanceChatRunGenerationAndUpdateStatus(ctx, database.AdvanceChatRunGenerationAndUpdateStatusParams{
ID: opts.ChatID,
Status: database.ChatStatusPending,
WorkerID: uuid.NullUUID{},
@@ -1284,6 +1293,11 @@ func (p *Server) EditMessage(
return xerrors.Errorf("set chat pending: %w", err)
}
controlWorkerID, controlMsg, publishControl = controlMessageForWorker(
lockedChat,
updatedChat.RunGeneration,
coderdpubsub.ChatControlReasonRestart,
)
result.Message = newMessage
result.Chat = updatedChat
return nil
@@ -1291,6 +1305,9 @@ func (p *Server) EditMessage(
if txErr != nil {
return EditMessageResult{}, txErr
}
if publishControl {
p.publishChatControl(controlWorkerID, controlMsg)
}
p.publishEditedMessage(opts.ChatID, result.Message)
p.publishEvent(opts.ChatID, codersdk.ChatStreamEvent{
@@ -1316,9 +1333,14 @@ func (p *Server) ArchiveChat(ctx context.Context, chat database.Chat) error {
return xerrors.New("chat_id is required")
}
statusChat := chat
interrupted := false
var archivedChats []database.Chat
var (
statusChat = chat
interrupted bool
archivedChats []database.Chat
controlWorkerID uuid.UUID
controlMsg coderdpubsub.ChatControlMessage
publishControl bool
)
if err := p.db.InTx(func(tx database.Store) error {
lockedChat, err := tx.GetChatByIDForUpdate(ctx, chat.ID)
if err != nil {
@@ -1341,6 +1363,11 @@ func (p *Server) ArchiveChat(ctx context.Context, chat database.Chat) error {
if err != nil {
return xerrors.Errorf("set chat waiting before archive: %w", err)
}
controlWorkerID, controlMsg, publishControl = controlMessageForWorker(
lockedChat,
statusChat.RunGeneration,
coderdpubsub.ChatControlReasonArchive,
)
interrupted = true
}
@@ -1354,6 +1381,9 @@ func (p *Server) ArchiveChat(ctx context.Context, chat database.Chat) error {
}
if interrupted {
if publishControl {
p.publishChatControl(controlWorkerID, controlMsg)
}
p.publishStatus(chat.ID, statusChat.Status, statusChat.WorkerID)
p.publishChatPubsubEvent(statusChat, codersdk.ChatWatchEventKindStatusChange, nil)
}
@@ -1607,7 +1637,13 @@ func (p *Server) SubmitToolResults(
// The GetLastChatMessageByRole lookup and all subsequent
// validation and persistence run inside a single transaction
// so the assistant message cannot change between reads.
var statusConflict *ToolResultStatusConflictError
var (
statusConflict *ToolResultStatusConflictError
updatedChat database.Chat
controlWorkerID uuid.UUID
controlMsg coderdpubsub.ChatControlMessage
publishControl bool
)
txErr := p.db.InTx(func(tx database.Store) error {
// Authoritative status check under row lock.
locked, lockErr := tx.GetChatByIDForUpdate(ctx, opts.ChatID)
@@ -1761,22 +1797,33 @@ func (p *Server) SubmitToolResults(
}
// Transition chat to pending.
if _, updateErr := tx.UpdateChatStatus(ctx, database.UpdateChatStatusParams{
updatedChat, err = tx.AdvanceChatRunGenerationAndUpdateStatus(ctx, database.AdvanceChatRunGenerationAndUpdateStatusParams{
ID: opts.ChatID,
Status: database.ChatStatusPending,
WorkerID: uuid.NullUUID{},
StartedAt: sql.NullTime{},
HeartbeatAt: sql.NullTime{},
LastError: sql.NullString{},
}); updateErr != nil {
return xerrors.Errorf("update chat status: %w", updateErr)
})
if err != nil {
return xerrors.Errorf("update chat status: %w", err)
}
controlWorkerID, controlMsg, publishControl = controlMessageForWorker(
locked,
updatedChat.RunGeneration,
coderdpubsub.ChatControlReasonRestart,
)
return nil
}, nil)
if txErr != nil {
return txErr
}
if publishControl {
p.publishChatControl(controlWorkerID, controlMsg)
}
p.publishStatus(opts.ChatID, updatedChat.Status, updatedChat.WorkerID)
p.publishChatPubsubEvent(updatedChat, coderdpubsub.ChatEventKindStatusChange, nil)
// Wake the chatd run loop so it processes the chat immediately.
p.signalWake()
@@ -2318,7 +2365,12 @@ func (p *Server) RefreshStatus(ctx context.Context, chatID uuid.UUID) error {
}
func (p *Server) setChatWaiting(ctx context.Context, chatID uuid.UUID) (database.Chat, error) {
var updatedChat database.Chat
var (
updatedChat database.Chat
controlWorkerID uuid.UUID
controlMsg coderdpubsub.ChatControlMessage
publishControl bool
)
err := p.db.InTx(func(tx database.Store) error {
locked, lockErr := tx.GetChatByIDForUpdate(ctx, chatID)
if lockErr != nil {
@@ -2341,11 +2393,22 @@ func (p *Server) setChatWaiting(ctx context.Context, chatID uuid.UUID) (database
HeartbeatAt: sql.NullTime{},
LastError: sql.NullString{},
})
return updateErr
if updateErr != nil {
return updateErr
}
controlWorkerID, controlMsg, publishControl = controlMessageForWorker(
locked,
updatedChat.RunGeneration,
coderdpubsub.ChatControlReasonInterrupt,
)
return nil
}, nil)
if err != nil {
return database.Chat{}, err
}
if publishControl {
p.publishChatControl(controlWorkerID, controlMsg)
}
p.publishStatus(chatID, updatedChat.Status, updatedChat.WorkerID)
p.publishChatPubsubEvent(updatedChat, codersdk.ChatWatchEventKindStatusChange, nil)
return updatedChat, nil
@@ -2646,7 +2709,7 @@ func insertUserMessageAndSetPending(
return message, lockedChat, nil
}
updatedChat, err := store.UpdateChatStatus(ctx, database.UpdateChatStatusParams{
updatedChat, err := store.AdvanceChatRunGenerationAndUpdateStatus(ctx, database.AdvanceChatRunGenerationAndUpdateStatusParams{
ID: lockedChat.ID,
Status: database.ChatStatusPending,
WorkerID: uuid.NullUUID{},
@@ -2786,6 +2849,7 @@ func New(cfg Config) *Server {
}
p.configCacheUnsubscribe = cancelConfigSub
}
p.controlUnsubscribe = p.subscribeWorkerControl(ctx)
go p.start(ctx)
return p
@@ -2874,7 +2938,7 @@ func (p *Server) processOnce(ctx context.Context) {
context.WithoutCancel(ctx), 10*time.Second,
)
for _, chat := range chats {
_, updateErr := p.db.UpdateChatStatus(releaseCtx, database.UpdateChatStatusParams{
_, updateErr := p.db.AdvanceChatRunGenerationAndUpdateStatus(releaseCtx, database.AdvanceChatRunGenerationAndUpdateStatusParams{
ID: chat.ID,
Status: database.ChatStatusPending,
WorkerID: uuid.NullUUID{},
@@ -3075,21 +3139,35 @@ func (p *Server) cleanupStreamIfIdle(chatID uuid.UUID, state *chatStreamState) {
func (p *Server) registerHeartbeat(entry *heartbeatEntry) {
p.heartbeatMu.Lock()
defer p.heartbeatMu.Unlock()
if _, exists := p.heartbeatRegistry[entry.chatID]; exists {
p.logger.Warn(context.Background(),
"duplicate heartbeat registration, skipping",
slog.F("chat_id", entry.chatID))
return
if existing, exists := p.heartbeatRegistry[entry.chatID]; exists {
if entry.runGeneration <= existing.runGeneration {
p.logger.Warn(context.Background(),
"duplicate heartbeat registration, skipping",
slog.F("chat_id", entry.chatID),
slog.F("existing_run_generation", existing.runGeneration),
slog.F("incoming_run_generation", entry.runGeneration))
return
}
// A newer generation for the same chat can start before the old
// processChat goroutine finishes unwinding its defers. Replace the
// stale entry now so the new run keeps heartbeat coverage and local
// worker-scoped control handling.
existing.cancelWithCause(chatloop.ErrInterrupted)
}
p.heartbeatRegistry[entry.chatID] = entry
}
// unregisterHeartbeat removes a chat from the centralized
// heartbeat loop when chat processing finishes.
func (p *Server) unregisterHeartbeat(chatID uuid.UUID) {
func (p *Server) unregisterHeartbeat(entry *heartbeatEntry) {
p.heartbeatMu.Lock()
defer p.heartbeatMu.Unlock()
delete(p.heartbeatRegistry, chatID)
if p.heartbeatRegistry[entry.chatID] == entry {
delete(p.heartbeatRegistry, entry.chatID)
}
}
// heartbeatLoop runs in a single goroutine, issuing one batch
@@ -3120,16 +3198,22 @@ func (p *Server) heartbeatTick(ctx context.Context) {
return
}
// Collect the IDs we believe we own.
ids := slices.Collect(maps.Keys(snapshot))
// Collect the IDs and generations we believe we own.
ids := make([]uuid.UUID, 0, len(snapshot))
runGenerations := make([]int64, 0, len(snapshot))
for id, entry := range snapshot {
ids = append(ids, id)
runGenerations = append(runGenerations, entry.runGeneration)
}
//nolint:gocritic // AsChatd provides narrowly-scoped daemon
// access for batch-updating heartbeats.
chatdCtx := dbauthz.AsChatd(ctx)
updatedIDs, err := p.db.UpdateChatHeartbeats(chatdCtx, database.UpdateChatHeartbeatsParams{
IDs: ids,
WorkerID: p.workerID,
Now: p.clock.Now(),
IDs: ids,
RunGenerations: runGenerations,
WorkerID: p.workerID,
Now: p.clock.Now(),
})
if err != nil {
p.logger.Error(ctx, "batch heartbeat failed", slog.Error(err))
@@ -3190,7 +3274,7 @@ func (p *Server) Subscribe(
var allCancels []func()
allCancels = append(allCancels, localCancel)
// Subscribe to pubsub for durable and structured control
// Subscribe to pubsub for durable and structured stream
// events (status, messages, queue updates, retry, errors).
// When pubsub is nil (e.g. in-memory
// single-instance) we skip this and deliver all local events.
@@ -3846,62 +3930,110 @@ func (p *Server) publishMessagePart(chatID uuid.UUID, role codersdk.ChatMessageR
})
}
func shouldCancelChatFromControlNotification(
notify coderdpubsub.ChatStreamNotifyMessage,
workerID uuid.UUID,
) bool {
status := database.ChatStatus(strings.TrimSpace(notify.Status))
switch status {
case database.ChatStatusWaiting, database.ChatStatusPending, database.ChatStatusError:
// controlMessageForWorker prepares the control message that should be sent
// to the worker recorded on chat.
//
// It returns the target worker ID, the message (carrying chat.ID, the
// supplied runGeneration, and reason), and true. When the chat has no valid
// worker_id there is nobody to notify, so it returns uuid.Nil, a zero-value
// message, and false.
func controlMessageForWorker(
	chat database.Chat,
	runGeneration int64,
	reason coderdpubsub.ChatControlReason,
) (uuid.UUID, coderdpubsub.ChatControlMessage, bool) {
	var msg coderdpubsub.ChatControlMessage
	owner := chat.WorkerID
	if !owner.Valid {
		return uuid.Nil, msg, false
	}

	msg.ChatID = chat.ID
	msg.RunGeneration = runGeneration
	msg.Reason = reason
	return owner.UUID, msg, true
}
func interruptsCurrentRun(reason coderdpubsub.ChatControlReason) bool {
switch reason {
case coderdpubsub.ChatControlReasonInterrupt, coderdpubsub.ChatControlReasonArchive:
return true
case database.ChatStatusRunning:
worker := strings.TrimSpace(notify.WorkerID)
if worker == "" {
return false
}
notifyWorkerID, err := uuid.Parse(worker)
if err != nil {
return false
}
return notifyWorkerID != workerID
default:
return false
}
}
func (p *Server) subscribeChatControl(
ctx context.Context,
chatID uuid.UUID,
cancel context.CancelCauseFunc,
logger slog.Logger,
) func() {
// shouldInterruptActiveRunFromControlMessage reports whether msg should
// cancel the local run described by entry.
//
// A nil entry or a message for a different chat never interrupts. Otherwise
// the decision is made on run generations:
//   - a message from a newer generation always interrupts;
//   - a message for the same generation interrupts only when its reason is
//     one that interruptsCurrentRun accepts;
//   - a message from an older generation is stale and is ignored.
func shouldInterruptActiveRunFromControlMessage(
	entry *heartbeatEntry,
	msg coderdpubsub.ChatControlMessage,
) bool {
	if entry == nil {
		return false
	}
	if entry.chatID != msg.ChatID {
		return false
	}

	switch {
	case msg.RunGeneration > entry.runGeneration:
		return true
	case msg.RunGeneration == entry.runGeneration:
		return interruptsCurrentRun(msg.Reason)
	default:
		return false
	}
}
// handleChatControlMessage applies a control message to this worker's
// active run for the chat, if there is one.
//
// Messages without a chat ID are dropped. The registered heartbeat entry is
// looked up under heartbeatMu, and the lock is released before the entry is
// evaluated. When shouldInterruptActiveRunFromControlMessage says the
// message applies, the run is logged and canceled with
// chatloop.ErrInterrupted.
func (p *Server) handleChatControlMessage(ctx context.Context, msg coderdpubsub.ChatControlMessage) {
	if msg.ChatID == uuid.Nil {
		return
	}

	var active *heartbeatEntry
	func() {
		p.heartbeatMu.Lock()
		defer p.heartbeatMu.Unlock()
		active = p.heartbeatRegistry[msg.ChatID]
	}()

	if shouldInterruptActiveRunFromControlMessage(active, msg) {
		active.logger.Info(ctx, "interrupting active run from control message",
			slog.F("reason", msg.Reason),
			slog.F("incoming_run_generation", msg.RunGeneration),
			slog.F("local_run_generation", active.runGeneration),
		)
		active.cancelWithCause(chatloop.ErrInterrupted)
	}
}
// publishChatControl delivers msg to the worker identified by workerID.
//
// A nil worker ID is ignored. A message addressed to this server's own
// worker is handled in-process via handleChatControlMessage and is not
// published. For any other worker the message is JSON-encoded and published
// on that worker's control channel; if no pubsub is configured the message
// is dropped. Marshal and publish failures are logged and not returned.
func (p *Server) publishChatControl(workerID uuid.UUID, msg coderdpubsub.ChatControlMessage) {
	switch {
	case workerID == uuid.Nil:
		return
	case workerID == p.workerID:
		// Local target: skip the pubsub round trip entirely.
		p.handleChatControlMessage(context.Background(), msg)
		return
	case p.pubsub == nil:
		return
	}

	encoded, marshalErr := json.Marshal(msg)
	if marshalErr != nil {
		p.logger.Error(context.Background(), "failed to marshal chat control message",
			slog.F("chat_id", msg.ChatID),
			slog.F("worker_id", workerID),
			slog.Error(marshalErr),
		)
		return
	}

	channel := coderdpubsub.ChatControlChannel(workerID)
	publishErr := p.pubsub.Publish(channel, encoded)
	if publishErr == nil {
		return
	}
	p.logger.Error(context.Background(), "failed to publish chat control message",
		slog.F("chat_id", msg.ChatID),
		slog.F("worker_id", workerID),
		slog.Error(publishErr),
	)
}
func (p *Server) subscribeWorkerControl(ctx context.Context) func() {
if p.pubsub == nil {
return nil
}
listener := func(_ context.Context, message []byte, err error) {
if err != nil {
logger.Warn(ctx, "chat control pubsub error", slog.Error(err))
return
}
var notify coderdpubsub.ChatStreamNotifyMessage
if unmarshalErr := json.Unmarshal(message, &notify); unmarshalErr != nil {
logger.Warn(ctx, "failed to unmarshal chat control notify", slog.Error(unmarshalErr))
return
}
if shouldCancelChatFromControlNotification(notify, p.workerID) {
cancel(chatloop.ErrInterrupted)
}
}
controlCancel, err := p.pubsub.SubscribeWithErr(
coderdpubsub.ChatStreamNotifyChannel(chatID),
listener,
coderdpubsub.ChatControlChannel(p.workerID),
coderdpubsub.HandleChatControl(func(ctx context.Context, msg coderdpubsub.ChatControlMessage, err error) {
if err != nil {
p.logger.Warn(ctx, "chat control pubsub error", slog.Error(err))
return
}
p.handleChatControlMessage(ctx, msg)
}),
)
if err != nil {
logger.Warn(ctx, "failed to subscribe to chat control notifications", slog.Error(err))
p.logger.Error(ctx, "failed to subscribe to worker chat controls", slog.Error(err))
return nil
}
return controlCancel
@@ -4023,6 +4155,21 @@ func (p *Server) trackWorkspaceUsage(
return wsID
}
// shouldStartOwnedRun re-reads the chat and reports whether this worker
// still owns the run it was handed.
//
// Ownership holds only when the stored row has the same run generation as
// chat, is in the running status, and names this server's worker ID. If the
// lookup itself fails the check cannot be made; the failure is logged as a
// warning and the function returns true so the run proceeds.
func (p *Server) shouldStartOwnedRun(ctx context.Context, chat database.Chat) bool {
	current, lookupErr := p.db.GetChatByID(ctx, chat.ID)
	if lookupErr != nil {
		p.logger.Warn(ctx, "failed to verify chat ownership before start",
			slog.F("chat_id", chat.ID),
			slog.Error(lookupErr),
		)
		// Fail open: an unverifiable run is allowed to start.
		return true
	}

	switch {
	case current.RunGeneration != chat.RunGeneration:
		return false
	case current.Status != database.ChatStatusRunning:
		return false
	case !current.WorkerID.Valid:
		return false
	default:
		return current.WorkerID.UUID == p.workerID
	}
}
func (p *Server) processChat(ctx context.Context, chat database.Chat) {
logger := p.logger.With(slog.F("chat_id", chat.ID))
logger.Info(ctx, "processing chat request")
@@ -4030,42 +4177,30 @@ func (p *Server) processChat(ctx context.Context, chat database.Chat) {
chatCtx, cancel := context.WithCancelCause(ctx)
defer cancel(nil)
// Gate the control subscriber behind a channel that is closed
// after we publish "running" status. This prevents stale
// pubsub notifications (e.g. the "pending" notification from
// SendMessage that triggered this processing) from
// interrupting us before we start work. Due to async
// PostgreSQL NOTIFY delivery, a notification published before
// subscribeChatControl registers its queue can still arrive
// after registration.
controlArmed := make(chan struct{})
gatedCancel := func(cause error) {
select {
case <-controlArmed:
cancel(cause)
default:
logger.Debug(ctx, "ignoring control notification before armed")
}
}
controlCancel := p.subscribeChatControl(chatCtx, chat.ID, gatedCancel, logger)
defer func() {
if controlCancel != nil {
controlCancel()
}
}()
// Register with the centralized heartbeat loop instead of
// running a per-chat goroutine. The loop issues a single batch
// UPDATE for all chats on this worker and detects stolen chats
// via set-difference.
p.registerHeartbeat(&heartbeatEntry{
heartbeat := &heartbeatEntry{
cancelWithCause: cancel,
chatID: chat.ID,
runGeneration: chat.RunGeneration,
workspaceID: chat.WorkspaceID,
logger: logger,
})
defer p.unregisterHeartbeat(chat.ID)
}
p.registerHeartbeat(heartbeat)
defer p.unregisterHeartbeat(heartbeat)
// Re-check ownership after registration so a mutation that won the
// race against AcquireChats cannot start work simply because its
// control message arrived before the local run was registered.
if !p.shouldStartOwnedRun(context.WithoutCancel(chatCtx), chat) {
logger.Info(ctx, "chat no longer owned before processing start",
slog.F("run_generation", chat.RunGeneration),
)
cancel(chatloop.ErrInterrupted)
return
}
// Start buffering stream events BEFORE publishing the running
// status. This closes a race where a subscriber sees
@@ -4098,12 +4233,6 @@ func (p *Server) processChat(ctx context.Context, chat database.Chat) {
Valid: true,
})
// Arm the control subscriber. Closing the channel is a
// happens-before guarantee in the Go memory model — any
// notification dispatched after this point will correctly
// interrupt processing.
close(controlArmed)
// Determine the final status and last error to set when we're done.
status := database.ChatStatusWaiting
wasInterrupted := false
@@ -4145,21 +4274,26 @@ func (p *Server) processChat(ctx context.Context, chat database.Chat) {
return xerrors.Errorf("lock chat for release: %w", lockErr)
}
// If another worker has already acquired this chat,
// bail out — we must not overwrite their running
// status or publish spurious events.
// If a newer run generation or another worker already owns
// this chat, bail out — we must not overwrite their status
// or publish spurious side effects.
if latestChat.RunGeneration != chat.RunGeneration {
return errChatSuperseded
}
if latestChat.Status == database.ChatStatusRunning &&
latestChat.WorkerID.Valid &&
latestChat.WorkerID.UUID != p.workerID {
return errChatTakenByOtherWorker
return errChatSuperseded
}
// If someone else already set the chat to pending (e.g.
// the promote endpoint), don't overwrite it — just clear
// the worker and let the processor pick it back up.
// during a transition from older code), don't overwrite it.
if latestChat.Status == database.ChatStatusPending {
status = database.ChatStatusPending
} else if status == database.ChatStatusWaiting && !latestChat.Archived {
status = latestChat.Status
updatedChat = latestChat
return nil
}
if status == database.ChatStatusWaiting && !latestChat.Archived {
// Queued messages were already admitted through SendMessage,
// so auto-promotion only preserves FIFO order here. Archived
// chats skip promotion so archiving behaves like a hard stop.
@@ -4173,20 +4307,31 @@ func (p *Server) processChat(ctx context.Context, chat database.Chat) {
}
var updateErr error
updatedChat, updateErr = tx.UpdateChatStatus(cleanupCtx, database.UpdateChatStatusParams{
ID: chat.ID,
Status: status,
WorkerID: uuid.NullUUID{},
StartedAt: sql.NullTime{},
HeartbeatAt: sql.NullTime{},
LastError: sql.NullString{String: lastError, Valid: lastError != ""},
})
if status == database.ChatStatusPending {
updatedChat, updateErr = tx.AdvanceChatRunGenerationAndUpdateStatus(cleanupCtx, database.AdvanceChatRunGenerationAndUpdateStatusParams{
ID: chat.ID,
Status: status,
WorkerID: uuid.NullUUID{},
StartedAt: sql.NullTime{},
HeartbeatAt: sql.NullTime{},
LastError: sql.NullString{String: lastError, Valid: lastError != ""},
})
} else {
updatedChat, updateErr = tx.UpdateChatStatus(cleanupCtx, database.UpdateChatStatusParams{
ID: chat.ID,
Status: status,
WorkerID: uuid.NullUUID{},
StartedAt: sql.NullTime{},
HeartbeatAt: sql.NullTime{},
LastError: sql.NullString{String: lastError, Valid: lastError != ""},
})
}
return updateErr
}, nil)
if errors.Is(err, errChatTakenByOtherWorker) {
// Another worker owns this chat now — skip all
// post-TX side effects (status publish, pubsub,
// web push) to avoid overwriting their state.
if errors.Is(err, errChatSuperseded) {
// A newer run owns this chat now — skip all post-TX side
// effects (status publish, pubsub, web push) to avoid
// overwriting their state.
return
}
if err != nil {
@@ -4800,13 +4945,17 @@ func (p *Server) runChat(
// already been cleared but we still want to persist
// the partial assistant response. We allow the write
// because the history has NOT been truncated — the
// user simply asked to stop. In contrast, EditMessage
// sets the chat to "pending" after truncating, so the
// pending check still correctly blocks stale writes.
// user simply asked to stop. In contrast, generation-
// advancing transitions such as EditMessage move the
// chat to a newer run generation, which blocks stale
// writes even if the old worker is still draining.
lockedChat, lockErr := tx.GetChatByIDForUpdate(persistCtx, chat.ID)
if lockErr != nil {
return xerrors.Errorf("lock chat for persist: %w", lockErr)
}
if lockedChat.RunGeneration != chat.RunGeneration {
return chatloop.ErrInterrupted
}
if !lockedChat.WorkerID.Valid || lockedChat.WorkerID.UUID != p.workerID {
// The worker_id was cleared. Only allow the persist
// if the chat transitioned to "waiting" (interrupt),
@@ -5914,6 +6063,12 @@ func (p *Server) recoverStaleChats(ctx context.Context) {
// between GetStaleChats (a bare SELECT) and here, the chat's
// heartbeat may have been refreshed. We re-check freshness
// under the row lock before resetting.
var (
updatedChat database.Chat
controlWorkerID uuid.UUID
controlMsg coderdpubsub.ChatControlMessage
publishControl bool
)
err := p.db.InTx(func(tx database.Store) error {
locked, lockErr := tx.GetChatByIDForUpdate(ctx, chat.ID)
if lockErr != nil {
@@ -5980,7 +6135,8 @@ func (p *Server) recoverStaleChats(ctx context.Context) {
// Reset so any replica can pick it up (pending) or
// the client sees the failure (error).
_, updateErr := tx.UpdateChatStatus(ctx, database.UpdateChatStatusParams{
var updateErr error
updatedChat, updateErr = tx.AdvanceChatRunGenerationAndUpdateStatus(ctx, database.AdvanceChatRunGenerationAndUpdateStatusParams{
ID: chat.ID,
Status: recoverStatus,
WorkerID: uuid.NullUUID{},
@@ -5991,13 +6147,24 @@ func (p *Server) recoverStaleChats(ctx context.Context) {
if updateErr != nil {
return updateErr
}
controlWorkerID, controlMsg, publishControl = controlMessageForWorker(
locked,
updatedChat.RunGeneration,
coderdpubsub.ChatControlReasonRecoverStale,
)
recovered++
return nil
}, nil)
if err != nil {
p.logger.Error(ctx, "failed to recover stale chat",
slog.F("chat_id", chat.ID), slog.Error(err))
continue
}
if publishControl {
p.publishChatControl(controlWorkerID, controlMsg)
}
p.publishStatus(updatedChat.ID, updatedChat.Status, updatedChat.WorkerID)
p.publishChatPubsubEvent(updatedChat, coderdpubsub.ChatEventKindStatusChange, nil)
}
if recovered > 0 {
@@ -6203,6 +6370,10 @@ func (p *Server) Close() error {
p.configCacheUnsubscribe = nil
unsub()
}
if unsub := p.controlUnsubscribe; unsub != nil {
p.controlUnsubscribe = nil
unsub()
}
p.cancel()
<-p.closed
p.drainInflight()
+244 -30
View File
@@ -4,6 +4,7 @@ import (
"context"
"database/sql"
"encoding/json"
"errors"
"sync"
"testing"
"time"
@@ -2580,26 +2581,219 @@ func chatMessageWithParts(parts []codersdk.ChatMessagePart) database.ChatMessage
}
}
// TestProcessChat_IgnoresStaleControlNotification verifies that
// processChat is not interrupted by a "pending" notification
// published before processing begins. This is the race that caused
// TestOpenAIReasoningWithWebSearchRoundTripStoreFalse to flake:
// SendMessage publishes "pending" via PostgreSQL NOTIFY, and due
// to async delivery the notification can arrive at the control
// subscriber after it registers but before the processor publishes
// "running".
func TestProcessChat_IgnoresStaleControlNotification(t *testing.T) {
// TestShouldInterruptActiveRunFromControlMessage table-tests
// shouldInterruptActiveRunFromControlMessage against a heartbeat entry at a
// fixed generation. It covers newer, equal, and older message generations,
// each control reason used at the equal generation, a message for another
// chat, and a nil entry.
func TestShouldInterruptActiveRunFromControlMessage(t *testing.T) {
	t.Parallel()

	const activeGeneration int64 = 7

	chatID := uuid.New()
	unrelatedChatID := uuid.New()
	active := &heartbeatEntry{chatID: chatID, runGeneration: activeGeneration}

	cases := []struct {
		name     string
		entry    *heartbeatEntry
		msg      coderdpubsub.ChatControlMessage
		expected bool
	}{
		{
			name:  "newer generation restarts",
			entry: active,
			msg: coderdpubsub.ChatControlMessage{
				ChatID:        chatID,
				RunGeneration: activeGeneration + 1,
				Reason:        coderdpubsub.ChatControlReasonRestart,
			},
			expected: true,
		},
		{
			name:  "equal generation interrupt cancels",
			entry: active,
			msg: coderdpubsub.ChatControlMessage{
				ChatID:        chatID,
				RunGeneration: activeGeneration,
				Reason:        coderdpubsub.ChatControlReasonInterrupt,
			},
			expected: true,
		},
		{
			name:  "equal generation archive cancels",
			entry: active,
			msg: coderdpubsub.ChatControlMessage{
				ChatID:        chatID,
				RunGeneration: activeGeneration,
				Reason:        coderdpubsub.ChatControlReasonArchive,
			},
			expected: true,
		},
		{
			name:  "equal generation restart is ignored",
			entry: active,
			msg: coderdpubsub.ChatControlMessage{
				ChatID:        chatID,
				RunGeneration: activeGeneration,
				Reason:        coderdpubsub.ChatControlReasonRestart,
			},
			expected: false,
		},
		{
			name:  "older generation interrupt is ignored",
			entry: active,
			msg: coderdpubsub.ChatControlMessage{
				ChatID:        chatID,
				RunGeneration: activeGeneration - 1,
				Reason:        coderdpubsub.ChatControlReasonInterrupt,
			},
			expected: false,
		},
		{
			name:  "different chat is ignored",
			entry: active,
			msg: coderdpubsub.ChatControlMessage{
				ChatID:        unrelatedChatID,
				RunGeneration: activeGeneration + 1,
				Reason:        coderdpubsub.ChatControlReasonRestart,
			},
			expected: false,
		},
		{
			name:  "missing entry is ignored",
			entry: nil,
			msg: coderdpubsub.ChatControlMessage{
				ChatID:        chatID,
				RunGeneration: activeGeneration + 1,
				Reason:        coderdpubsub.ChatControlReasonRestart,
			},
			expected: false,
		},
	}

	for _, tc := range cases {
		tc := tc
		t.Run(tc.name, func(t *testing.T) {
			t.Parallel()
			got := shouldInterruptActiveRunFromControlMessage(tc.entry, tc.msg)
			require.Equal(t, tc.expected, got)
		})
	}
}
// TestSubscribeWorkerControl_CancelsRegisteredRun registers a heartbeat entry
// for a chat, publishes a restart control message with a newer generation on
// this worker's control channel, and expects the entry's context to be
// canceled with chatloop.ErrInterrupted.
func TestSubscribeWorkerControl_CancelsRegisteredRun(t *testing.T) {
	t.Parallel()

	const activeGeneration int64 = 7

	var (
		ctx      = testutil.Context(t, testutil.WaitShort)
		logger   = slogtest.Make(t, &slogtest.Options{IgnoreErrors: true})
		ps       = dbpubsub.NewInMemory()
		chatID   = uuid.New()
		workerID = uuid.New()
	)

	server := &Server{
		logger:            logger,
		pubsub:            ps,
		workerID:          workerID,
		heartbeatRegistry: make(map[uuid.UUID]*heartbeatEntry),
	}

	// The run context stands in for an in-flight chat; its cancel cause is
	// what the control subscriber is expected to set.
	runCtx, cancelRun := context.WithCancelCause(ctx)
	defer cancelRun(nil)

	registered := &heartbeatEntry{
		cancelWithCause: cancelRun,
		chatID:          chatID,
		runGeneration:   activeGeneration,
		logger:          logger,
	}
	server.registerHeartbeat(registered)
	defer server.unregisterHeartbeat(registered)

	stopControl := server.subscribeWorkerControl(ctx)
	require.NotNil(t, stopControl)
	defer stopControl()

	restart := coderdpubsub.ChatControlMessage{
		ChatID:        chatID,
		RunGeneration: activeGeneration + 1,
		Reason:        coderdpubsub.ChatControlReasonRestart,
	}
	payload, err := json.Marshal(restart)
	require.NoError(t, err)

	err = ps.Publish(coderdpubsub.ChatControlChannel(workerID), payload)
	require.NoError(t, err)

	interrupted := func() bool {
		return errors.Is(context.Cause(runCtx), chatloop.ErrInterrupted)
	}
	require.Eventually(t, interrupted, testutil.WaitShort, testutil.IntervalFast)
}
// TestRegisterHeartbeat_ReplacesOlderGenerationAndIgnoresStaleUnregister
// registers two heartbeat entries for the same chat at generations 1 and 2
// and checks that:
//   - the older entry's context is canceled with chatloop.ErrInterrupted,
//     while the newer one stays live;
//   - unregistering the older entry leaves the newer entry in the registry;
//   - unregistering the newer entry removes the chat from the registry.
func TestRegisterHeartbeat_ReplacesOlderGenerationAndIgnoresStaleUnregister(t *testing.T) {
	t.Parallel()

	ctx := testutil.Context(t, testutil.WaitShort)
	logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true})
	chatID := uuid.New()

	server := &Server{
		logger:            logger,
		heartbeatRegistry: make(map[uuid.UUID]*heartbeatEntry),
	}

	// registryEntry snapshots the registry slot for chatID under the lock.
	// Assertions run on the snapshot after the lock is released, so a failing
	// require (which exits the test goroutine) never leaves heartbeatMu held.
	registryEntry := func() (*heartbeatEntry, bool) {
		server.heartbeatMu.Lock()
		defer server.heartbeatMu.Unlock()
		entry, ok := server.heartbeatRegistry[chatID]
		return entry, ok
	}

	oldCtx, oldCancel := context.WithCancelCause(ctx)
	defer oldCancel(nil)
	oldEntry := &heartbeatEntry{
		cancelWithCause: oldCancel,
		chatID:          chatID,
		runGeneration:   1,
		logger:          logger,
	}
	server.registerHeartbeat(oldEntry)

	newCtx, newCancel := context.WithCancelCause(ctx)
	defer newCancel(nil)
	newEntry := &heartbeatEntry{
		cancelWithCause: newCancel,
		chatID:          chatID,
		runGeneration:   2,
		logger:          logger,
	}
	server.registerHeartbeat(newEntry)

	// Registering the newer generation must interrupt only the older run.
	require.ErrorIs(t, context.Cause(oldCtx), chatloop.ErrInterrupted)
	require.NoError(t, context.Cause(newCtx))

	current, _ := registryEntry()
	require.Same(t, newEntry, current)

	// A stale unregister from the replaced run must not evict the new entry.
	server.unregisterHeartbeat(oldEntry)
	current, _ = registryEntry()
	require.Same(t, newEntry, current)

	// The owning run's unregister removes the entry.
	server.unregisterHeartbeat(newEntry)
	_, exists := registryEntry()
	require.False(t, exists)
}
// TestProcessChat_IgnoresStaleStatusNotification verifies that
// processChat is not interrupted by a stale "pending" status
// fanout delivered after the worker has already published
// "running". Worker control now lives on a separate per-worker
// channel, so delayed status fanout is observational only.
func TestProcessChat_IgnoresStaleStatusNotification(t *testing.T) {
t.Parallel()
ctx := testutil.Context(t, testutil.WaitShort)
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true})
ctrl := gomock.NewController(t)
db := dbmock.NewMockStore(ctrl)
ps := dbpubsub.NewInMemory()
ps := chattest.NewDelayedStatusPubsub(dbpubsub.NewInMemory())
clock := quartz.NewMock(t)
chatID := uuid.New()
workerID := uuid.New()
runGeneration := int64(7)
chatChannel := coderdpubsub.ChatStreamNotifyChannel(chatID)
ps.DelayStatus(chatChannel, string(database.ChatStatusPending))
server := &Server{
db: db,
@@ -2612,42 +2806,52 @@ func TestProcessChat_IgnoresStaleControlNotification(t *testing.T) {
heartbeatRegistry: make(map[uuid.UUID]*heartbeatEntry),
}
// Publish a stale "pending" notification on the chat's status channel.
// The delayed pubsub buffers it, and the test releases it only after
// processChat has published "running". In production this is the
// notification from SendMessage that triggered the processing.
staleNotify, err := json.Marshal(coderdpubsub.ChatStreamNotifyMessage{
Status: string(database.ChatStatusPending),
})
require.NoError(t, err)
err = ps.Publish(coderdpubsub.ChatStreamNotifyChannel(chatID), staleNotify)
require.NoError(t, err)
require.NoError(t, ps.Publish(chatChannel, staleNotify))
// Track which status processChat writes during cleanup.
var finalStatus database.ChatStatus
cleanupDone := make(chan struct{})
allowModelResolution := make(chan struct{})
// The deferred cleanup in processChat runs a transaction.
db.EXPECT().GetChatByID(gomock.Any(), chatID).Return(
database.Chat{
ID: chatID,
Status: database.ChatStatusRunning,
WorkerID: uuid.NullUUID{UUID: workerID, Valid: true},
RunGeneration: runGeneration,
}, nil,
)
db.EXPECT().InTx(gomock.Any(), gomock.Any()).DoAndReturn(
func(fn func(database.Store) error, _ *database.TxOptions) error {
return fn(db)
},
)
db.EXPECT().GetChatByIDForUpdate(gomock.Any(), chatID).Return(
database.Chat{ID: chatID, Status: database.ChatStatusRunning, WorkerID: uuid.NullUUID{UUID: workerID, Valid: true}}, nil,
database.Chat{
ID: chatID,
Status: database.ChatStatusRunning,
WorkerID: uuid.NullUUID{UUID: workerID, Valid: true},
RunGeneration: runGeneration,
}, nil,
)
db.EXPECT().UpdateChatStatus(gomock.Any(), gomock.Any()).DoAndReturn(
func(_ context.Context, params database.UpdateChatStatusParams) (database.Chat, error) {
finalStatus = params.Status
close(cleanupDone)
return database.Chat{ID: chatID, Status: params.Status}, nil
return database.Chat{ID: chatID, Status: params.Status, RunGeneration: runGeneration}, nil
},
)
// resolveChatModel fails immediately — that's fine, we only
// need processChat to get past initialization without being
// interrupted by the stale notification.
db.EXPECT().GetChatModelConfigByID(gomock.Any(), gomock.Any()).Return(
database.ChatModelConfig{}, xerrors.New("no model configured"),
// resolveChatModel fails after the stale notification has been
// released. That ensures the test exercises delayed status fanout
// while the chat is already running.
db.EXPECT().GetChatModelConfigByID(gomock.Any(), gomock.Any()).DoAndReturn(
func(context.Context, uuid.UUID) (database.ChatModelConfig, error) {
<-allowModelResolution
return database.ChatModelConfig{}, xerrors.New("no model configured")
},
).AnyTimes()
db.EXPECT().GetEnabledChatProviders(gomock.Any()).Return(nil, nil).AnyTimes()
db.EXPECT().GetEnabledChatModelConfigs(gomock.Any()).Return(nil, nil).AnyTimes()
@@ -2656,19 +2860,19 @@ func TestProcessChat_IgnoresStaleControlNotification(t *testing.T) {
).AnyTimes()
db.EXPECT().GetChatMessagesForPromptByChatID(gomock.Any(), chatID).Return(nil, nil).AnyTimes()
chat := database.Chat{ID: chatID, LastModelConfigID: uuid.New()}
chat := database.Chat{ID: chatID, LastModelConfigID: uuid.New(), RunGeneration: runGeneration}
go server.processChat(ctx, chat)
require.NoError(t, ps.WaitForStatusPublish(ctx, chatChannel, string(database.ChatStatusRunning)))
require.NoError(t, ps.ReleaseStatus(chatChannel, string(database.ChatStatusPending)))
close(allowModelResolution)
select {
case <-cleanupDone:
case <-ctx.Done():
t.Fatal("processChat did not complete")
}
// If the stale notification interrupted us, status would be
// "waiting" (the ErrInterrupted path). Since the gate blocked
// it, processChat reached runChat, which failed on model
// resolution → status is "error".
require.Equal(t, database.ChatStatusError, finalStatus,
"processChat should have reached runChat (error), not been interrupted (waiting)")
}
@@ -2710,16 +2914,19 @@ func TestHeartbeatTick_StolenChatIsInterrupted(t *testing.T) {
server.registerHeartbeat(&heartbeatEntry{
cancelWithCause: cancel1,
chatID: chat1,
runGeneration: 1,
logger: logger,
})
server.registerHeartbeat(&heartbeatEntry{
cancelWithCause: cancel2,
chatID: chat2,
runGeneration: 2,
logger: logger,
})
server.registerHeartbeat(&heartbeatEntry{
cancelWithCause: cancel3,
chatID: chat3,
runGeneration: 3,
logger: logger,
})
@@ -2729,6 +2936,12 @@ func TestHeartbeatTick_StolenChatIsInterrupted(t *testing.T) {
func(_ context.Context, params database.UpdateChatHeartbeatsParams) ([]uuid.UUID, error) {
require.Equal(t, workerID, params.WorkerID)
require.Len(t, params.IDs, 3)
require.Len(t, params.RunGenerations, 3)
got := make(map[uuid.UUID]int64, len(params.IDs))
for i, id := range params.IDs {
got[id] = params.RunGenerations[i]
}
require.Equal(t, map[uuid.UUID]int64{chat1: 1, chat2: 2, chat3: 3}, got)
// Return only chat1 and chat2 as surviving.
return []uuid.UUID{chat1, chat2}, nil
},
@@ -2784,6 +2997,7 @@ func TestHeartbeatTick_DBErrorDoesNotInterruptChats(t *testing.T) {
server.registerHeartbeat(&heartbeatEntry{
cancelWithCause: cancel,
chatID: chatID,
runGeneration: 11,
logger: logger,
})
+51 -8
View File
@@ -474,7 +474,7 @@ func TestArchiveChatInterruptsActiveProcessing(t *testing.T) {
require.Equal(t, 1, userMessages, "expected queued message to stay queued after archive")
}
func TestUpdateChatHeartbeatsRequiresOwnership(t *testing.T) {
func TestUpdateChatHeartbeatsRequiresOwnershipAndGeneration(t *testing.T) {
t.Parallel()
db, ps := dbtestutil.NewDB(t)
@@ -503,24 +503,67 @@ func TestUpdateChatHeartbeatsRequiresOwnership(t *testing.T) {
// Wrong worker_id should return no IDs.
ids, err := db.UpdateChatHeartbeats(ctx, database.UpdateChatHeartbeatsParams{
IDs: []uuid.UUID{chat.ID},
WorkerID: uuid.New(),
Now: time.Now(),
IDs: []uuid.UUID{chat.ID},
RunGenerations: []int64{chat.RunGeneration},
WorkerID: uuid.New(),
Now: time.Now(),
})
require.NoError(t, err)
require.Empty(t, ids)
// Correct worker_id should return the chat's ID.
// Wrong generation should also return no IDs.
ids, err = db.UpdateChatHeartbeats(ctx, database.UpdateChatHeartbeatsParams{
IDs: []uuid.UUID{chat.ID},
WorkerID: workerID,
Now: time.Now(),
IDs: []uuid.UUID{chat.ID},
RunGenerations: []int64{chat.RunGeneration + 1},
WorkerID: workerID,
Now: time.Now(),
})
require.NoError(t, err)
require.Empty(t, ids)
// Correct worker_id and generation should return the chat's ID.
ids, err = db.UpdateChatHeartbeats(ctx, database.UpdateChatHeartbeatsParams{
IDs: []uuid.UUID{chat.ID},
RunGenerations: []int64{chat.RunGeneration},
WorkerID: workerID,
Now: time.Now(),
})
require.NoError(t, err)
require.Len(t, ids, 1)
require.Equal(t, chat.ID, ids[0])
}
// TestAdvanceChatRunGenerationAndUpdateStatus creates a chat, confirms it
// starts at run generation 1, and verifies that
// AdvanceChatRunGenerationAndUpdateStatus returns a row whose generation is
// one higher and whose status is the requested one.
func TestAdvanceChatRunGenerationAndUpdateStatus(t *testing.T) {
	t.Parallel()

	db, ps := dbtestutil.NewDB(t)
	server := newTestServer(t, db, ps, uuid.New())

	ctx := testutil.Context(t, testutil.WaitLong)
	user, model := seedChatDependencies(ctx, t, db)

	createOpts := chatd.CreateOptions{
		OwnerID:            user.ID,
		Title:              "generation-advance",
		ModelConfigID:      model.ID,
		InitialUserContent: []codersdk.ChatMessagePart{codersdk.ChatMessageText("hello")},
	}
	chat, err := server.CreateChat(ctx, createOpts)
	require.NoError(t, err)
	require.EqualValues(t, 1, chat.RunGeneration)

	advanceParams := database.AdvanceChatRunGenerationAndUpdateStatusParams{
		ID:          chat.ID,
		Status:      database.ChatStatusPending,
		WorkerID:    uuid.NullUUID{},
		StartedAt:   sql.NullTime{},
		HeartbeatAt: sql.NullTime{},
		LastError:   sql.NullString{},
	}
	advanced, err := db.AdvanceChatRunGenerationAndUpdateStatus(ctx, advanceParams)
	require.NoError(t, err)

	require.Equal(t, chat.RunGeneration+1, advanced.RunGeneration)
	require.Equal(t, database.ChatStatusPending, advanced.Status)
}
func TestSendMessageQueueBehaviorQueuesWhenBusy(t *testing.T) {
t.Parallel()
+220
View File
@@ -0,0 +1,220 @@
package chattest
import (
"context"
"encoding/json"
"sync"
dbpubsub "github.com/coder/coder/v2/coderd/database/pubsub"
coderdpubsub "github.com/coder/coder/v2/coderd/pubsub"
)
// delayedStatusKey identifies one (pubsub event, status string) pair. It is
// comparable and is used as the map key for delay, buffer, and publish
// tracking.
type delayedStatusKey struct {
	event, status string
}
// DelayedStatusPubsub wraps another pubsub and holds back selected status
// notifications until a test explicitly releases them. Tests use it to
// reproduce stale-notification races deterministically instead of relying
// on PostgreSQL delivery timing.
type DelayedStatusPubsub struct {
	// inner is the wrapped pubsub that performs the real delivery.
	inner dbpubsub.Pubsub

	// mu guards every map below.
	mu sync.Mutex

	// delayEnabled marks the event/status pairs currently being buffered.
	delayEnabled map[delayedStatusKey]bool
	// delayedMessages holds buffered payloads per pair, in publish order.
	delayedMessages map[delayedStatusKey][][]byte

	// subscribed records events that have had a successful subscription.
	subscribed map[string]bool
	// subscribeWait holds channels closed on the first subscription to an
	// event.
	subscribeWait map[string]chan struct{}

	// published records event/status pairs that have been published.
	published map[delayedStatusKey]bool
	// publishWait holds channels closed on the first publish of a pair.
	publishWait map[delayedStatusKey]chan struct{}
}
// NewDelayedStatusPubsub returns a DelayedStatusPubsub that forwards to inner.
// All tracking maps start empty, so nothing is delayed until DelayStatus is
// called.
func NewDelayedStatusPubsub(inner dbpubsub.Pubsub) *DelayedStatusPubsub {
	wrapped := new(DelayedStatusPubsub)
	wrapped.inner = inner
	wrapped.delayEnabled = map[delayedStatusKey]bool{}
	wrapped.delayedMessages = map[delayedStatusKey][][]byte{}
	wrapped.subscribed = map[string]bool{}
	wrapped.subscribeWait = map[string]chan struct{}{}
	wrapped.published = map[delayedStatusKey]bool{}
	wrapped.publishWait = map[delayedStatusKey]chan struct{}{}
	return wrapped
}
// DelayStatus turns on buffering for the given event/status pair. Matching
// notifications passed to Publish afterwards are held instead of being
// forwarded to the wrapped pubsub.
func (p *DelayedStatusPubsub) DelayStatus(event, status string) {
	key := delayedStatusKey{event: event, status: status}

	p.mu.Lock()
	defer p.mu.Unlock()
	p.delayEnabled[key] = true
}
// ReleaseStatus stops buffering the event/status pair and forwards every
// message buffered for it to the wrapped pubsub, in the order it was
// published. It returns the first publish error; messages after the failing
// one are not sent.
func (p *DelayedStatusPubsub) ReleaseStatus(event, status string) error {
	key := delayedStatusKey{event: event, status: status}

	// Detach the buffer under the lock, then publish without holding it.
	var pending [][]byte
	func() {
		p.mu.Lock()
		defer p.mu.Unlock()
		delete(p.delayEnabled, key)
		pending = p.delayedMessages[key]
		delete(p.delayedMessages, key)
	}()

	for i := range pending {
		err := p.inner.Publish(event, pending[i])
		if err != nil {
			return err
		}
	}
	return nil
}
// WaitForSubscribe returns nil once a subscription for event has been
// registered through this wrapper, or ctx.Err() if ctx is done first.
func (p *DelayedStatusPubsub) WaitForSubscribe(ctx context.Context, event string) error {
	subscribed := p.subscribeWaiter(event)
	select {
	case <-subscribed:
		return nil
	case <-ctx.Done():
		return ctx.Err()
	}
}
// WaitForStatusPublish returns nil once a notification with the given status
// has been passed to Publish for event, whether it was buffered or forwarded.
// It returns ctx.Err() if ctx is done first.
func (p *DelayedStatusPubsub) WaitForStatusPublish(ctx context.Context, event, status string) error {
	key := delayedStatusKey{event: event, status: status}
	published := p.publishWaiter(key)
	select {
	case <-published:
		return nil
	case <-ctx.Done():
		return ctx.Err()
	}
}
// Subscribe registers listener on the wrapped pubsub. On success the event is
// recorded as subscribed, which releases WaitForSubscribe callers.
func (p *DelayedStatusPubsub) Subscribe(event string, listener dbpubsub.Listener) (func(), error) {
	cancel, err := p.inner.Subscribe(event, listener)
	if err != nil {
		return cancel, err
	}
	p.markSubscribed(event)
	return cancel, nil
}
// SubscribeWithErr registers an error-aware listener on the wrapped pubsub.
// On success the event is recorded as subscribed, which releases
// WaitForSubscribe callers.
func (p *DelayedStatusPubsub) SubscribeWithErr(event string, listener dbpubsub.ListenerWithErr) (func(), error) {
	cancel, err := p.inner.SubscribeWithErr(event, listener)
	if err != nil {
		return cancel, err
	}
	p.markSubscribed(event)
	return cancel, nil
}
// Publish forwards message to the wrapped pubsub unless it is a status
// notification whose event/status pair is currently delayed, in which case a
// copy is buffered and nil is returned.
//
// Messages that do not decode as a status notification with a non-empty
// status are passed straight through and are not tracked.
//
// The pair is marked as published only after the buffering decision has been
// made. A caller woken by WaitForStatusPublish is therefore guaranteed that a
// delayed message is already in the buffer, so a following ReleaseStatus
// flushes it rather than racing with this call.
func (p *DelayedStatusPubsub) Publish(event string, message []byte) error {
	status, ok := chatNotifyStatus(message)
	if !ok {
		return p.inner.Publish(event, message)
	}
	key := delayedStatusKey{event: event, status: status}

	p.mu.Lock()
	delay := p.delayEnabled[key]
	if delay {
		// Copy the payload so later mutation by the caller cannot change
		// what is eventually released.
		p.delayedMessages[key] = append(p.delayedMessages[key], append([]byte(nil), message...))
	}
	p.mu.Unlock()

	// Wake waiters after the message is buffered (when delayed) and before
	// it is forwarded (when not delayed).
	p.markPublished(key)

	if delay {
		return nil
	}
	return p.inner.Publish(event, message)
}
// Close closes the wrapped pubsub and returns its error.
func (p *DelayedStatusPubsub) Close() error {
	closeErr := p.inner.Close()
	return closeErr
}
// markSubscribed records the first subscription to event and wakes any
// goroutine waiting on it. Later calls for the same event do nothing.
func (p *DelayedStatusPubsub) markSubscribed(event string) {
	p.mu.Lock()
	defer p.mu.Unlock()

	if p.subscribed[event] {
		return
	}
	p.subscribed[event] = true

	waiter, waiting := p.subscribeWait[event]
	if !waiting {
		return
	}
	close(waiter)
	delete(p.subscribeWait, event)
}
// markPublished records the first publish for key and wakes any goroutine
// waiting on it. Later calls for the same key do nothing.
func (p *DelayedStatusPubsub) markPublished(key delayedStatusKey) {
	p.mu.Lock()
	defer p.mu.Unlock()

	if p.published[key] {
		return
	}
	p.published[key] = true

	waiter, waiting := p.publishWait[key]
	if !waiting {
		return
	}
	close(waiter)
	delete(p.publishWait, key)
}
// subscribeWaiter takes the lock and returns the channel that is closed once
// event has a subscriber.
func (p *DelayedStatusPubsub) subscribeWaiter(event string) chan struct{} {
	p.mu.Lock()
	defer p.mu.Unlock()

	signal := p.subscribeWaiterLocked(event)
	return signal
}
// subscribeWaiterLocked returns a channel that is closed once event has a
// subscriber. If the event is already subscribed it returns an
// already-closed channel; otherwise it returns the event's pending wait
// channel, creating it on first use. The caller must hold p.mu.
func (p *DelayedStatusPubsub) subscribeWaiterLocked(event string) chan struct{} {
	if p.subscribed[event] {
		return closedSignal()
	}
	if existing, ok := p.subscribeWait[event]; ok {
		return existing
	}
	created := make(chan struct{})
	p.subscribeWait[event] = created
	return created
}
// publishWaiter takes the lock and returns the channel that is closed once
// key has been published.
func (p *DelayedStatusPubsub) publishWaiter(key delayedStatusKey) chan struct{} {
	p.mu.Lock()
	defer p.mu.Unlock()

	signal := p.publishWaiterLocked(key)
	return signal
}
// publishWaiterLocked returns a channel that is closed once key has been
// published. If the key is already published it returns an already-closed
// channel; otherwise it returns the key's pending wait channel, creating it
// on first use. The caller must hold p.mu.
func (p *DelayedStatusPubsub) publishWaiterLocked(key delayedStatusKey) chan struct{} {
	if p.published[key] {
		return closedSignal()
	}
	if existing, ok := p.publishWait[key]; ok {
		return existing
	}
	created := make(chan struct{})
	p.publishWait[key] = created
	return created
}
// chatNotifyStatus extracts the status from a JSON-encoded
// ChatStreamNotifyMessage. It returns ("", false) when the payload does not
// decode or carries an empty status.
func chatNotifyStatus(message []byte) (string, bool) {
	var decoded coderdpubsub.ChatStreamNotifyMessage
	err := json.Unmarshal(message, &decoded)
	if err != nil || decoded.Status == "" {
		return "", false
	}
	return decoded.Status, true
}
// closedSignal returns a fresh channel that is already closed, so a receive
// on it never blocks.
func closedSignal() chan struct{} {
	done := make(chan struct{})
	close(done)
	return done
}
+1
View File
@@ -470,6 +470,7 @@ func (p *Server) createChildSubagentChatWithOptions(
Title: title,
Mode: opts.chatMode,
Status: database.ChatStatusPending,
RunGeneration: 1,
MCPServerIDs: mcpServerIDs,
Labels: pqtype.NullRawMessage{
RawMessage: labelsJSON,
+36 -19
View File
@@ -1425,12 +1425,16 @@ func TestSubscribeRelayDialCanceledOnFastCompletion(t *testing.T) {
func TestSubscribeRelayEstablishedMidStream(t *testing.T) {
t.Parallel()
db, ps := dbtestutil.NewDB(t)
db, rawPS := dbtestutil.NewDB(t)
ps := chattest.NewDelayedStatusPubsub(rawPS)
workerID := uuid.New()
subscriberID := uuid.New()
// Gate: worker blocks after first streaming request until we
// release it. This gives the relay time to establish.
// Gate: the worker cannot emit the first streaming chunk until we
// release it. This lets the test deliver a stale pending
// notification after the run is already current but before the LLM
// stream starts.
allowStreamingStart := make(chan struct{})
firstChunkEmitted := make(chan struct{})
continueStreaming := make(chan struct{})
@@ -1438,8 +1442,16 @@ func TestSubscribeRelayEstablishedMidStream(t *testing.T) {
if !req.Stream {
return chattest.OpenAINonStreamingResponse("mid-stream-relay")
}
// Signal that the first streaming request was received,
// then block until released.
select {
case <-allowStreamingStart:
case <-req.Context().Done():
return chattest.OpenAIErrorResponse(http.StatusRequestTimeout, "request_canceled", "request canceled before streaming")
}
select {
case <-req.Context().Done():
return chattest.OpenAIErrorResponse(http.StatusRequestTimeout, "request_canceled", "request canceled before first chunk")
default:
}
select {
case <-firstChunkEmitted:
default:
@@ -1501,6 +1513,8 @@ func TestSubscribeRelayEstablishedMidStream(t *testing.T) {
// Create the chat in waiting state.
chat := seedWaitingChat(ctx, t, db, user, model, "mid-stream-relay")
chatChannel := coderdpubsub.ChatStreamNotifyChannel(chat.ID)
ps.DelayStatus(chatChannel, string(database.ChatStatusPending))
// Subscribe from the subscriber replica while the chat is idle.
_, events, subCancel, ok := subscriber.Subscribe(ctx, chat.ID, nil, 0)
@@ -1515,6 +1529,23 @@ func TestSubscribeRelayEstablishedMidStream(t *testing.T) {
})
require.NoError(t, err)
// Wait for the subscriber to receive the running status, which
// proves the worker has started the current run. Release the
// delayed pending notification only after that point so it is
// definitively stale.
require.Eventually(t, func() bool {
select {
case event := <-events:
return event.Type == codersdk.ChatStreamEventTypeStatus &&
event.Status != nil &&
event.Status.Status == codersdk.ChatStatusRunning
default:
return false
}
}, testutil.WaitMedium, testutil.IntervalFast)
require.NoError(t, ps.ReleaseStatus(chatChannel, string(database.ChatStatusPending)))
close(allowStreamingStart)
// Wait for the worker to reach the LLM (first streaming
// request). Also poll the chat status so we fail fast with a
// clear message if the worker errors out instead of timing
@@ -1543,20 +1574,6 @@ waitForStream:
}
}
// Wait for the subscriber to receive the running status, which
// triggers the relay. Because the dialer is non-blocking, the
// relay establishes promptly.
require.Eventually(t, func() bool {
select {
case event := <-events:
return event.Type == codersdk.ChatStreamEventTypeStatus &&
event.Status != nil &&
event.Status.Status == codersdk.ChatStatusRunning
default:
return false
}
}, testutil.WaitMedium, testutil.IntervalFast)
// Now release the worker to continue streaming.
close(continueStreaming)