Compare commits
2 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 3c02ecbee9 | |||
| 44f0e00ee3 |
@@ -1561,6 +1561,17 @@ func (q *querier) ActivityBumpWorkspace(ctx context.Context, arg database.Activi
|
||||
return update(q.log, q.auth, fetch, q.db.ActivityBumpWorkspace)(ctx, arg)
|
||||
}
|
||||
|
||||
func (q *querier) AdvanceChatRunGenerationAndUpdateStatus(ctx context.Context, arg database.AdvanceChatRunGenerationAndUpdateStatusParams) (database.Chat, error) {
|
||||
chat, err := q.db.GetChatByID(ctx, arg.ID)
|
||||
if err != nil {
|
||||
return database.Chat{}, err
|
||||
}
|
||||
if err := q.authorizeContext(ctx, policy.ActionUpdate, chat); err != nil {
|
||||
return database.Chat{}, err
|
||||
}
|
||||
return q.db.AdvanceChatRunGenerationAndUpdateStatus(ctx, arg)
|
||||
}
|
||||
|
||||
func (q *querier) AllUserIDs(ctx context.Context, includeSystem bool) ([]uuid.UUID, error) {
|
||||
// Although this technically only reads users, only system-related functions
|
||||
// should be allowed to call this.
|
||||
|
||||
@@ -879,9 +879,10 @@ func (s *MethodTestSuite) TestChats() {
|
||||
s.Run("UpdateChatHeartbeats", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
|
||||
resultID := uuid.New()
|
||||
arg := database.UpdateChatHeartbeatsParams{
|
||||
IDs: []uuid.UUID{resultID},
|
||||
WorkerID: uuid.New(),
|
||||
Now: time.Now(),
|
||||
IDs: []uuid.UUID{resultID},
|
||||
RunGenerations: []int64{1},
|
||||
WorkerID: uuid.New(),
|
||||
Now: time.Now(),
|
||||
}
|
||||
dbm.EXPECT().UpdateChatHeartbeats(gomock.Any(), arg).Return([]uuid.UUID{resultID}, nil).AnyTimes()
|
||||
check.Args(arg).Asserts(rbac.ResourceChat, policy.ActionUpdate).Returns([]uuid.UUID{resultID})
|
||||
@@ -946,6 +947,17 @@ func (s *MethodTestSuite) TestChats() {
|
||||
dbm.EXPECT().UpdateChatStatus(gomock.Any(), arg).Return(chat, nil).AnyTimes()
|
||||
check.Args(arg).Asserts(chat, policy.ActionUpdate).Returns(chat)
|
||||
}))
|
||||
s.Run("AdvanceChatRunGenerationAndUpdateStatus", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
|
||||
chat := testutil.Fake(s.T(), faker, database.Chat{})
|
||||
arg := database.AdvanceChatRunGenerationAndUpdateStatusParams{
|
||||
ID: chat.ID,
|
||||
Status: database.ChatStatusPending,
|
||||
}
|
||||
updated := testutil.Fake(s.T(), faker, database.Chat{ID: chat.ID})
|
||||
dbm.EXPECT().GetChatByID(gomock.Any(), chat.ID).Return(chat, nil).AnyTimes()
|
||||
dbm.EXPECT().AdvanceChatRunGenerationAndUpdateStatus(gomock.Any(), arg).Return(updated, nil).AnyTimes()
|
||||
check.Args(arg).Asserts(chat, policy.ActionUpdate).Returns(updated)
|
||||
}))
|
||||
s.Run("UpdateChatBuildAgentBinding", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
|
||||
chat := testutil.Fake(s.T(), faker, database.Chat{})
|
||||
arg := database.UpdateChatBuildAgentBindingParams{
|
||||
|
||||
@@ -152,6 +152,14 @@ func (m queryMetricsStore) ActivityBumpWorkspace(ctx context.Context, arg databa
|
||||
return r0
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) AdvanceChatRunGenerationAndUpdateStatus(ctx context.Context, arg database.AdvanceChatRunGenerationAndUpdateStatusParams) (database.Chat, error) {
|
||||
start := time.Now()
|
||||
r0, r1 := m.s.AdvanceChatRunGenerationAndUpdateStatus(ctx, arg)
|
||||
m.queryLatencies.WithLabelValues("AdvanceChatRunGenerationAndUpdateStatus").Observe(time.Since(start).Seconds())
|
||||
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "AdvanceChatRunGenerationAndUpdateStatus").Inc()
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) AllUserIDs(ctx context.Context, includeSystem bool) ([]uuid.UUID, error) {
|
||||
start := time.Now()
|
||||
r0, r1 := m.s.AllUserIDs(ctx, includeSystem)
|
||||
|
||||
@@ -132,6 +132,21 @@ func (mr *MockStoreMockRecorder) ActivityBumpWorkspace(ctx, arg any) *gomock.Cal
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ActivityBumpWorkspace", reflect.TypeOf((*MockStore)(nil).ActivityBumpWorkspace), ctx, arg)
|
||||
}
|
||||
|
||||
// AdvanceChatRunGenerationAndUpdateStatus mocks base method.
|
||||
func (m *MockStore) AdvanceChatRunGenerationAndUpdateStatus(ctx context.Context, arg database.AdvanceChatRunGenerationAndUpdateStatusParams) (database.Chat, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "AdvanceChatRunGenerationAndUpdateStatus", ctx, arg)
|
||||
ret0, _ := ret[0].(database.Chat)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// AdvanceChatRunGenerationAndUpdateStatus indicates an expected call of AdvanceChatRunGenerationAndUpdateStatus.
|
||||
func (mr *MockStoreMockRecorder) AdvanceChatRunGenerationAndUpdateStatus(ctx, arg any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AdvanceChatRunGenerationAndUpdateStatus", reflect.TypeOf((*MockStore)(nil).AdvanceChatRunGenerationAndUpdateStatus), ctx, arg)
|
||||
}
|
||||
|
||||
// AllUserIDs mocks base method.
|
||||
func (m *MockStore) AllUserIDs(ctx context.Context, includeSystem bool) ([]uuid.UUID, error) {
|
||||
m.ctrl.T.Helper()
|
||||
|
||||
Generated
+1
@@ -1413,6 +1413,7 @@ CREATE TABLE chats (
|
||||
workspace_id uuid,
|
||||
title text DEFAULT 'New Chat'::text NOT NULL,
|
||||
status chat_status DEFAULT 'waiting'::chat_status NOT NULL,
|
||||
run_generation bigint DEFAULT 0 NOT NULL,
|
||||
worker_id uuid,
|
||||
started_at timestamp with time zone,
|
||||
heartbeat_at timestamp with time zone,
|
||||
|
||||
@@ -0,0 +1,2 @@
|
||||
-- Removes the run_generation column from the chats table.
ALTER TABLE chats
|
||||
DROP COLUMN run_generation;
|
||||
@@ -0,0 +1,2 @@
|
||||
-- Adds run_generation to the chats table as a non-null BIGINT with a
-- default of 0.
ALTER TABLE chats
|
||||
ADD COLUMN run_generation BIGINT NOT NULL DEFAULT 0;
|
||||
@@ -780,6 +780,7 @@ func (q *sqlQuerier) GetAuthorizedChats(ctx context.Context, arg GetChatsParams,
|
||||
&i.Chat.WorkspaceID,
|
||||
&i.Chat.Title,
|
||||
&i.Chat.Status,
|
||||
&i.Chat.RunGeneration,
|
||||
&i.Chat.WorkerID,
|
||||
&i.Chat.StartedAt,
|
||||
&i.Chat.HeartbeatAt,
|
||||
|
||||
@@ -4227,6 +4227,7 @@ type Chat struct {
|
||||
WorkspaceID uuid.NullUUID `db:"workspace_id" json:"workspace_id"`
|
||||
Title string `db:"title" json:"title"`
|
||||
Status ChatStatus `db:"status" json:"status"`
|
||||
RunGeneration int64 `db:"run_generation" json:"run_generation"`
|
||||
WorkerID uuid.NullUUID `db:"worker_id" json:"worker_id"`
|
||||
StartedAt sql.NullTime `db:"started_at" json:"started_at"`
|
||||
HeartbeatAt sql.NullTime `db:"heartbeat_at" json:"heartbeat_at"`
|
||||
|
||||
@@ -52,6 +52,7 @@ type sqlcQuerier interface {
|
||||
// We only bump if workspace shutdown is manual.
|
||||
// We only bump when 5% of the deadline has elapsed.
|
||||
ActivityBumpWorkspace(ctx context.Context, arg ActivityBumpWorkspaceParams) error
|
||||
AdvanceChatRunGenerationAndUpdateStatus(ctx context.Context, arg AdvanceChatRunGenerationAndUpdateStatusParams) (Chat, error)
|
||||
// AllUserIDs returns all UserIDs regardless of user status or deletion.
|
||||
AllUserIDs(ctx context.Context, includeSystem bool) ([]uuid.UUID, error)
|
||||
ArchiveChatByID(ctx context.Context, id uuid.UUID) ([]Chat, error)
|
||||
@@ -917,8 +918,9 @@ type sqlcQuerier interface {
|
||||
UpdateChatByID(ctx context.Context, arg UpdateChatByIDParams) (Chat, error)
|
||||
// Bumps the heartbeat timestamp for the given set of chat IDs,
|
||||
// provided they are still running and owned by the specified
|
||||
// worker. Returns the IDs that were actually updated so the
|
||||
// caller can detect stolen or completed chats via set-difference.
|
||||
// worker and run generation. Returns the IDs that were actually
|
||||
// updated so the caller can detect stolen, superseded, or
|
||||
// completed chats via set-difference.
|
||||
UpdateChatHeartbeats(ctx context.Context, arg UpdateChatHeartbeatsParams) ([]uuid.UUID, error)
|
||||
UpdateChatLabelsByID(ctx context.Context, arg UpdateChatLabelsByIDParams) (Chat, error)
|
||||
// Updates the cached injected context parts (AGENTS.md +
|
||||
|
||||
+140
-34
@@ -4244,7 +4244,7 @@ WHERE
|
||||
$3::int
|
||||
)
|
||||
RETURNING
|
||||
id, owner_id, workspace_id, title, status, worker_id, started_at, heartbeat_at, created_at, updated_at, parent_chat_id, root_chat_id, last_model_config_id, archived, last_error, mode, mcp_server_ids, labels, build_id, agent_id, pin_order, last_read_message_id, last_injected_context, dynamic_tools
|
||||
id, owner_id, workspace_id, title, status, run_generation, worker_id, started_at, heartbeat_at, created_at, updated_at, parent_chat_id, root_chat_id, last_model_config_id, archived, last_error, mode, mcp_server_ids, labels, build_id, agent_id, pin_order, last_read_message_id, last_injected_context, dynamic_tools
|
||||
`
|
||||
|
||||
type AcquireChatsParams struct {
|
||||
@@ -4270,6 +4270,7 @@ func (q *sqlQuerier) AcquireChats(ctx context.Context, arg AcquireChatsParams) (
|
||||
&i.WorkspaceID,
|
||||
&i.Title,
|
||||
&i.Status,
|
||||
&i.RunGeneration,
|
||||
&i.WorkerID,
|
||||
&i.StartedAt,
|
||||
&i.HeartbeatAt,
|
||||
@@ -4422,14 +4423,80 @@ func (q *sqlQuerier) AcquireStaleChatDiffStatuses(ctx context.Context, limitVal
|
||||
return items, nil
|
||||
}
|
||||
|
||||
// advanceChatRunGenerationAndUpdateStatus is the SQL text of the
// AdvanceChatRunGenerationAndUpdateStatus query. For the chat whose id is $6
// it increments run_generation by one, overwrites status, worker_id,
// started_at, heartbeat_at and last_error with $1 through $5, sets
// updated_at to NOW(), and returns the updated row.
const advanceChatRunGenerationAndUpdateStatus = `-- name: AdvanceChatRunGenerationAndUpdateStatus :one
|
||||
UPDATE
|
||||
chats
|
||||
SET
|
||||
run_generation = run_generation + 1,
|
||||
status = $1::chat_status,
|
||||
worker_id = $2::uuid,
|
||||
started_at = $3::timestamptz,
|
||||
heartbeat_at = $4::timestamptz,
|
||||
last_error = $5::text,
|
||||
updated_at = NOW()
|
||||
WHERE
|
||||
id = $6::uuid
|
||||
RETURNING
|
||||
id, owner_id, workspace_id, title, status, run_generation, worker_id, started_at, heartbeat_at, created_at, updated_at, parent_chat_id, root_chat_id, last_model_config_id, archived, last_error, mode, mcp_server_ids, labels, build_id, agent_id, pin_order, last_read_message_id, last_injected_context, dynamic_tools
|
||||
`
|
||||
|
||||
type AdvanceChatRunGenerationAndUpdateStatusParams struct {
|
||||
Status ChatStatus `db:"status" json:"status"`
|
||||
WorkerID uuid.NullUUID `db:"worker_id" json:"worker_id"`
|
||||
StartedAt sql.NullTime `db:"started_at" json:"started_at"`
|
||||
HeartbeatAt sql.NullTime `db:"heartbeat_at" json:"heartbeat_at"`
|
||||
LastError sql.NullString `db:"last_error" json:"last_error"`
|
||||
ID uuid.UUID `db:"id" json:"id"`
|
||||
}
|
||||
|
||||
func (q *sqlQuerier) AdvanceChatRunGenerationAndUpdateStatus(ctx context.Context, arg AdvanceChatRunGenerationAndUpdateStatusParams) (Chat, error) {
|
||||
row := q.db.QueryRowContext(ctx, advanceChatRunGenerationAndUpdateStatus,
|
||||
arg.Status,
|
||||
arg.WorkerID,
|
||||
arg.StartedAt,
|
||||
arg.HeartbeatAt,
|
||||
arg.LastError,
|
||||
arg.ID,
|
||||
)
|
||||
var i Chat
|
||||
err := row.Scan(
|
||||
&i.ID,
|
||||
&i.OwnerID,
|
||||
&i.WorkspaceID,
|
||||
&i.Title,
|
||||
&i.Status,
|
||||
&i.RunGeneration,
|
||||
&i.WorkerID,
|
||||
&i.StartedAt,
|
||||
&i.HeartbeatAt,
|
||||
&i.CreatedAt,
|
||||
&i.UpdatedAt,
|
||||
&i.ParentChatID,
|
||||
&i.RootChatID,
|
||||
&i.LastModelConfigID,
|
||||
&i.Archived,
|
||||
&i.LastError,
|
||||
&i.Mode,
|
||||
pq.Array(&i.MCPServerIDs),
|
||||
&i.Labels,
|
||||
&i.BuildID,
|
||||
&i.AgentID,
|
||||
&i.PinOrder,
|
||||
&i.LastReadMessageID,
|
||||
&i.LastInjectedContext,
|
||||
&i.DynamicTools,
|
||||
)
|
||||
return i, err
|
||||
}
|
||||
|
||||
const archiveChatByID = `-- name: ArchiveChatByID :many
|
||||
WITH chats AS (
|
||||
UPDATE chats
|
||||
SET archived = true, pin_order = 0, updated_at = NOW()
|
||||
WHERE id = $1::uuid OR root_chat_id = $1::uuid
|
||||
RETURNING id, owner_id, workspace_id, title, status, worker_id, started_at, heartbeat_at, created_at, updated_at, parent_chat_id, root_chat_id, last_model_config_id, archived, last_error, mode, mcp_server_ids, labels, build_id, agent_id, pin_order, last_read_message_id, last_injected_context, dynamic_tools
|
||||
RETURNING id, owner_id, workspace_id, title, status, run_generation, worker_id, started_at, heartbeat_at, created_at, updated_at, parent_chat_id, root_chat_id, last_model_config_id, archived, last_error, mode, mcp_server_ids, labels, build_id, agent_id, pin_order, last_read_message_id, last_injected_context, dynamic_tools
|
||||
)
|
||||
SELECT id, owner_id, workspace_id, title, status, worker_id, started_at, heartbeat_at, created_at, updated_at, parent_chat_id, root_chat_id, last_model_config_id, archived, last_error, mode, mcp_server_ids, labels, build_id, agent_id, pin_order, last_read_message_id, last_injected_context, dynamic_tools
|
||||
SELECT id, owner_id, workspace_id, title, status, run_generation, worker_id, started_at, heartbeat_at, created_at, updated_at, parent_chat_id, root_chat_id, last_model_config_id, archived, last_error, mode, mcp_server_ids, labels, build_id, agent_id, pin_order, last_read_message_id, last_injected_context, dynamic_tools
|
||||
FROM chats
|
||||
ORDER BY (id = $1::uuid) DESC, created_at ASC, id ASC
|
||||
`
|
||||
@@ -4449,6 +4516,7 @@ func (q *sqlQuerier) ArchiveChatByID(ctx context.Context, id uuid.UUID) ([]Chat,
|
||||
&i.WorkspaceID,
|
||||
&i.Title,
|
||||
&i.Status,
|
||||
&i.RunGeneration,
|
||||
&i.WorkerID,
|
||||
&i.StartedAt,
|
||||
&i.HeartbeatAt,
|
||||
@@ -4617,7 +4685,7 @@ func (q *sqlQuerier) DeleteOldChats(ctx context.Context, arg DeleteOldChatsParam
|
||||
}
|
||||
|
||||
const getActiveChatsByAgentID = `-- name: GetActiveChatsByAgentID :many
|
||||
SELECT id, owner_id, workspace_id, title, status, worker_id, started_at, heartbeat_at, created_at, updated_at, parent_chat_id, root_chat_id, last_model_config_id, archived, last_error, mode, mcp_server_ids, labels, build_id, agent_id, pin_order, last_read_message_id, last_injected_context, dynamic_tools
|
||||
SELECT id, owner_id, workspace_id, title, status, run_generation, worker_id, started_at, heartbeat_at, created_at, updated_at, parent_chat_id, root_chat_id, last_model_config_id, archived, last_error, mode, mcp_server_ids, labels, build_id, agent_id, pin_order, last_read_message_id, last_injected_context, dynamic_tools
|
||||
FROM chats
|
||||
WHERE agent_id = $1::uuid
|
||||
AND archived = false
|
||||
@@ -4643,6 +4711,7 @@ func (q *sqlQuerier) GetActiveChatsByAgentID(ctx context.Context, agentID uuid.U
|
||||
&i.WorkspaceID,
|
||||
&i.Title,
|
||||
&i.Status,
|
||||
&i.RunGeneration,
|
||||
&i.WorkerID,
|
||||
&i.StartedAt,
|
||||
&i.HeartbeatAt,
|
||||
@@ -4678,7 +4747,7 @@ func (q *sqlQuerier) GetActiveChatsByAgentID(ctx context.Context, agentID uuid.U
|
||||
|
||||
const getChatByID = `-- name: GetChatByID :one
|
||||
SELECT
|
||||
id, owner_id, workspace_id, title, status, worker_id, started_at, heartbeat_at, created_at, updated_at, parent_chat_id, root_chat_id, last_model_config_id, archived, last_error, mode, mcp_server_ids, labels, build_id, agent_id, pin_order, last_read_message_id, last_injected_context, dynamic_tools
|
||||
id, owner_id, workspace_id, title, status, run_generation, worker_id, started_at, heartbeat_at, created_at, updated_at, parent_chat_id, root_chat_id, last_model_config_id, archived, last_error, mode, mcp_server_ids, labels, build_id, agent_id, pin_order, last_read_message_id, last_injected_context, dynamic_tools
|
||||
FROM
|
||||
chats
|
||||
WHERE
|
||||
@@ -4694,6 +4763,7 @@ func (q *sqlQuerier) GetChatByID(ctx context.Context, id uuid.UUID) (Chat, error
|
||||
&i.WorkspaceID,
|
||||
&i.Title,
|
||||
&i.Status,
|
||||
&i.RunGeneration,
|
||||
&i.WorkerID,
|
||||
&i.StartedAt,
|
||||
&i.HeartbeatAt,
|
||||
@@ -4718,7 +4788,7 @@ func (q *sqlQuerier) GetChatByID(ctx context.Context, id uuid.UUID) (Chat, error
|
||||
}
|
||||
|
||||
const getChatByIDForUpdate = `-- name: GetChatByIDForUpdate :one
|
||||
SELECT id, owner_id, workspace_id, title, status, worker_id, started_at, heartbeat_at, created_at, updated_at, parent_chat_id, root_chat_id, last_model_config_id, archived, last_error, mode, mcp_server_ids, labels, build_id, agent_id, pin_order, last_read_message_id, last_injected_context, dynamic_tools FROM chats WHERE id = $1::uuid FOR UPDATE
|
||||
SELECT id, owner_id, workspace_id, title, status, run_generation, worker_id, started_at, heartbeat_at, created_at, updated_at, parent_chat_id, root_chat_id, last_model_config_id, archived, last_error, mode, mcp_server_ids, labels, build_id, agent_id, pin_order, last_read_message_id, last_injected_context, dynamic_tools FROM chats WHERE id = $1::uuid FOR UPDATE
|
||||
`
|
||||
|
||||
func (q *sqlQuerier) GetChatByIDForUpdate(ctx context.Context, id uuid.UUID) (Chat, error) {
|
||||
@@ -4730,6 +4800,7 @@ func (q *sqlQuerier) GetChatByIDForUpdate(ctx context.Context, id uuid.UUID) (Ch
|
||||
&i.WorkspaceID,
|
||||
&i.Title,
|
||||
&i.Status,
|
||||
&i.RunGeneration,
|
||||
&i.WorkerID,
|
||||
&i.StartedAt,
|
||||
&i.HeartbeatAt,
|
||||
@@ -5803,7 +5874,7 @@ func (q *sqlQuerier) GetChatUsageLimitUserOverride(ctx context.Context, userID u
|
||||
|
||||
const getChats = `-- name: GetChats :many
|
||||
SELECT
|
||||
chats.id, chats.owner_id, chats.workspace_id, chats.title, chats.status, chats.worker_id, chats.started_at, chats.heartbeat_at, chats.created_at, chats.updated_at, chats.parent_chat_id, chats.root_chat_id, chats.last_model_config_id, chats.archived, chats.last_error, chats.mode, chats.mcp_server_ids, chats.labels, chats.build_id, chats.agent_id, chats.pin_order, chats.last_read_message_id, chats.last_injected_context, chats.dynamic_tools,
|
||||
chats.id, chats.owner_id, chats.workspace_id, chats.title, chats.status, chats.run_generation, chats.worker_id, chats.started_at, chats.heartbeat_at, chats.created_at, chats.updated_at, chats.parent_chat_id, chats.root_chat_id, chats.last_model_config_id, chats.archived, chats.last_error, chats.mode, chats.mcp_server_ids, chats.labels, chats.build_id, chats.agent_id, chats.pin_order, chats.last_read_message_id, chats.last_injected_context, chats.dynamic_tools,
|
||||
EXISTS (
|
||||
SELECT 1 FROM chat_messages cm
|
||||
WHERE cm.chat_id = chats.id
|
||||
@@ -5893,6 +5964,7 @@ func (q *sqlQuerier) GetChats(ctx context.Context, arg GetChatsParams) ([]GetCha
|
||||
&i.Chat.WorkspaceID,
|
||||
&i.Chat.Title,
|
||||
&i.Chat.Status,
|
||||
&i.Chat.RunGeneration,
|
||||
&i.Chat.WorkerID,
|
||||
&i.Chat.StartedAt,
|
||||
&i.Chat.HeartbeatAt,
|
||||
@@ -5928,7 +6000,7 @@ func (q *sqlQuerier) GetChats(ctx context.Context, arg GetChatsParams) ([]GetCha
|
||||
}
|
||||
|
||||
const getChatsByWorkspaceIDs = `-- name: GetChatsByWorkspaceIDs :many
|
||||
SELECT id, owner_id, workspace_id, title, status, worker_id, started_at, heartbeat_at, created_at, updated_at, parent_chat_id, root_chat_id, last_model_config_id, archived, last_error, mode, mcp_server_ids, labels, build_id, agent_id, pin_order, last_read_message_id, last_injected_context, dynamic_tools
|
||||
SELECT id, owner_id, workspace_id, title, status, run_generation, worker_id, started_at, heartbeat_at, created_at, updated_at, parent_chat_id, root_chat_id, last_model_config_id, archived, last_error, mode, mcp_server_ids, labels, build_id, agent_id, pin_order, last_read_message_id, last_injected_context, dynamic_tools
|
||||
FROM chats
|
||||
WHERE archived = false
|
||||
AND workspace_id = ANY($1::uuid[])
|
||||
@@ -5950,6 +6022,7 @@ func (q *sqlQuerier) GetChatsByWorkspaceIDs(ctx context.Context, ids []uuid.UUID
|
||||
&i.WorkspaceID,
|
||||
&i.Title,
|
||||
&i.Status,
|
||||
&i.RunGeneration,
|
||||
&i.WorkerID,
|
||||
&i.StartedAt,
|
||||
&i.HeartbeatAt,
|
||||
@@ -6096,7 +6169,7 @@ func (q *sqlQuerier) GetLastChatMessageByRole(ctx context.Context, arg GetLastCh
|
||||
|
||||
const getStaleChats = `-- name: GetStaleChats :many
|
||||
SELECT
|
||||
id, owner_id, workspace_id, title, status, worker_id, started_at, heartbeat_at, created_at, updated_at, parent_chat_id, root_chat_id, last_model_config_id, archived, last_error, mode, mcp_server_ids, labels, build_id, agent_id, pin_order, last_read_message_id, last_injected_context, dynamic_tools
|
||||
id, owner_id, workspace_id, title, status, run_generation, worker_id, started_at, heartbeat_at, created_at, updated_at, parent_chat_id, root_chat_id, last_model_config_id, archived, last_error, mode, mcp_server_ids, labels, build_id, agent_id, pin_order, last_read_message_id, last_injected_context, dynamic_tools
|
||||
FROM
|
||||
chats
|
||||
WHERE
|
||||
@@ -6125,6 +6198,7 @@ func (q *sqlQuerier) GetStaleChats(ctx context.Context, staleThreshold time.Time
|
||||
&i.WorkspaceID,
|
||||
&i.Title,
|
||||
&i.Status,
|
||||
&i.RunGeneration,
|
||||
&i.WorkerID,
|
||||
&i.StartedAt,
|
||||
&i.HeartbeatAt,
|
||||
@@ -6210,6 +6284,7 @@ INSERT INTO chats (
|
||||
title,
|
||||
mode,
|
||||
status,
|
||||
run_generation,
|
||||
mcp_server_ids,
|
||||
labels,
|
||||
dynamic_tools
|
||||
@@ -6224,12 +6299,13 @@ INSERT INTO chats (
|
||||
$8::text,
|
||||
$9::chat_mode,
|
||||
$10::chat_status,
|
||||
COALESCE($11::uuid[], '{}'::uuid[]),
|
||||
COALESCE($12::jsonb, '{}'::jsonb),
|
||||
$13::jsonb
|
||||
$11::bigint,
|
||||
COALESCE($12::uuid[], '{}'::uuid[]),
|
||||
COALESCE($13::jsonb, '{}'::jsonb),
|
||||
$14::jsonb
|
||||
)
|
||||
RETURNING
|
||||
id, owner_id, workspace_id, title, status, worker_id, started_at, heartbeat_at, created_at, updated_at, parent_chat_id, root_chat_id, last_model_config_id, archived, last_error, mode, mcp_server_ids, labels, build_id, agent_id, pin_order, last_read_message_id, last_injected_context, dynamic_tools
|
||||
id, owner_id, workspace_id, title, status, run_generation, worker_id, started_at, heartbeat_at, created_at, updated_at, parent_chat_id, root_chat_id, last_model_config_id, archived, last_error, mode, mcp_server_ids, labels, build_id, agent_id, pin_order, last_read_message_id, last_injected_context, dynamic_tools
|
||||
`
|
||||
|
||||
type InsertChatParams struct {
|
||||
@@ -6243,6 +6319,7 @@ type InsertChatParams struct {
|
||||
Title string `db:"title" json:"title"`
|
||||
Mode NullChatMode `db:"mode" json:"mode"`
|
||||
Status ChatStatus `db:"status" json:"status"`
|
||||
RunGeneration int64 `db:"run_generation" json:"run_generation"`
|
||||
MCPServerIDs []uuid.UUID `db:"mcp_server_ids" json:"mcp_server_ids"`
|
||||
Labels pqtype.NullRawMessage `db:"labels" json:"labels"`
|
||||
DynamicTools pqtype.NullRawMessage `db:"dynamic_tools" json:"dynamic_tools"`
|
||||
@@ -6260,6 +6337,7 @@ func (q *sqlQuerier) InsertChat(ctx context.Context, arg InsertChatParams) (Chat
|
||||
arg.Title,
|
||||
arg.Mode,
|
||||
arg.Status,
|
||||
arg.RunGeneration,
|
||||
pq.Array(arg.MCPServerIDs),
|
||||
arg.Labels,
|
||||
arg.DynamicTools,
|
||||
@@ -6271,6 +6349,7 @@ func (q *sqlQuerier) InsertChat(ctx context.Context, arg InsertChatParams) (Chat
|
||||
&i.WorkspaceID,
|
||||
&i.Title,
|
||||
&i.Status,
|
||||
&i.RunGeneration,
|
||||
&i.WorkerID,
|
||||
&i.StartedAt,
|
||||
&i.HeartbeatAt,
|
||||
@@ -6797,9 +6876,9 @@ WITH chats AS (
|
||||
archived = false,
|
||||
updated_at = NOW()
|
||||
WHERE id = $1::uuid OR root_chat_id = $1::uuid
|
||||
RETURNING id, owner_id, workspace_id, title, status, worker_id, started_at, heartbeat_at, created_at, updated_at, parent_chat_id, root_chat_id, last_model_config_id, archived, last_error, mode, mcp_server_ids, labels, build_id, agent_id, pin_order, last_read_message_id, last_injected_context, dynamic_tools
|
||||
RETURNING id, owner_id, workspace_id, title, status, run_generation, worker_id, started_at, heartbeat_at, created_at, updated_at, parent_chat_id, root_chat_id, last_model_config_id, archived, last_error, mode, mcp_server_ids, labels, build_id, agent_id, pin_order, last_read_message_id, last_injected_context, dynamic_tools
|
||||
)
|
||||
SELECT id, owner_id, workspace_id, title, status, worker_id, started_at, heartbeat_at, created_at, updated_at, parent_chat_id, root_chat_id, last_model_config_id, archived, last_error, mode, mcp_server_ids, labels, build_id, agent_id, pin_order, last_read_message_id, last_injected_context, dynamic_tools
|
||||
SELECT id, owner_id, workspace_id, title, status, run_generation, worker_id, started_at, heartbeat_at, created_at, updated_at, parent_chat_id, root_chat_id, last_model_config_id, archived, last_error, mode, mcp_server_ids, labels, build_id, agent_id, pin_order, last_read_message_id, last_injected_context, dynamic_tools
|
||||
FROM chats
|
||||
ORDER BY (id = $1::uuid) DESC, created_at ASC, id ASC
|
||||
`
|
||||
@@ -6823,6 +6902,7 @@ func (q *sqlQuerier) UnarchiveChatByID(ctx context.Context, id uuid.UUID) ([]Cha
|
||||
&i.WorkspaceID,
|
||||
&i.Title,
|
||||
&i.Status,
|
||||
&i.RunGeneration,
|
||||
&i.WorkerID,
|
||||
&i.StartedAt,
|
||||
&i.HeartbeatAt,
|
||||
@@ -6922,7 +7002,7 @@ UPDATE chats SET
|
||||
updated_at = NOW()
|
||||
WHERE
|
||||
id = $3::uuid
|
||||
RETURNING id, owner_id, workspace_id, title, status, worker_id, started_at, heartbeat_at, created_at, updated_at, parent_chat_id, root_chat_id, last_model_config_id, archived, last_error, mode, mcp_server_ids, labels, build_id, agent_id, pin_order, last_read_message_id, last_injected_context, dynamic_tools
|
||||
RETURNING id, owner_id, workspace_id, title, status, run_generation, worker_id, started_at, heartbeat_at, created_at, updated_at, parent_chat_id, root_chat_id, last_model_config_id, archived, last_error, mode, mcp_server_ids, labels, build_id, agent_id, pin_order, last_read_message_id, last_injected_context, dynamic_tools
|
||||
`
|
||||
|
||||
type UpdateChatBuildAgentBindingParams struct {
|
||||
@@ -6940,6 +7020,7 @@ func (q *sqlQuerier) UpdateChatBuildAgentBinding(ctx context.Context, arg Update
|
||||
&i.WorkspaceID,
|
||||
&i.Title,
|
||||
&i.Status,
|
||||
&i.RunGeneration,
|
||||
&i.WorkerID,
|
||||
&i.StartedAt,
|
||||
&i.HeartbeatAt,
|
||||
@@ -6972,7 +7053,7 @@ SET
|
||||
WHERE
|
||||
id = $2::uuid
|
||||
RETURNING
|
||||
id, owner_id, workspace_id, title, status, worker_id, started_at, heartbeat_at, created_at, updated_at, parent_chat_id, root_chat_id, last_model_config_id, archived, last_error, mode, mcp_server_ids, labels, build_id, agent_id, pin_order, last_read_message_id, last_injected_context, dynamic_tools
|
||||
id, owner_id, workspace_id, title, status, run_generation, worker_id, started_at, heartbeat_at, created_at, updated_at, parent_chat_id, root_chat_id, last_model_config_id, archived, last_error, mode, mcp_server_ids, labels, build_id, agent_id, pin_order, last_read_message_id, last_injected_context, dynamic_tools
|
||||
`
|
||||
|
||||
type UpdateChatByIDParams struct {
|
||||
@@ -6989,6 +7070,7 @@ func (q *sqlQuerier) UpdateChatByID(ctx context.Context, arg UpdateChatByIDParam
|
||||
&i.WorkspaceID,
|
||||
&i.Title,
|
||||
&i.Status,
|
||||
&i.RunGeneration,
|
||||
&i.WorkerID,
|
||||
&i.StartedAt,
|
||||
&i.HeartbeatAt,
|
||||
@@ -7013,29 +7095,46 @@ func (q *sqlQuerier) UpdateChatByID(ctx context.Context, arg UpdateChatByIDParam
|
||||
}
|
||||
|
||||
const updateChatHeartbeats = `-- name: UpdateChatHeartbeats :many
|
||||
WITH input AS (
|
||||
SELECT
|
||||
($3::uuid[])[i] AS id,
|
||||
($4::bigint[])[i] AS run_generation
|
||||
FROM
|
||||
generate_subscripts($3::uuid[], 1) AS g(i)
|
||||
)
|
||||
UPDATE
|
||||
chats
|
||||
SET
|
||||
heartbeat_at = $1::timestamptz
|
||||
FROM
|
||||
input
|
||||
WHERE
|
||||
id = ANY($2::uuid[])
|
||||
AND worker_id = $3::uuid
|
||||
AND status = 'running'::chat_status
|
||||
RETURNING id
|
||||
chats.id = input.id
|
||||
AND chats.run_generation = input.run_generation
|
||||
AND chats.worker_id = $2::uuid
|
||||
AND chats.status = 'running'::chat_status
|
||||
RETURNING chats.id
|
||||
`
|
||||
|
||||
type UpdateChatHeartbeatsParams struct {
|
||||
Now time.Time `db:"now" json:"now"`
|
||||
IDs []uuid.UUID `db:"ids" json:"ids"`
|
||||
WorkerID uuid.UUID `db:"worker_id" json:"worker_id"`
|
||||
Now time.Time `db:"now" json:"now"`
|
||||
WorkerID uuid.UUID `db:"worker_id" json:"worker_id"`
|
||||
IDs []uuid.UUID `db:"ids" json:"ids"`
|
||||
RunGenerations []int64 `db:"run_generations" json:"run_generations"`
|
||||
}
|
||||
|
||||
// Bumps the heartbeat timestamp for the given set of chat IDs,
|
||||
// provided they are still running and owned by the specified
|
||||
// worker. Returns the IDs that were actually updated so the
|
||||
// caller can detect stolen or completed chats via set-difference.
|
||||
// worker and run generation. Returns the IDs that were actually
|
||||
// updated so the caller can detect stolen, superseded, or
|
||||
// completed chats via set-difference.
|
||||
func (q *sqlQuerier) UpdateChatHeartbeats(ctx context.Context, arg UpdateChatHeartbeatsParams) ([]uuid.UUID, error) {
|
||||
rows, err := q.db.QueryContext(ctx, updateChatHeartbeats, arg.Now, pq.Array(arg.IDs), arg.WorkerID)
|
||||
rows, err := q.db.QueryContext(ctx, updateChatHeartbeats,
|
||||
arg.Now,
|
||||
arg.WorkerID,
|
||||
pq.Array(arg.IDs),
|
||||
pq.Array(arg.RunGenerations),
|
||||
)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -7066,7 +7165,7 @@ SET
|
||||
WHERE
|
||||
id = $2::uuid
|
||||
RETURNING
|
||||
id, owner_id, workspace_id, title, status, worker_id, started_at, heartbeat_at, created_at, updated_at, parent_chat_id, root_chat_id, last_model_config_id, archived, last_error, mode, mcp_server_ids, labels, build_id, agent_id, pin_order, last_read_message_id, last_injected_context, dynamic_tools
|
||||
id, owner_id, workspace_id, title, status, run_generation, worker_id, started_at, heartbeat_at, created_at, updated_at, parent_chat_id, root_chat_id, last_model_config_id, archived, last_error, mode, mcp_server_ids, labels, build_id, agent_id, pin_order, last_read_message_id, last_injected_context, dynamic_tools
|
||||
`
|
||||
|
||||
type UpdateChatLabelsByIDParams struct {
|
||||
@@ -7083,6 +7182,7 @@ func (q *sqlQuerier) UpdateChatLabelsByID(ctx context.Context, arg UpdateChatLab
|
||||
&i.WorkspaceID,
|
||||
&i.Title,
|
||||
&i.Status,
|
||||
&i.RunGeneration,
|
||||
&i.WorkerID,
|
||||
&i.StartedAt,
|
||||
&i.HeartbeatAt,
|
||||
@@ -7111,7 +7211,7 @@ UPDATE chats SET
|
||||
last_injected_context = $1::jsonb
|
||||
WHERE
|
||||
id = $2::uuid
|
||||
RETURNING id, owner_id, workspace_id, title, status, worker_id, started_at, heartbeat_at, created_at, updated_at, parent_chat_id, root_chat_id, last_model_config_id, archived, last_error, mode, mcp_server_ids, labels, build_id, agent_id, pin_order, last_read_message_id, last_injected_context, dynamic_tools
|
||||
RETURNING id, owner_id, workspace_id, title, status, run_generation, worker_id, started_at, heartbeat_at, created_at, updated_at, parent_chat_id, root_chat_id, last_model_config_id, archived, last_error, mode, mcp_server_ids, labels, build_id, agent_id, pin_order, last_read_message_id, last_injected_context, dynamic_tools
|
||||
`
|
||||
|
||||
type UpdateChatLastInjectedContextParams struct {
|
||||
@@ -7132,6 +7232,7 @@ func (q *sqlQuerier) UpdateChatLastInjectedContext(ctx context.Context, arg Upda
|
||||
&i.WorkspaceID,
|
||||
&i.Title,
|
||||
&i.Status,
|
||||
&i.RunGeneration,
|
||||
&i.WorkerID,
|
||||
&i.StartedAt,
|
||||
&i.HeartbeatAt,
|
||||
@@ -7164,7 +7265,7 @@ SET
|
||||
WHERE
|
||||
id = $2::uuid
|
||||
RETURNING
|
||||
id, owner_id, workspace_id, title, status, worker_id, started_at, heartbeat_at, created_at, updated_at, parent_chat_id, root_chat_id, last_model_config_id, archived, last_error, mode, mcp_server_ids, labels, build_id, agent_id, pin_order, last_read_message_id, last_injected_context, dynamic_tools
|
||||
id, owner_id, workspace_id, title, status, run_generation, worker_id, started_at, heartbeat_at, created_at, updated_at, parent_chat_id, root_chat_id, last_model_config_id, archived, last_error, mode, mcp_server_ids, labels, build_id, agent_id, pin_order, last_read_message_id, last_injected_context, dynamic_tools
|
||||
`
|
||||
|
||||
type UpdateChatLastModelConfigByIDParams struct {
|
||||
@@ -7181,6 +7282,7 @@ func (q *sqlQuerier) UpdateChatLastModelConfigByID(ctx context.Context, arg Upda
|
||||
&i.WorkspaceID,
|
||||
&i.Title,
|
||||
&i.Status,
|
||||
&i.RunGeneration,
|
||||
&i.WorkerID,
|
||||
&i.StartedAt,
|
||||
&i.HeartbeatAt,
|
||||
@@ -7231,7 +7333,7 @@ SET
|
||||
WHERE
|
||||
id = $2::uuid
|
||||
RETURNING
|
||||
id, owner_id, workspace_id, title, status, worker_id, started_at, heartbeat_at, created_at, updated_at, parent_chat_id, root_chat_id, last_model_config_id, archived, last_error, mode, mcp_server_ids, labels, build_id, agent_id, pin_order, last_read_message_id, last_injected_context, dynamic_tools
|
||||
id, owner_id, workspace_id, title, status, run_generation, worker_id, started_at, heartbeat_at, created_at, updated_at, parent_chat_id, root_chat_id, last_model_config_id, archived, last_error, mode, mcp_server_ids, labels, build_id, agent_id, pin_order, last_read_message_id, last_injected_context, dynamic_tools
|
||||
`
|
||||
|
||||
type UpdateChatMCPServerIDsParams struct {
|
||||
@@ -7248,6 +7350,7 @@ func (q *sqlQuerier) UpdateChatMCPServerIDs(ctx context.Context, arg UpdateChatM
|
||||
&i.WorkspaceID,
|
||||
&i.Title,
|
||||
&i.Status,
|
||||
&i.RunGeneration,
|
||||
&i.WorkerID,
|
||||
&i.StartedAt,
|
||||
&i.HeartbeatAt,
|
||||
@@ -7402,7 +7505,7 @@ SET
|
||||
WHERE
|
||||
id = $6::uuid
|
||||
RETURNING
|
||||
id, owner_id, workspace_id, title, status, worker_id, started_at, heartbeat_at, created_at, updated_at, parent_chat_id, root_chat_id, last_model_config_id, archived, last_error, mode, mcp_server_ids, labels, build_id, agent_id, pin_order, last_read_message_id, last_injected_context, dynamic_tools
|
||||
id, owner_id, workspace_id, title, status, run_generation, worker_id, started_at, heartbeat_at, created_at, updated_at, parent_chat_id, root_chat_id, last_model_config_id, archived, last_error, mode, mcp_server_ids, labels, build_id, agent_id, pin_order, last_read_message_id, last_injected_context, dynamic_tools
|
||||
`
|
||||
|
||||
type UpdateChatStatusParams struct {
|
||||
@@ -7430,6 +7533,7 @@ func (q *sqlQuerier) UpdateChatStatus(ctx context.Context, arg UpdateChatStatusP
|
||||
&i.WorkspaceID,
|
||||
&i.Title,
|
||||
&i.Status,
|
||||
&i.RunGeneration,
|
||||
&i.WorkerID,
|
||||
&i.StartedAt,
|
||||
&i.HeartbeatAt,
|
||||
@@ -7466,7 +7570,7 @@ SET
|
||||
WHERE
|
||||
id = $7::uuid
|
||||
RETURNING
|
||||
id, owner_id, workspace_id, title, status, worker_id, started_at, heartbeat_at, created_at, updated_at, parent_chat_id, root_chat_id, last_model_config_id, archived, last_error, mode, mcp_server_ids, labels, build_id, agent_id, pin_order, last_read_message_id, last_injected_context, dynamic_tools
|
||||
id, owner_id, workspace_id, title, status, run_generation, worker_id, started_at, heartbeat_at, created_at, updated_at, parent_chat_id, root_chat_id, last_model_config_id, archived, last_error, mode, mcp_server_ids, labels, build_id, agent_id, pin_order, last_read_message_id, last_injected_context, dynamic_tools
|
||||
`
|
||||
|
||||
type UpdateChatStatusPreserveUpdatedAtParams struct {
|
||||
@@ -7496,6 +7600,7 @@ func (q *sqlQuerier) UpdateChatStatusPreserveUpdatedAt(ctx context.Context, arg
|
||||
&i.WorkspaceID,
|
||||
&i.Title,
|
||||
&i.Status,
|
||||
&i.RunGeneration,
|
||||
&i.WorkerID,
|
||||
&i.StartedAt,
|
||||
&i.HeartbeatAt,
|
||||
@@ -7526,7 +7631,7 @@ UPDATE chats SET
|
||||
agent_id = $3::uuid,
|
||||
updated_at = NOW()
|
||||
WHERE id = $4::uuid
|
||||
RETURNING id, owner_id, workspace_id, title, status, worker_id, started_at, heartbeat_at, created_at, updated_at, parent_chat_id, root_chat_id, last_model_config_id, archived, last_error, mode, mcp_server_ids, labels, build_id, agent_id, pin_order, last_read_message_id, last_injected_context, dynamic_tools
|
||||
RETURNING id, owner_id, workspace_id, title, status, run_generation, worker_id, started_at, heartbeat_at, created_at, updated_at, parent_chat_id, root_chat_id, last_model_config_id, archived, last_error, mode, mcp_server_ids, labels, build_id, agent_id, pin_order, last_read_message_id, last_injected_context, dynamic_tools
|
||||
`
|
||||
|
||||
type UpdateChatWorkspaceBindingParams struct {
|
||||
@@ -7550,6 +7655,7 @@ func (q *sqlQuerier) UpdateChatWorkspaceBinding(ctx context.Context, arg UpdateC
|
||||
&i.WorkspaceID,
|
||||
&i.Title,
|
||||
&i.Status,
|
||||
&i.RunGeneration,
|
||||
&i.WorkerID,
|
||||
&i.StartedAt,
|
||||
&i.HeartbeatAt,
|
||||
|
||||
@@ -398,6 +398,7 @@ INSERT INTO chats (
|
||||
title,
|
||||
mode,
|
||||
status,
|
||||
run_generation,
|
||||
mcp_server_ids,
|
||||
labels,
|
||||
dynamic_tools
|
||||
@@ -412,6 +413,7 @@ INSERT INTO chats (
|
||||
@title::text,
|
||||
sqlc.narg('mode')::chat_mode,
|
||||
@status::chat_status,
|
||||
@run_generation::bigint,
|
||||
COALESCE(@mcp_server_ids::uuid[], '{}'::uuid[]),
|
||||
COALESCE(sqlc.narg('labels')::jsonb, '{}'::jsonb),
|
||||
sqlc.narg('dynamic_tools')::jsonb
|
||||
@@ -655,6 +657,22 @@ WHERE
|
||||
RETURNING
|
||||
*;
|
||||
|
||||
-- name: AdvanceChatRunGenerationAndUpdateStatus :one
|
||||
UPDATE
|
||||
chats
|
||||
SET
|
||||
run_generation = run_generation + 1,
|
||||
status = @status::chat_status,
|
||||
worker_id = sqlc.narg('worker_id')::uuid,
|
||||
started_at = sqlc.narg('started_at')::timestamptz,
|
||||
heartbeat_at = sqlc.narg('heartbeat_at')::timestamptz,
|
||||
last_error = sqlc.narg('last_error')::text,
|
||||
updated_at = NOW()
|
||||
WHERE
|
||||
id = @id::uuid
|
||||
RETURNING
|
||||
*;
|
||||
|
||||
-- name: UpdateChatStatusPreserveUpdatedAt :one
|
||||
UPDATE
|
||||
chats
|
||||
@@ -688,17 +706,28 @@ WHERE
|
||||
-- name: UpdateChatHeartbeats :many
|
||||
-- Bumps the heartbeat timestamp for the given set of chat IDs,
|
||||
-- provided they are still running and owned by the specified
|
||||
-- worker. Returns the IDs that were actually updated so the
|
||||
-- caller can detect stolen or completed chats via set-difference.
|
||||
-- worker and run generation. Returns the IDs that were actually
|
||||
-- updated so the caller can detect stolen, superseded, or
|
||||
-- completed chats via set-difference.
|
||||
WITH input AS (
|
||||
SELECT
|
||||
(@ids::uuid[])[i] AS id,
|
||||
(@run_generations::bigint[])[i] AS run_generation
|
||||
FROM
|
||||
generate_subscripts(@ids::uuid[], 1) AS g(i)
|
||||
)
|
||||
UPDATE
|
||||
chats
|
||||
SET
|
||||
heartbeat_at = @now::timestamptz
|
||||
FROM
|
||||
input
|
||||
WHERE
|
||||
id = ANY(@ids::uuid[])
|
||||
AND worker_id = @worker_id::uuid
|
||||
AND status = 'running'::chat_status
|
||||
RETURNING id;
|
||||
chats.id = input.id
|
||||
AND chats.run_generation = input.run_generation
|
||||
AND chats.worker_id = @worker_id::uuid
|
||||
AND chats.status = 'running'::chat_status
|
||||
RETURNING chats.id;
|
||||
|
||||
-- name: GetChatDiffStatusByChatID :one
|
||||
SELECT
|
||||
|
||||
@@ -0,0 +1,58 @@
|
||||
package pubsub
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
// ChatControlReason identifies why a worker-scoped control message was sent.
|
||||
type ChatControlReason string
|
||||
|
||||
const (
|
||||
// ChatControlReasonInterrupt requests that the current generation stop
|
||||
// running without fencing a newer generation. This is used for explicit stop
|
||||
// actions that still allow the interrupted run to persist partial output.
|
||||
ChatControlReasonInterrupt ChatControlReason = "interrupt"
|
||||
// ChatControlReasonRestart requests that an older generation stop because a
|
||||
// newer generation has been scheduled.
|
||||
ChatControlReasonRestart ChatControlReason = "restart"
|
||||
// ChatControlReasonArchive requests that the current generation stop because
|
||||
// the chat is being archived.
|
||||
ChatControlReasonArchive ChatControlReason = "archive"
|
||||
// ChatControlReasonRecoverStale requests that the current generation stop
|
||||
// because stale recovery fenced it off.
|
||||
ChatControlReasonRecoverStale ChatControlReason = "recover_stale"
|
||||
)
|
||||
|
||||
// ChatControlChannel returns the pubsub channel for worker-scoped control
|
||||
// messages. Each worker subscribes to exactly one control channel.
|
||||
func ChatControlChannel(workerID uuid.UUID) string {
|
||||
return fmt.Sprintf("chat:control:%s", workerID)
|
||||
}
|
||||
|
||||
// ChatControlMessage requests that a worker stop an active run for a chat.
|
||||
// RunGeneration fences newer runs from stale control messages.
|
||||
type ChatControlMessage struct {
|
||||
ChatID uuid.UUID `json:"chat_id"`
|
||||
RunGeneration int64 `json:"run_generation"`
|
||||
Reason ChatControlReason `json:"reason,omitempty"`
|
||||
}
|
||||
|
||||
// HandleChatControl wraps a typed callback for ChatControlMessage payloads.
|
||||
func HandleChatControl(cb func(ctx context.Context, payload ChatControlMessage, err error)) func(ctx context.Context, message []byte, err error) {
|
||||
return func(ctx context.Context, message []byte, err error) {
|
||||
if err != nil {
|
||||
cb(ctx, ChatControlMessage{}, err)
|
||||
return
|
||||
}
|
||||
var payload ChatControlMessage
|
||||
if err := json.Unmarshal(message, &payload); err != nil {
|
||||
cb(ctx, ChatControlMessage{}, err)
|
||||
return
|
||||
}
|
||||
cb(ctx, payload, nil)
|
||||
}
|
||||
}
|
||||
@@ -17,7 +17,7 @@ func ChatStreamNotifyChannel(chatID uuid.UUID) string {
|
||||
|
||||
// ChatStreamNotifyMessage is the payload published on the per-chat
|
||||
// stream notification channel. Durable message content is still read
|
||||
// from the database, while transient control events can be carried
|
||||
// from the database, while lightweight status and relay hints travel
|
||||
// inline for cross-replica delivery.
|
||||
type ChatStreamNotifyMessage struct {
|
||||
// AfterMessageID tells subscribers to query messages after this
|
||||
|
||||
+310
-139
@@ -131,6 +131,7 @@ type Server struct {
|
||||
providerAPIKeys chatprovider.ProviderAPIKeys
|
||||
configCache *chatConfigCache
|
||||
configCacheUnsubscribe func()
|
||||
controlUnsubscribe func()
|
||||
|
||||
// chatStreams stores per-chat stream state. Using sync.Map
|
||||
// gives each chat independent locking — concurrent chats
|
||||
@@ -720,6 +721,7 @@ type chatStreamState struct {
|
||||
type heartbeatEntry struct {
|
||||
cancelWithCause context.CancelCauseFunc
|
||||
chatID uuid.UUID
|
||||
runGeneration int64
|
||||
workspaceID uuid.NullUUID
|
||||
logger slog.Logger
|
||||
}
|
||||
@@ -745,11 +747,12 @@ var (
|
||||
// ErrEditedMessageNotUser indicates a non-user message edit attempt.
|
||||
ErrEditedMessageNotUser = xerrors.New("only user messages can be edited")
|
||||
|
||||
// errChatTakenByOtherWorker is a sentinel used inside the
|
||||
// processChat cleanup transaction to signal that another
|
||||
// worker acquired the chat, so all post-TX side effects
|
||||
// (status publish, pubsub, web push) must be skipped.
|
||||
errChatTakenByOtherWorker = xerrors.New("chat acquired by another worker")
|
||||
// errChatSuperseded is a sentinel used inside the processChat
|
||||
// cleanup transaction to signal that a newer worker ownership or
|
||||
// run generation has superseded the local run. In that case all
|
||||
// post-TX side effects (status publish, pubsub, web push) must
|
||||
// be skipped.
|
||||
errChatSuperseded = xerrors.New("chat run superseded")
|
||||
)
|
||||
|
||||
// UsageLimitExceededError indicates the user has exceeded their chat spend
|
||||
@@ -894,8 +897,9 @@ func (p *Server) CreateChat(ctx context.Context, opts CreateOptions) (database.C
|
||||
Mode: opts.ChatMode,
|
||||
// Chats created with an initial user message start pending.
|
||||
// Waiting is reserved for idle chats with no pending work.
|
||||
Status: database.ChatStatusPending,
|
||||
MCPServerIDs: opts.MCPServerIDs,
|
||||
Status: database.ChatStatusPending,
|
||||
RunGeneration: 1,
|
||||
MCPServerIDs: opts.MCPServerIDs,
|
||||
Labels: pqtype.NullRawMessage{
|
||||
RawMessage: labelsJSON,
|
||||
Valid: true,
|
||||
@@ -1132,11 +1136,11 @@ func (p *Server) SendMessage(
|
||||
})
|
||||
|
||||
// For interrupt behavior, signal the running loop to
|
||||
// stop. setChatWaiting publishes a status notification
|
||||
// that the worker's control subscriber detects, causing
|
||||
// it to cancel with ErrInterrupted. The deferred cleanup
|
||||
// in processChat then auto-promotes the queued message
|
||||
// after persisting the partial assistant response.
|
||||
// stop. setChatWaiting publishes a worker-scoped control
|
||||
// message so the active run cancels with ErrInterrupted.
|
||||
// The deferred cleanup in processChat then auto-promotes
|
||||
// the queued message after persisting the partial
|
||||
// assistant response.
|
||||
if busyBehavior == SendMessageBusyBehaviorInterrupt {
|
||||
updatedChat, err := p.setChatWaiting(ctx, opts.ChatID)
|
||||
if err != nil {
|
||||
@@ -1210,7 +1214,12 @@ func (p *Server) EditMessage(
|
||||
return EditMessageResult{}, xerrors.Errorf("marshal message content: %w", err)
|
||||
}
|
||||
|
||||
var result EditMessageResult
|
||||
var (
|
||||
result EditMessageResult
|
||||
controlWorkerID uuid.UUID
|
||||
controlMsg coderdpubsub.ChatControlMessage
|
||||
publishControl bool
|
||||
)
|
||||
txErr := p.db.InTx(func(tx database.Store) error {
|
||||
lockedChat, err := tx.GetChatByIDForUpdate(ctx, opts.ChatID)
|
||||
if err != nil {
|
||||
@@ -1272,7 +1281,7 @@ func (p *Server) EditMessage(
|
||||
if err != nil {
|
||||
return xerrors.Errorf("delete queued messages: %w", err)
|
||||
}
|
||||
updatedChat, err := tx.UpdateChatStatus(ctx, database.UpdateChatStatusParams{
|
||||
updatedChat, err := tx.AdvanceChatRunGenerationAndUpdateStatus(ctx, database.AdvanceChatRunGenerationAndUpdateStatusParams{
|
||||
ID: opts.ChatID,
|
||||
Status: database.ChatStatusPending,
|
||||
WorkerID: uuid.NullUUID{},
|
||||
@@ -1284,6 +1293,11 @@ func (p *Server) EditMessage(
|
||||
return xerrors.Errorf("set chat pending: %w", err)
|
||||
}
|
||||
|
||||
controlWorkerID, controlMsg, publishControl = controlMessageForWorker(
|
||||
lockedChat,
|
||||
updatedChat.RunGeneration,
|
||||
coderdpubsub.ChatControlReasonRestart,
|
||||
)
|
||||
result.Message = newMessage
|
||||
result.Chat = updatedChat
|
||||
return nil
|
||||
@@ -1291,6 +1305,9 @@ func (p *Server) EditMessage(
|
||||
if txErr != nil {
|
||||
return EditMessageResult{}, txErr
|
||||
}
|
||||
if publishControl {
|
||||
p.publishChatControl(controlWorkerID, controlMsg)
|
||||
}
|
||||
|
||||
p.publishEditedMessage(opts.ChatID, result.Message)
|
||||
p.publishEvent(opts.ChatID, codersdk.ChatStreamEvent{
|
||||
@@ -1316,9 +1333,14 @@ func (p *Server) ArchiveChat(ctx context.Context, chat database.Chat) error {
|
||||
return xerrors.New("chat_id is required")
|
||||
}
|
||||
|
||||
statusChat := chat
|
||||
interrupted := false
|
||||
var archivedChats []database.Chat
|
||||
var (
|
||||
statusChat = chat
|
||||
interrupted bool
|
||||
archivedChats []database.Chat
|
||||
controlWorkerID uuid.UUID
|
||||
controlMsg coderdpubsub.ChatControlMessage
|
||||
publishControl bool
|
||||
)
|
||||
if err := p.db.InTx(func(tx database.Store) error {
|
||||
lockedChat, err := tx.GetChatByIDForUpdate(ctx, chat.ID)
|
||||
if err != nil {
|
||||
@@ -1341,6 +1363,11 @@ func (p *Server) ArchiveChat(ctx context.Context, chat database.Chat) error {
|
||||
if err != nil {
|
||||
return xerrors.Errorf("set chat waiting before archive: %w", err)
|
||||
}
|
||||
controlWorkerID, controlMsg, publishControl = controlMessageForWorker(
|
||||
lockedChat,
|
||||
statusChat.RunGeneration,
|
||||
coderdpubsub.ChatControlReasonArchive,
|
||||
)
|
||||
interrupted = true
|
||||
}
|
||||
|
||||
@@ -1354,6 +1381,9 @@ func (p *Server) ArchiveChat(ctx context.Context, chat database.Chat) error {
|
||||
}
|
||||
|
||||
if interrupted {
|
||||
if publishControl {
|
||||
p.publishChatControl(controlWorkerID, controlMsg)
|
||||
}
|
||||
p.publishStatus(chat.ID, statusChat.Status, statusChat.WorkerID)
|
||||
p.publishChatPubsubEvent(statusChat, codersdk.ChatWatchEventKindStatusChange, nil)
|
||||
}
|
||||
@@ -1607,7 +1637,13 @@ func (p *Server) SubmitToolResults(
|
||||
// The GetLastChatMessageByRole lookup and all subsequent
|
||||
// validation and persistence run inside a single transaction
|
||||
// so the assistant message cannot change between reads.
|
||||
var statusConflict *ToolResultStatusConflictError
|
||||
var (
|
||||
statusConflict *ToolResultStatusConflictError
|
||||
updatedChat database.Chat
|
||||
controlWorkerID uuid.UUID
|
||||
controlMsg coderdpubsub.ChatControlMessage
|
||||
publishControl bool
|
||||
)
|
||||
txErr := p.db.InTx(func(tx database.Store) error {
|
||||
// Authoritative status check under row lock.
|
||||
locked, lockErr := tx.GetChatByIDForUpdate(ctx, opts.ChatID)
|
||||
@@ -1761,22 +1797,33 @@ func (p *Server) SubmitToolResults(
|
||||
}
|
||||
|
||||
// Transition chat to pending.
|
||||
if _, updateErr := tx.UpdateChatStatus(ctx, database.UpdateChatStatusParams{
|
||||
updatedChat, err = tx.AdvanceChatRunGenerationAndUpdateStatus(ctx, database.AdvanceChatRunGenerationAndUpdateStatusParams{
|
||||
ID: opts.ChatID,
|
||||
Status: database.ChatStatusPending,
|
||||
WorkerID: uuid.NullUUID{},
|
||||
StartedAt: sql.NullTime{},
|
||||
HeartbeatAt: sql.NullTime{},
|
||||
LastError: sql.NullString{},
|
||||
}); updateErr != nil {
|
||||
return xerrors.Errorf("update chat status: %w", updateErr)
|
||||
})
|
||||
if err != nil {
|
||||
return xerrors.Errorf("update chat status: %w", err)
|
||||
}
|
||||
controlWorkerID, controlMsg, publishControl = controlMessageForWorker(
|
||||
locked,
|
||||
updatedChat.RunGeneration,
|
||||
coderdpubsub.ChatControlReasonRestart,
|
||||
)
|
||||
|
||||
return nil
|
||||
}, nil)
|
||||
if txErr != nil {
|
||||
return txErr
|
||||
}
|
||||
if publishControl {
|
||||
p.publishChatControl(controlWorkerID, controlMsg)
|
||||
}
|
||||
p.publishStatus(opts.ChatID, updatedChat.Status, updatedChat.WorkerID)
|
||||
p.publishChatPubsubEvent(updatedChat, coderdpubsub.ChatEventKindStatusChange, nil)
|
||||
|
||||
// Wake the chatd run loop so it processes the chat immediately.
|
||||
p.signalWake()
|
||||
@@ -2318,7 +2365,12 @@ func (p *Server) RefreshStatus(ctx context.Context, chatID uuid.UUID) error {
|
||||
}
|
||||
|
||||
func (p *Server) setChatWaiting(ctx context.Context, chatID uuid.UUID) (database.Chat, error) {
|
||||
var updatedChat database.Chat
|
||||
var (
|
||||
updatedChat database.Chat
|
||||
controlWorkerID uuid.UUID
|
||||
controlMsg coderdpubsub.ChatControlMessage
|
||||
publishControl bool
|
||||
)
|
||||
err := p.db.InTx(func(tx database.Store) error {
|
||||
locked, lockErr := tx.GetChatByIDForUpdate(ctx, chatID)
|
||||
if lockErr != nil {
|
||||
@@ -2341,11 +2393,22 @@ func (p *Server) setChatWaiting(ctx context.Context, chatID uuid.UUID) (database
|
||||
HeartbeatAt: sql.NullTime{},
|
||||
LastError: sql.NullString{},
|
||||
})
|
||||
return updateErr
|
||||
if updateErr != nil {
|
||||
return updateErr
|
||||
}
|
||||
controlWorkerID, controlMsg, publishControl = controlMessageForWorker(
|
||||
locked,
|
||||
updatedChat.RunGeneration,
|
||||
coderdpubsub.ChatControlReasonInterrupt,
|
||||
)
|
||||
return nil
|
||||
}, nil)
|
||||
if err != nil {
|
||||
return database.Chat{}, err
|
||||
}
|
||||
if publishControl {
|
||||
p.publishChatControl(controlWorkerID, controlMsg)
|
||||
}
|
||||
p.publishStatus(chatID, updatedChat.Status, updatedChat.WorkerID)
|
||||
p.publishChatPubsubEvent(updatedChat, codersdk.ChatWatchEventKindStatusChange, nil)
|
||||
return updatedChat, nil
|
||||
@@ -2646,7 +2709,7 @@ func insertUserMessageAndSetPending(
|
||||
return message, lockedChat, nil
|
||||
}
|
||||
|
||||
updatedChat, err := store.UpdateChatStatus(ctx, database.UpdateChatStatusParams{
|
||||
updatedChat, err := store.AdvanceChatRunGenerationAndUpdateStatus(ctx, database.AdvanceChatRunGenerationAndUpdateStatusParams{
|
||||
ID: lockedChat.ID,
|
||||
Status: database.ChatStatusPending,
|
||||
WorkerID: uuid.NullUUID{},
|
||||
@@ -2786,6 +2849,7 @@ func New(cfg Config) *Server {
|
||||
}
|
||||
p.configCacheUnsubscribe = cancelConfigSub
|
||||
}
|
||||
p.controlUnsubscribe = p.subscribeWorkerControl(ctx)
|
||||
go p.start(ctx)
|
||||
|
||||
return p
|
||||
@@ -2874,7 +2938,7 @@ func (p *Server) processOnce(ctx context.Context) {
|
||||
context.WithoutCancel(ctx), 10*time.Second,
|
||||
)
|
||||
for _, chat := range chats {
|
||||
_, updateErr := p.db.UpdateChatStatus(releaseCtx, database.UpdateChatStatusParams{
|
||||
_, updateErr := p.db.AdvanceChatRunGenerationAndUpdateStatus(releaseCtx, database.AdvanceChatRunGenerationAndUpdateStatusParams{
|
||||
ID: chat.ID,
|
||||
Status: database.ChatStatusPending,
|
||||
WorkerID: uuid.NullUUID{},
|
||||
@@ -3075,21 +3139,35 @@ func (p *Server) cleanupStreamIfIdle(chatID uuid.UUID, state *chatStreamState) {
|
||||
func (p *Server) registerHeartbeat(entry *heartbeatEntry) {
|
||||
p.heartbeatMu.Lock()
|
||||
defer p.heartbeatMu.Unlock()
|
||||
if _, exists := p.heartbeatRegistry[entry.chatID]; exists {
|
||||
p.logger.Warn(context.Background(),
|
||||
"duplicate heartbeat registration, skipping",
|
||||
slog.F("chat_id", entry.chatID))
|
||||
return
|
||||
|
||||
if existing, exists := p.heartbeatRegistry[entry.chatID]; exists {
|
||||
if entry.runGeneration <= existing.runGeneration {
|
||||
p.logger.Warn(context.Background(),
|
||||
"duplicate heartbeat registration, skipping",
|
||||
slog.F("chat_id", entry.chatID),
|
||||
slog.F("existing_run_generation", existing.runGeneration),
|
||||
slog.F("incoming_run_generation", entry.runGeneration))
|
||||
return
|
||||
}
|
||||
|
||||
// A newer generation for the same chat can start before the old
|
||||
// processChat goroutine finishes unwinding its defers. Replace the
|
||||
// stale entry now so the new run keeps heartbeat coverage and local
|
||||
// worker-scoped control handling.
|
||||
existing.cancelWithCause(chatloop.ErrInterrupted)
|
||||
}
|
||||
|
||||
p.heartbeatRegistry[entry.chatID] = entry
|
||||
}
|
||||
|
||||
// unregisterHeartbeat removes a chat from the centralized
|
||||
// heartbeat loop when chat processing finishes.
|
||||
func (p *Server) unregisterHeartbeat(chatID uuid.UUID) {
|
||||
func (p *Server) unregisterHeartbeat(entry *heartbeatEntry) {
|
||||
p.heartbeatMu.Lock()
|
||||
defer p.heartbeatMu.Unlock()
|
||||
delete(p.heartbeatRegistry, chatID)
|
||||
if p.heartbeatRegistry[entry.chatID] == entry {
|
||||
delete(p.heartbeatRegistry, entry.chatID)
|
||||
}
|
||||
}
|
||||
|
||||
// heartbeatLoop runs in a single goroutine, issuing one batch
|
||||
@@ -3120,16 +3198,22 @@ func (p *Server) heartbeatTick(ctx context.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
// Collect the IDs we believe we own.
|
||||
ids := slices.Collect(maps.Keys(snapshot))
|
||||
// Collect the IDs and generations we believe we own.
|
||||
ids := make([]uuid.UUID, 0, len(snapshot))
|
||||
runGenerations := make([]int64, 0, len(snapshot))
|
||||
for id, entry := range snapshot {
|
||||
ids = append(ids, id)
|
||||
runGenerations = append(runGenerations, entry.runGeneration)
|
||||
}
|
||||
|
||||
//nolint:gocritic // AsChatd provides narrowly-scoped daemon
|
||||
// access for batch-updating heartbeats.
|
||||
chatdCtx := dbauthz.AsChatd(ctx)
|
||||
updatedIDs, err := p.db.UpdateChatHeartbeats(chatdCtx, database.UpdateChatHeartbeatsParams{
|
||||
IDs: ids,
|
||||
WorkerID: p.workerID,
|
||||
Now: p.clock.Now(),
|
||||
IDs: ids,
|
||||
RunGenerations: runGenerations,
|
||||
WorkerID: p.workerID,
|
||||
Now: p.clock.Now(),
|
||||
})
|
||||
if err != nil {
|
||||
p.logger.Error(ctx, "batch heartbeat failed", slog.Error(err))
|
||||
@@ -3190,7 +3274,7 @@ func (p *Server) Subscribe(
|
||||
var allCancels []func()
|
||||
allCancels = append(allCancels, localCancel)
|
||||
|
||||
// Subscribe to pubsub for durable and structured control
|
||||
// Subscribe to pubsub for durable and structured stream
|
||||
// events (status, messages, queue updates, retry, errors).
|
||||
// When pubsub is nil (e.g. in-memory
|
||||
// single-instance) we skip this and deliver all local events.
|
||||
@@ -3846,62 +3930,110 @@ func (p *Server) publishMessagePart(chatID uuid.UUID, role codersdk.ChatMessageR
|
||||
})
|
||||
}
|
||||
|
||||
func shouldCancelChatFromControlNotification(
|
||||
notify coderdpubsub.ChatStreamNotifyMessage,
|
||||
workerID uuid.UUID,
|
||||
) bool {
|
||||
status := database.ChatStatus(strings.TrimSpace(notify.Status))
|
||||
switch status {
|
||||
case database.ChatStatusWaiting, database.ChatStatusPending, database.ChatStatusError:
|
||||
func controlMessageForWorker(
|
||||
chat database.Chat,
|
||||
runGeneration int64,
|
||||
reason coderdpubsub.ChatControlReason,
|
||||
) (uuid.UUID, coderdpubsub.ChatControlMessage, bool) {
|
||||
if !chat.WorkerID.Valid {
|
||||
return uuid.Nil, coderdpubsub.ChatControlMessage{}, false
|
||||
}
|
||||
return chat.WorkerID.UUID, coderdpubsub.ChatControlMessage{
|
||||
ChatID: chat.ID,
|
||||
RunGeneration: runGeneration,
|
||||
Reason: reason,
|
||||
}, true
|
||||
}
|
||||
|
||||
func interruptsCurrentRun(reason coderdpubsub.ChatControlReason) bool {
|
||||
switch reason {
|
||||
case coderdpubsub.ChatControlReasonInterrupt, coderdpubsub.ChatControlReasonArchive:
|
||||
return true
|
||||
case database.ChatStatusRunning:
|
||||
worker := strings.TrimSpace(notify.WorkerID)
|
||||
if worker == "" {
|
||||
return false
|
||||
}
|
||||
notifyWorkerID, err := uuid.Parse(worker)
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
return notifyWorkerID != workerID
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
func (p *Server) subscribeChatControl(
|
||||
ctx context.Context,
|
||||
chatID uuid.UUID,
|
||||
cancel context.CancelCauseFunc,
|
||||
logger slog.Logger,
|
||||
) func() {
|
||||
func shouldInterruptActiveRunFromControlMessage(
|
||||
entry *heartbeatEntry,
|
||||
msg coderdpubsub.ChatControlMessage,
|
||||
) bool {
|
||||
if entry == nil || msg.ChatID != entry.chatID {
|
||||
return false
|
||||
}
|
||||
if msg.RunGeneration > entry.runGeneration {
|
||||
return true
|
||||
}
|
||||
return msg.RunGeneration == entry.runGeneration && interruptsCurrentRun(msg.Reason)
|
||||
}
|
||||
|
||||
func (p *Server) handleChatControlMessage(ctx context.Context, msg coderdpubsub.ChatControlMessage) {
|
||||
if msg.ChatID == uuid.Nil {
|
||||
return
|
||||
}
|
||||
|
||||
p.heartbeatMu.Lock()
|
||||
entry := p.heartbeatRegistry[msg.ChatID]
|
||||
p.heartbeatMu.Unlock()
|
||||
if !shouldInterruptActiveRunFromControlMessage(entry, msg) {
|
||||
return
|
||||
}
|
||||
|
||||
entry.logger.Info(ctx, "interrupting active run from control message",
|
||||
slog.F("reason", msg.Reason),
|
||||
slog.F("incoming_run_generation", msg.RunGeneration),
|
||||
slog.F("local_run_generation", entry.runGeneration),
|
||||
)
|
||||
entry.cancelWithCause(chatloop.ErrInterrupted)
|
||||
}
|
||||
|
||||
func (p *Server) publishChatControl(workerID uuid.UUID, msg coderdpubsub.ChatControlMessage) {
|
||||
if workerID == uuid.Nil {
|
||||
return
|
||||
}
|
||||
if workerID == p.workerID {
|
||||
p.handleChatControlMessage(context.Background(), msg)
|
||||
return
|
||||
}
|
||||
if p.pubsub == nil {
|
||||
return
|
||||
}
|
||||
|
||||
payload, err := json.Marshal(msg)
|
||||
if err != nil {
|
||||
p.logger.Error(context.Background(), "failed to marshal chat control message",
|
||||
slog.F("chat_id", msg.ChatID),
|
||||
slog.F("worker_id", workerID),
|
||||
slog.Error(err),
|
||||
)
|
||||
return
|
||||
}
|
||||
if err := p.pubsub.Publish(coderdpubsub.ChatControlChannel(workerID), payload); err != nil {
|
||||
p.logger.Error(context.Background(), "failed to publish chat control message",
|
||||
slog.F("chat_id", msg.ChatID),
|
||||
slog.F("worker_id", workerID),
|
||||
slog.Error(err),
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
func (p *Server) subscribeWorkerControl(ctx context.Context) func() {
|
||||
if p.pubsub == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
listener := func(_ context.Context, message []byte, err error) {
|
||||
if err != nil {
|
||||
logger.Warn(ctx, "chat control pubsub error", slog.Error(err))
|
||||
return
|
||||
}
|
||||
|
||||
var notify coderdpubsub.ChatStreamNotifyMessage
|
||||
if unmarshalErr := json.Unmarshal(message, &notify); unmarshalErr != nil {
|
||||
logger.Warn(ctx, "failed to unmarshal chat control notify", slog.Error(unmarshalErr))
|
||||
return
|
||||
}
|
||||
|
||||
if shouldCancelChatFromControlNotification(notify, p.workerID) {
|
||||
cancel(chatloop.ErrInterrupted)
|
||||
}
|
||||
}
|
||||
|
||||
controlCancel, err := p.pubsub.SubscribeWithErr(
|
||||
coderdpubsub.ChatStreamNotifyChannel(chatID),
|
||||
listener,
|
||||
coderdpubsub.ChatControlChannel(p.workerID),
|
||||
coderdpubsub.HandleChatControl(func(ctx context.Context, msg coderdpubsub.ChatControlMessage, err error) {
|
||||
if err != nil {
|
||||
p.logger.Warn(ctx, "chat control pubsub error", slog.Error(err))
|
||||
return
|
||||
}
|
||||
p.handleChatControlMessage(ctx, msg)
|
||||
}),
|
||||
)
|
||||
if err != nil {
|
||||
logger.Warn(ctx, "failed to subscribe to chat control notifications", slog.Error(err))
|
||||
p.logger.Error(ctx, "failed to subscribe to worker chat controls", slog.Error(err))
|
||||
return nil
|
||||
}
|
||||
return controlCancel
|
||||
@@ -4023,6 +4155,21 @@ func (p *Server) trackWorkspaceUsage(
|
||||
return wsID
|
||||
}
|
||||
|
||||
func (p *Server) shouldStartOwnedRun(ctx context.Context, chat database.Chat) bool {
|
||||
latest, err := p.db.GetChatByID(ctx, chat.ID)
|
||||
if err != nil {
|
||||
p.logger.Warn(ctx, "failed to verify chat ownership before start",
|
||||
slog.F("chat_id", chat.ID),
|
||||
slog.Error(err),
|
||||
)
|
||||
return true
|
||||
}
|
||||
return latest.RunGeneration == chat.RunGeneration &&
|
||||
latest.Status == database.ChatStatusRunning &&
|
||||
latest.WorkerID.Valid &&
|
||||
latest.WorkerID.UUID == p.workerID
|
||||
}
|
||||
|
||||
func (p *Server) processChat(ctx context.Context, chat database.Chat) {
|
||||
logger := p.logger.With(slog.F("chat_id", chat.ID))
|
||||
logger.Info(ctx, "processing chat request")
|
||||
@@ -4030,42 +4177,30 @@ func (p *Server) processChat(ctx context.Context, chat database.Chat) {
|
||||
chatCtx, cancel := context.WithCancelCause(ctx)
|
||||
defer cancel(nil)
|
||||
|
||||
// Gate the control subscriber behind a channel that is closed
|
||||
// after we publish "running" status. This prevents stale
|
||||
// pubsub notifications (e.g. the "pending" notification from
|
||||
// SendMessage that triggered this processing) from
|
||||
// interrupting us before we start work. Due to async
|
||||
// PostgreSQL NOTIFY delivery, a notification published before
|
||||
// subscribeChatControl registers its queue can still arrive
|
||||
// after registration.
|
||||
controlArmed := make(chan struct{})
|
||||
gatedCancel := func(cause error) {
|
||||
select {
|
||||
case <-controlArmed:
|
||||
cancel(cause)
|
||||
default:
|
||||
logger.Debug(ctx, "ignoring control notification before armed")
|
||||
}
|
||||
}
|
||||
|
||||
controlCancel := p.subscribeChatControl(chatCtx, chat.ID, gatedCancel, logger)
|
||||
defer func() {
|
||||
if controlCancel != nil {
|
||||
controlCancel()
|
||||
}
|
||||
}()
|
||||
|
||||
// Register with the centralized heartbeat loop instead of
|
||||
// running a per-chat goroutine. The loop issues a single batch
|
||||
// UPDATE for all chats on this worker and detects stolen chats
|
||||
// via set-difference.
|
||||
p.registerHeartbeat(&heartbeatEntry{
|
||||
heartbeat := &heartbeatEntry{
|
||||
cancelWithCause: cancel,
|
||||
chatID: chat.ID,
|
||||
runGeneration: chat.RunGeneration,
|
||||
workspaceID: chat.WorkspaceID,
|
||||
logger: logger,
|
||||
})
|
||||
defer p.unregisterHeartbeat(chat.ID)
|
||||
}
|
||||
p.registerHeartbeat(heartbeat)
|
||||
defer p.unregisterHeartbeat(heartbeat)
|
||||
|
||||
// Re-check ownership after registration so a mutation that won the
|
||||
// race against AcquireChats cannot start work simply because its
|
||||
// control message arrived before the local run was registered.
|
||||
if !p.shouldStartOwnedRun(context.WithoutCancel(chatCtx), chat) {
|
||||
logger.Info(ctx, "chat no longer owned before processing start",
|
||||
slog.F("run_generation", chat.RunGeneration),
|
||||
)
|
||||
cancel(chatloop.ErrInterrupted)
|
||||
return
|
||||
}
|
||||
|
||||
// Start buffering stream events BEFORE publishing the running
|
||||
// status. This closes a race where a subscriber sees
|
||||
@@ -4098,12 +4233,6 @@ func (p *Server) processChat(ctx context.Context, chat database.Chat) {
|
||||
Valid: true,
|
||||
})
|
||||
|
||||
// Arm the control subscriber. Closing the channel is a
|
||||
// happens-before guarantee in the Go memory model — any
|
||||
// notification dispatched after this point will correctly
|
||||
// interrupt processing.
|
||||
close(controlArmed)
|
||||
|
||||
// Determine the final status and last error to set when we're done.
|
||||
status := database.ChatStatusWaiting
|
||||
wasInterrupted := false
|
||||
@@ -4145,21 +4274,26 @@ func (p *Server) processChat(ctx context.Context, chat database.Chat) {
|
||||
return xerrors.Errorf("lock chat for release: %w", lockErr)
|
||||
}
|
||||
|
||||
// If another worker has already acquired this chat,
|
||||
// bail out — we must not overwrite their running
|
||||
// status or publish spurious events.
|
||||
// If a newer run generation or another worker already owns
|
||||
// this chat, bail out — we must not overwrite their status
|
||||
// or publish spurious side effects.
|
||||
if latestChat.RunGeneration != chat.RunGeneration {
|
||||
return errChatSuperseded
|
||||
}
|
||||
if latestChat.Status == database.ChatStatusRunning &&
|
||||
latestChat.WorkerID.Valid &&
|
||||
latestChat.WorkerID.UUID != p.workerID {
|
||||
return errChatTakenByOtherWorker
|
||||
return errChatSuperseded
|
||||
}
|
||||
|
||||
// If someone else already set the chat to pending (e.g.
|
||||
// the promote endpoint), don't overwrite it — just clear
|
||||
// the worker and let the processor pick it back up.
|
||||
// during a transition from older code), don't overwrite it.
|
||||
if latestChat.Status == database.ChatStatusPending {
|
||||
status = database.ChatStatusPending
|
||||
} else if status == database.ChatStatusWaiting && !latestChat.Archived {
|
||||
status = latestChat.Status
|
||||
updatedChat = latestChat
|
||||
return nil
|
||||
}
|
||||
if status == database.ChatStatusWaiting && !latestChat.Archived {
|
||||
// Queued messages were already admitted through SendMessage,
|
||||
// so auto-promotion only preserves FIFO order here. Archived
|
||||
// chats skip promotion so archiving behaves like a hard stop.
|
||||
@@ -4173,20 +4307,31 @@ func (p *Server) processChat(ctx context.Context, chat database.Chat) {
|
||||
}
|
||||
|
||||
var updateErr error
|
||||
updatedChat, updateErr = tx.UpdateChatStatus(cleanupCtx, database.UpdateChatStatusParams{
|
||||
ID: chat.ID,
|
||||
Status: status,
|
||||
WorkerID: uuid.NullUUID{},
|
||||
StartedAt: sql.NullTime{},
|
||||
HeartbeatAt: sql.NullTime{},
|
||||
LastError: sql.NullString{String: lastError, Valid: lastError != ""},
|
||||
})
|
||||
if status == database.ChatStatusPending {
|
||||
updatedChat, updateErr = tx.AdvanceChatRunGenerationAndUpdateStatus(cleanupCtx, database.AdvanceChatRunGenerationAndUpdateStatusParams{
|
||||
ID: chat.ID,
|
||||
Status: status,
|
||||
WorkerID: uuid.NullUUID{},
|
||||
StartedAt: sql.NullTime{},
|
||||
HeartbeatAt: sql.NullTime{},
|
||||
LastError: sql.NullString{String: lastError, Valid: lastError != ""},
|
||||
})
|
||||
} else {
|
||||
updatedChat, updateErr = tx.UpdateChatStatus(cleanupCtx, database.UpdateChatStatusParams{
|
||||
ID: chat.ID,
|
||||
Status: status,
|
||||
WorkerID: uuid.NullUUID{},
|
||||
StartedAt: sql.NullTime{},
|
||||
HeartbeatAt: sql.NullTime{},
|
||||
LastError: sql.NullString{String: lastError, Valid: lastError != ""},
|
||||
})
|
||||
}
|
||||
return updateErr
|
||||
}, nil)
|
||||
if errors.Is(err, errChatTakenByOtherWorker) {
|
||||
// Another worker owns this chat now — skip all
|
||||
// post-TX side effects (status publish, pubsub,
|
||||
// web push) to avoid overwriting their state.
|
||||
if errors.Is(err, errChatSuperseded) {
|
||||
// A newer run owns this chat now — skip all post-TX side
|
||||
// effects (status publish, pubsub, web push) to avoid
|
||||
// overwriting their state.
|
||||
return
|
||||
}
|
||||
if err != nil {
|
||||
@@ -4800,13 +4945,17 @@ func (p *Server) runChat(
|
||||
// already been cleared but we still want to persist
|
||||
// the partial assistant response. We allow the write
|
||||
// because the history has NOT been truncated — the
|
||||
// user simply asked to stop. In contrast, EditMessage
|
||||
// sets the chat to "pending" after truncating, so the
|
||||
// pending check still correctly blocks stale writes.
|
||||
// user simply asked to stop. In contrast, generation-
|
||||
// advancing transitions such as EditMessage move the
|
||||
// chat to a newer run generation, which blocks stale
|
||||
// writes even if the old worker is still draining.
|
||||
lockedChat, lockErr := tx.GetChatByIDForUpdate(persistCtx, chat.ID)
|
||||
if lockErr != nil {
|
||||
return xerrors.Errorf("lock chat for persist: %w", lockErr)
|
||||
}
|
||||
if lockedChat.RunGeneration != chat.RunGeneration {
|
||||
return chatloop.ErrInterrupted
|
||||
}
|
||||
if !lockedChat.WorkerID.Valid || lockedChat.WorkerID.UUID != p.workerID {
|
||||
// The worker_id was cleared. Only allow the persist
|
||||
// if the chat transitioned to "waiting" (interrupt),
|
||||
@@ -5914,6 +6063,12 @@ func (p *Server) recoverStaleChats(ctx context.Context) {
|
||||
// between GetStaleChats (a bare SELECT) and here, the chat's
|
||||
// heartbeat may have been refreshed. We re-check freshness
|
||||
// under the row lock before resetting.
|
||||
var (
|
||||
updatedChat database.Chat
|
||||
controlWorkerID uuid.UUID
|
||||
controlMsg coderdpubsub.ChatControlMessage
|
||||
publishControl bool
|
||||
)
|
||||
err := p.db.InTx(func(tx database.Store) error {
|
||||
locked, lockErr := tx.GetChatByIDForUpdate(ctx, chat.ID)
|
||||
if lockErr != nil {
|
||||
@@ -5980,7 +6135,8 @@ func (p *Server) recoverStaleChats(ctx context.Context) {
|
||||
|
||||
// Reset so any replica can pick it up (pending) or
|
||||
// the client sees the failure (error).
|
||||
_, updateErr := tx.UpdateChatStatus(ctx, database.UpdateChatStatusParams{
|
||||
var updateErr error
|
||||
updatedChat, updateErr = tx.AdvanceChatRunGenerationAndUpdateStatus(ctx, database.AdvanceChatRunGenerationAndUpdateStatusParams{
|
||||
ID: chat.ID,
|
||||
Status: recoverStatus,
|
||||
WorkerID: uuid.NullUUID{},
|
||||
@@ -5991,13 +6147,24 @@ func (p *Server) recoverStaleChats(ctx context.Context) {
|
||||
if updateErr != nil {
|
||||
return updateErr
|
||||
}
|
||||
controlWorkerID, controlMsg, publishControl = controlMessageForWorker(
|
||||
locked,
|
||||
updatedChat.RunGeneration,
|
||||
coderdpubsub.ChatControlReasonRecoverStale,
|
||||
)
|
||||
recovered++
|
||||
return nil
|
||||
}, nil)
|
||||
if err != nil {
|
||||
p.logger.Error(ctx, "failed to recover stale chat",
|
||||
slog.F("chat_id", chat.ID), slog.Error(err))
|
||||
continue
|
||||
}
|
||||
if publishControl {
|
||||
p.publishChatControl(controlWorkerID, controlMsg)
|
||||
}
|
||||
p.publishStatus(updatedChat.ID, updatedChat.Status, updatedChat.WorkerID)
|
||||
p.publishChatPubsubEvent(updatedChat, coderdpubsub.ChatEventKindStatusChange, nil)
|
||||
}
|
||||
|
||||
if recovered > 0 {
|
||||
@@ -6203,6 +6370,10 @@ func (p *Server) Close() error {
|
||||
p.configCacheUnsubscribe = nil
|
||||
unsub()
|
||||
}
|
||||
if unsub := p.controlUnsubscribe; unsub != nil {
|
||||
p.controlUnsubscribe = nil
|
||||
unsub()
|
||||
}
|
||||
p.cancel()
|
||||
<-p.closed
|
||||
p.drainInflight()
|
||||
|
||||
@@ -4,6 +4,7 @@ import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
@@ -2580,26 +2581,219 @@ func chatMessageWithParts(parts []codersdk.ChatMessagePart) database.ChatMessage
|
||||
}
|
||||
}
|
||||
|
||||
// TestProcessChat_IgnoresStaleControlNotification verifies that
|
||||
// processChat is not interrupted by a "pending" notification
|
||||
// published before processing begins. This is the race that caused
|
||||
// TestOpenAIReasoningWithWebSearchRoundTripStoreFalse to flake:
|
||||
// SendMessage publishes "pending" via PostgreSQL NOTIFY, and due
|
||||
// to async delivery the notification can arrive at the control
|
||||
// subscriber after it registers but before the processor publishes
|
||||
// "running".
|
||||
func TestProcessChat_IgnoresStaleControlNotification(t *testing.T) {
|
||||
func TestShouldInterruptActiveRunFromControlMessage(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
chatID := uuid.New()
|
||||
entry := &heartbeatEntry{chatID: chatID, runGeneration: 7}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
entry *heartbeatEntry
|
||||
msg coderdpubsub.ChatControlMessage
|
||||
want bool
|
||||
}{
|
||||
{
|
||||
name: "newer generation restarts",
|
||||
entry: entry,
|
||||
msg: coderdpubsub.ChatControlMessage{
|
||||
ChatID: chatID,
|
||||
RunGeneration: 8,
|
||||
Reason: coderdpubsub.ChatControlReasonRestart,
|
||||
},
|
||||
want: true,
|
||||
},
|
||||
{
|
||||
name: "equal generation interrupt cancels",
|
||||
entry: entry,
|
||||
msg: coderdpubsub.ChatControlMessage{
|
||||
ChatID: chatID,
|
||||
RunGeneration: 7,
|
||||
Reason: coderdpubsub.ChatControlReasonInterrupt,
|
||||
},
|
||||
want: true,
|
||||
},
|
||||
{
|
||||
name: "equal generation archive cancels",
|
||||
entry: entry,
|
||||
msg: coderdpubsub.ChatControlMessage{
|
||||
ChatID: chatID,
|
||||
RunGeneration: 7,
|
||||
Reason: coderdpubsub.ChatControlReasonArchive,
|
||||
},
|
||||
want: true,
|
||||
},
|
||||
{
|
||||
name: "equal generation restart is ignored",
|
||||
entry: entry,
|
||||
msg: coderdpubsub.ChatControlMessage{
|
||||
ChatID: chatID,
|
||||
RunGeneration: 7,
|
||||
Reason: coderdpubsub.ChatControlReasonRestart,
|
||||
},
|
||||
want: false,
|
||||
},
|
||||
{
|
||||
name: "older generation interrupt is ignored",
|
||||
entry: entry,
|
||||
msg: coderdpubsub.ChatControlMessage{
|
||||
ChatID: chatID,
|
||||
RunGeneration: 6,
|
||||
Reason: coderdpubsub.ChatControlReasonInterrupt,
|
||||
},
|
||||
want: false,
|
||||
},
|
||||
{
|
||||
name: "different chat is ignored",
|
||||
entry: entry,
|
||||
msg: coderdpubsub.ChatControlMessage{
|
||||
ChatID: uuid.New(),
|
||||
RunGeneration: 8,
|
||||
Reason: coderdpubsub.ChatControlReasonRestart,
|
||||
},
|
||||
want: false,
|
||||
},
|
||||
{
|
||||
name: "missing entry is ignored",
|
||||
entry: nil,
|
||||
msg: coderdpubsub.ChatControlMessage{
|
||||
ChatID: chatID,
|
||||
RunGeneration: 8,
|
||||
Reason: coderdpubsub.ChatControlReasonRestart,
|
||||
},
|
||||
want: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
tt := tt
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
require.Equal(t, tt.want, shouldInterruptActiveRunFromControlMessage(tt.entry, tt.msg))
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestSubscribeWorkerControl_CancelsRegisteredRun(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctx := testutil.Context(t, testutil.WaitShort)
|
||||
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true})
|
||||
ps := dbpubsub.NewInMemory()
|
||||
|
||||
chatID := uuid.New()
|
||||
workerID := uuid.New()
|
||||
runGeneration := int64(7)
|
||||
|
||||
server := &Server{
|
||||
logger: logger,
|
||||
pubsub: ps,
|
||||
workerID: workerID,
|
||||
heartbeatRegistry: make(map[uuid.UUID]*heartbeatEntry),
|
||||
}
|
||||
|
||||
chatCtx, cancel := context.WithCancelCause(ctx)
|
||||
defer cancel(nil)
|
||||
|
||||
entry := &heartbeatEntry{
|
||||
cancelWithCause: cancel,
|
||||
chatID: chatID,
|
||||
runGeneration: runGeneration,
|
||||
logger: logger,
|
||||
}
|
||||
server.registerHeartbeat(entry)
|
||||
defer server.unregisterHeartbeat(entry)
|
||||
|
||||
controlCancel := server.subscribeWorkerControl(ctx)
|
||||
require.NotNil(t, controlCancel)
|
||||
defer controlCancel()
|
||||
|
||||
payload, err := json.Marshal(coderdpubsub.ChatControlMessage{
|
||||
ChatID: chatID,
|
||||
RunGeneration: runGeneration + 1,
|
||||
Reason: coderdpubsub.ChatControlReasonRestart,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.NoError(t, ps.Publish(coderdpubsub.ChatControlChannel(workerID), payload))
|
||||
|
||||
require.Eventually(t, func() bool {
|
||||
return errors.Is(context.Cause(chatCtx), chatloop.ErrInterrupted)
|
||||
}, testutil.WaitShort, testutil.IntervalFast)
|
||||
}
|
||||
|
||||
func TestRegisterHeartbeat_ReplacesOlderGenerationAndIgnoresStaleUnregister(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctx := testutil.Context(t, testutil.WaitShort)
|
||||
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true})
|
||||
chatID := uuid.New()
|
||||
|
||||
server := &Server{
|
||||
logger: logger,
|
||||
heartbeatRegistry: make(map[uuid.UUID]*heartbeatEntry),
|
||||
}
|
||||
|
||||
oldCtx, oldCancel := context.WithCancelCause(ctx)
|
||||
defer oldCancel(nil)
|
||||
oldEntry := &heartbeatEntry{
|
||||
cancelWithCause: oldCancel,
|
||||
chatID: chatID,
|
||||
runGeneration: 1,
|
||||
logger: logger,
|
||||
}
|
||||
server.registerHeartbeat(oldEntry)
|
||||
|
||||
newCtx, newCancel := context.WithCancelCause(ctx)
|
||||
defer newCancel(nil)
|
||||
newEntry := &heartbeatEntry{
|
||||
cancelWithCause: newCancel,
|
||||
chatID: chatID,
|
||||
runGeneration: 2,
|
||||
logger: logger,
|
||||
}
|
||||
server.registerHeartbeat(newEntry)
|
||||
|
||||
require.ErrorIs(t, context.Cause(oldCtx), chatloop.ErrInterrupted)
|
||||
require.NoError(t, context.Cause(newCtx))
|
||||
|
||||
server.heartbeatMu.Lock()
|
||||
require.Same(t, newEntry, server.heartbeatRegistry[chatID])
|
||||
server.heartbeatMu.Unlock()
|
||||
|
||||
server.unregisterHeartbeat(oldEntry)
|
||||
|
||||
server.heartbeatMu.Lock()
|
||||
require.Same(t, newEntry, server.heartbeatRegistry[chatID])
|
||||
server.heartbeatMu.Unlock()
|
||||
|
||||
server.unregisterHeartbeat(newEntry)
|
||||
|
||||
server.heartbeatMu.Lock()
|
||||
_, exists := server.heartbeatRegistry[chatID]
|
||||
server.heartbeatMu.Unlock()
|
||||
require.False(t, exists)
|
||||
}
|
||||
|
||||
// TestProcessChat_IgnoresStaleStatusNotification verifies that
|
||||
// processChat is not interrupted by a stale "pending" status
|
||||
// fanout delivered after the worker has already published
|
||||
// "running". Worker control now lives on a separate per-worker
|
||||
// channel, so delayed status fanout is observational only.
|
||||
func TestProcessChat_IgnoresStaleStatusNotification(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctx := testutil.Context(t, testutil.WaitShort)
|
||||
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true})
|
||||
ctrl := gomock.NewController(t)
|
||||
db := dbmock.NewMockStore(ctrl)
|
||||
ps := dbpubsub.NewInMemory()
|
||||
ps := chattest.NewDelayedStatusPubsub(dbpubsub.NewInMemory())
|
||||
clock := quartz.NewMock(t)
|
||||
|
||||
chatID := uuid.New()
|
||||
workerID := uuid.New()
|
||||
runGeneration := int64(7)
|
||||
chatChannel := coderdpubsub.ChatStreamNotifyChannel(chatID)
|
||||
ps.DelayStatus(chatChannel, string(database.ChatStatusPending))
|
||||
|
||||
server := &Server{
|
||||
db: db,
|
||||
@@ -2612,42 +2806,52 @@ func TestProcessChat_IgnoresStaleControlNotification(t *testing.T) {
|
||||
heartbeatRegistry: make(map[uuid.UUID]*heartbeatEntry),
|
||||
}
|
||||
|
||||
// Publish a stale "pending" notification on the control channel
|
||||
// BEFORE processChat subscribes. In production this is the
|
||||
// notification from SendMessage that triggered the processing.
|
||||
staleNotify, err := json.Marshal(coderdpubsub.ChatStreamNotifyMessage{
|
||||
Status: string(database.ChatStatusPending),
|
||||
})
|
||||
require.NoError(t, err)
|
||||
err = ps.Publish(coderdpubsub.ChatStreamNotifyChannel(chatID), staleNotify)
|
||||
require.NoError(t, err)
|
||||
require.NoError(t, ps.Publish(chatChannel, staleNotify))
|
||||
|
||||
// Track which status processChat writes during cleanup.
|
||||
var finalStatus database.ChatStatus
|
||||
cleanupDone := make(chan struct{})
|
||||
allowModelResolution := make(chan struct{})
|
||||
|
||||
// The deferred cleanup in processChat runs a transaction.
|
||||
db.EXPECT().GetChatByID(gomock.Any(), chatID).Return(
|
||||
database.Chat{
|
||||
ID: chatID,
|
||||
Status: database.ChatStatusRunning,
|
||||
WorkerID: uuid.NullUUID{UUID: workerID, Valid: true},
|
||||
RunGeneration: runGeneration,
|
||||
}, nil,
|
||||
)
|
||||
db.EXPECT().InTx(gomock.Any(), gomock.Any()).DoAndReturn(
|
||||
func(fn func(database.Store) error, _ *database.TxOptions) error {
|
||||
return fn(db)
|
||||
},
|
||||
)
|
||||
db.EXPECT().GetChatByIDForUpdate(gomock.Any(), chatID).Return(
|
||||
database.Chat{ID: chatID, Status: database.ChatStatusRunning, WorkerID: uuid.NullUUID{UUID: workerID, Valid: true}}, nil,
|
||||
database.Chat{
|
||||
ID: chatID,
|
||||
Status: database.ChatStatusRunning,
|
||||
WorkerID: uuid.NullUUID{UUID: workerID, Valid: true},
|
||||
RunGeneration: runGeneration,
|
||||
}, nil,
|
||||
)
|
||||
db.EXPECT().UpdateChatStatus(gomock.Any(), gomock.Any()).DoAndReturn(
|
||||
func(_ context.Context, params database.UpdateChatStatusParams) (database.Chat, error) {
|
||||
finalStatus = params.Status
|
||||
close(cleanupDone)
|
||||
return database.Chat{ID: chatID, Status: params.Status}, nil
|
||||
return database.Chat{ID: chatID, Status: params.Status, RunGeneration: runGeneration}, nil
|
||||
},
|
||||
)
|
||||
|
||||
// resolveChatModel fails immediately — that's fine, we only
|
||||
// need processChat to get past initialization without being
|
||||
// interrupted by the stale notification.
|
||||
db.EXPECT().GetChatModelConfigByID(gomock.Any(), gomock.Any()).Return(
|
||||
database.ChatModelConfig{}, xerrors.New("no model configured"),
|
||||
// resolveChatModel fails after the stale notification has been
|
||||
// released. That ensures the test exercises delayed status fanout
|
||||
// while the chat is already running.
|
||||
db.EXPECT().GetChatModelConfigByID(gomock.Any(), gomock.Any()).DoAndReturn(
|
||||
func(context.Context, uuid.UUID) (database.ChatModelConfig, error) {
|
||||
<-allowModelResolution
|
||||
return database.ChatModelConfig{}, xerrors.New("no model configured")
|
||||
},
|
||||
).AnyTimes()
|
||||
db.EXPECT().GetEnabledChatProviders(gomock.Any()).Return(nil, nil).AnyTimes()
|
||||
db.EXPECT().GetEnabledChatModelConfigs(gomock.Any()).Return(nil, nil).AnyTimes()
|
||||
@@ -2656,19 +2860,19 @@ func TestProcessChat_IgnoresStaleControlNotification(t *testing.T) {
|
||||
).AnyTimes()
|
||||
db.EXPECT().GetChatMessagesForPromptByChatID(gomock.Any(), chatID).Return(nil, nil).AnyTimes()
|
||||
|
||||
chat := database.Chat{ID: chatID, LastModelConfigID: uuid.New()}
|
||||
chat := database.Chat{ID: chatID, LastModelConfigID: uuid.New(), RunGeneration: runGeneration}
|
||||
go server.processChat(ctx, chat)
|
||||
|
||||
require.NoError(t, ps.WaitForStatusPublish(ctx, chatChannel, string(database.ChatStatusRunning)))
|
||||
require.NoError(t, ps.ReleaseStatus(chatChannel, string(database.ChatStatusPending)))
|
||||
close(allowModelResolution)
|
||||
|
||||
select {
|
||||
case <-cleanupDone:
|
||||
case <-ctx.Done():
|
||||
t.Fatal("processChat did not complete")
|
||||
}
|
||||
|
||||
// If the stale notification interrupted us, status would be
|
||||
// "waiting" (the ErrInterrupted path). Since the gate blocked
|
||||
// it, processChat reached runChat, which failed on model
|
||||
// resolution → status is "error".
|
||||
require.Equal(t, database.ChatStatusError, finalStatus,
|
||||
"processChat should have reached runChat (error), not been interrupted (waiting)")
|
||||
}
|
||||
@@ -2710,16 +2914,19 @@ func TestHeartbeatTick_StolenChatIsInterrupted(t *testing.T) {
|
||||
server.registerHeartbeat(&heartbeatEntry{
|
||||
cancelWithCause: cancel1,
|
||||
chatID: chat1,
|
||||
runGeneration: 1,
|
||||
logger: logger,
|
||||
})
|
||||
server.registerHeartbeat(&heartbeatEntry{
|
||||
cancelWithCause: cancel2,
|
||||
chatID: chat2,
|
||||
runGeneration: 2,
|
||||
logger: logger,
|
||||
})
|
||||
server.registerHeartbeat(&heartbeatEntry{
|
||||
cancelWithCause: cancel3,
|
||||
chatID: chat3,
|
||||
runGeneration: 3,
|
||||
logger: logger,
|
||||
})
|
||||
|
||||
@@ -2729,6 +2936,12 @@ func TestHeartbeatTick_StolenChatIsInterrupted(t *testing.T) {
|
||||
func(_ context.Context, params database.UpdateChatHeartbeatsParams) ([]uuid.UUID, error) {
|
||||
require.Equal(t, workerID, params.WorkerID)
|
||||
require.Len(t, params.IDs, 3)
|
||||
require.Len(t, params.RunGenerations, 3)
|
||||
got := make(map[uuid.UUID]int64, len(params.IDs))
|
||||
for i, id := range params.IDs {
|
||||
got[id] = params.RunGenerations[i]
|
||||
}
|
||||
require.Equal(t, map[uuid.UUID]int64{chat1: 1, chat2: 2, chat3: 3}, got)
|
||||
// Return only chat1 and chat2 as surviving.
|
||||
return []uuid.UUID{chat1, chat2}, nil
|
||||
},
|
||||
@@ -2784,6 +2997,7 @@ func TestHeartbeatTick_DBErrorDoesNotInterruptChats(t *testing.T) {
|
||||
server.registerHeartbeat(&heartbeatEntry{
|
||||
cancelWithCause: cancel,
|
||||
chatID: chatID,
|
||||
runGeneration: 11,
|
||||
logger: logger,
|
||||
})
|
||||
|
||||
|
||||
@@ -474,7 +474,7 @@ func TestArchiveChatInterruptsActiveProcessing(t *testing.T) {
|
||||
require.Equal(t, 1, userMessages, "expected queued message to stay queued after archive")
|
||||
}
|
||||
|
||||
func TestUpdateChatHeartbeatsRequiresOwnership(t *testing.T) {
|
||||
func TestUpdateChatHeartbeatsRequiresOwnershipAndGeneration(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
db, ps := dbtestutil.NewDB(t)
|
||||
@@ -503,24 +503,67 @@ func TestUpdateChatHeartbeatsRequiresOwnership(t *testing.T) {
|
||||
|
||||
// Wrong worker_id should return no IDs.
|
||||
ids, err := db.UpdateChatHeartbeats(ctx, database.UpdateChatHeartbeatsParams{
|
||||
IDs: []uuid.UUID{chat.ID},
|
||||
WorkerID: uuid.New(),
|
||||
Now: time.Now(),
|
||||
IDs: []uuid.UUID{chat.ID},
|
||||
RunGenerations: []int64{chat.RunGeneration},
|
||||
WorkerID: uuid.New(),
|
||||
Now: time.Now(),
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.Empty(t, ids)
|
||||
|
||||
// Correct worker_id should return the chat's ID.
|
||||
// Wrong generation should also return no IDs.
|
||||
ids, err = db.UpdateChatHeartbeats(ctx, database.UpdateChatHeartbeatsParams{
|
||||
IDs: []uuid.UUID{chat.ID},
|
||||
WorkerID: workerID,
|
||||
Now: time.Now(),
|
||||
IDs: []uuid.UUID{chat.ID},
|
||||
RunGenerations: []int64{chat.RunGeneration + 1},
|
||||
WorkerID: workerID,
|
||||
Now: time.Now(),
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.Empty(t, ids)
|
||||
|
||||
// Correct worker_id and generation should return the chat's ID.
|
||||
ids, err = db.UpdateChatHeartbeats(ctx, database.UpdateChatHeartbeatsParams{
|
||||
IDs: []uuid.UUID{chat.ID},
|
||||
RunGenerations: []int64{chat.RunGeneration},
|
||||
WorkerID: workerID,
|
||||
Now: time.Now(),
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.Len(t, ids, 1)
|
||||
require.Equal(t, chat.ID, ids[0])
|
||||
}
|
||||
|
||||
func TestAdvanceChatRunGenerationAndUpdateStatus(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
db, ps := dbtestutil.NewDB(t)
|
||||
replica := newTestServer(t, db, ps, uuid.New())
|
||||
|
||||
ctx := testutil.Context(t, testutil.WaitLong)
|
||||
user, model := seedChatDependencies(ctx, t, db)
|
||||
|
||||
chat, err := replica.CreateChat(ctx, chatd.CreateOptions{
|
||||
OwnerID: user.ID,
|
||||
Title: "generation-advance",
|
||||
ModelConfigID: model.ID,
|
||||
InitialUserContent: []codersdk.ChatMessagePart{codersdk.ChatMessageText("hello")},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.EqualValues(t, 1, chat.RunGeneration)
|
||||
|
||||
updated, err := db.AdvanceChatRunGenerationAndUpdateStatus(ctx, database.AdvanceChatRunGenerationAndUpdateStatusParams{
|
||||
ID: chat.ID,
|
||||
Status: database.ChatStatusPending,
|
||||
WorkerID: uuid.NullUUID{},
|
||||
StartedAt: sql.NullTime{},
|
||||
HeartbeatAt: sql.NullTime{},
|
||||
LastError: sql.NullString{},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, chat.RunGeneration+1, updated.RunGeneration)
|
||||
require.Equal(t, database.ChatStatusPending, updated.Status)
|
||||
}
|
||||
|
||||
func TestSendMessageQueueBehaviorQueuesWhenBusy(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
|
||||
@@ -0,0 +1,220 @@
|
||||
package chattest
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"sync"
|
||||
|
||||
dbpubsub "github.com/coder/coder/v2/coderd/database/pubsub"
|
||||
coderdpubsub "github.com/coder/coder/v2/coderd/pubsub"
|
||||
)
|
||||
|
||||
// delayedStatusKey pairs a pubsub event name with a chat status string. It is
// the map key used for the per-status delay, buffer, and publish tracking in
// DelayedStatusPubsub.
type delayedStatusKey struct {
	event, status string
}
|
||||
|
||||
// DelayedStatusPubsub buffers selected status notifications until a test
|
||||
// releases them. This lets tests exercise stale-notify races deterministically
|
||||
// without depending on PostgreSQL delivery timing.
|
||||
type DelayedStatusPubsub struct {
|
||||
inner dbpubsub.Pubsub
|
||||
|
||||
mu sync.Mutex
|
||||
delayEnabled map[delayedStatusKey]bool
|
||||
delayedMessages map[delayedStatusKey][][]byte
|
||||
subscribed map[string]bool
|
||||
subscribeWait map[string]chan struct{}
|
||||
published map[delayedStatusKey]bool
|
||||
publishWait map[delayedStatusKey]chan struct{}
|
||||
}
|
||||
|
||||
// NewDelayedStatusPubsub wraps a pubsub implementation with deterministic
|
||||
// buffering for chosen status notifications.
|
||||
func NewDelayedStatusPubsub(inner dbpubsub.Pubsub) *DelayedStatusPubsub {
|
||||
return &DelayedStatusPubsub{
|
||||
inner: inner,
|
||||
delayEnabled: make(map[delayedStatusKey]bool),
|
||||
delayedMessages: make(map[delayedStatusKey][][]byte),
|
||||
subscribed: make(map[string]bool),
|
||||
subscribeWait: make(map[string]chan struct{}),
|
||||
published: make(map[delayedStatusKey]bool),
|
||||
publishWait: make(map[delayedStatusKey]chan struct{}),
|
||||
}
|
||||
}
|
||||
|
||||
// DelayStatus starts buffering matching status notifications instead of
|
||||
// publishing them immediately.
|
||||
func (p *DelayedStatusPubsub) DelayStatus(event, status string) {
|
||||
p.mu.Lock()
|
||||
defer p.mu.Unlock()
|
||||
|
||||
p.delayEnabled[delayedStatusKey{event: event, status: status}] = true
|
||||
}
|
||||
|
||||
// ReleaseStatus flushes the buffered notifications for an event/status pair in
|
||||
// publish order and resumes immediate delivery for later notifications.
|
||||
func (p *DelayedStatusPubsub) ReleaseStatus(event, status string) error {
|
||||
key := delayedStatusKey{event: event, status: status}
|
||||
|
||||
p.mu.Lock()
|
||||
delete(p.delayEnabled, key)
|
||||
messages := p.delayedMessages[key]
|
||||
delete(p.delayedMessages, key)
|
||||
p.mu.Unlock()
|
||||
|
||||
for _, message := range messages {
|
||||
if err := p.inner.Publish(event, message); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// WaitForSubscribe blocks until a subscriber registers for the event.
|
||||
func (p *DelayedStatusPubsub) WaitForSubscribe(ctx context.Context, event string) error {
|
||||
wait := p.subscribeWaiter(event)
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
case <-wait:
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
// WaitForStatusPublish blocks until a matching status notification is
|
||||
// published, whether or not it is currently buffered.
|
||||
func (p *DelayedStatusPubsub) WaitForStatusPublish(ctx context.Context, event, status string) error {
|
||||
wait := p.publishWaiter(delayedStatusKey{event: event, status: status})
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
case <-wait:
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
func (p *DelayedStatusPubsub) Subscribe(event string, listener dbpubsub.Listener) (func(), error) {
|
||||
cancel, err := p.inner.Subscribe(event, listener)
|
||||
if err == nil {
|
||||
p.markSubscribed(event)
|
||||
}
|
||||
return cancel, err
|
||||
}
|
||||
|
||||
func (p *DelayedStatusPubsub) SubscribeWithErr(event string, listener dbpubsub.ListenerWithErr) (func(), error) {
|
||||
cancel, err := p.inner.SubscribeWithErr(event, listener)
|
||||
if err == nil {
|
||||
p.markSubscribed(event)
|
||||
}
|
||||
return cancel, err
|
||||
}
|
||||
|
||||
func (p *DelayedStatusPubsub) Publish(event string, message []byte) error {
|
||||
status, ok := chatNotifyStatus(message)
|
||||
if !ok {
|
||||
return p.inner.Publish(event, message)
|
||||
}
|
||||
|
||||
key := delayedStatusKey{event: event, status: status}
|
||||
p.markPublished(key)
|
||||
|
||||
p.mu.Lock()
|
||||
delay := p.delayEnabled[key]
|
||||
if delay {
|
||||
p.delayedMessages[key] = append(p.delayedMessages[key], append([]byte(nil), message...))
|
||||
}
|
||||
p.mu.Unlock()
|
||||
if delay {
|
||||
return nil
|
||||
}
|
||||
|
||||
return p.inner.Publish(event, message)
|
||||
}
|
||||
|
||||
func (p *DelayedStatusPubsub) Close() error {
|
||||
return p.inner.Close()
|
||||
}
|
||||
|
||||
func (p *DelayedStatusPubsub) markSubscribed(event string) {
|
||||
p.mu.Lock()
|
||||
defer p.mu.Unlock()
|
||||
|
||||
if p.subscribed[event] {
|
||||
return
|
||||
}
|
||||
p.subscribed[event] = true
|
||||
if wait, ok := p.subscribeWait[event]; ok {
|
||||
close(wait)
|
||||
delete(p.subscribeWait, event)
|
||||
}
|
||||
}
|
||||
|
||||
func (p *DelayedStatusPubsub) markPublished(key delayedStatusKey) {
|
||||
p.mu.Lock()
|
||||
defer p.mu.Unlock()
|
||||
|
||||
if p.published[key] {
|
||||
return
|
||||
}
|
||||
p.published[key] = true
|
||||
if wait, ok := p.publishWait[key]; ok {
|
||||
close(wait)
|
||||
delete(p.publishWait, key)
|
||||
}
|
||||
}
|
||||
|
||||
func (p *DelayedStatusPubsub) subscribeWaiter(event string) chan struct{} {
|
||||
p.mu.Lock()
|
||||
defer p.mu.Unlock()
|
||||
|
||||
return p.subscribeWaiterLocked(event)
|
||||
}
|
||||
|
||||
func (p *DelayedStatusPubsub) subscribeWaiterLocked(event string) chan struct{} {
|
||||
if p.subscribed[event] {
|
||||
return closedSignal()
|
||||
}
|
||||
wait, ok := p.subscribeWait[event]
|
||||
if !ok {
|
||||
wait = make(chan struct{})
|
||||
p.subscribeWait[event] = wait
|
||||
}
|
||||
return wait
|
||||
}
|
||||
|
||||
func (p *DelayedStatusPubsub) publishWaiter(key delayedStatusKey) chan struct{} {
|
||||
p.mu.Lock()
|
||||
defer p.mu.Unlock()
|
||||
|
||||
return p.publishWaiterLocked(key)
|
||||
}
|
||||
|
||||
func (p *DelayedStatusPubsub) publishWaiterLocked(key delayedStatusKey) chan struct{} {
|
||||
if p.published[key] {
|
||||
return closedSignal()
|
||||
}
|
||||
wait, ok := p.publishWait[key]
|
||||
if !ok {
|
||||
wait = make(chan struct{})
|
||||
p.publishWait[key] = wait
|
||||
}
|
||||
return wait
|
||||
}
|
||||
|
||||
func chatNotifyStatus(message []byte) (string, bool) {
|
||||
var notify coderdpubsub.ChatStreamNotifyMessage
|
||||
if err := json.Unmarshal(message, ¬ify); err != nil {
|
||||
return "", false
|
||||
}
|
||||
if notify.Status == "" {
|
||||
return "", false
|
||||
}
|
||||
return notify.Status, true
|
||||
}
|
||||
|
||||
// closedSignal returns a freshly allocated channel that is already closed, so
// any receive on it completes immediately.
func closedSignal() chan struct{} {
	signal := make(chan struct{})
	// The deferred close runs before the caller receives the channel.
	defer close(signal)
	return signal
}
|
||||
@@ -470,6 +470,7 @@ func (p *Server) createChildSubagentChatWithOptions(
|
||||
Title: title,
|
||||
Mode: opts.chatMode,
|
||||
Status: database.ChatStatusPending,
|
||||
RunGeneration: 1,
|
||||
MCPServerIDs: mcpServerIDs,
|
||||
Labels: pqtype.NullRawMessage{
|
||||
RawMessage: labelsJSON,
|
||||
|
||||
@@ -1425,12 +1425,16 @@ func TestSubscribeRelayDialCanceledOnFastCompletion(t *testing.T) {
|
||||
func TestSubscribeRelayEstablishedMidStream(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
db, ps := dbtestutil.NewDB(t)
|
||||
db, rawPS := dbtestutil.NewDB(t)
|
||||
ps := chattest.NewDelayedStatusPubsub(rawPS)
|
||||
workerID := uuid.New()
|
||||
subscriberID := uuid.New()
|
||||
|
||||
// Gate: worker blocks after first streaming request until we
|
||||
// release it. This gives the relay time to establish.
|
||||
// Gate: the worker cannot emit the first streaming chunk until we
|
||||
// release it. This lets the test deliver a stale pending
|
||||
// notification after the run is already current but before the LLM
|
||||
// stream starts.
|
||||
allowStreamingStart := make(chan struct{})
|
||||
firstChunkEmitted := make(chan struct{})
|
||||
continueStreaming := make(chan struct{})
|
||||
|
||||
@@ -1438,8 +1442,16 @@ func TestSubscribeRelayEstablishedMidStream(t *testing.T) {
|
||||
if !req.Stream {
|
||||
return chattest.OpenAINonStreamingResponse("mid-stream-relay")
|
||||
}
|
||||
// Signal that the first streaming request was received,
|
||||
// then block until released.
|
||||
select {
|
||||
case <-allowStreamingStart:
|
||||
case <-req.Context().Done():
|
||||
return chattest.OpenAIErrorResponse(http.StatusRequestTimeout, "request_canceled", "request canceled before streaming")
|
||||
}
|
||||
select {
|
||||
case <-req.Context().Done():
|
||||
return chattest.OpenAIErrorResponse(http.StatusRequestTimeout, "request_canceled", "request canceled before first chunk")
|
||||
default:
|
||||
}
|
||||
select {
|
||||
case <-firstChunkEmitted:
|
||||
default:
|
||||
@@ -1501,6 +1513,8 @@ func TestSubscribeRelayEstablishedMidStream(t *testing.T) {
|
||||
|
||||
// Create the chat in waiting state.
|
||||
chat := seedWaitingChat(ctx, t, db, user, model, "mid-stream-relay")
|
||||
chatChannel := coderdpubsub.ChatStreamNotifyChannel(chat.ID)
|
||||
ps.DelayStatus(chatChannel, string(database.ChatStatusPending))
|
||||
|
||||
// Subscribe from the subscriber replica while the chat is idle.
|
||||
_, events, subCancel, ok := subscriber.Subscribe(ctx, chat.ID, nil, 0)
|
||||
@@ -1515,6 +1529,23 @@ func TestSubscribeRelayEstablishedMidStream(t *testing.T) {
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
// Wait for the subscriber to receive the running status, which
|
||||
// proves the worker has started the current run. Release the
|
||||
// delayed pending notification only after that point so it is
|
||||
// definitively stale.
|
||||
require.Eventually(t, func() bool {
|
||||
select {
|
||||
case event := <-events:
|
||||
return event.Type == codersdk.ChatStreamEventTypeStatus &&
|
||||
event.Status != nil &&
|
||||
event.Status.Status == codersdk.ChatStatusRunning
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}, testutil.WaitMedium, testutil.IntervalFast)
|
||||
require.NoError(t, ps.ReleaseStatus(chatChannel, string(database.ChatStatusPending)))
|
||||
close(allowStreamingStart)
|
||||
|
||||
// Wait for the worker to reach the LLM (first streaming
|
||||
// request). Also poll the chat status so we fail fast with a
|
||||
// clear message if the worker errors out instead of timing
|
||||
@@ -1543,20 +1574,6 @@ waitForStream:
|
||||
}
|
||||
}
|
||||
|
||||
// Wait for the subscriber to receive the running status, which
|
||||
// triggers the relay. Because the dialer is non-blocking, the
|
||||
// relay establishes promptly.
|
||||
require.Eventually(t, func() bool {
|
||||
select {
|
||||
case event := <-events:
|
||||
return event.Type == codersdk.ChatStreamEventTypeStatus &&
|
||||
event.Status != nil &&
|
||||
event.Status.Status == codersdk.ChatStatusRunning
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}, testutil.WaitMedium, testutil.IntervalFast)
|
||||
|
||||
// Now release the worker to continue streaming.
|
||||
close(continueStreaming)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user