fix: target specific chat in MarkStale instead of broadcasting to all workspace chats (#23883)
## Problem Subagent chats were receiving git context (branch, remote origin, PR status) from their parent or sibling chats' git operations. When a git operation triggers external auth, the workspace agent sends `chat_id` identifying which chat initiated it — but this was broken at two levels: 1. **Agent side:** `CODER_CHAT_ID` was never injected into process environments. `chatd` sets `Coder-Chat-Id` HTTP headers and the agent extracts them for process isolation, but never propagated `CODER_CHAT_ID` to `cmd.Env`. So `gitaskpass` always sent an empty `chat_id`. 2. **Server side:** `workspaceAgentsExternalAuth` ignored the `chat_id` query param. `MarkStale` broadcast git context to **all** chats on the workspace via `filterChatsByWorkspaceID`. ## Fix - Inject `CODER_CHAT_ID` into `cmd.Env` in `agentproc` when the chat ID is known, so `gitaskpass` can read and forward it. - Read `chat_id` from query params in `workspaceAgentsExternalAuth` and thread it through `chatGitRef`. - Refactor `MarkStale` to accept a `MarkStaleParams` struct. When `ChatID` is provided, target only that specific chat. When empty (legacy agents, non-chat git operations), fall back to the existing workspace-wide broadcast. - Extract `markStaleSingle` helper to deduplicate the upsert+publish logic. <details><summary>Investigation notes</summary> ### Data flow before fix ``` chatd → sets Coder-Chat-Id header on agent conn agent → extracts chatID, stores on process struct agent → does NOT set CODER_CHAT_ID in cmd.Env ← gap 1 gitaskpass → reads CODER_CHAT_ID (always empty), sends chat_id="" server handler → ignores chat_id query param ← gap 2 MarkStale → broadcasts to ALL workspace chats ``` ### Data flow after fix ``` chatd → sets Coder-Chat-Id header on agent conn agent → extracts chatID, stores on process struct agent → sets CODER_CHAT_ID in cmd.Env gitaskpass → reads CODER_CHAT_ID, sends chat_id=<uuid> server handler → reads chat_id, passes to MarkStale MarkStale → targets only that specific chat ``` </details>
This commit is contained in:
@@ -148,6 +148,11 @@ func (m *manager) start(req workspacesdk.StartProcessRequest, chatID string) (*p
|
||||
for k, v := range req.Env {
|
||||
cmd.Env = append(cmd.Env, fmt.Sprintf("%s=%s", k, v))
|
||||
}
|
||||
// Propagate the chat ID so child processes (e.g.
|
||||
// GIT_ASKPASS) can send it back to the server.
|
||||
if chatID != "" {
|
||||
cmd.Env = append(cmd.Env, fmt.Sprintf("CODER_CHAT_ID=%s", chatID))
|
||||
}
|
||||
|
||||
if err := cmd.Start(); err != nil {
|
||||
cancel()
|
||||
|
||||
+3
-2
@@ -66,11 +66,12 @@ const (
|
||||
maxSystemPromptLenBytes = 131072 // 128 KiB
|
||||
)
|
||||
|
||||
// chatGitRef holds the branch and remote origin reported by the
|
||||
// workspace agent during a git operation.
|
||||
// chatGitRef holds the branch, remote origin, and optional chat
|
||||
// ID reported by the workspace agent during a git operation.
|
||||
type chatGitRef struct {
|
||||
Branch string
|
||||
RemoteOrigin string
|
||||
ChatID uuid.UUID
|
||||
}
|
||||
|
||||
type chatRepositoryRef struct {
|
||||
|
||||
@@ -42,6 +42,7 @@ import (
|
||||
"github.com/coder/coder/v2/coderd/telemetry"
|
||||
maputil "github.com/coder/coder/v2/coderd/util/maps"
|
||||
"github.com/coder/coder/v2/coderd/wspubsub"
|
||||
"github.com/coder/coder/v2/coderd/x/gitsync"
|
||||
"github.com/coder/coder/v2/codersdk"
|
||||
"github.com/coder/coder/v2/codersdk/agentsdk"
|
||||
"github.com/coder/coder/v2/codersdk/workspacesdk"
|
||||
@@ -1840,6 +1841,11 @@ func (api *API) workspaceAgentsExternalAuth(rw http.ResponseWriter, r *http.Requ
|
||||
Branch: strings.TrimSpace(query.Get("git_branch")),
|
||||
RemoteOrigin: strings.TrimSpace(query.Get("git_remote_origin")),
|
||||
}
|
||||
if raw := strings.TrimSpace(query.Get("chat_id")); raw != "" {
|
||||
if parsed, err := uuid.Parse(raw); err == nil {
|
||||
gitRef.ChatID = parsed
|
||||
}
|
||||
}
|
||||
// Either match or configID must be provided!
|
||||
match := query.Get("match")
|
||||
if match == "" {
|
||||
@@ -1938,7 +1944,13 @@ func (api *API) workspaceAgentsExternalAuth(rw http.ResponseWriter, r *http.Requ
|
||||
// context is retained even if the flow requires an out-of-band login.
|
||||
if gitRef.Branch != "" && gitRef.RemoteOrigin != "" {
|
||||
//nolint:gocritic // Chat processor context required for cross-user chat lookup
|
||||
api.gitSyncWorker.MarkStale(dbauthz.AsChatd(ctx), workspace.ID, workspace.OwnerID, gitRef.Branch, gitRef.RemoteOrigin)
|
||||
api.gitSyncWorker.MarkStale(dbauthz.AsChatd(ctx), gitsync.MarkStaleParams{
|
||||
WorkspaceID: workspace.ID,
|
||||
OwnerID: workspace.OwnerID,
|
||||
Branch: gitRef.Branch,
|
||||
Origin: gitRef.RemoteOrigin,
|
||||
ChatID: gitRef.ChatID,
|
||||
})
|
||||
}
|
||||
|
||||
var previousToken *database.ExternalAuthLink
|
||||
@@ -2087,7 +2099,13 @@ func (api *API) workspaceAgentsExternalAuthListen(ctx context.Context, rw http.R
|
||||
}
|
||||
// MarkStale will trigger a refresh by coderd/gitsync.
|
||||
//nolint:gocritic // Chat processor context required for cross-user chat lookup
|
||||
api.gitSyncWorker.MarkStale(dbauthz.AsChatd(ctx), workspace.ID, workspace.OwnerID, gitRef.Branch, gitRef.RemoteOrigin)
|
||||
api.gitSyncWorker.MarkStale(dbauthz.AsChatd(ctx), gitsync.MarkStaleParams{
|
||||
WorkspaceID: workspace.ID,
|
||||
OwnerID: workspace.OwnerID,
|
||||
Branch: gitRef.Branch,
|
||||
Origin: gitRef.RemoteOrigin,
|
||||
ChatID: gitRef.ChatID,
|
||||
})
|
||||
httpapi.Write(ctx, rw, http.StatusOK, resp)
|
||||
return
|
||||
}
|
||||
|
||||
+63
-35
@@ -274,25 +274,44 @@ func (w *Worker) tick(ctx context.Context) {
|
||||
}
|
||||
}
|
||||
|
||||
// MarkStale persists the git ref on all chats for a workspace,
|
||||
// setting stale_at to the past so the next tick picks them up.
|
||||
// Publishes a diff status event for each affected chat.
|
||||
// MarkStaleParams holds the arguments for Worker.MarkStale.
|
||||
type MarkStaleParams struct {
|
||||
WorkspaceID uuid.UUID
|
||||
OwnerID uuid.UUID
|
||||
Branch string
|
||||
Origin string
|
||||
// ChatID, when set, targets a single chat instead of
|
||||
// broadcasting to every chat on the workspace.
|
||||
ChatID uuid.UUID
|
||||
}
|
||||
|
||||
// MarkStale persists the git ref for a chat (or all chats on a
|
||||
// workspace when no ChatID is provided), setting stale_at to the
|
||||
// past so the next tick picks them up. Publishes a diff status
|
||||
// event for each affected chat.
|
||||
// Called from workspaceagents handlers. No goroutines spawned.
|
||||
func (w *Worker) MarkStale(
|
||||
ctx context.Context,
|
||||
workspaceID, ownerID uuid.UUID,
|
||||
branch, origin string,
|
||||
) {
|
||||
if branch == "" || origin == "" {
|
||||
func (w *Worker) MarkStale(ctx context.Context, p MarkStaleParams) {
|
||||
if p.Branch == "" || p.Origin == "" {
|
||||
return
|
||||
}
|
||||
|
||||
// When a specific chat is identified, target it directly
|
||||
// instead of broadcasting to every chat on the workspace.
|
||||
// Note: this path does not verify that the chat belongs to
|
||||
// WorkspaceID. This is safe because ChatID originates from
|
||||
// chatd via the agent (trusted data flow), but differs from
|
||||
// the broadcast path which filters by workspace.
|
||||
if p.ChatID != uuid.Nil {
|
||||
w.markStaleSingle(ctx, p.ChatID, p.Branch, p.Origin)
|
||||
return
|
||||
}
|
||||
|
||||
chatRows, err := w.store.GetChats(ctx, database.GetChatsParams{
|
||||
OwnerID: ownerID,
|
||||
OwnerID: p.OwnerID,
|
||||
})
|
||||
if err != nil {
|
||||
w.logger.Warn(ctx, "list chats for git ref storage",
|
||||
slog.F("workspace_id", workspaceID),
|
||||
slog.F("workspace_id", p.WorkspaceID),
|
||||
slog.Error(err))
|
||||
return
|
||||
}
|
||||
@@ -302,30 +321,39 @@ func (w *Worker) MarkStale(
|
||||
chats[i] = row.Chat
|
||||
}
|
||||
|
||||
for _, chat := range filterChatsByWorkspaceID(chats, workspaceID) {
|
||||
_, err := w.store.UpsertChatDiffStatusReference(ctx,
|
||||
database.UpsertChatDiffStatusReferenceParams{
|
||||
ChatID: chat.ID,
|
||||
GitBranch: branch,
|
||||
GitRemoteOrigin: origin,
|
||||
StaleAt: w.clock.Now().Add(-time.Second),
|
||||
Url: sql.NullString{},
|
||||
},
|
||||
)
|
||||
if err != nil {
|
||||
w.logger.Warn(ctx, "store git ref on chat diff status",
|
||||
slog.F("chat_id", chat.ID),
|
||||
slog.F("workspace_id", workspaceID),
|
||||
slog.Error(err))
|
||||
continue
|
||||
}
|
||||
// Notify the frontend immediately so the UI shows the
|
||||
// branch info even before the worker refreshes PR data.
|
||||
if w.publishDiffStatusChangeFn != nil {
|
||||
if pubErr := w.publishDiffStatusChangeFn(ctx, chat.ID); pubErr != nil {
|
||||
w.logger.Debug(ctx, "publish diff status after mark stale",
|
||||
slog.F("chat_id", chat.ID), slog.Error(pubErr))
|
||||
}
|
||||
for _, chat := range filterChatsByWorkspaceID(chats, p.WorkspaceID) {
|
||||
w.markStaleSingle(ctx, chat.ID, p.Branch, p.Origin)
|
||||
}
|
||||
}
|
||||
|
||||
// markStaleSingle upserts the git ref for a single chat and
|
||||
// publishes a diff-status change event.
|
||||
func (w *Worker) markStaleSingle(
|
||||
ctx context.Context,
|
||||
chatID uuid.UUID,
|
||||
branch, origin string,
|
||||
) {
|
||||
_, err := w.store.UpsertChatDiffStatusReference(ctx,
|
||||
database.UpsertChatDiffStatusReferenceParams{
|
||||
ChatID: chatID,
|
||||
GitBranch: branch,
|
||||
GitRemoteOrigin: origin,
|
||||
StaleAt: w.clock.Now().Add(-time.Second),
|
||||
Url: sql.NullString{},
|
||||
},
|
||||
)
|
||||
if err != nil {
|
||||
w.logger.Warn(ctx, "store git ref on chat diff status",
|
||||
slog.F("chat_id", chatID),
|
||||
slog.Error(err))
|
||||
return
|
||||
}
|
||||
// Notify the frontend immediately so the UI shows the
|
||||
// branch info even before the worker refreshes PR data.
|
||||
if w.publishDiffStatusChangeFn != nil {
|
||||
if pubErr := w.publishDiffStatusChangeFn(ctx, chatID); pubErr != nil {
|
||||
w.logger.Debug(ctx, "publish diff status after mark stale",
|
||||
slog.F("chat_id", chatID), slog.Error(pubErr))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -644,7 +644,12 @@ func TestWorker_MarkStale_UpsertAndPublish(t *testing.T) {
|
||||
refresher := newTestRefresher(t, mClock)
|
||||
worker := gitsync.NewWorker(store, refresher, pub, mClock, logger)
|
||||
|
||||
worker.MarkStale(ctx, workspaceID, ownerID, "feature", "https://github.com/owner/repo")
|
||||
worker.MarkStale(ctx, gitsync.MarkStaleParams{
|
||||
WorkspaceID: workspaceID,
|
||||
OwnerID: ownerID,
|
||||
Branch: "feature",
|
||||
Origin: "https://github.com/owner/repo",
|
||||
})
|
||||
|
||||
mu.Lock()
|
||||
defer mu.Unlock()
|
||||
@@ -683,7 +688,12 @@ func TestWorker_MarkStale_NoMatchingChats(t *testing.T) {
|
||||
refresher := newTestRefresher(t, mClock)
|
||||
worker := gitsync.NewWorker(store, refresher, nil, mClock, logger)
|
||||
|
||||
worker.MarkStale(ctx, workspaceID, ownerID, "main", "https://github.com/x/y")
|
||||
worker.MarkStale(ctx, gitsync.MarkStaleParams{
|
||||
WorkspaceID: workspaceID,
|
||||
OwnerID: ownerID,
|
||||
Branch: "main",
|
||||
Origin: "https://github.com/x/y",
|
||||
})
|
||||
}
|
||||
|
||||
func TestWorker_MarkStale_UpsertFails_ContinuesNext(t *testing.T) {
|
||||
@@ -723,7 +733,12 @@ func TestWorker_MarkStale_UpsertFails_ContinuesNext(t *testing.T) {
|
||||
refresher := newTestRefresher(t, mClock)
|
||||
worker := gitsync.NewWorker(store, refresher, pub, mClock, logger)
|
||||
|
||||
worker.MarkStale(ctx, workspaceID, ownerID, "dev", "https://github.com/a/b")
|
||||
worker.MarkStale(ctx, gitsync.MarkStaleParams{
|
||||
WorkspaceID: workspaceID,
|
||||
OwnerID: ownerID,
|
||||
Branch: "dev",
|
||||
Origin: "https://github.com/a/b",
|
||||
})
|
||||
|
||||
assert.Equal(t, int32(1), publishCount.Load())
|
||||
}
|
||||
@@ -743,7 +758,12 @@ func TestWorker_MarkStale_GetChatsFails(t *testing.T) {
|
||||
refresher := newTestRefresher(t, mClock)
|
||||
worker := gitsync.NewWorker(store, refresher, nil, mClock, logger)
|
||||
|
||||
worker.MarkStale(ctx, uuid.New(), uuid.New(), "main", "https://github.com/x/y")
|
||||
worker.MarkStale(ctx, gitsync.MarkStaleParams{
|
||||
WorkspaceID: uuid.New(),
|
||||
OwnerID: uuid.New(),
|
||||
Branch: "main",
|
||||
Origin: "https://github.com/x/y",
|
||||
})
|
||||
}
|
||||
|
||||
func TestWorker_TickStoreError(t *testing.T) {
|
||||
@@ -795,11 +815,135 @@ func TestWorker_MarkStale_EmptyBranchOrOrigin(t *testing.T) {
|
||||
refresher := newTestRefresher(t, mClock)
|
||||
worker := gitsync.NewWorker(store, refresher, nil, mClock, logger)
|
||||
|
||||
worker.MarkStale(ctx, uuid.New(), uuid.New(), tc.branch, tc.origin)
|
||||
worker.MarkStale(ctx, gitsync.MarkStaleParams{
|
||||
WorkspaceID: uuid.New(),
|
||||
OwnerID: uuid.New(),
|
||||
Branch: tc.branch,
|
||||
Origin: tc.origin,
|
||||
})
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestWorker_MarkStale_WithChatID(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := testutil.Context(t, testutil.WaitShort)
|
||||
|
||||
targetChat := uuid.New()
|
||||
|
||||
var mu sync.Mutex
|
||||
var upsertRefCalls []database.UpsertChatDiffStatusReferenceParams
|
||||
var publishedIDs []uuid.UUID
|
||||
|
||||
ctrl := gomock.NewController(t)
|
||||
store := dbmock.NewMockStore(ctrl)
|
||||
|
||||
// GetChats should NOT be called when a specific chat ID is provided.
|
||||
store.EXPECT().GetChats(gomock.Any(), gomock.Any()).Times(0)
|
||||
store.EXPECT().UpsertChatDiffStatusReference(gomock.Any(), gomock.Any()).DoAndReturn(func(_ context.Context, arg database.UpsertChatDiffStatusReferenceParams) (database.ChatDiffStatus, error) {
|
||||
mu.Lock()
|
||||
upsertRefCalls = append(upsertRefCalls, arg)
|
||||
mu.Unlock()
|
||||
return database.ChatDiffStatus{ChatID: arg.ChatID}, nil
|
||||
}).Times(1)
|
||||
|
||||
pub := func(_ context.Context, chatID uuid.UUID) error {
|
||||
mu.Lock()
|
||||
publishedIDs = append(publishedIDs, chatID)
|
||||
mu.Unlock()
|
||||
return nil
|
||||
}
|
||||
|
||||
mClock := quartz.NewMock(t)
|
||||
now := mClock.Now()
|
||||
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true})
|
||||
refresher := newTestRefresher(t, mClock)
|
||||
worker := gitsync.NewWorker(store, refresher, pub, mClock, logger)
|
||||
|
||||
worker.MarkStale(ctx, gitsync.MarkStaleParams{
|
||||
WorkspaceID: uuid.New(),
|
||||
OwnerID: uuid.New(),
|
||||
Branch: "my-branch",
|
||||
Origin: "https://github.com/org/repo",
|
||||
ChatID: targetChat,
|
||||
})
|
||||
|
||||
mu.Lock()
|
||||
defer mu.Unlock()
|
||||
|
||||
require.Len(t, upsertRefCalls, 1)
|
||||
assert.Equal(t, targetChat, upsertRefCalls[0].ChatID)
|
||||
assert.Equal(t, "my-branch", upsertRefCalls[0].GitBranch)
|
||||
assert.Equal(t, "https://github.com/org/repo", upsertRefCalls[0].GitRemoteOrigin)
|
||||
assert.True(t, upsertRefCalls[0].StaleAt.Before(now),
|
||||
"stale_at should be in the past, got %v vs now %v", upsertRefCalls[0].StaleAt, now)
|
||||
|
||||
require.Len(t, publishedIDs, 1)
|
||||
assert.Equal(t, targetChat, publishedIDs[0])
|
||||
}
|
||||
|
||||
func TestWorker_MarkStale_NilChatID_Broadcasts(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := testutil.Context(t, testutil.WaitShort)
|
||||
|
||||
workspaceID := uuid.New()
|
||||
ownerID := uuid.New()
|
||||
chat1 := uuid.New()
|
||||
|
||||
var mu sync.Mutex
|
||||
var upsertRefCalls []database.UpsertChatDiffStatusReferenceParams
|
||||
var publishedIDs []uuid.UUID
|
||||
|
||||
ctrl := gomock.NewController(t)
|
||||
store := dbmock.NewMockStore(ctrl)
|
||||
|
||||
// GetChats IS called because a nil ChatID triggers the
|
||||
// workspace-wide broadcast path.
|
||||
store.EXPECT().GetChats(gomock.Any(), gomock.Any()).
|
||||
DoAndReturn(func(_ context.Context, arg database.GetChatsParams) ([]database.GetChatsRow, error) {
|
||||
require.Equal(t, ownerID, arg.OwnerID)
|
||||
return []database.GetChatsRow{
|
||||
{Chat: database.Chat{ID: chat1, OwnerID: ownerID, WorkspaceID: uuid.NullUUID{UUID: workspaceID, Valid: true}}},
|
||||
}, nil
|
||||
})
|
||||
store.EXPECT().UpsertChatDiffStatusReference(gomock.Any(), gomock.Any()).DoAndReturn(func(_ context.Context, arg database.UpsertChatDiffStatusReferenceParams) (database.ChatDiffStatus, error) {
|
||||
mu.Lock()
|
||||
upsertRefCalls = append(upsertRefCalls, arg)
|
||||
mu.Unlock()
|
||||
return database.ChatDiffStatus{ChatID: arg.ChatID}, nil
|
||||
}).Times(1)
|
||||
|
||||
pub := func(_ context.Context, chatID uuid.UUID) error {
|
||||
mu.Lock()
|
||||
publishedIDs = append(publishedIDs, chatID)
|
||||
mu.Unlock()
|
||||
return nil
|
||||
}
|
||||
|
||||
mClock := quartz.NewMock(t)
|
||||
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true})
|
||||
refresher := newTestRefresher(t, mClock)
|
||||
worker := gitsync.NewWorker(store, refresher, pub, mClock, logger)
|
||||
|
||||
// Zero-value ChatID (uuid.Nil) triggers broadcast.
|
||||
worker.MarkStale(ctx, gitsync.MarkStaleParams{
|
||||
WorkspaceID: workspaceID,
|
||||
OwnerID: ownerID,
|
||||
Branch: "main",
|
||||
Origin: "https://github.com/org/repo",
|
||||
})
|
||||
|
||||
mu.Lock()
|
||||
defer mu.Unlock()
|
||||
|
||||
require.Len(t, upsertRefCalls, 1)
|
||||
assert.Equal(t, chat1, upsertRefCalls[0].ChatID)
|
||||
assert.Equal(t, "main", upsertRefCalls[0].GitBranch)
|
||||
|
||||
require.Len(t, publishedIDs, 1)
|
||||
assert.Equal(t, chat1, publishedIDs[0])
|
||||
}
|
||||
|
||||
// TestWorker exercises the worker tick against a
|
||||
// real PostgreSQL database to verify that the SQL queries, foreign key
|
||||
// constraints, and upsert logic work end-to-end.
|
||||
|
||||
Reference in New Issue
Block a user