Compare commits

..

2 Commits

Author SHA1 Message Date
Cian Johnston cb67070ff7 merge: resolve conflicts with main (preserve non-blocking workspace creation) 2026-04-10 08:15:22 +00:00
Cian Johnston 48c00e4fde feat(coderd/x/chatd): make workspace creation async so the agent stops twiddling its thumbs
create_workspace and start_workspace now return immediately instead of
blocking for up to 22 minutes (10min build + 2min agent + 10min scripts).

The wait logic moves to getWorkspaceConn via a new waitForWorkspaceReady
method. When workspace-dependent tools (execute, read_file, write_file,
etc.) actually need the workspace, they transparently wait for the build
to finish. Concurrent callers share a single poll loop via sync.Once.

This lets the LLM continue working while the workspace builds — reading
GitHub repos, searching the web, planning its approach — so the user
never notices the build time unless they inspect tool calls.

Key changes:
- Extract WaitForBuild/WaitForAgentReady to workspacereadiness.go
- create_workspace returns {status: "building"} immediately
- start_workspace returns {status: "starting"} immediately
- checkExistingWorkspace returns immediately for building workspaces
- getWorkspaceConn calls waitForWorkspaceReady before dialing
- waitForWorkspaceReady fast-paths when agents already exist
- System prompt updated to tell model about async behavior
2026-03-30 10:10:36 +00:00
100 changed files with 1590 additions and 3443 deletions
+7 -16
View File
@@ -134,19 +134,10 @@ jobs:
exit 0
fi
NEW_PR_URL=$(
gh pr create \
--base "$RELEASE_BRANCH" \
--head "$BACKPORT_BRANCH" \
--title "$TITLE" \
--body "$BODY" \
--assignee "$SENDER" \
--reviewer "$SENDER"
)
# Comment on the original PR to notify the author.
COMMENT="Cherry-pick PR created: ${NEW_PR_URL}"
if [ "$CONFLICT" = true ]; then
COMMENT="${COMMENT} (⚠️ conflicts need manual resolution)"
fi
gh pr comment "$PR_NUMBER" --body "$COMMENT"
gh pr create \
--base "$RELEASE_BRANCH" \
--head "$BACKPORT_BRANCH" \
--title "$TITLE" \
--body "$BODY" \
--assignee "$SENDER" \
--reviewer "$SENDER"
-120
View File
@@ -2862,126 +2862,6 @@ func TestAPI(t *testing.T) {
"rebuilt agent should include updated display apps")
})
// Verify that when a terraform-managed subagent is injected into
// a devcontainer, the Directory field sent to Create reflects
// the container-internal workspaceFolder from devcontainer
// read-configuration, not the host-side workspace_folder from
// the terraform resource. This is the scenario described in
// https://linear.app/codercom/issue/PRODUCT-259:
// 1. Non-terraform subagent → directory = /workspaces/foo (correct)
// 2. Terraform subagent → directory was stuck on host path (bug)
t.Run("TerraformDefinedSubAgentUsesContainerInternalDirectory", func(t *testing.T) {
t.Parallel()
if runtime.GOOS == "windows" {
t.Skip("Dev Container tests are not supported on Windows (this test uses mocks but fails due to Windows paths)")
}
var (
ctx = testutil.Context(t, testutil.WaitMedium)
logger = slogtest.Make(t, &slogtest.Options{IgnoreErrors: true}).Leveled(slog.LevelDebug)
mCtrl = gomock.NewController(t)
terraformAgentID = uuid.New()
containerID = "test-container-id"
// Given: A container with a host-side workspace folder.
terraformContainer = codersdk.WorkspaceAgentContainer{
ID: containerID,
FriendlyName: "test-container",
Image: "test-image",
Running: true,
CreatedAt: time.Now(),
Labels: map[string]string{
agentcontainers.DevcontainerLocalFolderLabel: "/home/coder/project",
agentcontainers.DevcontainerConfigFileLabel: "/home/coder/project/.devcontainer/devcontainer.json",
},
}
// Given: A terraform-defined devcontainer whose
// workspace_folder is the HOST-side path (set by provisioner).
terraformDevcontainer = codersdk.WorkspaceAgentDevcontainer{
ID: uuid.New(),
Name: "terraform-devcontainer",
WorkspaceFolder: "/home/coder/project",
ConfigPath: "/home/coder/project/.devcontainer/devcontainer.json",
SubagentID: uuid.NullUUID{UUID: terraformAgentID, Valid: true},
}
fCCLI = &fakeContainerCLI{
containers: codersdk.WorkspaceAgentListContainersResponse{
Containers: []codersdk.WorkspaceAgentContainer{terraformContainer},
},
arch: runtime.GOARCH,
}
// Given: devcontainer read-configuration returns the
// CONTAINER-INTERNAL workspace folder.
fDCCLI = &fakeDevcontainerCLI{
upID: containerID,
readConfig: agentcontainers.DevcontainerConfig{
Workspace: agentcontainers.DevcontainerWorkspace{
WorkspaceFolder: "/workspaces/project",
},
MergedConfiguration: agentcontainers.DevcontainerMergedConfiguration{
Customizations: agentcontainers.DevcontainerMergedCustomizations{
Coder: []agentcontainers.CoderCustomization{{}},
},
},
},
}
mSAC = acmock.NewMockSubAgentClient(mCtrl)
createCalls = make(chan agentcontainers.SubAgent, 1)
closed bool
)
mSAC.EXPECT().List(gomock.Any()).Return([]agentcontainers.SubAgent{}, nil).AnyTimes()
mSAC.EXPECT().Create(gomock.Any(), gomock.Any()).DoAndReturn(
func(_ context.Context, agent agentcontainers.SubAgent) (agentcontainers.SubAgent, error) {
agent.AuthToken = uuid.New()
createCalls <- agent
return agent, nil
},
).Times(1)
mSAC.EXPECT().Delete(gomock.Any(), gomock.Any()).DoAndReturn(func(_ context.Context, _ uuid.UUID) error {
assert.True(t, closed, "Delete should only be called after Close")
return nil
}).AnyTimes()
api := agentcontainers.NewAPI(logger,
agentcontainers.WithContainerCLI(fCCLI),
agentcontainers.WithDevcontainerCLI(fDCCLI),
agentcontainers.WithDevcontainers(
[]codersdk.WorkspaceAgentDevcontainer{terraformDevcontainer},
[]codersdk.WorkspaceAgentScript{{ID: terraformDevcontainer.ID, LogSourceID: uuid.New()}},
),
agentcontainers.WithSubAgentClient(mSAC),
agentcontainers.WithSubAgentURL("test-subagent-url"),
agentcontainers.WithWatcher(watcher.NewNoop()),
)
api.Start()
defer func() {
closed = true
api.Close()
}()
// When: The devcontainer is created (triggering injection).
err := api.CreateDevcontainer(terraformDevcontainer.WorkspaceFolder, terraformDevcontainer.ConfigPath)
require.NoError(t, err)
// Then: The subagent sent to Create has the correct
// container-internal directory, not the host path.
createdAgent := testutil.RequireReceive(ctx, t, createCalls)
assert.Equal(t, terraformAgentID, createdAgent.ID,
"agent should use terraform-defined ID")
assert.Equal(t, "/workspaces/project", createdAgent.Directory,
"directory should be the container-internal path from devcontainer "+
"read-configuration, not the host-side workspace_folder")
})
t.Run("Error", func(t *testing.T) {
t.Parallel()
+1 -1
View File
@@ -11,7 +11,7 @@ OPTIONS:
-O, --org string, $CODER_ORGANIZATION
Select which organization (uuid or name) to use.
-c, --column [id|created at|started at|completed at|canceled at|error|error code|status|worker id|worker name|file id|tags|queue position|queue size|organization id|initiator id|template version id|workspace build id|type|available workers|template version name|template id|template name|template display name|template icon|workspace id|workspace name|workspace build transition|logs overflowed|organization|queue] (default: created at,id,type,template display name,status,queue,tags)
-c, --column [id|created at|started at|completed at|canceled at|error|error code|status|worker id|worker name|file id|tags|queue position|queue size|organization id|initiator id|template version id|workspace build id|type|available workers|template version name|template id|template name|template display name|template icon|workspace id|workspace name|logs overflowed|organization|queue] (default: created at,id,type,template display name,status,queue,tags)
Columns to display in table output.
-i, --initiator string, $CODER_PROVISIONER_JOB_LIST_INITIATOR
@@ -58,8 +58,7 @@
"template_display_name": "",
"template_icon": "",
"workspace_id": "===========[workspace ID]===========",
"workspace_name": "test-workspace",
"workspace_build_transition": "start"
"workspace_name": "test-workspace"
},
"logs_overflowed": false,
"organization_name": "Coder"
+1 -11
View File
@@ -71,7 +71,7 @@ func (a *SubAgentAPI) CreateSubAgent(ctx context.Context, req *agentproto.Create
// An ID is only given in the request when it is a terraform-defined devcontainer
// that has attached resources. These subagents are pre-provisioned by terraform
// (the agent record already exists), so we update configurable fields like
// display_apps and directory rather than creating a new agent.
// display_apps rather than creating a new agent.
if req.Id != nil {
id, err := uuid.FromBytes(req.Id)
if err != nil {
@@ -97,16 +97,6 @@ func (a *SubAgentAPI) CreateSubAgent(ctx context.Context, req *agentproto.Create
return nil, xerrors.Errorf("update workspace agent display apps: %w", err)
}
if req.Directory != "" {
if err := a.Database.UpdateWorkspaceAgentDirectoryByID(ctx, database.UpdateWorkspaceAgentDirectoryByIDParams{
ID: id,
Directory: req.Directory,
UpdatedAt: createdAt,
}); err != nil {
return nil, xerrors.Errorf("update workspace agent directory: %w", err)
}
}
return &agentproto.CreateSubAgentResponse{
Agent: &agentproto.SubAgent{
Name: subAgent.Name,
+2 -38
View File
@@ -1267,11 +1267,11 @@ func TestSubAgentAPI(t *testing.T) {
agentID, err := uuid.FromBytes(resp.Agent.Id)
require.NoError(t, err)
// And: The database agent's name, architecture, and OS are unchanged.
// And: The database agent's other fields are unchanged.
updatedAgent, err := db.GetWorkspaceAgentByID(dbauthz.AsSystemRestricted(ctx), agentID)
require.NoError(t, err)
require.Equal(t, baseChildAgent.Name, updatedAgent.Name)
require.Equal(t, "/different/path", updatedAgent.Directory)
require.Equal(t, baseChildAgent.Directory, updatedAgent.Directory)
require.Equal(t, baseChildAgent.Architecture, updatedAgent.Architecture)
require.Equal(t, baseChildAgent.OperatingSystem, updatedAgent.OperatingSystem)
@@ -1280,42 +1280,6 @@ func TestSubAgentAPI(t *testing.T) {
require.Equal(t, database.DisplayAppWebTerminal, updatedAgent.DisplayApps[0])
},
},
{
name: "OK_DirectoryUpdated",
setup: func(t *testing.T, db database.Store, agent database.WorkspaceAgent) *proto.CreateSubAgentRequest {
// Given: An existing child agent with a stale host-side
// directory (as set by the provisioner at build time).
childAgent := dbgen.WorkspaceAgent(t, db, database.WorkspaceAgent{
ParentID: uuid.NullUUID{Valid: true, UUID: agent.ID},
ResourceID: agent.ResourceID,
Name: baseChildAgent.Name,
Directory: "/home/coder/project",
Architecture: baseChildAgent.Architecture,
OperatingSystem: baseChildAgent.OperatingSystem,
DisplayApps: baseChildAgent.DisplayApps,
})
// When: Agent injection sends the correct
// container-internal path.
return &proto.CreateSubAgentRequest{
Id: childAgent.ID[:],
Directory: "/workspaces/project",
DisplayApps: []proto.CreateSubAgentRequest_DisplayApp{
proto.CreateSubAgentRequest_WEB_TERMINAL,
},
}
},
check: func(t *testing.T, ctx context.Context, db database.Store, resp *proto.CreateSubAgentResponse, agent database.WorkspaceAgent) {
agentID, err := uuid.FromBytes(resp.Agent.Id)
require.NoError(t, err)
// Then: Directory is updated to the container-internal
// path.
updatedAgent, err := db.GetWorkspaceAgentByID(dbauthz.AsSystemRestricted(ctx), agentID)
require.NoError(t, err)
require.Equal(t, "/workspaces/project", updatedAgent.Directory)
},
},
{
name: "Error/MalformedID",
setup: func(t *testing.T, db database.Store, agent database.WorkspaceAgent) *proto.CreateSubAgentRequest {
-3
View File
@@ -19149,9 +19149,6 @@ const docTemplate = `{
"template_version_name": {
"type": "string"
},
"workspace_build_transition": {
"$ref": "#/definitions/codersdk.WorkspaceTransition"
},
"workspace_id": {
"type": "string",
"format": "uuid"
-3
View File
@@ -17509,9 +17509,6 @@
"template_version_name": {
"type": "string"
},
"workspace_build_transition": {
"$ref": "#/definitions/codersdk.WorkspaceTransition"
},
"workspace_id": {
"type": "string",
"format": "uuid"
+2 -15
View File
@@ -3401,11 +3401,11 @@ func (q *querier) GetPRInsightsPerModel(ctx context.Context, arg database.GetPRI
return q.db.GetPRInsightsPerModel(ctx, arg)
}
func (q *querier) GetPRInsightsPullRequests(ctx context.Context, arg database.GetPRInsightsPullRequestsParams) ([]database.GetPRInsightsPullRequestsRow, error) {
func (q *querier) GetPRInsightsRecentPRs(ctx context.Context, arg database.GetPRInsightsRecentPRsParams) ([]database.GetPRInsightsRecentPRsRow, error) {
if err := q.authorizeContext(ctx, policy.ActionRead, rbac.ResourceDeploymentConfig); err != nil {
return nil, err
}
return q.db.GetPRInsightsPullRequests(ctx, arg)
return q.db.GetPRInsightsRecentPRs(ctx, arg)
}
func (q *querier) GetPRInsightsSummary(ctx context.Context, arg database.GetPRInsightsSummaryParams) (database.GetPRInsightsSummaryRow, error) {
@@ -6783,19 +6783,6 @@ func (q *querier) UpdateWorkspaceAgentConnectionByID(ctx context.Context, arg da
return q.db.UpdateWorkspaceAgentConnectionByID(ctx, arg)
}
func (q *querier) UpdateWorkspaceAgentDirectoryByID(ctx context.Context, arg database.UpdateWorkspaceAgentDirectoryByIDParams) error {
workspace, err := q.db.GetWorkspaceByAgentID(ctx, arg.ID)
if err != nil {
return err
}
if err := q.authorizeContext(ctx, policy.ActionUpdateAgent, workspace); err != nil {
return err
}
return q.db.UpdateWorkspaceAgentDirectoryByID(ctx, arg)
}
func (q *querier) UpdateWorkspaceAgentDisplayAppsByID(ctx context.Context, arg database.UpdateWorkspaceAgentDisplayAppsByIDParams) error {
workspace, err := q.db.GetWorkspaceByAgentID(ctx, arg.ID)
if err != nil {
+3 -14
View File
@@ -2261,9 +2261,9 @@ func (s *MethodTestSuite) TestTemplate() {
dbm.EXPECT().GetPRInsightsPerModel(gomock.Any(), arg).Return([]database.GetPRInsightsPerModelRow{}, nil).AnyTimes()
check.Args(arg).Asserts(rbac.ResourceDeploymentConfig, policy.ActionRead)
}))
s.Run("GetPRInsightsPullRequests", s.Mocked(func(dbm *dbmock.MockStore, _ *gofakeit.Faker, check *expects) {
arg := database.GetPRInsightsPullRequestsParams{}
dbm.EXPECT().GetPRInsightsPullRequests(gomock.Any(), arg).Return([]database.GetPRInsightsPullRequestsRow{}, nil).AnyTimes()
s.Run("GetPRInsightsRecentPRs", s.Mocked(func(dbm *dbmock.MockStore, _ *gofakeit.Faker, check *expects) {
arg := database.GetPRInsightsRecentPRsParams{}
dbm.EXPECT().GetPRInsightsRecentPRs(gomock.Any(), arg).Return([]database.GetPRInsightsRecentPRsRow{}, nil).AnyTimes()
check.Args(arg).Asserts(rbac.ResourceDeploymentConfig, policy.ActionRead)
}))
s.Run("GetTelemetryTaskEvents", s.Mocked(func(dbm *dbmock.MockStore, _ *gofakeit.Faker, check *expects) {
@@ -2935,17 +2935,6 @@ func (s *MethodTestSuite) TestWorkspace() {
dbm.EXPECT().UpdateWorkspaceAgentStartupByID(gomock.Any(), arg).Return(nil).AnyTimes()
check.Args(arg).Asserts(w, policy.ActionUpdate).Returns()
}))
s.Run("UpdateWorkspaceAgentDirectoryByID", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
w := testutil.Fake(s.T(), faker, database.Workspace{})
agt := testutil.Fake(s.T(), faker, database.WorkspaceAgent{})
arg := database.UpdateWorkspaceAgentDirectoryByIDParams{
ID: agt.ID,
Directory: "/workspaces/project",
}
dbm.EXPECT().GetWorkspaceByAgentID(gomock.Any(), agt.ID).Return(w, nil).AnyTimes()
dbm.EXPECT().UpdateWorkspaceAgentDirectoryByID(gomock.Any(), arg).Return(nil).AnyTimes()
check.Args(arg).Asserts(w, policy.ActionUpdateAgent).Returns()
}))
s.Run("UpdateWorkspaceAgentDisplayAppsByID", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
w := testutil.Fake(s.T(), faker, database.Workspace{})
agt := testutil.Fake(s.T(), faker, database.WorkspaceAgent{})
+4 -12
View File
@@ -1992,11 +1992,11 @@ func (m queryMetricsStore) GetPRInsightsPerModel(ctx context.Context, arg databa
return r0, r1
}
func (m queryMetricsStore) GetPRInsightsPullRequests(ctx context.Context, arg database.GetPRInsightsPullRequestsParams) ([]database.GetPRInsightsPullRequestsRow, error) {
func (m queryMetricsStore) GetPRInsightsRecentPRs(ctx context.Context, arg database.GetPRInsightsRecentPRsParams) ([]database.GetPRInsightsRecentPRsRow, error) {
start := time.Now()
r0, r1 := m.s.GetPRInsightsPullRequests(ctx, arg)
m.queryLatencies.WithLabelValues("GetPRInsightsPullRequests").Observe(time.Since(start).Seconds())
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "GetPRInsightsPullRequests").Inc()
r0, r1 := m.s.GetPRInsightsRecentPRs(ctx, arg)
m.queryLatencies.WithLabelValues("GetPRInsightsRecentPRs").Observe(time.Since(start).Seconds())
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "GetPRInsightsRecentPRs").Inc()
return r0, r1
}
@@ -4840,14 +4840,6 @@ func (m queryMetricsStore) UpdateWorkspaceAgentConnectionByID(ctx context.Contex
return r0
}
func (m queryMetricsStore) UpdateWorkspaceAgentDirectoryByID(ctx context.Context, arg database.UpdateWorkspaceAgentDirectoryByIDParams) error {
start := time.Now()
r0 := m.s.UpdateWorkspaceAgentDirectoryByID(ctx, arg)
m.queryLatencies.WithLabelValues("UpdateWorkspaceAgentDirectoryByID").Observe(time.Since(start).Seconds())
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "UpdateWorkspaceAgentDirectoryByID").Inc()
return r0
}
func (m queryMetricsStore) UpdateWorkspaceAgentDisplayAppsByID(ctx context.Context, arg database.UpdateWorkspaceAgentDisplayAppsByIDParams) error {
start := time.Now()
r0 := m.s.UpdateWorkspaceAgentDisplayAppsByID(ctx, arg)
+7 -21
View File
@@ -3692,19 +3692,19 @@ func (mr *MockStoreMockRecorder) GetPRInsightsPerModel(ctx, arg any) *gomock.Cal
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetPRInsightsPerModel", reflect.TypeOf((*MockStore)(nil).GetPRInsightsPerModel), ctx, arg)
}
// GetPRInsightsPullRequests mocks base method.
func (m *MockStore) GetPRInsightsPullRequests(ctx context.Context, arg database.GetPRInsightsPullRequestsParams) ([]database.GetPRInsightsPullRequestsRow, error) {
// GetPRInsightsRecentPRs mocks base method.
func (m *MockStore) GetPRInsightsRecentPRs(ctx context.Context, arg database.GetPRInsightsRecentPRsParams) ([]database.GetPRInsightsRecentPRsRow, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetPRInsightsPullRequests", ctx, arg)
ret0, _ := ret[0].([]database.GetPRInsightsPullRequestsRow)
ret := m.ctrl.Call(m, "GetPRInsightsRecentPRs", ctx, arg)
ret0, _ := ret[0].([]database.GetPRInsightsRecentPRsRow)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// GetPRInsightsPullRequests indicates an expected call of GetPRInsightsPullRequests.
func (mr *MockStoreMockRecorder) GetPRInsightsPullRequests(ctx, arg any) *gomock.Call {
// GetPRInsightsRecentPRs indicates an expected call of GetPRInsightsRecentPRs.
func (mr *MockStoreMockRecorder) GetPRInsightsRecentPRs(ctx, arg any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetPRInsightsPullRequests", reflect.TypeOf((*MockStore)(nil).GetPRInsightsPullRequests), ctx, arg)
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetPRInsightsRecentPRs", reflect.TypeOf((*MockStore)(nil).GetPRInsightsRecentPRs), ctx, arg)
}
// GetPRInsightsSummary mocks base method.
@@ -9120,20 +9120,6 @@ func (mr *MockStoreMockRecorder) UpdateWorkspaceAgentConnectionByID(ctx, arg any
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateWorkspaceAgentConnectionByID", reflect.TypeOf((*MockStore)(nil).UpdateWorkspaceAgentConnectionByID), ctx, arg)
}
// UpdateWorkspaceAgentDirectoryByID mocks base method.
func (m *MockStore) UpdateWorkspaceAgentDirectoryByID(ctx context.Context, arg database.UpdateWorkspaceAgentDirectoryByIDParams) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "UpdateWorkspaceAgentDirectoryByID", ctx, arg)
ret0, _ := ret[0].(error)
return ret0
}
// UpdateWorkspaceAgentDirectoryByID indicates an expected call of UpdateWorkspaceAgentDirectoryByID.
func (mr *MockStoreMockRecorder) UpdateWorkspaceAgentDirectoryByID(ctx, arg any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateWorkspaceAgentDirectoryByID", reflect.TypeOf((*MockStore)(nil).UpdateWorkspaceAgentDirectoryByID), ctx, arg)
}
// UpdateWorkspaceAgentDisplayAppsByID mocks base method.
func (m *MockStore) UpdateWorkspaceAgentDisplayAppsByID(ctx context.Context, arg database.UpdateWorkspaceAgentDisplayAppsByIDParams) error {
m.ctrl.T.Helper()
+2
View File
@@ -3791,6 +3791,8 @@ CREATE INDEX idx_chats_last_model_config_id ON chats USING btree (last_model_con
CREATE INDEX idx_chats_owner ON chats USING btree (owner_id);
CREATE INDEX idx_chats_owner_updated_id ON chats USING btree (owner_id, updated_at DESC, id DESC);
CREATE INDEX idx_chats_parent_chat_id ON chats USING btree (parent_chat_id);
CREATE INDEX idx_chats_pending ON chats USING btree (status) WHERE (status = 'pending'::chat_status);
@@ -1 +0,0 @@
CREATE INDEX idx_chats_owner_updated_id ON chats (owner_id, updated_at DESC, id DESC);
@@ -1,5 +0,0 @@
-- The GetChats ORDER BY changed from (updated_at, id) DESC to a 4-column
-- expression sort (pinned-first flag, negated pin_order, updated_at, id).
-- This index was purpose-built for the old sort and no longer provides
-- read benefit. The simpler idx_chats_owner covers the owner_id filter.
DROP INDEX IF EXISTS idx_chats_owner_updated_id;
+3 -5
View File
@@ -418,12 +418,11 @@ type sqlcQuerier interface {
// per PR for state/additions/deletions/model (model comes from the
// most recent chat).
GetPRInsightsPerModel(ctx context.Context, arg GetPRInsightsPerModelParams) ([]GetPRInsightsPerModelRow, error)
// Returns all individual PR rows with cost for the selected time range.
// Returns individual PR rows with cost for the recent PRs table.
// Uses two CTEs: pr_costs sums cost for the PR-linked chat and its
// direct children (that lack their own PR), and deduped picks one row
// per PR for metadata. A safety-cap LIMIT guards against unexpectedly
// large result sets from direct API callers.
GetPRInsightsPullRequests(ctx context.Context, arg GetPRInsightsPullRequestsParams) ([]GetPRInsightsPullRequestsRow, error)
// per PR for metadata.
GetPRInsightsRecentPRs(ctx context.Context, arg GetPRInsightsRecentPRsParams) ([]GetPRInsightsRecentPRsRow, error)
// PR Insights queries for the /agents analytics dashboard.
// These aggregate data from chat_diff_statuses (PR metadata) joined
// with chats and chat_messages (cost) to power the PR Insights view.
@@ -1012,7 +1011,6 @@ type sqlcQuerier interface {
UpdateWorkspace(ctx context.Context, arg UpdateWorkspaceParams) (WorkspaceTable, error)
UpdateWorkspaceACLByID(ctx context.Context, arg UpdateWorkspaceACLByIDParams) error
UpdateWorkspaceAgentConnectionByID(ctx context.Context, arg UpdateWorkspaceAgentConnectionByIDParams) error
UpdateWorkspaceAgentDirectoryByID(ctx context.Context, arg UpdateWorkspaceAgentDirectoryByIDParams) error
UpdateWorkspaceAgentDisplayAppsByID(ctx context.Context, arg UpdateWorkspaceAgentDisplayAppsByIDParams) error
UpdateWorkspaceAgentLifecycleStateByID(ctx context.Context, arg UpdateWorkspaceAgentLifecycleStateByIDParams) error
UpdateWorkspaceAgentLogOverflowByID(ctx context.Context, arg UpdateWorkspaceAgentLogOverflowByIDParams) error
+20 -34
View File
@@ -10408,10 +10408,11 @@ func TestGetPRInsights(t *testing.T) {
assert.Equal(t, int64(1), summary.TotalPrsCreated)
assert.Equal(t, int64(8_000_000), summary.TotalCostMicros)
recent, err := store.GetPRInsightsPullRequests(context.Background(), database.GetPRInsightsPullRequestsParams{
recent, err := store.GetPRInsightsRecentPRs(context.Background(), database.GetPRInsightsRecentPRsParams{
StartDate: startDate,
EndDate: endDate,
OwnerID: noOwner,
LimitVal: 20,
})
require.NoError(t, err)
require.Len(t, recent, 1)
@@ -10441,10 +10442,11 @@ func TestGetPRInsights(t *testing.T) {
assert.Equal(t, int64(1), summary.TotalPrsMerged)
// RecentPRs ordered by created_at DESC: chatB is newer.
recent, err := store.GetPRInsightsPullRequests(context.Background(), database.GetPRInsightsPullRequestsParams{
recent, err := store.GetPRInsightsRecentPRs(context.Background(), database.GetPRInsightsRecentPRsParams{
StartDate: startDate,
EndDate: endDate,
OwnerID: noOwner,
LimitVal: 20,
})
require.NoError(t, err)
require.Len(t, recent, 2)
@@ -10489,10 +10491,11 @@ func TestGetPRInsights(t *testing.T) {
assert.Equal(t, int64(1), summary.TotalPrsCreated)
assert.Equal(t, int64(1), summary.TotalPrsMerged)
recent, err := store.GetPRInsightsPullRequests(context.Background(), database.GetPRInsightsPullRequestsParams{
recent, err := store.GetPRInsightsRecentPRs(context.Background(), database.GetPRInsightsRecentPRsParams{
StartDate: startDate,
EndDate: endDate,
OwnerID: noOwner,
LimitVal: 20,
})
require.NoError(t, err)
require.Len(t, recent, 1)
@@ -10530,10 +10533,11 @@ func TestGetPRInsights(t *testing.T) {
assert.Equal(t, int64(9_000_000), summary.TotalCostMicros)
// RecentPRs should return 1 row with the full tree cost.
recent, err := store.GetPRInsightsPullRequests(context.Background(), database.GetPRInsightsPullRequestsParams{
recent, err := store.GetPRInsightsRecentPRs(context.Background(), database.GetPRInsightsRecentPRsParams{
StartDate: startDate,
EndDate: endDate,
OwnerID: noOwner,
LimitVal: 20,
})
require.NoError(t, err)
require.Len(t, recent, 1)
@@ -10571,10 +10575,11 @@ func TestGetPRInsights(t *testing.T) {
assert.Equal(t, int64(2), summary.TotalPrsCreated)
assert.Equal(t, int64(8_000_000), summary.TotalCostMicros)
recent, err := store.GetPRInsightsPullRequests(context.Background(), database.GetPRInsightsPullRequestsParams{
recent, err := store.GetPRInsightsRecentPRs(context.Background(), database.GetPRInsightsRecentPRsParams{
StartDate: startDate,
EndDate: endDate,
OwnerID: noOwner,
LimitVal: 20,
})
require.NoError(t, err)
require.Len(t, recent, 2)
@@ -10616,10 +10621,11 @@ func TestGetPRInsights(t *testing.T) {
assert.Equal(t, int64(2), summary.TotalPrsCreated)
assert.Equal(t, int64(17_000_000), summary.TotalCostMicros)
recent, err := store.GetPRInsightsPullRequests(context.Background(), database.GetPRInsightsPullRequestsParams{
recent, err := store.GetPRInsightsRecentPRs(context.Background(), database.GetPRInsightsRecentPRsParams{
StartDate: startDate,
EndDate: endDate,
OwnerID: noOwner,
LimitVal: 20,
})
require.NoError(t, err)
require.Len(t, recent, 2)
@@ -10652,10 +10658,11 @@ func TestGetPRInsights(t *testing.T) {
assert.Equal(t, int64(2), summary.TotalPrsCreated)
assert.Equal(t, int64(10_000_000), summary.TotalCostMicros)
recent, err := store.GetPRInsightsPullRequests(context.Background(), database.GetPRInsightsPullRequestsParams{
recent, err := store.GetPRInsightsRecentPRs(context.Background(), database.GetPRInsightsRecentPRsParams{
StartDate: startDate,
EndDate: endDate,
OwnerID: noOwner,
LimitVal: 20,
})
require.NoError(t, err)
require.Len(t, recent, 2)
@@ -10688,10 +10695,11 @@ func TestGetPRInsights(t *testing.T) {
assert.Equal(t, int64(1), summary.TotalPrsCreated)
assert.Equal(t, int64(15_000_000), summary.TotalCostMicros)
recent, err := store.GetPRInsightsPullRequests(context.Background(), database.GetPRInsightsPullRequestsParams{
recent, err := store.GetPRInsightsRecentPRs(context.Background(), database.GetPRInsightsRecentPRsParams{
StartDate: startDate,
EndDate: endDate,
OwnerID: noOwner,
LimitVal: 20,
})
require.NoError(t, err)
require.Len(t, recent, 1)
@@ -10716,10 +10724,11 @@ func TestGetPRInsights(t *testing.T) {
assert.Equal(t, int64(1), summary.TotalPrsCreated)
assert.Equal(t, int64(0), summary.TotalCostMicros)
recent, err := store.GetPRInsightsPullRequests(context.Background(), database.GetPRInsightsPullRequestsParams{
recent, err := store.GetPRInsightsRecentPRs(context.Background(), database.GetPRInsightsRecentPRsParams{
StartDate: startDate,
EndDate: endDate,
OwnerID: noOwner,
LimitVal: 20,
})
require.NoError(t, err)
require.Len(t, recent, 1)
@@ -10758,10 +10767,11 @@ func TestGetPRInsights(t *testing.T) {
require.Len(t, byModel, 1)
assert.Equal(t, modelName, byModel[0].DisplayName)
recent, err := store.GetPRInsightsPullRequests(context.Background(), database.GetPRInsightsPullRequestsParams{
recent, err := store.GetPRInsightsRecentPRs(context.Background(), database.GetPRInsightsRecentPRsParams{
StartDate: startDate,
EndDate: endDate,
OwnerID: noOwner,
LimitVal: 20,
})
require.NoError(t, err)
require.Len(t, recent, 1)
@@ -10793,30 +10803,6 @@ func TestGetPRInsights(t *testing.T) {
assert.Equal(t, int64(8_000_000), summary.TotalCostMicros)
assert.Equal(t, int64(5_000_000), summary.MergedCostMicros)
})
t.Run("AllPRsReturnedWithSafetyCap", func(t *testing.T) {
t.Parallel()
store, userID, mcID := setupChatInfra(t)
// Create 25 distinct PRs — more than the old LIMIT 20 — and
// verify all are returned.
const prCount = 25
for i := range prCount {
chat := createChat(t, store, userID, mcID, fmt.Sprintf("chat-%d", i))
insertCostMessage(t, store, chat.ID, userID, mcID, 1_000_000)
linkPR(t, store, chat.ID,
fmt.Sprintf("https://github.com/org/repo/pull/%d", 100+i),
"merged", fmt.Sprintf("fix: pr-%d", i), 10, 2, 1)
}
recent, err := store.GetPRInsightsPullRequests(context.Background(), database.GetPRInsightsPullRequestsParams{
StartDate: startDate,
EndDate: endDate,
OwnerID: noOwner,
})
require.NoError(t, err)
assert.Len(t, recent, prCount, "all PRs within the date range should be returned")
})
}
func TestChatPinOrderQueries(t *testing.T) {
+49 -72
View File
@@ -3218,7 +3218,7 @@ func (q *sqlQuerier) GetPRInsightsPerModel(ctx context.Context, arg GetPRInsight
return items, nil
}
const getPRInsightsPullRequests = `-- name: GetPRInsightsPullRequests :many
const getPRInsightsRecentPRs = `-- name: GetPRInsightsRecentPRs :many
WITH pr_costs AS (
SELECT
prc.pr_key,
@@ -3238,9 +3238,9 @@ WITH pr_costs AS (
AND cds2.pull_request_state IS NOT NULL
))
WHERE cds.pull_request_state IS NOT NULL
AND c.created_at >= $1::timestamptz
AND c.created_at < $2::timestamptz
AND ($3::uuid IS NULL OR c.owner_id = $3::uuid)
AND c.created_at >= $2::timestamptz
AND c.created_at < $3::timestamptz
AND ($4::uuid IS NULL OR c.owner_id = $4::uuid)
) prc
LEFT JOIN LATERAL (
SELECT COALESCE(SUM(cm.total_cost_micros), 0) AS cost_micros
@@ -3275,9 +3275,9 @@ deduped AS (
JOIN chats c ON c.id = cds.chat_id
LEFT JOIN chat_model_configs cmc ON cmc.id = c.last_model_config_id
WHERE cds.pull_request_state IS NOT NULL
AND c.created_at >= $1::timestamptz
AND c.created_at < $2::timestamptz
AND ($3::uuid IS NULL OR c.owner_id = $3::uuid)
AND c.created_at >= $2::timestamptz
AND c.created_at < $3::timestamptz
AND ($4::uuid IS NULL OR c.owner_id = $4::uuid)
ORDER BY COALESCE(NULLIF(cds.url, ''), c.id::text), c.created_at DESC, c.id DESC
)
SELECT chat_id, pr_title, pr_url, pr_number, state, draft, additions, deletions, changed_files, commits, approved, changes_requested, reviewer_count, author_login, author_avatar_url, base_branch, model_display_name, cost_micros, created_at FROM (
@@ -3305,16 +3305,17 @@ SELECT chat_id, pr_title, pr_url, pr_number, state, draft, additions, deletions,
JOIN pr_costs pc ON pc.pr_key = d.pr_key
) sub
ORDER BY sub.created_at DESC
LIMIT 500
LIMIT $1::int
`
type GetPRInsightsPullRequestsParams struct {
type GetPRInsightsRecentPRsParams struct {
LimitVal int32 `db:"limit_val" json:"limit_val"`
StartDate time.Time `db:"start_date" json:"start_date"`
EndDate time.Time `db:"end_date" json:"end_date"`
OwnerID uuid.NullUUID `db:"owner_id" json:"owner_id"`
}
type GetPRInsightsPullRequestsRow struct {
type GetPRInsightsRecentPRsRow struct {
ChatID uuid.UUID `db:"chat_id" json:"chat_id"`
PrTitle string `db:"pr_title" json:"pr_title"`
PrUrl sql.NullString `db:"pr_url" json:"pr_url"`
@@ -3336,20 +3337,24 @@ type GetPRInsightsPullRequestsRow struct {
CreatedAt time.Time `db:"created_at" json:"created_at"`
}
// Returns all individual PR rows with cost for the selected time range.
// Returns individual PR rows with cost for the recent PRs table.
// Uses two CTEs: pr_costs sums cost for the PR-linked chat and its
// direct children (that lack their own PR), and deduped picks one row
// per PR for metadata. A safety-cap LIMIT guards against unexpectedly
// large result sets from direct API callers.
func (q *sqlQuerier) GetPRInsightsPullRequests(ctx context.Context, arg GetPRInsightsPullRequestsParams) ([]GetPRInsightsPullRequestsRow, error) {
rows, err := q.db.QueryContext(ctx, getPRInsightsPullRequests, arg.StartDate, arg.EndDate, arg.OwnerID)
// per PR for metadata.
func (q *sqlQuerier) GetPRInsightsRecentPRs(ctx context.Context, arg GetPRInsightsRecentPRsParams) ([]GetPRInsightsRecentPRsRow, error) {
rows, err := q.db.QueryContext(ctx, getPRInsightsRecentPRs,
arg.LimitVal,
arg.StartDate,
arg.EndDate,
arg.OwnerID,
)
if err != nil {
return nil, err
}
defer rows.Close()
var items []GetPRInsightsPullRequestsRow
var items []GetPRInsightsRecentPRsRow
for rows.Next() {
var i GetPRInsightsPullRequestsRow
var i GetPRInsightsRecentPRsRow
if err := rows.Scan(
&i.ChatID,
&i.PrTitle,
@@ -5818,18 +5823,20 @@ WHERE
ELSE chats.archived = $2 :: boolean
END
AND CASE
-- Cursor pagination: the last element on a page acts as the cursor.
-- The 4-tuple matches the ORDER BY below. All columns sort DESC
-- (pin_order is negated so lower values sort first in DESC order),
-- which lets us use a single tuple < comparison.
-- This allows using the last element on a page as effectively a cursor.
-- This is an important option for scripts that need to paginate without
-- duplicating or missing data.
WHEN $3 :: uuid != '00000000-0000-0000-0000-000000000000'::uuid THEN (
(CASE WHEN pin_order > 0 THEN 1 ELSE 0 END, -pin_order, updated_at, id) < (
-- The pagination cursor is the last ID of the previous page.
-- The query is ordered by the updated_at field, so select all
-- rows before the cursor.
(updated_at, id) < (
SELECT
CASE WHEN c2.pin_order > 0 THEN 1 ELSE 0 END, -c2.pin_order, c2.updated_at, c2.id
updated_at, id
FROM
chats c2
chats
WHERE
c2.id = $3
id = $3
)
)
ELSE true
@@ -5841,15 +5848,9 @@ WHERE
-- Authorize Filter clause will be injected below in GetAuthorizedChats
-- @authorize_filter
ORDER BY
-- Pinned chats (pin_order > 0) sort before unpinned ones. Within
-- pinned chats, lower pin_order values come first. The negation
-- trick (-pin_order) keeps all sort columns DESC so the cursor
-- tuple < comparison works with uniform direction.
CASE WHEN pin_order > 0 THEN 1 ELSE 0 END DESC,
-pin_order DESC,
updated_at DESC,
id DESC
OFFSET $5
-- Deterministic and consistent ordering of all rows, even if they share
-- a timestamp. This is to ensure consistent pagination.
(updated_at, id) DESC OFFSET $5
LIMIT
-- The chat list is unbounded and expected to grow large.
-- Default to 50 to prevent accidental excessively large queries.
@@ -17518,8 +17519,7 @@ SELECT
w.id AS workspace_id,
COALESCE(w.name, '') AS workspace_name,
-- Include the name of the provisioner_daemon associated to the job
COALESCE(pd.name, '') AS worker_name,
wb.transition as workspace_build_transition
COALESCE(pd.name, '') AS worker_name
FROM
provisioner_jobs pj
LEFT JOIN
@@ -17564,8 +17564,7 @@ GROUP BY
t.icon,
w.id,
w.name,
pd.name,
wb.transition
pd.name
ORDER BY
pj.created_at DESC
LIMIT
@@ -17582,19 +17581,18 @@ type GetProvisionerJobsByOrganizationAndStatusWithQueuePositionAndProvisionerPar
}
type GetProvisionerJobsByOrganizationAndStatusWithQueuePositionAndProvisionerRow struct {
ProvisionerJob ProvisionerJob `db:"provisioner_job" json:"provisioner_job"`
QueuePosition int64 `db:"queue_position" json:"queue_position"`
QueueSize int64 `db:"queue_size" json:"queue_size"`
AvailableWorkers []uuid.UUID `db:"available_workers" json:"available_workers"`
TemplateVersionName string `db:"template_version_name" json:"template_version_name"`
TemplateID uuid.NullUUID `db:"template_id" json:"template_id"`
TemplateName string `db:"template_name" json:"template_name"`
TemplateDisplayName string `db:"template_display_name" json:"template_display_name"`
TemplateIcon string `db:"template_icon" json:"template_icon"`
WorkspaceID uuid.NullUUID `db:"workspace_id" json:"workspace_id"`
WorkspaceName string `db:"workspace_name" json:"workspace_name"`
WorkerName string `db:"worker_name" json:"worker_name"`
WorkspaceBuildTransition NullWorkspaceTransition `db:"workspace_build_transition" json:"workspace_build_transition"`
ProvisionerJob ProvisionerJob `db:"provisioner_job" json:"provisioner_job"`
QueuePosition int64 `db:"queue_position" json:"queue_position"`
QueueSize int64 `db:"queue_size" json:"queue_size"`
AvailableWorkers []uuid.UUID `db:"available_workers" json:"available_workers"`
TemplateVersionName string `db:"template_version_name" json:"template_version_name"`
TemplateID uuid.NullUUID `db:"template_id" json:"template_id"`
TemplateName string `db:"template_name" json:"template_name"`
TemplateDisplayName string `db:"template_display_name" json:"template_display_name"`
TemplateIcon string `db:"template_icon" json:"template_icon"`
WorkspaceID uuid.NullUUID `db:"workspace_id" json:"workspace_id"`
WorkspaceName string `db:"workspace_name" json:"workspace_name"`
WorkerName string `db:"worker_name" json:"worker_name"`
}
func (q *sqlQuerier) GetProvisionerJobsByOrganizationAndStatusWithQueuePositionAndProvisioner(ctx context.Context, arg GetProvisionerJobsByOrganizationAndStatusWithQueuePositionAndProvisionerParams) ([]GetProvisionerJobsByOrganizationAndStatusWithQueuePositionAndProvisionerRow, error) {
@@ -17646,7 +17644,6 @@ func (q *sqlQuerier) GetProvisionerJobsByOrganizationAndStatusWithQueuePositionA
&i.WorkspaceID,
&i.WorkspaceName,
&i.WorkerName,
&i.WorkspaceBuildTransition,
); err != nil {
return nil, err
}
@@ -26819,26 +26816,6 @@ func (q *sqlQuerier) UpdateWorkspaceAgentConnectionByID(ctx context.Context, arg
return err
}
const updateWorkspaceAgentDirectoryByID = `-- name: UpdateWorkspaceAgentDirectoryByID :exec
UPDATE
workspace_agents
SET
directory = $2, updated_at = $3
WHERE
id = $1
`
type UpdateWorkspaceAgentDirectoryByIDParams struct {
ID uuid.UUID `db:"id" json:"id"`
Directory string `db:"directory" json:"directory"`
UpdatedAt time.Time `db:"updated_at" json:"updated_at"`
}
func (q *sqlQuerier) UpdateWorkspaceAgentDirectoryByID(ctx context.Context, arg UpdateWorkspaceAgentDirectoryByIDParams) error {
_, err := q.db.ExecContext(ctx, updateWorkspaceAgentDirectoryByID, arg.ID, arg.Directory, arg.UpdatedAt)
return err
}
const updateWorkspaceAgentDisplayAppsByID = `-- name: UpdateWorkspaceAgentDisplayAppsByID :exec
UPDATE
workspace_agents
+4 -5
View File
@@ -173,12 +173,11 @@ JOIN pr_costs pc ON pc.pr_key = d.pr_key
GROUP BY d.model_config_id, d.display_name, d.model, d.provider
ORDER BY total_prs DESC;
-- name: GetPRInsightsPullRequests :many
-- Returns all individual PR rows with cost for the selected time range.
-- name: GetPRInsightsRecentPRs :many
-- Returns individual PR rows with cost for the recent PRs table.
-- Uses two CTEs: pr_costs sums cost for the PR-linked chat and its
-- direct children (that lack their own PR), and deduped picks one row
-- per PR for metadata. A safety-cap LIMIT guards against unexpectedly
-- large result sets from direct API callers.
-- per PR for metadata.
WITH pr_costs AS (
SELECT
prc.pr_key,
@@ -265,4 +264,4 @@ SELECT * FROM (
JOIN pr_costs pc ON pc.pr_key = d.pr_key
) sub
ORDER BY sub.created_at DESC
LIMIT 500;
LIMIT @limit_val::int;
+13 -17
View File
@@ -353,18 +353,20 @@ WHERE
ELSE chats.archived = sqlc.narg('archived') :: boolean
END
AND CASE
-- Cursor pagination: the last element on a page acts as the cursor.
-- The 4-tuple matches the ORDER BY below. All columns sort DESC
-- (pin_order is negated so lower values sort first in DESC order),
-- which lets us use a single tuple < comparison.
-- This allows using the last element on a page as effectively a cursor.
-- This is an important option for scripts that need to paginate without
-- duplicating or missing data.
WHEN @after_id :: uuid != '00000000-0000-0000-0000-000000000000'::uuid THEN (
(CASE WHEN pin_order > 0 THEN 1 ELSE 0 END, -pin_order, updated_at, id) < (
-- The pagination cursor is the last ID of the previous page.
-- The query is ordered by the updated_at field, so select all
-- rows before the cursor.
(updated_at, id) < (
SELECT
CASE WHEN c2.pin_order > 0 THEN 1 ELSE 0 END, -c2.pin_order, c2.updated_at, c2.id
updated_at, id
FROM
chats c2
chats
WHERE
c2.id = @after_id
id = @after_id
)
)
ELSE true
@@ -376,15 +378,9 @@ WHERE
-- Authorize Filter clause will be injected below in GetAuthorizedChats
-- @authorize_filter
ORDER BY
-- Pinned chats (pin_order > 0) sort before unpinned ones. Within
-- pinned chats, lower pin_order values come first. The negation
-- trick (-pin_order) keeps all sort columns DESC so the cursor
-- tuple < comparison works with uniform direction.
CASE WHEN pin_order > 0 THEN 1 ELSE 0 END DESC,
-pin_order DESC,
updated_at DESC,
id DESC
OFFSET @offset_opt
-- Deterministic and consistent ordering of all rows, even if they share
-- a timestamp. This is to ensure consistent pagination.
(updated_at, id) DESC OFFSET @offset_opt
LIMIT
-- The chat list is unbounded and expected to grow large.
-- Default to 50 to prevent accidental excessively large queries.
+2 -4
View File
@@ -195,8 +195,7 @@ SELECT
w.id AS workspace_id,
COALESCE(w.name, '') AS workspace_name,
-- Include the name of the provisioner_daemon associated to the job
COALESCE(pd.name, '') AS worker_name,
wb.transition as workspace_build_transition
COALESCE(pd.name, '') AS worker_name
FROM
provisioner_jobs pj
LEFT JOIN
@@ -241,8 +240,7 @@ GROUP BY
t.icon,
w.id,
w.name,
pd.name,
wb.transition
pd.name
ORDER BY
pj.created_at DESC
LIMIT
@@ -190,14 +190,6 @@ SET
WHERE
id = $1;
-- name: UpdateWorkspaceAgentDirectoryByID :exec
UPDATE
workspace_agents
SET
directory = $2, updated_at = $3
WHERE
id = $1;
-- name: GetWorkspaceAgentLogsAfter :many
SELECT
*
+77 -70
View File
@@ -137,9 +137,8 @@ func publishChatConfigEvent(logger slog.Logger, ps dbpubsub.Pubsub, kind pubsub.
func (api *API) watchChats(rw http.ResponseWriter, r *http.Request) {
ctx := r.Context()
apiKey := httpmw.APIKey(r)
logger := api.Logger.Named("chat_watcher")
conn, err := websocket.Accept(rw, r, nil)
sendEvent, senderClosed, err := httpapi.OneWayWebSocketEventSender(api.Logger)(rw, r)
if err != nil {
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
Message: "Failed to open chat watch stream.",
@@ -147,44 +146,54 @@ func (api *API) watchChats(rw http.ResponseWriter, r *http.Request) {
})
return
}
defer func() {
<-senderClosed
}()
ctx, cancel := context.WithCancel(ctx)
defer cancel()
_ = conn.CloseRead(context.Background())
ctx, wsNetConn := codersdk.WebsocketNetConn(ctx, conn, websocket.MessageText)
defer wsNetConn.Close()
go httpapi.HeartbeatClose(ctx, logger, cancel, conn)
// The encoder is only written from the SubscribeWithErr callback,
// which delivers serially per subscription. Do not add a second
// write path without introducing synchronization.
encoder := json.NewEncoder(wsNetConn)
cancelSubscribe, err := api.Pubsub.SubscribeWithErr(pubsub.ChatWatchEventChannel(apiKey.UserID),
pubsub.HandleChatWatchEvent(
func(ctx context.Context, payload codersdk.ChatWatchEvent, err error) {
cancelSubscribe, err := api.Pubsub.SubscribeWithErr(pubsub.ChatEventChannel(apiKey.UserID),
pubsub.HandleChatEvent(
func(ctx context.Context, payload pubsub.ChatEvent, err error) {
if err != nil {
logger.Error(ctx, "chat watch event subscription error", slog.Error(err))
api.Logger.Error(ctx, "chat event subscription error", slog.Error(err))
return
}
if err := encoder.Encode(payload); err != nil {
logger.Debug(ctx, "failed to send chat watch event", slog.Error(err))
cancel()
return
if err := sendEvent(codersdk.ServerSentEvent{
Type: codersdk.ServerSentEventTypeData,
Data: payload,
}); err != nil {
api.Logger.Debug(ctx, "failed to send chat event", slog.Error(err))
}
},
))
if err != nil {
logger.Error(ctx, "failed to subscribe to chat watch events", slog.Error(err))
_ = conn.Close(websocket.StatusInternalError, "Failed to subscribe to chat events.")
if err := sendEvent(codersdk.ServerSentEvent{
Type: codersdk.ServerSentEventTypeError,
Data: codersdk.Response{
Message: "Internal error subscribing to chat events.",
Detail: err.Error(),
},
}); err != nil {
api.Logger.Debug(ctx, "failed to send chat subscribe error event", slog.Error(err))
}
return
}
defer cancelSubscribe()
<-ctx.Done()
// Send initial ping to signal the connection is ready.
if err := sendEvent(codersdk.ServerSentEvent{
Type: codersdk.ServerSentEventTypePing,
}); err != nil {
api.Logger.Debug(ctx, "failed to send chat ping event", slog.Error(err))
}
for {
select {
case <-ctx.Done():
return
case <-senderClosed:
return
}
}
}
// EXPERIMENTAL: chatsByWorkspace returns a mapping of workspace ID to
@@ -1810,9 +1819,9 @@ func (api *API) patchChat(rw http.ResponseWriter, r *http.Request) {
// - pinOrder > 0 && already pinned: reorder (shift
// neighbors, clamp to [1, count]).
// - pinOrder > 0 && not pinned: append to end. The
// requested value is intentionally ignored; the
// SQL ORDER BY sorts pinned chats first so they
// appear on page 1 of the paginated sidebar.
// requested value is intentionally ignored because
// PinChatByID also bumps updated_at to keep the
// chat visible in the paginated sidebar.
var err error
errMsg := "Failed to pin chat."
switch {
@@ -2167,7 +2176,6 @@ func (api *API) streamChat(rw http.ResponseWriter, r *http.Request) {
ctx := r.Context()
chat := httpmw.ChatParam(r)
chatID := chat.ID
logger := api.Logger.Named("chat_streamer").With(slog.F("chat_id", chatID))
if api.chatDaemon == nil {
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
@@ -2190,22 +2198,7 @@ func (api *API) streamChat(rw http.ResponseWriter, r *http.Request) {
}
}
// Subscribe before accepting the WebSocket so that failures
// can still be reported as normal HTTP errors.
snapshot, events, cancelSub, ok := api.chatDaemon.Subscribe(ctx, chatID, r.Header, afterMessageID)
// Subscribe only fails today when the receiver is nil, which
// the chatDaemon == nil guard above already catches. This is
// defensive against future Subscribe failure modes.
if !ok {
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
Message: "Chat streaming is not available.",
Detail: "Chat stream state is not configured.",
})
return
}
defer cancelSub()
conn, err := websocket.Accept(rw, r, nil)
sendEvent, senderClosed, err := httpapi.OneWayWebSocketEventSender(api.Logger)(rw, r)
if err != nil {
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
Message: "Failed to open chat stream.",
@@ -2213,30 +2206,41 @@ func (api *API) streamChat(rw http.ResponseWriter, r *http.Request) {
})
return
}
ctx, cancel := context.WithCancel(ctx)
snapshot, events, cancel, ok := api.chatDaemon.Subscribe(ctx, chatID, r.Header, afterMessageID)
if !ok {
if err := sendEvent(codersdk.ServerSentEvent{
Type: codersdk.ServerSentEventTypeError,
Data: codersdk.Response{
Message: "Chat streaming is not available.",
Detail: "Chat stream state is not configured.",
},
}); err != nil {
api.Logger.Debug(ctx, "failed to send chat stream unavailable event", slog.Error(err))
}
// Ensure the WebSocket is closed so senderClosed
// completes and the handler can return.
<-senderClosed
return
}
defer func() {
<-senderClosed
}()
defer cancel()
_ = conn.CloseRead(context.Background())
ctx, wsNetConn := codersdk.WebsocketNetConn(ctx, conn, websocket.MessageText)
defer wsNetConn.Close()
go httpapi.HeartbeatClose(ctx, logger, cancel, conn)
// Mark the chat as read when the stream connects and again
// when it disconnects so we avoid per-message API calls while
// messages are actively streaming.
api.markChatAsRead(ctx, chatID)
defer api.markChatAsRead(context.WithoutCancel(ctx), chatID)
encoder := json.NewEncoder(wsNetConn)
sendChatStreamBatch := func(batch []codersdk.ChatStreamEvent) error {
if len(batch) == 0 {
return nil
}
return encoder.Encode(batch)
return sendEvent(codersdk.ServerSentEvent{
Type: codersdk.ServerSentEventTypeData,
Data: batch,
})
}
drainChatStreamBatch := func(
@@ -2269,7 +2273,7 @@ func (api *API) streamChat(rw http.ResponseWriter, r *http.Request) {
end = len(snapshot)
}
if err := sendChatStreamBatch(snapshot[start:end]); err != nil {
logger.Debug(ctx, "failed to send chat stream snapshot", slog.Error(err))
api.Logger.Debug(ctx, "failed to send chat stream snapshot", slog.Error(err))
return
}
}
@@ -2278,6 +2282,8 @@ func (api *API) streamChat(rw http.ResponseWriter, r *http.Request) {
select {
case <-ctx.Done():
return
case <-senderClosed:
return
case firstEvent, ok := <-events:
if !ok {
return
@@ -2287,7 +2293,7 @@ func (api *API) streamChat(rw http.ResponseWriter, r *http.Request) {
chatStreamBatchSize,
)
if err := sendChatStreamBatch(batch); err != nil {
logger.Debug(ctx, "failed to send chat stream event", slog.Error(err))
api.Logger.Debug(ctx, "failed to send chat stream event", slog.Error(err))
return
}
if streamClosed {
@@ -2302,7 +2308,6 @@ func (api *API) interruptChat(rw http.ResponseWriter, r *http.Request) {
ctx := r.Context()
chat := httpmw.ChatParam(r)
chatID := chat.ID
logger := api.Logger.Named("chat_interrupt").With(slog.F("chat_id", chatID))
if api.chatDaemon != nil {
chat = api.chatDaemon.InterruptChat(ctx, chat)
@@ -2316,7 +2321,8 @@ func (api *API) interruptChat(rw http.ResponseWriter, r *http.Request) {
LastError: sql.NullString{},
})
if updateErr != nil {
logger.Error(ctx, "failed to mark chat as waiting", slog.Error(updateErr))
api.Logger.Error(ctx, "failed to mark chat as waiting",
slog.F("chat_id", chatID), slog.Error(updateErr))
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
Message: "Failed to interrupt chat.",
Detail: updateErr.Error(),
@@ -5626,7 +5632,7 @@ func (api *API) prInsights(rw http.ResponseWriter, r *http.Request) {
previousSummary database.GetPRInsightsSummaryRow
timeSeries []database.GetPRInsightsTimeSeriesRow
byModel []database.GetPRInsightsPerModelRow
recentPRs []database.GetPRInsightsPullRequestsRow
recentPRs []database.GetPRInsightsRecentPRsRow
)
eg, egCtx := errgroup.WithContext(ctx)
@@ -5674,10 +5680,11 @@ func (api *API) prInsights(rw http.ResponseWriter, r *http.Request) {
eg.Go(func() error {
var err error
recentPRs, err = api.Database.GetPRInsightsPullRequests(egCtx, database.GetPRInsightsPullRequestsParams{
recentPRs, err = api.Database.GetPRInsightsRecentPRs(egCtx, database.GetPRInsightsRecentPRsParams{
StartDate: startDate,
EndDate: endDate,
OwnerID: ownerID,
LimitVal: 20,
})
return err
})
@@ -5787,10 +5794,10 @@ func (api *API) prInsights(rw http.ResponseWriter, r *http.Request) {
}
httpapi.Write(ctx, rw, http.StatusOK, codersdk.PRInsightsResponse{
Summary: summary,
TimeSeries: tsEntries,
ByModel: modelEntries,
PullRequests: prEntries,
Summary: summary,
TimeSeries: tsEntries,
ByModel: modelEntries,
RecentPRs: prEntries,
})
}
+98 -199
View File
@@ -876,186 +876,6 @@ func TestListChats(t *testing.T) {
require.NoError(t, err)
require.Len(t, allChats, totalChats)
})
// Test that a pinned chat with an old updated_at appears on page 1.
t.Run("PinnedOnFirstPage", func(t *testing.T) {
t.Parallel()
ctx := testutil.Context(t, testutil.WaitLong)
client, _ := newChatClientWithDatabase(t)
_ = coderdtest.CreateFirstUser(t, client.Client)
_ = createChatModelConfig(t, client)
// Create the chat that will later be pinned. It gets the
// earliest updated_at because it is inserted first.
pinnedChat, err := client.CreateChat(ctx, codersdk.CreateChatRequest{
Content: []codersdk.ChatInputPart{{
Type: codersdk.ChatInputPartTypeText,
Text: "pinned-chat",
}},
})
require.NoError(t, err)
// Fill page 1 with newer chats so the pinned chat would
// normally be pushed off the first page (default limit 50).
const fillerCount = 51
fillerChats := make([]codersdk.Chat, 0, fillerCount)
for i := range fillerCount {
c, createErr := client.CreateChat(ctx, codersdk.CreateChatRequest{
Content: []codersdk.ChatInputPart{{
Type: codersdk.ChatInputPartTypeText,
Text: fmt.Sprintf("filler-%d", i),
}},
})
require.NoError(t, createErr)
fillerChats = append(fillerChats, c)
}
// Wait for all chats to reach a terminal status so
// updated_at is stable before paginating. A single
// polling loop checks every chat per tick to avoid
// O(N) separate Eventually loops.
allCreated := append([]codersdk.Chat{pinnedChat}, fillerChats...)
pending := make(map[uuid.UUID]struct{}, len(allCreated))
for _, c := range allCreated {
pending[c.ID] = struct{}{}
}
testutil.Eventually(ctx, t, func(_ context.Context) bool {
all, listErr := client.ListChats(ctx, &codersdk.ListChatsOptions{
Pagination: codersdk.Pagination{Limit: fillerCount + 10},
})
if listErr != nil {
return false
}
for _, ch := range all {
if _, ok := pending[ch.ID]; ok && ch.Status != codersdk.ChatStatusPending && ch.Status != codersdk.ChatStatusRunning {
delete(pending, ch.ID)
}
}
return len(pending) == 0
}, testutil.IntervalFast)
// Pin the earliest chat.
err = client.UpdateChat(ctx, pinnedChat.ID, codersdk.UpdateChatRequest{
PinOrder: ptr.Ref(int32(1)),
})
require.NoError(t, err)
// Fetch page 1 with default limit (50).
page1, err := client.ListChats(ctx, &codersdk.ListChatsOptions{
Pagination: codersdk.Pagination{Limit: 50},
})
require.NoError(t, err)
// The pinned chat must appear on page 1.
page1IDs := make(map[uuid.UUID]struct{}, len(page1))
for _, c := range page1 {
page1IDs[c.ID] = struct{}{}
}
_, found := page1IDs[pinnedChat.ID]
require.True(t, found, "pinned chat should appear on page 1")
// The pinned chat should be the first item in the list.
require.Equal(t, pinnedChat.ID, page1[0].ID, "pinned chat should be first")
})
// Test cursor pagination with a mix of pinned and unpinned chats.
t.Run("CursorWithPins", func(t *testing.T) {
t.Parallel()
ctx := testutil.Context(t, testutil.WaitLong)
client, _ := newChatClientWithDatabase(t)
_ = coderdtest.CreateFirstUser(t, client.Client)
_ = createChatModelConfig(t, client)
// Create 5 chats: 2 will be pinned, 3 unpinned.
const totalChats = 5
createdChats := make([]codersdk.Chat, 0, totalChats)
for i := range totalChats {
c, createErr := client.CreateChat(ctx, codersdk.CreateChatRequest{
Content: []codersdk.ChatInputPart{{
Type: codersdk.ChatInputPartTypeText,
Text: fmt.Sprintf("cursor-pin-chat-%d", i),
}},
})
require.NoError(t, createErr)
createdChats = append(createdChats, c)
}
// Wait for all chats to reach terminal status.
// Check each chat by ID rather than fetching the full list.
testutil.Eventually(ctx, t, func(_ context.Context) bool {
for _, c := range createdChats {
ch, err := client.GetChat(ctx, c.ID)
require.NoError(t, err, "GetChat should succeed for just-created chat %s", c.ID)
if ch.Status == codersdk.ChatStatusPending || ch.Status == codersdk.ChatStatusRunning {
return false
}
}
return true
}, testutil.IntervalFast)
// Pin the first two chats (oldest updated_at).
err := client.UpdateChat(ctx, createdChats[0].ID, codersdk.UpdateChatRequest{
PinOrder: ptr.Ref(int32(1)),
})
require.NoError(t, err)
err = client.UpdateChat(ctx, createdChats[1].ID, codersdk.UpdateChatRequest{
PinOrder: ptr.Ref(int32(1)),
})
require.NoError(t, err)
// Paginate with limit=2 using cursor (after_id).
const pageSize = 2
maxPages := totalChats/pageSize + 2
var allPaginated []codersdk.Chat
var afterID uuid.UUID
for range maxPages {
opts := &codersdk.ListChatsOptions{
Pagination: codersdk.Pagination{Limit: pageSize},
}
if afterID != uuid.Nil {
opts.Pagination.AfterID = afterID
}
page, listErr := client.ListChats(ctx, opts)
require.NoError(t, listErr)
if len(page) == 0 {
break
}
allPaginated = append(allPaginated, page...)
afterID = page[len(page)-1].ID
}
// All chats should appear exactly once.
seenIDs := make(map[uuid.UUID]struct{}, len(allPaginated))
for _, c := range allPaginated {
_, dup := seenIDs[c.ID]
require.False(t, dup, "chat %s appeared more than once", c.ID)
seenIDs[c.ID] = struct{}{}
}
require.Len(t, seenIDs, totalChats, "all chats should appear in paginated results")
// Pinned chats should come before unpinned ones, and
// within the pinned group, lower pin_order sorts first.
pinnedSeen := false
unpinnedSeen := false
for _, c := range allPaginated {
if c.PinOrder > 0 {
require.False(t, unpinnedSeen, "pinned chat %s appeared after unpinned chat", c.ID)
pinnedSeen = true
} else {
unpinnedSeen = true
}
}
require.True(t, pinnedSeen, "at least one pinned chat should exist")
// Verify within-pinned ordering: pin_order=1 before
// pin_order=2 (the -pin_order DESC column).
require.Equal(t, createdChats[0].ID, allPaginated[0].ID,
"pin_order=1 chat should be first")
require.Equal(t, createdChats[1].ID, allPaginated[1].ID,
"pin_order=2 chat should be second")
})
}
func TestListChatModels(t *testing.T) {
@@ -1294,6 +1114,17 @@ func TestWatchChats(t *testing.T) {
require.NoError(t, err)
defer conn.Close(websocket.StatusNormalClosure, "done")
type watchEvent struct {
Type codersdk.ServerSentEventType `json:"type"`
Data json.RawMessage `json:"data,omitempty"`
}
var event watchEvent
err = wsjson.Read(ctx, conn, &event)
require.NoError(t, err)
require.Equal(t, codersdk.ServerSentEventTypePing, event.Type)
require.True(t, len(event.Data) == 0 || string(event.Data) == "null")
createdChat, err := client.CreateChat(ctx, codersdk.CreateChatRequest{
Content: []codersdk.ChatInputPart{
{
@@ -1305,16 +1136,25 @@ func TestWatchChats(t *testing.T) {
require.NoError(t, err)
for {
var payload codersdk.ChatWatchEvent
err = wsjson.Read(ctx, conn, &payload)
var update watchEvent
err = wsjson.Read(ctx, conn, &update)
require.NoError(t, err)
if payload.Kind == codersdk.ChatWatchEventKindCreated &&
if update.Type == codersdk.ServerSentEventTypePing {
continue
}
require.Equal(t, codersdk.ServerSentEventTypeData, update.Type)
var payload coderdpubsub.ChatEvent
err = json.Unmarshal(update.Data, &payload)
require.NoError(t, err)
if payload.Kind == coderdpubsub.ChatEventKindCreated &&
payload.Chat.ID == createdChat.ID {
break
}
}
})
t.Run("CreatedEventIncludesAllChatFields", func(t *testing.T) {
t.Parallel()
@@ -1334,6 +1174,18 @@ func TestWatchChats(t *testing.T) {
require.NoError(t, err)
defer conn.Close(websocket.StatusNormalClosure, "done")
type watchEvent struct {
Type codersdk.ServerSentEventType `json:"type"`
Data json.RawMessage `json:"data,omitempty"`
}
// Skip the initial ping.
var event watchEvent
err = wsjson.Read(ctx, conn, &event)
require.NoError(t, err)
require.Equal(t, codersdk.ServerSentEventTypePing, event.Type)
require.True(t, len(event.Data) == 0 || string(event.Data) == "null")
createdChat, err := client.CreateChat(ctx, codersdk.CreateChatRequest{
Content: []codersdk.ChatInputPart{
{
@@ -1346,11 +1198,18 @@ func TestWatchChats(t *testing.T) {
var got codersdk.Chat
testutil.Eventually(ctx, t, func(_ context.Context) bool {
var payload codersdk.ChatWatchEvent
if readErr := wsjson.Read(ctx, conn, &payload); readErr != nil {
var update watchEvent
if readErr := wsjson.Read(ctx, conn, &update); readErr != nil {
return false
}
if payload.Kind == codersdk.ChatWatchEventKindCreated &&
if update.Type != codersdk.ServerSentEventTypeData {
return false
}
var payload coderdpubsub.ChatEvent
if unmarshalErr := json.Unmarshal(update.Data, &payload); unmarshalErr != nil {
return false
}
if payload.Kind == coderdpubsub.ChatEventKindCreated &&
payload.Chat.ID == createdChat.ID {
got = payload.Chat
return true
@@ -1423,14 +1282,25 @@ func TestWatchChats(t *testing.T) {
require.NoError(t, err)
defer conn.Close(websocket.StatusNormalClosure, "done")
type watchEvent struct {
Type codersdk.ServerSentEventType `json:"type"`
Data json.RawMessage `json:"data,omitempty"`
}
// Read the initial ping.
var ping watchEvent
err = wsjson.Read(ctx, conn, &ping)
require.NoError(t, err)
require.Equal(t, codersdk.ServerSentEventTypePing, ping.Type)
// Publish a diff_status_change event via pubsub,
// mimicking what PublishDiffStatusChange does after
// it reads the diff status from the DB.
dbStatus, err := db.GetChatDiffStatusByChatID(dbauthz.AsSystemRestricted(ctx), chat.ID)
require.NoError(t, err)
sdkDiffStatus := db2sdk.ChatDiffStatus(chat.ID, &dbStatus)
event := codersdk.ChatWatchEvent{
Kind: codersdk.ChatWatchEventKindDiffStatusChange,
event := coderdpubsub.ChatEvent{
Kind: coderdpubsub.ChatEventKindDiffStatusChange,
Chat: codersdk.Chat{
ID: chat.ID,
OwnerID: chat.OwnerID,
@@ -1443,15 +1313,25 @@ func TestWatchChats(t *testing.T) {
}
payload, err := json.Marshal(event)
require.NoError(t, err)
err = api.Pubsub.Publish(coderdpubsub.ChatWatchEventChannel(user.UserID), payload)
err = api.Pubsub.Publish(coderdpubsub.ChatEventChannel(user.UserID), payload)
require.NoError(t, err)
// Read events until we find the diff_status_change.
for {
var received codersdk.ChatWatchEvent
err = wsjson.Read(ctx, conn, &received)
var update watchEvent
err = wsjson.Read(ctx, conn, &update)
require.NoError(t, err)
if received.Kind != codersdk.ChatWatchEventKindDiffStatusChange ||
if update.Type == codersdk.ServerSentEventTypePing {
continue
}
require.Equal(t, codersdk.ServerSentEventTypeData, update.Type)
var received coderdpubsub.ChatEvent
err = json.Unmarshal(update.Data, &received)
require.NoError(t, err)
if received.Kind != coderdpubsub.ChatEventKindDiffStatusChange ||
received.Chat.ID != chat.ID {
continue
}
@@ -1470,6 +1350,7 @@ func TestWatchChats(t *testing.T) {
break
}
})
t.Run("ArchiveAndUnarchiveEmitEventsForDescendants", func(t *testing.T) {
t.Parallel()
@@ -1512,13 +1393,31 @@ func TestWatchChats(t *testing.T) {
require.NoError(t, err)
defer conn.Close(websocket.StatusNormalClosure, "done")
collectLifecycleEvents := func(expectedKind codersdk.ChatWatchEventKind) map[uuid.UUID]codersdk.ChatWatchEvent {
type watchEvent struct {
Type codersdk.ServerSentEventType `json:"type"`
Data json.RawMessage `json:"data,omitempty"`
}
var ping watchEvent
err = wsjson.Read(ctx, conn, &ping)
require.NoError(t, err)
require.Equal(t, codersdk.ServerSentEventTypePing, ping.Type)
collectLifecycleEvents := func(expectedKind coderdpubsub.ChatEventKind) map[uuid.UUID]coderdpubsub.ChatEvent {
t.Helper()
events := make(map[uuid.UUID]codersdk.ChatWatchEvent, 3)
events := make(map[uuid.UUID]coderdpubsub.ChatEvent, 3)
for len(events) < 3 {
var payload codersdk.ChatWatchEvent
err = wsjson.Read(ctx, conn, &payload)
var update watchEvent
err = wsjson.Read(ctx, conn, &update)
require.NoError(t, err)
if update.Type == codersdk.ServerSentEventTypePing {
continue
}
require.Equal(t, codersdk.ServerSentEventTypeData, update.Type)
var payload coderdpubsub.ChatEvent
err = json.Unmarshal(update.Data, &payload)
require.NoError(t, err)
if payload.Kind != expectedKind {
continue
@@ -1528,7 +1427,7 @@ func TestWatchChats(t *testing.T) {
return events
}
assertLifecycleEvents := func(events map[uuid.UUID]codersdk.ChatWatchEvent, archived bool) {
assertLifecycleEvents := func(events map[uuid.UUID]coderdpubsub.ChatEvent, archived bool) {
t.Helper()
require.Len(t, events, 3)
@@ -1541,12 +1440,12 @@ func TestWatchChats(t *testing.T) {
err = client.UpdateChat(ctx, parentChat.ID, codersdk.UpdateChatRequest{Archived: ptr.Ref(true)})
require.NoError(t, err)
deletedEvents := collectLifecycleEvents(codersdk.ChatWatchEventKindDeleted)
deletedEvents := collectLifecycleEvents(coderdpubsub.ChatEventKindDeleted)
assertLifecycleEvents(deletedEvents, true)
err = client.UpdateChat(ctx, parentChat.ID, codersdk.UpdateChatRequest{Archived: ptr.Ref(false)})
require.NoError(t, err)
createdEvents := collectLifecycleEvents(codersdk.ChatWatchEventKindCreated)
createdEvents := collectLifecycleEvents(coderdpubsub.ChatEventKindCreated)
assertLifecycleEvents(createdEvents, false)
})
+3 -27
View File
@@ -5,7 +5,6 @@ import (
"database/sql"
"encoding/hex"
"errors"
htmltemplate "html/template"
"net/http"
"net/url"
"strings"
@@ -147,35 +146,12 @@ func ShowAuthorizePage(accessURL *url.URL) http.HandlerFunc {
cancel := params.redirectURL
cancelQuery := params.redirectURL.Query()
cancelQuery.Add("error", "access_denied")
cancelQuery.Add("error_description", "The resource owner or authorization server denied the request")
if params.state != "" {
cancelQuery.Add("state", params.state)
}
cancel.RawQuery = cancelQuery.Encode()
cancelURI := cancel.String()
if err := codersdk.ValidateRedirectURIScheme(cancel); err != nil {
site.RenderStaticErrorPage(rw, r, site.ErrorPageData{
Status: http.StatusBadRequest,
HideStatus: false,
Title: "Invalid Callback URL",
Description: "The application's registered callback URL has an invalid scheme.",
Actions: []site.Action{
{
URL: accessURL.String(),
Text: "Back to site",
},
},
})
return
}
site.RenderOAuthAllowPage(rw, r, site.RenderOAuthAllowData{
AppIcon: app.Icon,
AppName: app.Name,
// #nosec G203 -- The scheme is validated by
// codersdk.ValidateRedirectURIScheme above.
CancelURI: htmltemplate.URL(cancelURI),
AppIcon: app.Icon,
AppName: app.Name,
CancelURI: cancel.String(),
RedirectURI: r.URL.String(),
CSRFToken: nosurf.Token(r),
Username: ua.FriendlyName,
+1 -2
View File
@@ -1,7 +1,6 @@
package oauth2provider_test
import (
htmltemplate "html/template"
"net/http"
"net/http/httptest"
"testing"
@@ -21,7 +20,7 @@ func TestOAuthConsentFormIncludesCSRFToken(t *testing.T) {
site.RenderOAuthAllowPage(rec, req, site.RenderOAuthAllowData{
AppName: "Test OAuth App",
CancelURI: htmltemplate.URL("https://coder.com/cancel"),
CancelURI: "https://coder.com/cancel",
RedirectURI: "https://coder.com/oauth2/authorize?client_id=test",
CSRFToken: csrfFieldValue,
Username: "test-user",
-3
View File
@@ -435,9 +435,6 @@ func convertProvisionerJobWithQueuePosition(pj database.GetProvisionerJobsByOrga
if pj.WorkspaceID.Valid {
job.Metadata.WorkspaceID = &pj.WorkspaceID.UUID
}
if pj.WorkspaceBuildTransition.Valid {
job.Metadata.WorkspaceBuildTransition = codersdk.WorkspaceTransition(pj.WorkspaceBuildTransition.WorkspaceTransition)
}
return job
}
+7 -8
View File
@@ -97,14 +97,13 @@ func TestProvisionerJobs(t *testing.T) {
// Verify that job metadata is correct.
assert.Equal(t, job2.Metadata, codersdk.ProvisionerJobMetadata{
TemplateVersionName: version.Name,
TemplateID: template.ID,
TemplateName: template.Name,
TemplateDisplayName: template.DisplayName,
TemplateIcon: template.Icon,
WorkspaceID: &w.ID,
WorkspaceName: w.Name,
WorkspaceBuildTransition: codersdk.WorkspaceTransitionStart,
TemplateVersionName: version.Name,
TemplateID: template.ID,
TemplateName: template.Name,
TemplateDisplayName: template.DisplayName,
TemplateIcon: template.Icon,
WorkspaceID: &w.ID,
WorkspaceName: w.Name,
})
})
})
+1 -1
View File
@@ -14,7 +14,7 @@ import (
const ChatConfigEventChannel = "chat:config_change"
// HandleChatConfigEvent wraps a typed callback for ChatConfigEvent
// messages, following the same pattern as HandleChatWatchEvent.
// messages, following the same pattern as HandleChatEvent.
func HandleChatConfigEvent(cb func(ctx context.Context, payload ChatConfigEvent, err error)) func(ctx context.Context, message []byte, err error) {
return func(ctx context.Context, message []byte, err error) {
if err != nil {
+49
View File
@@ -0,0 +1,49 @@
package pubsub
import (
"context"
"encoding/json"
"fmt"
"github.com/google/uuid"
"golang.org/x/xerrors"
"github.com/coder/coder/v2/codersdk"
)
// ChatEventChannel returns the name of the pubsub channel carrying
// chat lifecycle events for a single owner. All chats belonging to
// the same user are broadcast on this one channel, keyed by the
// owner's UUID.
func ChatEventChannel(ownerID uuid.UUID) string {
	return fmt.Sprintf("chat:owner:%s", ownerID)
}
// HandleChatEvent adapts a typed ChatEvent callback into the raw
// []byte handler shape expected by the pubsub subscriber.
// Subscription-level errors and undecodable messages are both
// forwarded to cb as wrapped errors with a zero-value payload.
func HandleChatEvent(cb func(ctx context.Context, payload ChatEvent, err error)) func(ctx context.Context, message []byte, err error) {
	return func(ctx context.Context, message []byte, err error) {
		if err != nil {
			// The subscription itself failed; there is no message
			// to decode.
			cb(ctx, ChatEvent{}, xerrors.Errorf("chat event pubsub: %w", err))
			return
		}
		var event ChatEvent
		if decodeErr := json.Unmarshal(message, &event); decodeErr != nil {
			cb(ctx, ChatEvent{}, xerrors.Errorf("unmarshal chat event: %w", decodeErr))
			return
		}
		// err is known to be nil on this path.
		cb(ctx, event, nil)
	}
}
// ChatEvent is the payload broadcast on ChatEventChannel for chat
// lifecycle changes. Chat carries the SDK representation of the
// affected chat; ToolCalls carries pending tool calls for
// action_required events and is omitted otherwise.
type ChatEvent struct {
	Kind      ChatEventKind                 `json:"kind"`
	Chat      codersdk.Chat                 `json:"chat"`
	ToolCalls []codersdk.ChatStreamToolCall `json:"tool_calls,omitempty"`
}

// ChatEventKind discriminates which lifecycle change a ChatEvent
// describes.
type ChatEventKind string

const (
	ChatEventKindStatusChange     ChatEventKind = "status_change"
	ChatEventKindTitleChange      ChatEventKind = "title_change"
	ChatEventKindCreated          ChatEventKind = "created"
	ChatEventKindDeleted          ChatEventKind = "deleted"
	ChatEventKindDiffStatusChange ChatEventKind = "diff_status_change"
	ChatEventKindActionRequired   ChatEventKind = "action_required"
)
-36
View File
@@ -1,36 +0,0 @@
package pubsub
import (
"context"
"encoding/json"
"fmt"
"github.com/google/uuid"
"golang.org/x/xerrors"
"github.com/coder/coder/v2/codersdk"
)
// ChatWatchEventChannel returns the pubsub channel for chat
// lifecycle events scoped to a single user. Every chat owned by
// the same user shares this one channel, keyed by owner UUID.
func ChatWatchEventChannel(ownerID uuid.UUID) string {
	return fmt.Sprintf("chat:owner:%s", ownerID)
}
// HandleChatWatchEvent wraps a typed callback for
// ChatWatchEvent messages delivered via pubsub. Subscription
// errors are forwarded to cb with a zero-value payload, as are
// messages that fail to decode.
func HandleChatWatchEvent(cb func(ctx context.Context, payload codersdk.ChatWatchEvent, err error)) func(ctx context.Context, message []byte, err error) {
	return func(ctx context.Context, message []byte, err error) {
		if err != nil {
			cb(ctx, codersdk.ChatWatchEvent{}, xerrors.Errorf("chat watch event pubsub: %w", err))
			return
		}
		var payload codersdk.ChatWatchEvent
		if err := json.Unmarshal(message, &payload); err != nil {
			cb(ctx, codersdk.ChatWatchEvent{}, xerrors.Errorf("unmarshal chat watch event: %w", err))
			return
		}
		// The outer err is nil here; it is passed through unchanged.
		cb(ctx, payload, err)
	}
}
+129 -20
View File
@@ -238,6 +238,13 @@ type turnWorkspaceContext struct {
conn workspacesdk.AgentConn
releaseConn func()
cachedWorkspaceID uuid.NullUUID
// readinessOnce ensures only one goroutine polls the
// workspace build status when multiple tools call
// getWorkspaceConn concurrently on a building workspace.
// All callers share the single result.
readinessOnce sync.Once
readinessErr error
}
func (c *turnWorkspaceContext) close() {
@@ -252,6 +259,10 @@ func (c *turnWorkspaceContext) clearCachedWorkspaceState() {
c.conn = nil
c.releaseConn = nil
c.cachedWorkspaceID = uuid.NullUUID{}
// Reset readiness tracking so a new workspace (or
// rebuilt workspace) triggers a fresh readiness wait.
c.readinessOnce = sync.Once{}
c.readinessErr = nil
c.mu.Unlock()
if releaseConn != nil {
@@ -277,6 +288,96 @@ func (c *turnWorkspaceContext) selectWorkspace(chat database.Chat) {
c.clearCachedWorkspaceState()
}
// waitForWorkspaceReady blocks until the chat's workspace build
// completes and the agent comes online. Returns immediately if
// the workspace is already running. Concurrent callers share a
// single wait via sync.Once.
//
// NOTE(review): a failed wait (including one aborted by context
// cancellation) is cached in readinessErr until
// clearCachedWorkspaceState resets readinessOnce, so later callers
// in the same turn observe the same error — confirm that
// retry-after-cancel is not expected here.
func (c *turnWorkspaceContext) waitForWorkspaceReady(ctx context.Context) error {
	// Snapshot the workspace binding under the chat-state lock;
	// the bound workspace can change concurrently.
	c.chatStateMu.Lock()
	wsID := c.currentChat.WorkspaceID
	c.chatStateMu.Unlock()
	if !wsID.Valid {
		return xerrors.New(
			"no workspace associated with this chat; " +
				"use create_workspace first",
		)
	}
	db := c.server.db
	// Fast path: if agents already exist in the latest build
	// the workspace is usable regardless of any current build
	// status. This covers workspaces that were pre-attached
	// at chat creation time (already running) and workspaces
	// where a new build was triggered on top of a running one.
	agents, agentsErr := db.GetWorkspaceAgentsInLatestBuildByWorkspaceID(
		ctx, wsID.UUID,
	)
	if agentsErr == nil && len(agents) > 0 {
		return nil
	}
	// No agents yet — check the build status.
	build, err := db.GetLatestWorkspaceBuildByWorkspaceID(
		ctx, wsID.UUID,
	)
	if err != nil {
		return xerrors.Errorf("check workspace build: %w", err)
	}
	job, err := db.GetProvisionerJobByID(ctx, build.JobID)
	if err != nil {
		return xerrors.Errorf("check provisioner job: %w", err)
	}
	// Build succeeded but no agents were found: report ready and
	// let the subsequent dial surface a more specific error.
	if job.JobStatus == database.ProvisionerJobStatusSucceeded {
		return nil
	}
	// Failed or canceled builds surface the provisioner's own
	// error message when one was recorded.
	if job.JobStatus == database.ProvisionerJobStatusFailed ||
		job.JobStatus == database.ProvisionerJobStatusCanceled {
		errMsg := "workspace build failed"
		if job.Error.Valid {
			errMsg = job.Error.String
		}
		return xerrors.New(errMsg)
	}
	// Slow path: workspace is building. Wait once, share result
	// across concurrent tool calls within this turn.
	c.readinessOnce.Do(func() {
		c.readinessErr = c.doWaitForWorkspaceReady(
			ctx, wsID.UUID,
		)
	})
	return c.readinessErr
}
// doWaitForWorkspaceReady performs the actual readiness wait: it
// blocks until the provisioner build completes, then best-effort
// waits for the first agent in the latest build to come online.
// Invoked at most once per cached workspace via readinessOnce in
// waitForWorkspaceReady.
func (c *turnWorkspaceContext) doWaitForWorkspaceReady(
	ctx context.Context, workspaceID uuid.UUID,
) error {
	db := c.server.db
	// Phase 1: wait for the provisioner build to complete.
	if err := chattool.WaitForBuild(ctx, db, workspaceID); err != nil {
		return xerrors.Errorf("workspace build: %w", err)
	}
	// Phase 2: resolve the agent and wait for it to come
	// online. This is best-effort — if it times out, we
	// still try to dial and let the dial give a better error.
	agents, err := db.GetWorkspaceAgentsInLatestBuildByWorkspaceID(
		ctx, workspaceID,
	)
	if err != nil || len(agents) == 0 {
		// No agents found (or lookup failed): return success and
		// let the caller's dial produce a more specific error.
		return nil
	}
	agentConnFn := chattool.AgentConnFunc(c.server.agentConnFn)
	// Result deliberately discarded — see best-effort note above.
	_ = chattool.WaitForAgentReady(
		ctx, db, agents[0].ID, agentConnFn,
	)
	return nil
}
func (c *turnWorkspaceContext) currentWorkspaceMatches(expected uuid.NullUUID) (database.Chat, bool) {
chatSnapshot := c.currentChatSnapshot()
return chatSnapshot, nullUUIDEqual(chatSnapshot.WorkspaceID, expected)
@@ -537,6 +638,14 @@ func (c *turnWorkspaceContext) getWorkspaceConn(ctx context.Context) (workspaces
return nil, xerrors.New("workspace agent connector is not configured")
}
// Wait for the workspace build to finish if it is still in
// progress. This transparently blocks workspace-dependent
// tools until the workspace is ready, without requiring the
// create_workspace tool to block.
if err := c.waitForWorkspaceReady(ctx); err != nil {
return nil, xerrors.Errorf("workspace not ready: %w", err)
}
for attempt := 0; attempt < 2; attempt++ {
c.mu.Lock()
currentConn, staleRelease := c.getWorkspaceConnLocked()
@@ -915,7 +1024,7 @@ func (p *Server) CreateChat(ctx context.Context, opts CreateOptions) (database.C
if opts.WorkspaceID.Valid {
workspaceAwareness = "This chat is attached to a workspace. You can use workspace tools like execute, read_file, write_file, etc."
} else {
workspaceAwareness = "There is no workspace associated with this chat yet. Create one using the create_workspace tool before using workspace tools like execute, read_file, write_file, etc."
workspaceAwareness = "There is no workspace associated with this chat yet. Create one using the create_workspace tool before using workspace tools. Workspace creation runs in the background — you can continue using other tools while it builds."
}
workspaceAwarenessContent, err := chatprompt.MarshalParts([]codersdk.ChatMessagePart{
codersdk.ChatMessageText(workspaceAwareness),
@@ -996,7 +1105,7 @@ func (p *Server) CreateChat(ctx context.Context, opts CreateOptions) (database.C
return database.Chat{}, txErr
}
p.publishChatPubsubEvent(chat, codersdk.ChatWatchEventKindCreated, nil)
p.publishChatPubsubEvent(chat, coderdpubsub.ChatEventKindCreated, nil)
p.signalWake()
return chat, nil
}
@@ -1158,7 +1267,7 @@ func (p *Server) SendMessage(
p.publishMessage(opts.ChatID, result.Message)
p.publishStatus(opts.ChatID, result.Chat.Status, result.Chat.WorkerID)
p.publishChatPubsubEvent(result.Chat, codersdk.ChatWatchEventKindStatusChange, nil)
p.publishChatPubsubEvent(result.Chat, coderdpubsub.ChatEventKindStatusChange, nil)
p.signalWake()
return result, nil
}
@@ -1301,7 +1410,7 @@ func (p *Server) EditMessage(
QueueUpdate: true,
})
p.publishStatus(opts.ChatID, result.Chat.Status, result.Chat.WorkerID)
p.publishChatPubsubEvent(result.Chat, codersdk.ChatWatchEventKindStatusChange, nil)
p.publishChatPubsubEvent(result.Chat, coderdpubsub.ChatEventKindStatusChange, nil)
p.signalWake()
return result, nil
@@ -1355,10 +1464,10 @@ func (p *Server) ArchiveChat(ctx context.Context, chat database.Chat) error {
if interrupted {
p.publishStatus(chat.ID, statusChat.Status, statusChat.WorkerID)
p.publishChatPubsubEvent(statusChat, codersdk.ChatWatchEventKindStatusChange, nil)
p.publishChatPubsubEvent(statusChat, coderdpubsub.ChatEventKindStatusChange, nil)
}
p.publishChatPubsubEvents(archivedChats, codersdk.ChatWatchEventKindDeleted)
p.publishChatPubsubEvents(archivedChats, coderdpubsub.ChatEventKindDeleted)
return nil
}
@@ -1373,7 +1482,7 @@ func (p *Server) UnarchiveChat(ctx context.Context, chat database.Chat) error {
ctx,
chat.ID,
"unarchive",
codersdk.ChatWatchEventKindCreated,
coderdpubsub.ChatEventKindCreated,
p.db.UnarchiveChatByID,
)
}
@@ -1382,7 +1491,7 @@ func (p *Server) applyChatLifecycleTransition(
ctx context.Context,
chatID uuid.UUID,
action string,
kind codersdk.ChatWatchEventKind,
kind coderdpubsub.ChatEventKind,
transition func(context.Context, uuid.UUID) ([]database.Chat, error),
) error {
updatedChats, err := transition(ctx, chatID)
@@ -1545,7 +1654,7 @@ func (p *Server) PromoteQueued(
})
p.publishMessage(opts.ChatID, promoted)
p.publishStatus(opts.ChatID, updatedChat.Status, updatedChat.WorkerID)
p.publishChatPubsubEvent(updatedChat, codersdk.ChatWatchEventKindStatusChange, nil)
p.publishChatPubsubEvent(updatedChat, coderdpubsub.ChatEventKindStatusChange, nil)
p.signalWake()
return result, nil
@@ -2092,7 +2201,7 @@ func (p *Server) regenerateChatTitleWithStore(
return updatedChat, nil
}
p.publishChatPubsubEvent(updatedChat, codersdk.ChatWatchEventKindTitleChange, nil)
p.publishChatPubsubEvent(updatedChat, coderdpubsub.ChatEventKindTitleChange, nil)
return updatedChat, nil
}
@@ -2347,7 +2456,7 @@ func (p *Server) setChatWaiting(ctx context.Context, chatID uuid.UUID) (database
return database.Chat{}, err
}
p.publishStatus(chatID, updatedChat.Status, updatedChat.WorkerID)
p.publishChatPubsubEvent(updatedChat, codersdk.ChatWatchEventKindStatusChange, nil)
p.publishChatPubsubEvent(updatedChat, coderdpubsub.ChatEventKindStatusChange, nil)
return updatedChat, nil
}
@@ -3627,7 +3736,7 @@ func (p *Server) publishChatStreamNotify(chatID uuid.UUID, notify coderdpubsub.C
}
// publishChatPubsubEvents broadcasts a lifecycle event for each affected chat.
func (p *Server) publishChatPubsubEvents(chats []database.Chat, kind codersdk.ChatWatchEventKind) {
func (p *Server) publishChatPubsubEvents(chats []database.Chat, kind coderdpubsub.ChatEventKind) {
for _, chat := range chats {
p.publishChatPubsubEvent(chat, kind, nil)
}
@@ -3635,7 +3744,7 @@ func (p *Server) publishChatPubsubEvents(chats []database.Chat, kind codersdk.Ch
// publishChatPubsubEvent broadcasts a chat lifecycle event via PostgreSQL
// pubsub so that all replicas can push updates to watching clients.
func (p *Server) publishChatPubsubEvent(chat database.Chat, kind codersdk.ChatWatchEventKind, diffStatus *codersdk.ChatDiffStatus) {
func (p *Server) publishChatPubsubEvent(chat database.Chat, kind coderdpubsub.ChatEventKind, diffStatus *codersdk.ChatDiffStatus) {
if p.pubsub == nil {
return
}
@@ -3647,7 +3756,7 @@ func (p *Server) publishChatPubsubEvent(chat database.Chat, kind codersdk.ChatWa
if diffStatus != nil {
sdkChat.DiffStatus = diffStatus
}
event := codersdk.ChatWatchEvent{
event := coderdpubsub.ChatEvent{
Kind: kind,
Chat: sdkChat,
}
@@ -3659,7 +3768,7 @@ func (p *Server) publishChatPubsubEvent(chat database.Chat, kind codersdk.ChatWa
)
return
}
if err := p.pubsub.Publish(coderdpubsub.ChatWatchEventChannel(chat.OwnerID), payload); err != nil {
if err := p.pubsub.Publish(coderdpubsub.ChatEventChannel(chat.OwnerID), payload); err != nil {
p.logger.Error(context.Background(), "failed to publish chat pubsub event",
slog.F("chat_id", chat.ID),
slog.F("kind", kind),
@@ -3692,8 +3801,8 @@ func (p *Server) publishChatActionRequired(chat database.Chat, pending []chatloo
toolCalls := pendingToStreamToolCalls(pending)
sdkChat := db2sdk.Chat(chat, nil, nil)
event := codersdk.ChatWatchEvent{
Kind: codersdk.ChatWatchEventKindActionRequired,
event := coderdpubsub.ChatEvent{
Kind: coderdpubsub.ChatEventKindActionRequired,
Chat: sdkChat,
ToolCalls: toolCalls,
}
@@ -3705,7 +3814,7 @@ func (p *Server) publishChatActionRequired(chat database.Chat, pending []chatloo
)
return
}
if err := p.pubsub.Publish(coderdpubsub.ChatWatchEventChannel(chat.OwnerID), payload); err != nil {
if err := p.pubsub.Publish(coderdpubsub.ChatEventChannel(chat.OwnerID), payload); err != nil {
p.logger.Error(context.Background(), "failed to publish chat action_required pubsub event",
slog.F("chat_id", chat.ID),
slog.Error(err),
@@ -3733,7 +3842,7 @@ func (p *Server) PublishDiffStatusChange(ctx context.Context, chatID uuid.UUID)
}
sdkStatus := db2sdk.ChatDiffStatus(chatID, &dbStatus)
p.publishChatPubsubEvent(chat, codersdk.ChatWatchEventKindDiffStatusChange, &sdkStatus)
p.publishChatPubsubEvent(chat, coderdpubsub.ChatEventKindDiffStatusChange, &sdkStatus)
return nil
}
@@ -4215,7 +4324,7 @@ func (p *Server) processChat(ctx context.Context, chat database.Chat) {
if title, ok := generatedTitle.Load(); ok {
updatedChat.Title = title
}
p.publishChatPubsubEvent(updatedChat, codersdk.ChatWatchEventKindStatusChange, nil)
p.publishChatPubsubEvent(updatedChat, coderdpubsub.ChatEventKindStatusChange, nil)
// When the chat is parked in requires_action,
// publish the stream event and global pubsub event
+51 -35
View File
@@ -71,14 +71,14 @@ func TestRegenerateChatTitle_PersistsAndBroadcasts(t *testing.T) {
updatedChat.Title = wantTitle
messageEvents := make(chan struct {
payload codersdk.ChatWatchEvent
payload coderdpubsub.ChatEvent
err error
}, 1)
cancelSub, err := pubsub.SubscribeWithErr(
coderdpubsub.ChatWatchEventChannel(ownerID),
coderdpubsub.HandleChatWatchEvent(func(_ context.Context, payload codersdk.ChatWatchEvent, err error) {
coderdpubsub.ChatEventChannel(ownerID),
coderdpubsub.HandleChatEvent(func(_ context.Context, payload coderdpubsub.ChatEvent, err error) {
messageEvents <- struct {
payload codersdk.ChatWatchEvent
payload coderdpubsub.ChatEvent
err error
}{payload: payload, err: err}
}),
@@ -184,7 +184,7 @@ func TestRegenerateChatTitle_PersistsAndBroadcasts(t *testing.T) {
select {
case event := <-messageEvents:
require.NoError(t, event.err)
require.Equal(t, codersdk.ChatWatchEventKindTitleChange, event.payload.Kind)
require.Equal(t, coderdpubsub.ChatEventKindTitleChange, event.payload.Kind)
require.Equal(t, chatID, event.payload.Chat.ID)
require.Equal(t, wantTitle, event.payload.Chat.Title)
case <-time.After(time.Second):
@@ -234,14 +234,14 @@ func TestRegenerateChatTitle_PersistsAndBroadcasts_IdleChatReleasesManualLock(t
unlockedChat.StartedAt = sql.NullTime{}
messageEvents := make(chan struct {
payload codersdk.ChatWatchEvent
payload coderdpubsub.ChatEvent
err error
}, 1)
cancelSub, err := pubsub.SubscribeWithErr(
coderdpubsub.ChatWatchEventChannel(ownerID),
coderdpubsub.HandleChatWatchEvent(func(_ context.Context, payload codersdk.ChatWatchEvent, err error) {
coderdpubsub.ChatEventChannel(ownerID),
coderdpubsub.HandleChatEvent(func(_ context.Context, payload coderdpubsub.ChatEvent, err error) {
messageEvents <- struct {
payload codersdk.ChatWatchEvent
payload coderdpubsub.ChatEvent
err error
}{payload: payload, err: err}
}),
@@ -373,7 +373,7 @@ func TestRegenerateChatTitle_PersistsAndBroadcasts_IdleChatReleasesManualLock(t
select {
case event := <-messageEvents:
require.NoError(t, event.err)
require.Equal(t, codersdk.ChatWatchEventKindTitleChange, event.payload.Kind)
require.Equal(t, coderdpubsub.ChatEventKindTitleChange, event.payload.Kind)
require.Equal(t, chatID, event.payload.Chat.ID)
require.Equal(t, wantTitle, event.payload.Chat.Title)
case <-time.After(time.Second):
@@ -568,32 +568,37 @@ func TestPersistInstructionFilesIncludesAgentMetadata(t *testing.T) {
agentID,
).Return(workspaceAgent, nil).Times(1)
db.EXPECT().InsertChatMessages(gomock.Any(), gomock.Any()).Return(nil, nil).AnyTimes()
db.EXPECT().UpdateChatLastInjectedContext(gomock.Any(),
gomock.Cond(func(x any) bool {
arg, ok := x.(database.UpdateChatLastInjectedContextParams)
if !ok || arg.ID != chat.ID {
return false
}
if !arg.LastInjectedContext.Valid {
return false
}
var parts []codersdk.ChatMessagePart
if err := json.Unmarshal(arg.LastInjectedContext.RawMessage, &parts); err != nil {
return false
}
// Expect at least one context-file part for the
// working-directory AGENTS.md, with internal fields
// stripped (no content, OS, or directory).
for _, p := range parts {
if p.Type == codersdk.ChatMessagePartTypeContextFile && p.ContextFilePath != "" {
return p.ContextFileContent == "" &&
p.ContextFileOS == "" &&
p.ContextFileDirectory == ""
// waitForWorkspaceReady checks for existing agents to
// fast-path workspaces that are already running.
db.EXPECT().GetWorkspaceAgentsInLatestBuildByWorkspaceID(
gomock.Any(), workspaceID,
).Return([]database.WorkspaceAgent{workspaceAgent}, nil).AnyTimes()
db.EXPECT().UpdateChatLastInjectedContext(gomock.Any(),
gomock.Cond(func(x any) bool {
arg, ok := x.(database.UpdateChatLastInjectedContextParams)
if !ok || arg.ID != chat.ID {
return false
}
}
return false
}),
).Return(database.Chat{}, nil).Times(1)
if !arg.LastInjectedContext.Valid {
return false
}
var parts []codersdk.ChatMessagePart
if err := json.Unmarshal(arg.LastInjectedContext.RawMessage, &parts); err != nil {
return false
}
// Expect at least one context-file part for the
// working-directory AGENTS.md, with internal fields
// stripped (no content, OS, or directory).
for _, p := range parts {
if p.Type == codersdk.ChatMessagePartTypeContextFile && p.ContextFilePath != "" {
return p.ContextFileContent == "" &&
p.ContextFileOS == "" &&
p.ContextFileDirectory == ""
}
}
return false
}),
).Return(database.Chat{}, nil).Times(1)
conn := agentconnmock.NewMockAgentConn(ctrl)
conn.EXPECT().SetExtraHeaders(gomock.Any()).Times(1)
@@ -1069,6 +1074,9 @@ func TestTurnWorkspaceContextGetWorkspaceConnLazyValidationSwitchesWorkspaceAgen
updatedChat.AgentID = uuid.NullUUID{UUID: currentAgentID, Valid: true}
gomock.InOrder(
// waitForWorkspaceReady checks for existing agents to
// fast-path workspaces that are already running.
db.EXPECT().GetWorkspaceAgentsInLatestBuildByWorkspaceID(gomock.Any(), workspaceID).Return([]database.WorkspaceAgent{staleAgent}, nil),
db.EXPECT().GetWorkspaceAgentByID(gomock.Any(), staleAgentID).Return(staleAgent, nil),
db.EXPECT().GetWorkspaceAgentsInLatestBuildByWorkspaceID(gomock.Any(), workspaceID).Return([]database.WorkspaceAgent{currentAgent}, nil),
db.EXPECT().GetLatestWorkspaceBuildByWorkspaceID(gomock.Any(), workspaceID).Return(database.WorkspaceBuild{ID: buildID}, nil),
@@ -1137,9 +1145,17 @@ func TestTurnWorkspaceContextGetWorkspaceConnFastFailsWithoutCurrentAgent(t *tes
staleAgent := database.WorkspaceAgent{ID: staleAgentID}
// waitForWorkspaceReady checks for existing agents first
// to fast-path running workspaces. Return the stale agent
// here so the readiness check passes.
db.EXPECT().GetWorkspaceAgentsInLatestBuildByWorkspaceID(gomock.Any(), workspaceID).
Return([]database.WorkspaceAgent{staleAgent}, nil).
Times(1)
db.EXPECT().GetWorkspaceAgentByID(gomock.Any(), staleAgentID).
Return(staleAgent, nil).
Times(1)
// The validation flow then calls this again and finds no
// agents, triggering errChatHasNoWorkspaceAgent.
db.EXPECT().GetWorkspaceAgentsInLatestBuildByWorkspaceID(gomock.Any(), workspaceID).
Return([]database.WorkspaceAgent{}, nil).
Times(1)
+9 -12
View File
@@ -2744,25 +2744,22 @@ func TestCreateWorkspaceTool_EndToEnd(t *testing.T) {
created, ok := result["created"].(bool)
require.True(t, ok)
require.True(t, created)
// create_workspace now returns immediately with
// a "building" status instead of blocking until
// the build completes. Workspace tools wait
// transparently via getWorkspaceConn.
status, _ := result["status"].(string)
require.Equal(t, "building", status)
foundCreateWorkspaceResult = true
}
}
require.True(t, foundCreateWorkspaceResult, "expected create_workspace tool result message")
// Verify that the tool waited for startup scripts to
// complete. The agent should be in "ready" state by the
// time create_workspace returns its result.
// Since create_workspace returns immediately with
// status="building", the agent may not yet be ready.
// The workspace should still exist and be accessible.
workspace, err = client.Workspace(ctx, workspaceID)
require.NoError(t, err)
var agentLifecycle codersdk.WorkspaceAgentLifecycle
for _, res := range workspace.LatestBuild.Resources {
for _, agt := range res.Agents {
agentLifecycle = agt.LifecycleState
}
}
require.Equal(t, codersdk.WorkspaceAgentLifecycleReady, agentLifecycle,
"agent should be ready after create_workspace returns; startup scripts were not awaited")
require.GreaterOrEqual(t, streamedCallCount.Load(), int32(2))
streamedCallsMu.Lock()
recordedStreamCalls := append([][]chattest.OpenAIMessage(nil), streamedCalls...)
+22 -246
View File
@@ -2,7 +2,6 @@ package chattool
import (
"context"
"errors"
"fmt"
"strings"
"sync"
@@ -20,31 +19,6 @@ import (
"github.com/coder/coder/v2/codersdk/workspacesdk"
)
const (
// buildPollInterval is how often we check if the workspace
// build has completed.
buildPollInterval = 2 * time.Second
// buildTimeout is the maximum time to wait for a workspace
// build to complete before giving up.
buildTimeout = 10 * time.Minute
// agentConnectTimeout is the maximum time to wait for the
// workspace agent to become reachable after a successful build.
agentConnectTimeout = 2 * time.Minute
// agentRetryInterval is how often we retry connecting to the
// workspace agent.
agentRetryInterval = 2 * time.Second
// agentAttemptTimeout is the timeout for a single connection
// attempt to the workspace agent during the retry loop.
agentAttemptTimeout = 5 * time.Second
// startupScriptTimeout is the maximum time to wait for the
// workspace agent's startup scripts to finish after the agent
// is reachable.
startupScriptTimeout = 10 * time.Minute
// startupScriptPollInterval is how often we check the agent's
// lifecycle state while waiting for startup scripts.
startupScriptPollInterval = 2 * time.Second
)
// CreateWorkspaceFn creates a workspace for the given owner.
type CreateWorkspaceFn func(
ctx context.Context,
@@ -193,44 +167,9 @@ func CreateWorkspace(options CreateWorkspaceOptions) fantasy.AgentTool {
return fantasy.NewTextErrorResponse(err.Error()), nil
}
// Wait for the build to complete and the agent to
// come online so subsequent tools can use the
// workspace immediately.
if options.DB != nil {
if err := waitForBuild(ctx, options.DB, workspace.ID); err != nil {
return fantasy.NewTextErrorResponse(
xerrors.Errorf("workspace build failed: %w", err).Error(),
), nil
}
}
result := map[string]any{
"created": true,
"workspace_name": workspace.FullName(),
}
// Select the chat agent so follow-up tools wait on the
// intended workspace agent.
workspaceAgentID := uuid.Nil
if options.DB != nil {
agents, agentErr := options.DB.GetWorkspaceAgentsInLatestBuildByWorkspaceID(ctx, workspace.ID)
if agentErr == nil {
if len(agents) == 0 {
result["agent_status"] = "no_agent"
} else {
selected, selectErr := agentselect.FindChatAgent(agents)
if selectErr != nil {
result["agent_status"] = "selection_error"
result["agent_error"] = selectErr.Error()
} else {
workspaceAgentID = selected.ID
}
}
}
}
// Persist the workspace binding on the chat.
if options.DB != nil && options.ChatID != uuid.Nil {
// Persist the workspace binding immediately so that
// workspace-dependent tools can discover the workspace
// and wait for the build via getWorkspaceConn. if options.DB != nil && options.ChatID != uuid.Nil {
updatedChat, err := options.DB.UpdateChatWorkspaceBinding(ctx, database.UpdateChatWorkspaceBindingParams{
ID: options.ChatID,
WorkspaceID: uuid.NullUUID{
@@ -255,16 +194,15 @@ func CreateWorkspace(options CreateWorkspaceOptions) fantasy.AgentTool {
}
}
// Wait for the agent to come online and startup scripts to finish.
if workspaceAgentID != uuid.Nil {
agentStatus := waitForAgentReady(ctx, options.DB, workspaceAgentID, options.AgentConnFn)
for k, v := range agentStatus {
result[k] = v
}
}
return toolResponse(result), nil
})
// Return immediately — workspace tools will
// transparently wait for the build to complete via
// getWorkspaceConn when they are actually invoked.
return toolResponse(map[string]any{
"created": true,
"workspace_name": workspace.FullName(),
"status": "building",
"message": "Workspace build started. Workspace tools will wait for it automatically.",
}), nil })
}
// checkExistingWorkspace checks whether the configured chat already has
@@ -316,35 +254,15 @@ func (o CreateWorkspaceOptions) checkExistingWorkspace(
switch job.JobStatus {
case database.ProvisionerJobStatusPending,
database.ProvisionerJobStatusRunning:
// Build is in progress — wait for it instead of
// creating a new workspace.
if err := waitForBuild(ctx, db, ws.ID); err != nil {
return nil, false, xerrors.Errorf(
"existing workspace build failed: %w", err,
)
}
result := map[string]any{
// Build is in progress — return immediately so the
// agent can continue working. Workspace tools will
// wait for the build via getWorkspaceConn.
return map[string]any{
"created": false,
"workspace_name": ws.Name,
"status": "already_exists",
"message": "workspace build completed",
}
agents, agentsErr := db.GetWorkspaceAgentsInLatestBuildByWorkspaceID(ctx, ws.ID)
if agentsErr == nil && len(agents) > 0 {
selected, selectErr := agentselect.FindChatAgent(agents)
if selectErr != nil {
o.Logger.Debug(ctx, "agent selection failed, falling back to first agent for readiness check",
slog.F("workspace_id", ws.ID),
slog.Error(selectErr),
)
selected = agents[0]
}
for k, v := range waitForAgentReady(ctx, db, selected.ID, agentConnFn) {
result[k] = v
}
}
return result, true, nil
"workspace_name": ws.OwnerUsername + "/" + ws.Name,
"status": "building",
"message": "Workspace is currently building. Workspace tools will wait for it automatically.",
}, true, nil
case database.ProvisionerJobStatusSucceeded:
// If the workspace was stopped, tell the model to use
// start_workspace instead of creating a new one.
@@ -380,14 +298,12 @@ func (o CreateWorkspaceOptions) checkExistingWorkspace(
switch status.Status {
case database.WorkspaceAgentStatusConnected:
result["message"] = "workspace is already running and recently connected"
for k, v := range waitForAgentReady(ctx, db, selected.ID, nil) {
result[k] = v
for k, v := range WaitForAgentReady(ctx, db, selected.ID, nil) { result[k] = v
}
return result, true, nil
case database.WorkspaceAgentStatusConnecting:
result["message"] = "workspace exists and the agent is still connecting"
for k, v := range waitForAgentReady(ctx, db, selected.ID, agentConnFn) {
result[k] = v
for k, v := range WaitForAgentReady(ctx, db, selected.ID, agentConnFn) { result[k] = v
}
return result, true, nil
case database.WorkspaceAgentStatusDisconnected,
@@ -405,146 +321,6 @@ func (o CreateWorkspaceOptions) checkExistingWorkspace(
}
}
// waitForBuild polls the workspace's latest build until it
// completes or the context expires. The total wait is bounded by
// buildTimeout; the status is re-checked every buildPollInterval.
func waitForBuild(
	ctx context.Context,
	db database.Store,
	workspaceID uuid.UUID,
) error {
	buildCtx, cancel := context.WithTimeout(ctx, buildTimeout)
	defer cancel()
	ticker := time.NewTicker(buildPollInterval)
	defer ticker.Stop()
	for {
		// Re-fetch the latest build each iteration so that a build
		// replaced mid-wait is tracked instead of the original one.
		build, err := db.GetLatestWorkspaceBuildByWorkspaceID(
			buildCtx, workspaceID,
		)
		if err != nil {
			return xerrors.Errorf("get latest build: %w", err)
		}
		job, err := db.GetProvisionerJobByID(buildCtx, build.JobID)
		if err != nil {
			return xerrors.Errorf("get provisioner job: %w", err)
		}
		switch job.JobStatus {
		case database.ProvisionerJobStatusSucceeded:
			return nil
		case database.ProvisionerJobStatusFailed:
			// Prefer the provisioner's recorded error message when
			// one is present.
			errMsg := "build failed"
			if job.Error.Valid {
				errMsg = job.Error.String
			}
			return xerrors.New(errMsg)
		case database.ProvisionerJobStatusCanceled:
			return xerrors.New("build was canceled")
		case database.ProvisionerJobStatusPending,
			database.ProvisionerJobStatusRunning,
			database.ProvisionerJobStatusCanceling:
			// Still in progress — keep waiting.
		default:
			// Defensive: fail loudly on job statuses added later.
			return xerrors.Errorf("unexpected job status: %s", job.JobStatus)
		}
		select {
		case <-buildCtx.Done():
			return xerrors.Errorf(
				"timed out waiting for workspace build: %w",
				buildCtx.Err(),
			)
		case <-ticker.C:
		}
	}
}
// waitForAgentReady waits for the workspace agent to become
// reachable and for its startup scripts to finish. It returns
// status fields suitable for merging into a tool response.
//
// Phase 1 (skipped when agentConnFn is nil) retries dialing the agent
// until agentConnectTimeout elapses. Phase 2 (skipped when db is nil)
// polls the agent's lifecycle state until startup scripts reach a
// terminal state or startupScriptTimeout elapses. An empty map means
// the agent is fully ready.
func waitForAgentReady(
	ctx context.Context,
	db database.Store,
	agentID uuid.UUID,
	agentConnFn AgentConnFunc,
) map[string]any {
	result := map[string]any{}
	// Phase 1: retry connecting to the agent.
	if agentConnFn != nil {
		agentCtx, agentCancel := context.WithTimeout(ctx, agentConnectTimeout)
		defer agentCancel()
		ticker := time.NewTicker(agentRetryInterval)
		defer ticker.Stop()
		var lastErr error
		for {
			// Bound each dial attempt separately so one hung attempt
			// cannot consume the whole connect budget.
			attemptCtx, attemptCancel := context.WithTimeout(agentCtx, agentAttemptTimeout)
			conn, release, err := agentConnFn(attemptCtx, agentID)
			attemptCancel()
			if err == nil {
				// Reachability is all we need; release the
				// connection immediately.
				release()
				_ = conn
				break
			}
			lastErr = err
			select {
			case <-agentCtx.Done():
				// Surface the last dial error rather than the
				// context error — it is more actionable.
				result["agent_status"] = "not_ready"
				result["agent_error"] = lastErr.Error()
				return result
			case <-ticker.C:
			}
		}
	}
	// Phase 2: poll lifecycle until startup scripts finish.
	if db != nil {
		scriptCtx, scriptCancel := context.WithTimeout(ctx, startupScriptTimeout)
		defer scriptCancel()
		ticker := time.NewTicker(startupScriptPollInterval)
		defer ticker.Stop()
		var lastState database.WorkspaceAgentLifecycleState
		for {
			// NOTE(review): lookup errors are swallowed and retried;
			// a persistently failing DB query surfaces only as a
			// startup_scripts_timeout — confirm this is intended.
			row, err := db.GetWorkspaceAgentLifecycleStateByID(scriptCtx, agentID)
			if err == nil {
				lastState = row.LifecycleState
				switch lastState {
				case database.WorkspaceAgentLifecycleStateCreated,
					database.WorkspaceAgentLifecycleStateStarting:
					// Still in progress, keep polling.
				case database.WorkspaceAgentLifecycleStateReady:
					return result
				default:
					// Terminal non-ready state.
					result["startup_scripts"] = "startup_scripts_failed"
					result["lifecycle_state"] = string(lastState)
					return result
				}
			}
			select {
			case <-scriptCtx.Done():
				if errors.Is(scriptCtx.Err(), context.DeadlineExceeded) {
					result["startup_scripts"] = "startup_scripts_timeout"
				} else {
					result["startup_scripts"] = "startup_scripts_unknown"
				}
				return result
			case <-ticker.C:
			}
		}
	}
	return result
}
func generatedWorkspaceName(seed string) string {
base := codersdk.UsernameFrom(strings.TrimSpace(strings.ToLower(seed)))
if strings.TrimSpace(base) == "" {
@@ -44,7 +44,7 @@ func TestWaitForAgentReady(t *testing.T) {
return nil, func() {}, nil
}
result := waitForAgentReady(context.Background(), db, agentID, connFn)
result := WaitForAgentReady(context.Background(), db, agentID, connFn)
require.Empty(t, result)
})
@@ -63,7 +63,7 @@ func TestWaitForAgentReady(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
cancel()
result := waitForAgentReady(ctx, db, agentID, connFn)
result := WaitForAgentReady(ctx, db, agentID, connFn)
require.Equal(t, "not_ready", result["agent_status"])
require.NotEmpty(t, result["agent_error"])
})
@@ -85,7 +85,7 @@ func TestWaitForAgentReady(t *testing.T) {
return nil, func() {}, nil
}
result := waitForAgentReady(context.Background(), db, agentID, connFn)
result := WaitForAgentReady(context.Background(), db, agentID, connFn)
require.Equal(t, "startup_scripts_failed", result["startup_scripts"])
require.Equal(t, "start_error", result["lifecycle_state"])
})
@@ -103,7 +103,7 @@ func TestWaitForAgentReady(t *testing.T) {
LifecycleState: database.WorkspaceAgentLifecycleStateReady,
}, nil)
result := waitForAgentReady(context.Background(), db, agentID, nil)
result := WaitForAgentReady(context.Background(), db, agentID, nil)
require.Empty(t, result)
})
@@ -114,7 +114,7 @@ func TestWaitForAgentReady(t *testing.T) {
return nil, func() {}, nil
}
result := waitForAgentReady(context.Background(), nil, uuid.New(), connFn)
result := WaitForAgentReady(context.Background(), nil, uuid.New(), connFn)
require.Empty(t, result)
})
}
@@ -334,8 +334,6 @@ func TestCreateWorkspace_GlobalTTL(t *testing.T) {
ownerID := uuid.New()
templateID := uuid.New()
workspaceID := uuid.New()
jobID := uuid.New()
db.EXPECT().
GetAuthorizationUserRoles(gomock.Any(), ownerID).
Return(database.GetAuthorizationUserRolesRow{
@@ -349,23 +347,6 @@ func TestCreateWorkspace_GlobalTTL(t *testing.T) {
GetChatWorkspaceTTL(gomock.Any()).
Return(tc.ttlReturn, tc.ttlErr)
db.EXPECT().
GetLatestWorkspaceBuildByWorkspaceID(gomock.Any(), workspaceID).
Return(database.WorkspaceBuild{
WorkspaceID: workspaceID,
JobID: jobID,
}, nil)
db.EXPECT().
GetProvisionerJobByID(gomock.Any(), jobID).
Return(database.ProvisionerJob{
ID: jobID,
JobStatus: database.ProvisionerJobStatusSucceeded,
}, nil)
db.EXPECT().
GetWorkspaceAgentsInLatestBuildByWorkspaceID(gomock.Any(), workspaceID).
Return([]database.WorkspaceAgent{}, nil)
var capturedReq codersdk.CreateWorkspaceRequest
createFn := func(_ context.Context, _ uuid.UUID, req codersdk.CreateWorkspaceRequest) (codersdk.Workspace, error) {
capturedReq = req
+24 -58
View File
@@ -9,7 +9,6 @@ import (
"golang.org/x/xerrors"
"github.com/coder/coder/v2/coderd/database"
"github.com/coder/coder/v2/coderd/x/chatd/internal/agentselect"
"github.com/coder/coder/v2/codersdk"
)
@@ -95,22 +94,27 @@ func StartWorkspace(options StartWorkspaceOptions) fantasy.AgentTool {
), nil
}
// If a build is already in progress, wait for it.
switch job.JobStatus {
case database.ProvisionerJobStatusPending,
database.ProvisionerJobStatusRunning:
if err := waitForBuild(ctx, options.DB, ws.ID); err != nil {
return fantasy.NewTextErrorResponse(
xerrors.Errorf("waiting for in-progress build: %w", err).Error(),
), nil
}
return waitForAgentAndRespond(ctx, options.DB, options.AgentConnFn, ws)
// Build already in progress — return immediately.
// Workspace tools will wait via getWorkspaceConn.
return toolResponse(map[string]any{
"started": true,
"workspace_name": ws.OwnerUsername + "/" + ws.Name,
"status": "building",
"message": "Workspace build is in progress. Workspace tools will wait for it automatically.",
}), nil
case database.ProvisionerJobStatusSucceeded:
// If the latest successful build is a start
// transition, the workspace should be running.
if build.Transition == database.WorkspaceTransitionStart {
return waitForAgentAndRespond(ctx, options.DB, options.AgentConnFn, ws)
// Already running — return immediately.
return toolResponse(map[string]any{
"started": true,
"workspace_name": ws.OwnerUsername + "/" + ws.Name,
"status": "running",
"message": "Workspace is already running.",
}), nil
}
// Otherwise it is stopped (or deleted) — proceed
// to start it below.
@@ -134,53 +138,15 @@ func StartWorkspace(options StartWorkspaceOptions) fantasy.AgentTool {
), nil
}
if err := waitForBuild(ctx, options.DB, ws.ID); err != nil {
return fantasy.NewTextErrorResponse(
xerrors.Errorf("workspace start build failed: %w", err).Error(),
), nil
}
return waitForAgentAndRespond(ctx, options.DB, options.AgentConnFn, ws)
// Return immediately — workspace tools will
// transparently wait for the build to complete via
// getWorkspaceConn when they are actually invoked.
return toolResponse(map[string]any{
"started": true,
"workspace_name": ws.OwnerUsername + "/" + ws.Name,
"status": "starting",
"message": "Workspace start initiated. Workspace tools will wait for it automatically.",
}), nil
},
)
}
// waitForAgentAndRespond selects the chat agent from the workspace's
// latest build, waits for it to become reachable, and returns a
// success response.
//
// Agent lookup and selection failures are still reported as success
// (started: true) with an explanatory agent_status field, because the
// workspace itself did start.
func waitForAgentAndRespond(
	ctx context.Context,
	db database.Store,
	agentConnFn AgentConnFunc,
	ws database.Workspace,
) (fantasy.ToolResponse, error) {
	agents, err := db.GetWorkspaceAgentsInLatestBuildByWorkspaceID(ctx, ws.ID)
	if err != nil || len(agents) == 0 {
		// Workspace started but no agent found - still report
		// success so the model knows the workspace is up.
		return toolResponse(map[string]any{
			"started":        true,
			"workspace_name": ws.Name,
			"agent_status":   "no_agent",
		}), nil
	}
	selected, err := agentselect.FindChatAgent(agents)
	if err != nil {
		return toolResponse(map[string]any{
			"started":        true,
			"workspace_name": ws.Name,
			"agent_status":   "selection_error",
			"agent_error":    err.Error(),
		}), nil
	}
	result := map[string]any{
		"started":        true,
		"workspace_name": ws.Name,
	}
	// Merge readiness fields (agent_status, startup_scripts, …)
	// produced by the shared wait helper.
	for k, v := range waitForAgentReady(ctx, db, selected.ID, agentConnFn) {
		result[k] = v
	}
	return toolResponse(result), nil
}
@@ -0,0 +1,178 @@
package chattool
import (
"context"
"errors"
"time"
"github.com/google/uuid"
"golang.org/x/xerrors"
"github.com/coder/coder/v2/coderd/database"
)
// Tuning knobs for the workspace readiness wait loops. Every wait is
// bounded: build waits by BuildTimeout, agent dialing by
// AgentConnectTimeout, and the startup-script wait by
// StartupScriptTimeout.
const (
	// BuildPollInterval is how often we check if the workspace
	// build has completed.
	BuildPollInterval = 2 * time.Second
	// BuildTimeout is the maximum time to wait for a workspace
	// build to complete before giving up.
	BuildTimeout = 10 * time.Minute
	// AgentConnectTimeout is the maximum time to wait for the
	// workspace agent to become reachable after a successful
	// build.
	AgentConnectTimeout = 2 * time.Minute
	// AgentRetryInterval is how often we retry connecting to
	// the workspace agent.
	AgentRetryInterval = 2 * time.Second
	// AgentAttemptTimeout is the timeout for a single connection
	// attempt to the workspace agent during the retry loop.
	AgentAttemptTimeout = 5 * time.Second
	// StartupScriptTimeout is the maximum time to wait for the
	// workspace agent's startup scripts to finish after the
	// agent is reachable.
	StartupScriptTimeout = 10 * time.Minute
	// StartupScriptPollInterval is how often we check the
	// agent's lifecycle state while waiting for startup scripts.
	StartupScriptPollInterval = 2 * time.Second
)
// WaitForBuild polls the workspace's latest build until it
// completes or the context expires.
//
// The wait is bounded by BuildTimeout. Each iteration re-reads the
// latest build row before sleeping, so an already-terminal build
// returns immediately. Returns nil only on provisioner success.
func WaitForBuild(
	ctx context.Context,
	db database.Store,
	workspaceID uuid.UUID,
) error {
	pollCtx, cancel := context.WithTimeout(ctx, BuildTimeout)
	defer cancel()

	poll := time.NewTicker(BuildPollInterval)
	defer poll.Stop()

	for {
		latest, err := db.GetLatestWorkspaceBuildByWorkspaceID(pollCtx, workspaceID)
		if err != nil {
			return xerrors.Errorf("get latest build: %w", err)
		}
		job, err := db.GetProvisionerJobByID(pollCtx, latest.JobID)
		if err != nil {
			return xerrors.Errorf("get provisioner job: %w", err)
		}

		switch job.JobStatus {
		case database.ProvisionerJobStatusSucceeded:
			return nil
		case database.ProvisionerJobStatusFailed:
			// Prefer the provisioner's own error text when present.
			if job.Error.Valid {
				return xerrors.New(job.Error.String)
			}
			return xerrors.New("build failed")
		case database.ProvisionerJobStatusCanceled:
			return xerrors.New("build was canceled")
		case database.ProvisionerJobStatusPending,
			database.ProvisionerJobStatusRunning,
			database.ProvisionerJobStatusCanceling:
			// Still in progress — keep waiting.
		default:
			return xerrors.Errorf("unexpected job status: %s", job.JobStatus)
		}

		select {
		case <-poll.C:
		case <-pollCtx.Done():
			return xerrors.Errorf(
				"timed out waiting for workspace build: %w",
				pollCtx.Err(),
			)
		}
	}
}
// WaitForAgentReady waits for the workspace agent to become
// reachable and for its startup scripts to finish. It returns
// status fields suitable for merging into a tool response.
//
// Phase 1 (skipped when agentConnFn is nil) retries dialing the agent
// until AgentConnectTimeout elapses. Phase 2 (skipped when db is nil)
// polls the agent's lifecycle state until startup scripts reach a
// terminal state or StartupScriptTimeout elapses. An empty map means
// the agent is fully ready.
func WaitForAgentReady(
	ctx context.Context,
	db database.Store,
	agentID uuid.UUID,
	agentConnFn AgentConnFunc,
) map[string]any {
	result := map[string]any{}
	// Phase 1: retry connecting to the agent.
	if agentConnFn != nil {
		agentCtx, agentCancel := context.WithTimeout(ctx, AgentConnectTimeout)
		defer agentCancel()
		ticker := time.NewTicker(AgentRetryInterval)
		defer ticker.Stop()
		var lastErr error
		for {
			// Bound each dial attempt separately so one hung attempt
			// cannot consume the whole connect budget.
			attemptCtx, attemptCancel := context.WithTimeout(agentCtx, AgentAttemptTimeout)
			conn, release, err := agentConnFn(attemptCtx, agentID)
			attemptCancel()
			if err == nil {
				// Reachability is all we need; release the
				// connection immediately.
				release()
				_ = conn
				break
			}
			lastErr = err
			select {
			case <-agentCtx.Done():
				// Surface the last dial error rather than the
				// context error — it is more actionable.
				result["agent_status"] = "not_ready"
				result["agent_error"] = lastErr.Error()
				return result
			case <-ticker.C:
			}
		}
	}
	// Phase 2: poll lifecycle until startup scripts finish.
	if db != nil {
		scriptCtx, scriptCancel := context.WithTimeout(ctx, StartupScriptTimeout)
		defer scriptCancel()
		ticker := time.NewTicker(StartupScriptPollInterval)
		defer ticker.Stop()
		var lastState database.WorkspaceAgentLifecycleState
		for {
			// NOTE(review): lookup errors are swallowed and retried;
			// a persistently failing DB query surfaces only as a
			// startup_scripts_timeout — confirm this is intended.
			row, err := db.GetWorkspaceAgentLifecycleStateByID(scriptCtx, agentID)
			if err == nil {
				lastState = row.LifecycleState
				switch lastState {
				case database.WorkspaceAgentLifecycleStateCreated,
					database.WorkspaceAgentLifecycleStateStarting:
					// Still in progress, keep polling.
				case database.WorkspaceAgentLifecycleStateReady:
					return result
				default:
					// Terminal non-ready state.
					result["startup_scripts"] = "startup_scripts_failed"
					result["lifecycle_state"] = string(lastState)
					return result
				}
			}
			select {
			case <-scriptCtx.Done():
				if errors.Is(scriptCtx.Err(), context.DeadlineExceeded) {
					result["startup_scripts"] = "startup_scripts_timeout"
				} else {
					result["startup_scripts"] = "startup_scripts_unknown"
				}
				return result
			case <-ticker.C:
			}
		}
	}
	return result
}
+4
View File
@@ -86,6 +86,10 @@ Propose a plan when:
- The user asks for a plan.
If no workspace is attached to this chat yet, create and start one first using create_workspace and start_workspace.
Workspace creation is non-blocking — the build runs in the background while you continue working.
Use this time to gather context with non-workspace tools (GitHub, web search, spawn_agent for research, etc.) and plan your approach.
Workspace tools (execute, read_file, write_file, etc.) will automatically wait for the workspace to be ready when called.
Once a workspace is available:
1. Use spawn_agent and wait_agent to research the codebase and gather context as needed.
2. Use write_file to create a Markdown plan file in the workspace (e.g. /home/coder/PLAN.md).
+2 -1
View File
@@ -21,6 +21,7 @@ import (
"cdr.dev/slog/v3"
"github.com/coder/coder/v2/coderd/database"
coderdpubsub "github.com/coder/coder/v2/coderd/pubsub"
"github.com/coder/coder/v2/coderd/x/chatd/chatprompt"
"github.com/coder/coder/v2/coderd/x/chatd/chatprovider"
"github.com/coder/coder/v2/coderd/x/chatd/chatretry"
@@ -159,7 +160,7 @@ func (p *Server) maybeGenerateChatTitle(
}
chat.Title = title
generatedTitle.Store(title)
p.publishChatPubsubEvent(chat, codersdk.ChatWatchEventKindTitleChange, nil)
p.publishChatPubsubEvent(chat, coderdpubsub.ChatEventKindTitleChange, nil)
return
}
+1 -1
View File
@@ -574,7 +574,7 @@ func (p *Server) createChildSubagentChatWithOptions(
return database.Chat{}, xerrors.Errorf("create child chat: %w", txErr)
}
p.publishChatPubsubEvent(child, codersdk.ChatWatchEventKindCreated, nil)
p.publishChatPubsubEvent(child, coderdpubsub.ChatEventKindCreated, nil)
p.signalWake()
return child, nil
}
+97 -13
View File
@@ -1130,6 +1130,11 @@ type ChatStreamEvent struct {
ActionRequired *ChatStreamActionRequired `json:"action_required,omitempty"`
}
type chatStreamEnvelope struct {
Type ServerSentEventType `json:"type"`
Data json.RawMessage `json:"data,omitempty"`
}
// ChatCostSummaryOptions are optional query parameters for GetChatCostSummary.
type ChatCostSummaryOptions struct {
StartDate time.Time
@@ -1982,8 +1987,8 @@ func (c *ExperimentalClient) StreamChat(ctx context.Context, chatID uuid.UUID, o
}()
for {
var batch []ChatStreamEvent
if err := wsjson.Read(streamCtx, conn, &batch); err != nil {
var envelope chatStreamEnvelope
if err := wsjson.Read(streamCtx, conn, &envelope); err != nil {
if streamCtx.Err() != nil {
return
}
@@ -2000,10 +2005,61 @@ func (c *ExperimentalClient) StreamChat(ctx context.Context, chatID uuid.UUID, o
return
}
for _, event := range batch {
if !send(event) {
switch envelope.Type {
case ServerSentEventTypePing:
continue
case ServerSentEventTypeData:
var batch []ChatStreamEvent
decodeErr := json.Unmarshal(envelope.Data, &batch)
if decodeErr == nil {
for _, streamedEvent := range batch {
if !send(streamedEvent) {
return
}
}
continue
}
{
_ = send(ChatStreamEvent{
Type: ChatStreamEventTypeError,
Error: &ChatStreamError{
Message: fmt.Sprintf(
"decode chat stream event batch: %v",
decodeErr,
),
},
})
return
}
case ServerSentEventTypeError:
message := "chat stream returned an error"
if len(envelope.Data) > 0 {
var response Response
if err := json.Unmarshal(envelope.Data, &response); err == nil {
message = formatChatStreamResponseError(response)
} else {
trimmed := strings.TrimSpace(string(envelope.Data))
if trimmed != "" {
message = trimmed
}
}
}
_ = send(ChatStreamEvent{
Type: ChatStreamEventTypeError,
Error: &ChatStreamError{
Message: message,
},
})
return
default:
_ = send(ChatStreamEvent{
Type: ChatStreamEventTypeError,
Error: &ChatStreamError{
Message: fmt.Sprintf("unknown chat stream event type %q", envelope.Type),
},
})
return
}
}
}()
@@ -2042,8 +2098,8 @@ func (c *ExperimentalClient) WatchChats(ctx context.Context) (<-chan ChatWatchEv
}()
for {
var event ChatWatchEvent
if err := wsjson.Read(streamCtx, conn, &event); err != nil {
var envelope chatStreamEnvelope
if err := wsjson.Read(streamCtx, conn, &envelope); err != nil {
if streamCtx.Err() != nil {
return
}
@@ -2054,10 +2110,23 @@ func (c *ExperimentalClient) WatchChats(ctx context.Context) (<-chan ChatWatchEv
return
}
select {
case <-streamCtx.Done():
switch envelope.Type {
case ServerSentEventTypePing:
continue
case ServerSentEventTypeData:
var event ChatWatchEvent
if err := json.Unmarshal(envelope.Data, &event); err != nil {
return
}
select {
case <-streamCtx.Done():
return
case events <- event:
}
case ServerSentEventTypeError:
return
default:
return
case events <- event:
}
}
}()
@@ -2409,12 +2478,27 @@ func (c *ExperimentalClient) GetChatsByWorkspace(ctx context.Context, workspaceI
return result, json.NewDecoder(res.Body).Decode(&result)
}
// formatChatStreamResponseError renders an API error Response as one
// human-readable line: "message: detail" when both fields are set,
// whichever is non-empty otherwise, and a generic fallback when both
// are blank after trimming.
func formatChatStreamResponseError(response Response) string {
	message := strings.TrimSpace(response.Message)
	detail := strings.TrimSpace(response.Detail)
	if message == "" && detail == "" {
		return "chat stream returned an error"
	}
	if message == "" {
		return detail
	}
	if detail == "" {
		return message
	}
	return fmt.Sprintf("%s: %s", message, detail)
}
// PRInsightsResponse is the response from the PR insights endpoint.
type PRInsightsResponse struct {
Summary PRInsightsSummary `json:"summary"`
TimeSeries []PRInsightsTimeSeriesEntry `json:"time_series"`
ByModel []PRInsightsModelBreakdown `json:"by_model"`
PullRequests []PRInsightsPullRequest `json:"recent_prs"`
Summary PRInsightsSummary `json:"summary"`
TimeSeries []PRInsightsTimeSeriesEntry `json:"time_series"`
ByModel []PRInsightsModelBreakdown `json:"by_model"`
RecentPRs []PRInsightsPullRequest `json:"recent_prs"`
}
// PRInsightsSummary contains aggregate PR metrics for a time period,
+18 -48
View File
@@ -75,49 +75,6 @@ func (req *OAuth2ClientRegistrationRequest) Validate() error {
return nil
}
// ValidateRedirectURIScheme reports whether the callback URL's scheme is
// safe to use as a redirect target. It returns an error when the scheme
// is empty, an unsupported URN, or one of the schemes that are dangerous
// in browser/HTML contexts (javascript, data, file, ftp).
//
// Legitimate custom schemes for native apps (e.g. vscode://, jetbrains://)
// are allowed.
func ValidateRedirectURIScheme(u *url.URL) error {
	// Delegate to the shared policy used by redirect-URI registration.
	return validateScheme(u)
}
// validateScheme implements the shared scheme policy for redirect
// URIs: a scheme must be present; the "urn" scheme is accepted only
// for the out-of-band URN urn:ietf:wg:oauth:2.0:oob (RFC 6749
// section 3.1.2.1); and browser-dangerous schemes are rejected. Any
// other scheme — including custom native-app schemes — is allowed.
func validateScheme(u *url.URL) error {
	if u.Scheme == "" {
		return xerrors.New("redirect URI must have a scheme")
	}

	// Handle special URNs (RFC 6749 section 3.1.2.1).
	if u.Scheme == "urn" {
		if u.String() != "urn:ietf:wg:oauth:2.0:oob" {
			return xerrors.New("redirect URI uses unsupported URN scheme")
		}
		return nil
	}

	// Block dangerous schemes for security (not allowed by RFCs
	// for OAuth2).
	for _, dangerous := range [...]string{"javascript", "data", "file", "ftp"} {
		if strings.EqualFold(u.Scheme, dangerous) {
			return xerrors.Errorf("redirect URI uses dangerous scheme %s which is not allowed", dangerous)
		}
	}
	return nil
}
// validateRedirectURIs validates redirect URIs according to RFC 7591, 8252
func validateRedirectURIs(uris []string, tokenEndpointAuthMethod OAuth2TokenEndpointAuthMethod) error {
if len(uris) == 0 {
@@ -134,14 +91,27 @@ func validateRedirectURIs(uris []string, tokenEndpointAuthMethod OAuth2TokenEndp
return xerrors.Errorf("redirect URI at index %d is not a valid URL: %w", i, err)
}
if err := validateScheme(uri); err != nil {
return xerrors.Errorf("redirect URI at index %d: %w", i, err)
// Validate schemes according to RFC requirements
if uri.Scheme == "" {
return xerrors.Errorf("redirect URI at index %d must have a scheme", i)
}
// The urn:ietf:wg:oauth:2.0:oob scheme passed validation
// above but needs no further checks.
// Handle special URNs (RFC 6749 section 3.1.2.1)
if uri.Scheme == "urn" {
continue
// Allow the out-of-band redirect URI for native apps
if uriStr == "urn:ietf:wg:oauth:2.0:oob" {
continue // This is valid for native apps
}
// Other URNs are not standard for OAuth2
return xerrors.Errorf("redirect URI at index %d uses unsupported URN scheme", i)
}
// Block dangerous schemes for security (not allowed by RFCs for OAuth2)
dangerousSchemes := []string{"javascript", "data", "file", "ftp"}
for _, dangerous := range dangerousSchemes {
if strings.EqualFold(uri.Scheme, dangerous) {
return xerrors.Errorf("redirect URI at index %d uses dangerous scheme %s which is not allowed", i, dangerous)
}
}
// Determine if this is a public client based on token endpoint auth method
+7 -8
View File
@@ -143,14 +143,13 @@ type ProvisionerJobInput struct {
// ProvisionerJobMetadata contains metadata for the job.
type ProvisionerJobMetadata struct {
TemplateVersionName string `json:"template_version_name" table:"template version name"`
TemplateID uuid.UUID `json:"template_id" format:"uuid" table:"template id"`
TemplateName string `json:"template_name" table:"template name"`
TemplateDisplayName string `json:"template_display_name" table:"template display name"`
TemplateIcon string `json:"template_icon" table:"template icon"`
WorkspaceID *uuid.UUID `json:"workspace_id,omitempty" format:"uuid" table:"workspace id"`
WorkspaceName string `json:"workspace_name,omitempty" table:"workspace name"`
WorkspaceBuildTransition WorkspaceTransition `json:"workspace_build_transition,omitempty" table:"workspace build transition"`
TemplateVersionName string `json:"template_version_name" table:"template version name"`
TemplateID uuid.UUID `json:"template_id" format:"uuid" table:"template id"`
TemplateName string `json:"template_name" table:"template name"`
TemplateDisplayName string `json:"template_display_name" table:"template display name"`
TemplateIcon string `json:"template_icon" table:"template icon"`
WorkspaceID *uuid.UUID `json:"workspace_id,omitempty" format:"uuid" table:"workspace id"`
WorkspaceName string `json:"workspace_name,omitempty" table:"workspace name"`
}
// ProvisionerJobType represents the type of job.
+2 -17
View File
@@ -40,7 +40,7 @@ CODER_EXPERIMENTS=oauth2
2. Click **Create Application**
3. Fill in the application details:
- **Name**: Your application name
- **Callback URL**: `https://yourapp.example.com/callback` (web) or `myapp://callback` (native/desktop)
- **Callback URL**: `https://yourapp.example.com/callback`
- **Icon**: Optional icon URL
### Method 2: Management API
@@ -251,31 +251,16 @@ Add `oauth2` to your experiment flags: `coder server --experiments oauth2`
Ensure the redirect URI in your request exactly matches the one registered for your application.
### "Invalid Callback URL" on the consent page
If you see this error when authorizing, the registered callback URL uses a
blocked scheme (`javascript:`, `data:`, `file:`, or `ftp:`). Update the
application's callback URL to a valid scheme (see
[Callback URL schemes](#callback-url-schemes)).
### "PKCE verification failed"
Verify that the `code_verifier` used in the token request matches the one used to generate the `code_challenge`.
## Callback URL schemes
Custom URI schemes (`myapp://`, `vscode://`, `jetbrains://`, etc.) are fully supported for native and desktop applications. The OS routes the redirect back to the registered application without requiring a running HTTP server.
The following schemes are blocked for security reasons: `javascript:`, `data:`, `file:`, `ftp:`.
## Security Considerations
- **Use HTTPS**: Always use HTTPS in production to protect tokens in transit
- **Implement PKCE**: PKCE is mandatory for all authorization code clients
(public and confidential)
- **Validate redirect URLs**: Only register trusted redirect URIs. Dangerous
schemes (`javascript:`, `data:`, `file:`, `ftp:`) are blocked by the server,
but custom URI schemes for native apps (`myapp://`) are permitted
- **Validate redirect URLs**: Only register trusted redirect URIs for your applications
- **Rotate secrets**: Periodically rotate client secrets using the management API
## Limitations
@@ -23,7 +23,6 @@ The following database fields are currently encrypted:
- `external_auth_links.oauth_access_token`
- `external_auth_links.oauth_refresh_token`
- `crypto_keys.secret`
- `user_secrets.value`
Additional database fields may be encrypted in the future.
@@ -80,19 +80,9 @@ See [Proxy TLS Configuration](#proxy-tls-configuration) for configuration steps.
### Restricting proxy access
Requests to non-allowlisted domains are tunneled through the proxy, but connections to private and reserved IP ranges are blocked by default.
The IP validation and TCP connect happen atomically, preventing DNS rebinding attacks where the resolved address could change between the check and the connection.
Requests to non-allowlisted domains are tunneled through the proxy without restriction.
To prevent unauthorized use, restrict network access to the proxy so that only authorized clients can connect.
In case the Coder access URL resolves to a private address, it is automatically exempt from this restriction so the proxy can always reach its own deployment.
If you need to allow access to additional internal networks via the proxy, use the Allowlist CIDRs option ([`CODER_AIBRIDGE_PROXY_ALLOWED_PRIVATE_CIDRS`](../../../reference/cli/server.md#--aibridge-proxy-allowed-private-cidrs)):
```shell
CODER_AIBRIDGE_PROXY_ALLOWED_PRIVATE_CIDRS=10.0.0.0/8,172.16.0.0/12
# or via CLI flag:
--aibridge-proxy-allowed-private-cidrs=10.0.0.0/8,172.16.0.0/12
```
## CA Certificate
AI Gateway Proxy uses a CA (Certificate Authority) certificate to perform MITM interception of HTTPS traffic.
@@ -250,11 +240,6 @@ To ensure AI Gateway also routes requests through the upstream proxy, make sure
<!-- TODO(ssncferreira): Add diagram showing how AI Gateway Proxy integrates with upstream proxies -->
> [!NOTE]
> When an upstream proxy is configured, AI Gateway Proxy validates the destination IP before forwarding the request.
> However, the upstream proxy re-resolves DNS independently, so a small DNS rebinding window exists between the validation and the actual connection.
> Ensure your upstream proxy enforces its own restrictions on private and reserved IP ranges.
### Configuration
Configure the upstream proxy URL:
+14 -21
View File
@@ -60,7 +60,6 @@ curl -X GET http://coder-server:8080/api/v2/users/{user}/workspace/{workspacenam
"template_id": "c6d67e98-83ea-49f0-8812-e4abae2b68bc",
"template_name": "string",
"template_version_name": "string",
"workspace_build_transition": "start",
"workspace_id": "0967198e-ec7b-4c6b-b4d3-f71244cadbe9",
"workspace_name": "string"
},
@@ -301,7 +300,6 @@ curl -X GET http://coder-server:8080/api/v2/workspacebuilds/{workspacebuild} \
"template_id": "c6d67e98-83ea-49f0-8812-e4abae2b68bc",
"template_name": "string",
"template_version_name": "string",
"workspace_build_transition": "start",
"workspace_id": "0967198e-ec7b-4c6b-b4d3-f71244cadbe9",
"workspace_name": "string"
},
@@ -1010,7 +1008,6 @@ curl -X GET http://coder-server:8080/api/v2/workspacebuilds/{workspacebuild}/sta
"template_id": "c6d67e98-83ea-49f0-8812-e4abae2b68bc",
"template_name": "string",
"template_version_name": "string",
"workspace_build_transition": "start",
"workspace_id": "0967198e-ec7b-4c6b-b4d3-f71244cadbe9",
"workspace_name": "string"
},
@@ -1362,7 +1359,6 @@ curl -X GET http://coder-server:8080/api/v2/workspaces/{workspace}/builds \
"template_id": "c6d67e98-83ea-49f0-8812-e4abae2b68bc",
"template_name": "string",
"template_version_name": "string",
"workspace_build_transition": "start",
"workspace_id": "0967198e-ec7b-4c6b-b4d3-f71244cadbe9",
"workspace_name": "string"
},
@@ -1581,7 +1577,6 @@ Status Code **200**
| `»»» template_id` | string(uuid) | false | | |
| `»»» template_name` | string | false | | |
| `»»» template_version_name` | string | false | | |
| `»»» workspace_build_transition` | [codersdk.WorkspaceTransition](schemas.md#codersdkworkspacetransition) | false | | |
| `»»» workspace_id` | string(uuid) | false | | |
| `»»» workspace_name` | string | false | | |
| `»» organization_id` | string(uuid) | false | | |
@@ -1715,21 +1710,20 @@ Status Code **200**
#### Enumerated Values
| Property | Value(s) |
|------------------------------|--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|
| `error_code` | `REQUIRED_TEMPLATE_VARIABLES` |
| `workspace_build_transition` | `delete`, `start`, `stop` |
| `status` | `canceled`, `canceling`, `connected`, `connecting`, `deleted`, `deleting`, `disconnected`, `failed`, `pending`, `running`, `starting`, `stopped`, `stopping`, `succeeded`, `timeout` |
| `type` | `template_version_dry_run`, `template_version_import`, `workspace_build` |
| `reason` | `autostart`, `autostop`, `initiator` |
| `health` | `disabled`, `healthy`, `initializing`, `unhealthy` |
| `open_in` | `slim-window`, `tab` |
| `sharing_level` | `authenticated`, `organization`, `owner`, `public` |
| `state` | `complete`, `failure`, `idle`, `working` |
| `lifecycle_state` | `created`, `off`, `ready`, `shutdown_error`, `shutdown_timeout`, `shutting_down`, `start_error`, `start_timeout`, `starting` |
| `startup_script_behavior` | `blocking`, `non-blocking` |
| `workspace_transition` | `delete`, `start`, `stop` |
| `transition` | `delete`, `start`, `stop` |
| Property | Value(s) |
|---------------------------|--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|
| `error_code` | `REQUIRED_TEMPLATE_VARIABLES` |
| `status` | `canceled`, `canceling`, `connected`, `connecting`, `deleted`, `deleting`, `disconnected`, `failed`, `pending`, `running`, `starting`, `stopped`, `stopping`, `succeeded`, `timeout` |
| `type` | `template_version_dry_run`, `template_version_import`, `workspace_build` |
| `reason` | `autostart`, `autostop`, `initiator` |
| `health` | `disabled`, `healthy`, `initializing`, `unhealthy` |
| `open_in` | `slim-window`, `tab` |
| `sharing_level` | `authenticated`, `organization`, `owner`, `public` |
| `state` | `complete`, `failure`, `idle`, `working` |
| `lifecycle_state` | `created`, `off`, `ready`, `shutdown_error`, `shutdown_timeout`, `shutting_down`, `start_error`, `start_timeout`, `starting` |
| `startup_script_behavior` | `blocking`, `non-blocking` |
| `workspace_transition` | `delete`, `start`, `stop` |
| `transition` | `delete`, `start`, `stop` |
To perform this operation, you must be authenticated. [Learn more](authentication.md).
@@ -1816,7 +1810,6 @@ curl -X POST http://coder-server:8080/api/v2/workspaces/{workspace}/builds \
"template_id": "c6d67e98-83ea-49f0-8812-e4abae2b68bc",
"template_name": "string",
"template_version_name": "string",
"workspace_build_transition": "start",
"workspace_id": "0967198e-ec7b-4c6b-b4d3-f71244cadbe9",
"workspace_name": "string"
},
+40 -44
View File
@@ -317,7 +317,6 @@ curl -X GET http://coder-server:8080/api/v2/organizations/{organization}/provisi
"template_id": "c6d67e98-83ea-49f0-8812-e4abae2b68bc",
"template_name": "string",
"template_version_name": "string",
"workspace_build_transition": "start",
"workspace_id": "0967198e-ec7b-4c6b-b4d3-f71244cadbe9",
"workspace_name": "string"
},
@@ -347,51 +346,49 @@ curl -X GET http://coder-server:8080/api/v2/organizations/{organization}/provisi
Status Code **200**
| Name | Type | Required | Restrictions | Description |
|---------------------------------|------------------------------------------------------------------------------|----------|--------------|-------------|
| `[array item]` | array | false | | |
| `» available_workers` | array | false | | |
| `» canceled_at` | string(date-time) | false | | |
| `» completed_at` | string(date-time) | false | | |
| `» created_at` | string(date-time) | false | | |
| `» error` | string | false | | |
| `» error_code` | [codersdk.JobErrorCode](schemas.md#codersdkjoberrorcode) | false | | |
| `» file_id` | string(uuid) | false | | |
| `» id` | string(uuid) | false | | |
| `» initiator_id` | string(uuid) | false | | |
| `» input` | [codersdk.ProvisionerJobInput](schemas.md#codersdkprovisionerjobinput) | false | | |
| `»» error` | string | false | | |
| `»» template_version_id` | string(uuid) | false | | |
| `»» workspace_build_id` | string(uuid) | false | | |
| `» logs_overflowed` | boolean | false | | |
| `» metadata` | [codersdk.ProvisionerJobMetadata](schemas.md#codersdkprovisionerjobmetadata) | false | | |
| `»» template_display_name` | string | false | | |
| `»» template_icon` | string | false | | |
| `»» template_id` | string(uuid) | false | | |
| `»» template_name` | string | false | | |
| `»» template_version_name` | string | false | | |
| `»» workspace_build_transition` | [codersdk.WorkspaceTransition](schemas.md#codersdkworkspacetransition) | false | | |
| `»» workspace_id` | string(uuid) | false | | |
| `»» workspace_name` | string | false | | |
| `» organization_id` | string(uuid) | false | | |
| `» queue_position` | integer | false | | |
| `» queue_size` | integer | false | | |
| `» started_at` | string(date-time) | false | | |
| `» status` | [codersdk.ProvisionerJobStatus](schemas.md#codersdkprovisionerjobstatus) | false | | |
| `» tags` | object | false | | |
| `»» [any property]` | string | false | | |
| `» type` | [codersdk.ProvisionerJobType](schemas.md#codersdkprovisionerjobtype) | false | | |
| `» worker_id` | string(uuid) | false | | |
| `» worker_name` | string | false | | |
| Name | Type | Required | Restrictions | Description |
|----------------------------|------------------------------------------------------------------------------|----------|--------------|-------------|
| `[array item]` | array | false | | |
| `» available_workers` | array | false | | |
| `» canceled_at` | string(date-time) | false | | |
| `» completed_at` | string(date-time) | false | | |
| `» created_at` | string(date-time) | false | | |
| `» error` | string | false | | |
| `» error_code` | [codersdk.JobErrorCode](schemas.md#codersdkjoberrorcode) | false | | |
| `» file_id` | string(uuid) | false | | |
| `» id` | string(uuid) | false | | |
| `» initiator_id` | string(uuid) | false | | |
| `» input` | [codersdk.ProvisionerJobInput](schemas.md#codersdkprovisionerjobinput) | false | | |
| `»» error` | string | false | | |
| `»» template_version_id` | string(uuid) | false | | |
| `»» workspace_build_id` | string(uuid) | false | | |
| `» logs_overflowed` | boolean | false | | |
| `» metadata` | [codersdk.ProvisionerJobMetadata](schemas.md#codersdkprovisionerjobmetadata) | false | | |
| `»» template_display_name` | string | false | | |
| `»» template_icon` | string | false | | |
| `»» template_id` | string(uuid) | false | | |
| `»» template_name` | string | false | | |
| `»» template_version_name` | string | false | | |
| `»» workspace_id` | string(uuid) | false | | |
| `»» workspace_name` | string | false | | |
| `» organization_id` | string(uuid) | false | | |
| `» queue_position` | integer | false | | |
| `» queue_size` | integer | false | | |
| `» started_at` | string(date-time) | false | | |
| `» status` | [codersdk.ProvisionerJobStatus](schemas.md#codersdkprovisionerjobstatus) | false | | |
| `» tags` | object | false | | |
| `»» [any property]` | string | false | | |
| `» type` | [codersdk.ProvisionerJobType](schemas.md#codersdkprovisionerjobtype) | false | | |
| `» worker_id` | string(uuid) | false | | |
| `» worker_name` | string | false | | |
#### Enumerated Values
| Property | Value(s) |
|------------------------------|--------------------------------------------------------------------------|
| `error_code` | `REQUIRED_TEMPLATE_VARIABLES` |
| `workspace_build_transition` | `delete`, `start`, `stop` |
| `status` | `canceled`, `canceling`, `failed`, `pending`, `running`, `succeeded` |
| `type` | `template_version_dry_run`, `template_version_import`, `workspace_build` |
| Property | Value(s) |
|--------------|--------------------------------------------------------------------------|
| `error_code` | `REQUIRED_TEMPLATE_VARIABLES` |
| `status` | `canceled`, `canceling`, `failed`, `pending`, `running`, `succeeded` |
| `type` | `template_version_dry_run`, `template_version_import`, `workspace_build` |
To perform this operation, you must be authenticated. [Learn more](authentication.md).
@@ -444,7 +441,6 @@ curl -X GET http://coder-server:8080/api/v2/organizations/{organization}/provisi
"template_id": "c6d67e98-83ea-49f0-8812-e4abae2b68bc",
"template_name": "string",
"template_version_name": "string",
"workspace_build_transition": "start",
"workspace_id": "0967198e-ec7b-4c6b-b4d3-f71244cadbe9",
"workspace_name": "string"
},
+9 -18
View File
@@ -7121,7 +7121,6 @@ Only certain features set these fields: - FeatureManagedAgentLimit|
"template_id": "c6d67e98-83ea-49f0-8812-e4abae2b68bc",
"template_name": "string",
"template_version_name": "string",
"workspace_build_transition": "start",
"workspace_id": "0967198e-ec7b-4c6b-b4d3-f71244cadbe9",
"workspace_name": "string"
},
@@ -7788,7 +7787,6 @@ Only certain features set these fields: - FeatureManagedAgentLimit|
"template_id": "c6d67e98-83ea-49f0-8812-e4abae2b68bc",
"template_name": "string",
"template_version_name": "string",
"workspace_build_transition": "start",
"workspace_id": "0967198e-ec7b-4c6b-b4d3-f71244cadbe9",
"workspace_name": "string"
},
@@ -7898,7 +7896,6 @@ Only certain features set these fields: - FeatureManagedAgentLimit|
"template_id": "c6d67e98-83ea-49f0-8812-e4abae2b68bc",
"template_name": "string",
"template_version_name": "string",
"workspace_build_transition": "start",
"workspace_id": "0967198e-ec7b-4c6b-b4d3-f71244cadbe9",
"workspace_name": "string"
}
@@ -7906,16 +7903,15 @@ Only certain features set these fields: - FeatureManagedAgentLimit|
### Properties
| Name | Type | Required | Restrictions | Description |
|------------------------------|--------------------------------------------------------------|----------|--------------|-------------|
| `template_display_name` | string | false | | |
| `template_icon` | string | false | | |
| `template_id` | string | false | | |
| `template_name` | string | false | | |
| `template_version_name` | string | false | | |
| `workspace_build_transition` | [codersdk.WorkspaceTransition](#codersdkworkspacetransition) | false | | |
| `workspace_id` | string | false | | |
| `workspace_name` | string | false | | |
| Name | Type | Required | Restrictions | Description |
|-------------------------|--------|----------|--------------|-------------|
| `template_display_name` | string | false | | |
| `template_icon` | string | false | | |
| `template_id` | string | false | | |
| `template_name` | string | false | | |
| `template_version_name` | string | false | | |
| `workspace_id` | string | false | | |
| `workspace_name` | string | false | | |
## codersdk.ProvisionerJobStatus
@@ -8471,7 +8467,6 @@ Only certain features set these fields: - FeatureManagedAgentLimit|
"template_id": "c6d67e98-83ea-49f0-8812-e4abae2b68bc",
"template_name": "string",
"template_version_name": "string",
"workspace_build_transition": "start",
"workspace_id": "0967198e-ec7b-4c6b-b4d3-f71244cadbe9",
"workspace_name": "string"
},
@@ -10019,7 +10014,6 @@ Restarts will only happen on weekdays in this list on weeks which line up with W
"template_id": "c6d67e98-83ea-49f0-8812-e4abae2b68bc",
"template_name": "string",
"template_version_name": "string",
"workspace_build_transition": "start",
"workspace_id": "0967198e-ec7b-4c6b-b4d3-f71244cadbe9",
"workspace_name": "string"
},
@@ -11410,7 +11404,6 @@ If the schedule is empty, the user will be updated to use the default schedule.|
"template_id": "c6d67e98-83ea-49f0-8812-e4abae2b68bc",
"template_name": "string",
"template_version_name": "string",
"workspace_build_transition": "start",
"workspace_id": "0967198e-ec7b-4c6b-b4d3-f71244cadbe9",
"workspace_name": "string"
},
@@ -12569,7 +12562,6 @@ If the schedule is empty, the user will be updated to use the default schedule.|
"template_id": "c6d67e98-83ea-49f0-8812-e4abae2b68bc",
"template_name": "string",
"template_version_name": "string",
"workspace_build_transition": "start",
"workspace_id": "0967198e-ec7b-4c6b-b4d3-f71244cadbe9",
"workspace_name": "string"
},
@@ -13402,7 +13394,6 @@ If the schedule is empty, the user will be updated to use the default schedule.|
"template_id": "c6d67e98-83ea-49f0-8812-e4abae2b68bc",
"template_name": "string",
"template_version_name": "string",
"workspace_build_transition": "start",
"workspace_id": "0967198e-ec7b-4c6b-b4d3-f71244cadbe9",
"workspace_name": "string"
},
-2
View File
@@ -425,7 +425,6 @@ curl -X POST http://coder-server:8080/api/v2/tasks/{user}/{task}/pause \
"template_id": "c6d67e98-83ea-49f0-8812-e4abae2b68bc",
"template_name": "string",
"template_version_name": "string",
"workspace_build_transition": "start",
"workspace_id": "0967198e-ec7b-4c6b-b4d3-f71244cadbe9",
"workspace_name": "string"
},
@@ -669,7 +668,6 @@ curl -X POST http://coder-server:8080/api/v2/tasks/{user}/{task}/resume \
"template_id": "c6d67e98-83ea-49f0-8812-e4abae2b68bc",
"template_name": "string",
"template_version_name": "string",
"workspace_build_transition": "start",
"workspace_id": "0967198e-ec7b-4c6b-b4d3-f71244cadbe9",
"workspace_name": "string"
},
+122 -135
View File
@@ -493,7 +493,6 @@ curl -X GET http://coder-server:8080/api/v2/organizations/{organization}/templat
"template_id": "c6d67e98-83ea-49f0-8812-e4abae2b68bc",
"template_name": "string",
"template_version_name": "string",
"workspace_build_transition": "start",
"workspace_id": "0967198e-ec7b-4c6b-b4d3-f71244cadbe9",
"workspace_name": "string"
},
@@ -596,7 +595,6 @@ curl -X GET http://coder-server:8080/api/v2/organizations/{organization}/templat
"template_id": "c6d67e98-83ea-49f0-8812-e4abae2b68bc",
"template_name": "string",
"template_version_name": "string",
"workspace_build_transition": "start",
"workspace_id": "0967198e-ec7b-4c6b-b4d3-f71244cadbe9",
"workspace_name": "string"
},
@@ -723,7 +721,6 @@ curl -X POST http://coder-server:8080/api/v2/organizations/{organization}/templa
"template_id": "c6d67e98-83ea-49f0-8812-e4abae2b68bc",
"template_name": "string",
"template_version_name": "string",
"workspace_build_transition": "start",
"workspace_id": "0967198e-ec7b-4c6b-b4d3-f71244cadbe9",
"workspace_name": "string"
},
@@ -1338,7 +1335,6 @@ curl -X GET http://coder-server:8080/api/v2/templates/{template}/versions \
"template_id": "c6d67e98-83ea-49f0-8812-e4abae2b68bc",
"template_name": "string",
"template_version_name": "string",
"workspace_build_transition": "start",
"workspace_id": "0967198e-ec7b-4c6b-b4d3-f71244cadbe9",
"workspace_name": "string"
},
@@ -1383,72 +1379,70 @@ curl -X GET http://coder-server:8080/api/v2/templates/{template}/versions \
Status Code **200**
| Name | Type | Required | Restrictions | Description |
|----------------------------------|------------------------------------------------------------------------------|----------|--------------|---------------------------------------------------------------------------------------------------------------------------------------------------------------------|
| `[array item]` | array | false | | |
| `» archived` | boolean | false | | |
| `» created_at` | string(date-time) | false | | |
| `» created_by` | [codersdk.MinimalUser](schemas.md#codersdkminimaluser) | false | | |
| `»» avatar_url` | string(uri) | false | | |
| `»» id` | string(uuid) | true | | |
| `»» name` | string | false | | |
| `»» username` | string | true | | |
| `» has_external_agent` | boolean | false | | |
| `» id` | string(uuid) | false | | |
| `» job` | [codersdk.ProvisionerJob](schemas.md#codersdkprovisionerjob) | false | | |
| `»» available_workers` | array | false | | |
| `»» canceled_at` | string(date-time) | false | | |
| `»» completed_at` | string(date-time) | false | | |
| `»» created_at` | string(date-time) | false | | |
| `»» error` | string | false | | |
| `»» error_code` | [codersdk.JobErrorCode](schemas.md#codersdkjoberrorcode) | false | | |
| `»» file_id` | string(uuid) | false | | |
| `»» id` | string(uuid) | false | | |
| `»» initiator_id` | string(uuid) | false | | |
| `»» input` | [codersdk.ProvisionerJobInput](schemas.md#codersdkprovisionerjobinput) | false | | |
| `»»» error` | string | false | | |
| `»»» template_version_id` | string(uuid) | false | | |
| `»»» workspace_build_id` | string(uuid) | false | | |
| `»» logs_overflowed` | boolean | false | | |
| `»» metadata` | [codersdk.ProvisionerJobMetadata](schemas.md#codersdkprovisionerjobmetadata) | false | | |
| `»»» template_display_name` | string | false | | |
| `»»» template_icon` | string | false | | |
| `»»» template_id` | string(uuid) | false | | |
| `»»» template_name` | string | false | | |
| `»»» template_version_name` | string | false | | |
| `»»» workspace_build_transition` | [codersdk.WorkspaceTransition](schemas.md#codersdkworkspacetransition) | false | | |
| `»»» workspace_id` | string(uuid) | false | | |
| `»»» workspace_name` | string | false | | |
| `»» organization_id` | string(uuid) | false | | |
| `»» queue_position` | integer | false | | |
| `»» queue_size` | integer | false | | |
| `»» started_at` | string(date-time) | false | | |
| `»» status` | [codersdk.ProvisionerJobStatus](schemas.md#codersdkprovisionerjobstatus) | false | | |
| `»» tags` | object | false | | |
| `»»» [any property]` | string | false | | |
| `»» type` | [codersdk.ProvisionerJobType](schemas.md#codersdkprovisionerjobtype) | false | | |
| `»» worker_id` | string(uuid) | false | | |
| `»» worker_name` | string | false | | |
| `» matched_provisioners` | [codersdk.MatchedProvisioners](schemas.md#codersdkmatchedprovisioners) | false | | |
| `»» available` | integer | false | | Available is the number of provisioner daemons that are available to take jobs. This may be less than the count if some provisioners are busy or have been stopped. |
| `»» count` | integer | false | | Count is the number of provisioner daemons that matched the given tags. If the count is 0, it means no provisioner daemons matched the requested tags. |
| `»» most_recently_seen` | string(date-time) | false | | Most recently seen is the most recently seen time of the set of matched provisioners. If no provisioners matched, this field will be null. |
| `» message` | string | false | | |
| `» name` | string | false | | |
| `» organization_id` | string(uuid) | false | | |
| `» readme` | string | false | | |
| `» template_id` | string(uuid) | false | | |
| `» updated_at` | string(date-time) | false | | |
| `» warnings` | array | false | | |
| Name | Type | Required | Restrictions | Description |
|-----------------------------|------------------------------------------------------------------------------|----------|--------------|---------------------------------------------------------------------------------------------------------------------------------------------------------------------|
| `[array item]` | array | false | | |
| `» archived` | boolean | false | | |
| `» created_at` | string(date-time) | false | | |
| `» created_by` | [codersdk.MinimalUser](schemas.md#codersdkminimaluser) | false | | |
| `»» avatar_url` | string(uri) | false | | |
| `»» id` | string(uuid) | true | | |
| `»» name` | string | false | | |
| `»» username` | string | true | | |
| `» has_external_agent` | boolean | false | | |
| `» id` | string(uuid) | false | | |
| `» job` | [codersdk.ProvisionerJob](schemas.md#codersdkprovisionerjob) | false | | |
| `»» available_workers` | array | false | | |
| `»» canceled_at` | string(date-time) | false | | |
| `»» completed_at` | string(date-time) | false | | |
| `»» created_at` | string(date-time) | false | | |
| `»» error` | string | false | | |
| `»» error_code` | [codersdk.JobErrorCode](schemas.md#codersdkjoberrorcode) | false | | |
| `»» file_id` | string(uuid) | false | | |
| `»» id` | string(uuid) | false | | |
| `»» initiator_id` | string(uuid) | false | | |
| `»» input` | [codersdk.ProvisionerJobInput](schemas.md#codersdkprovisionerjobinput) | false | | |
| `»»» error` | string | false | | |
| `»»» template_version_id` | string(uuid) | false | | |
| `»»» workspace_build_id` | string(uuid) | false | | |
| `»» logs_overflowed` | boolean | false | | |
| `»» metadata` | [codersdk.ProvisionerJobMetadata](schemas.md#codersdkprovisionerjobmetadata) | false | | |
| `»»» template_display_name` | string | false | | |
| `»»» template_icon` | string | false | | |
| `»»» template_id` | string(uuid) | false | | |
| `»»» template_name` | string | false | | |
| `»»» template_version_name` | string | false | | |
| `»»» workspace_id` | string(uuid) | false | | |
| `»»» workspace_name` | string | false | | |
| `»» organization_id` | string(uuid) | false | | |
| `»» queue_position` | integer | false | | |
| `»» queue_size` | integer | false | | |
| `»» started_at` | string(date-time) | false | | |
| `»» status` | [codersdk.ProvisionerJobStatus](schemas.md#codersdkprovisionerjobstatus) | false | | |
| `»» tags` | object | false | | |
| `»»» [any property]` | string | false | | |
| `»» type` | [codersdk.ProvisionerJobType](schemas.md#codersdkprovisionerjobtype) | false | | |
| `»» worker_id` | string(uuid) | false | | |
| `»» worker_name` | string | false | | |
| `» matched_provisioners` | [codersdk.MatchedProvisioners](schemas.md#codersdkmatchedprovisioners) | false | | |
| `»» available` | integer | false | | Available is the number of provisioner daemons that are available to take jobs. This may be less than the count if some provisioners are busy or have been stopped. |
| `»» count` | integer | false | | Count is the number of provisioner daemons that matched the given tags. If the count is 0, it means no provisioner daemons matched the requested tags. |
| `»» most_recently_seen` | string(date-time) | false | | Most recently seen is the most recently seen time of the set of matched provisioners. If no provisioners matched, this field will be null. |
| `» message` | string | false | | |
| `» name` | string | false | | |
| `» organization_id` | string(uuid) | false | | |
| `» readme` | string | false | | |
| `» template_id` | string(uuid) | false | | |
| `» updated_at` | string(date-time) | false | | |
| `» warnings` | array | false | | |
#### Enumerated Values
| Property | Value(s) |
|------------------------------|--------------------------------------------------------------------------|
| `error_code` | `REQUIRED_TEMPLATE_VARIABLES` |
| `workspace_build_transition` | `delete`, `start`, `stop` |
| `status` | `canceled`, `canceling`, `failed`, `pending`, `running`, `succeeded` |
| `type` | `template_version_dry_run`, `template_version_import`, `workspace_build` |
| Property | Value(s) |
|--------------|--------------------------------------------------------------------------|
| `error_code` | `REQUIRED_TEMPLATE_VARIABLES` |
| `status` | `canceled`, `canceling`, `failed`, `pending`, `running`, `succeeded` |
| `type` | `template_version_dry_run`, `template_version_import`, `workspace_build` |
To perform this operation, you must be authenticated. [Learn more](authentication.md).
@@ -1621,7 +1615,6 @@ curl -X GET http://coder-server:8080/api/v2/templates/{template}/versions/{templ
"template_id": "c6d67e98-83ea-49f0-8812-e4abae2b68bc",
"template_name": "string",
"template_version_name": "string",
"workspace_build_transition": "start",
"workspace_id": "0967198e-ec7b-4c6b-b4d3-f71244cadbe9",
"workspace_name": "string"
},
@@ -1666,72 +1659,70 @@ curl -X GET http://coder-server:8080/api/v2/templates/{template}/versions/{templ
Status Code **200**
| Name | Type | Required | Restrictions | Description |
|----------------------------------|------------------------------------------------------------------------------|----------|--------------|---------------------------------------------------------------------------------------------------------------------------------------------------------------------|
| `[array item]` | array | false | | |
| `» archived` | boolean | false | | |
| `» created_at` | string(date-time) | false | | |
| `» created_by` | [codersdk.MinimalUser](schemas.md#codersdkminimaluser) | false | | |
| `»» avatar_url` | string(uri) | false | | |
| `»» id` | string(uuid) | true | | |
| `»» name` | string | false | | |
| `»» username` | string | true | | |
| `» has_external_agent` | boolean | false | | |
| `» id` | string(uuid) | false | | |
| `» job` | [codersdk.ProvisionerJob](schemas.md#codersdkprovisionerjob) | false | | |
| `»» available_workers` | array | false | | |
| `»» canceled_at` | string(date-time) | false | | |
| `»» completed_at` | string(date-time) | false | | |
| `»» created_at` | string(date-time) | false | | |
| `»» error` | string | false | | |
| `»» error_code` | [codersdk.JobErrorCode](schemas.md#codersdkjoberrorcode) | false | | |
| `»» file_id` | string(uuid) | false | | |
| `»» id` | string(uuid) | false | | |
| `»» initiator_id` | string(uuid) | false | | |
| `»» input` | [codersdk.ProvisionerJobInput](schemas.md#codersdkprovisionerjobinput) | false | | |
| `»»» error` | string | false | | |
| `»»» template_version_id` | string(uuid) | false | | |
| `»»» workspace_build_id` | string(uuid) | false | | |
| `»» logs_overflowed` | boolean | false | | |
| `»» metadata` | [codersdk.ProvisionerJobMetadata](schemas.md#codersdkprovisionerjobmetadata) | false | | |
| `»»» template_display_name` | string | false | | |
| `»»» template_icon` | string | false | | |
| `»»» template_id` | string(uuid) | false | | |
| `»»» template_name` | string | false | | |
| `»»» template_version_name` | string | false | | |
| `»»» workspace_build_transition` | [codersdk.WorkspaceTransition](schemas.md#codersdkworkspacetransition) | false | | |
| `»»» workspace_id` | string(uuid) | false | | |
| `»»» workspace_name` | string | false | | |
| `»» organization_id` | string(uuid) | false | | |
| `»» queue_position` | integer | false | | |
| `»» queue_size` | integer | false | | |
| `»» started_at` | string(date-time) | false | | |
| `»» status` | [codersdk.ProvisionerJobStatus](schemas.md#codersdkprovisionerjobstatus) | false | | |
| `»» tags` | object | false | | |
| `»»» [any property]` | string | false | | |
| `»» type` | [codersdk.ProvisionerJobType](schemas.md#codersdkprovisionerjobtype) | false | | |
| `»» worker_id` | string(uuid) | false | | |
| `»» worker_name` | string | false | | |
| `» matched_provisioners` | [codersdk.MatchedProvisioners](schemas.md#codersdkmatchedprovisioners) | false | | |
| `»» available` | integer | false | | Available is the number of provisioner daemons that are available to take jobs. This may be less than the count if some provisioners are busy or have been stopped. |
| `»» count` | integer | false | | Count is the number of provisioner daemons that matched the given tags. If the count is 0, it means no provisioner daemons matched the requested tags. |
| `»» most_recently_seen` | string(date-time) | false | | Most recently seen is the most recently seen time of the set of matched provisioners. If no provisioners matched, this field will be null. |
| `» message` | string | false | | |
| `» name` | string | false | | |
| `» organization_id` | string(uuid) | false | | |
| `» readme` | string | false | | |
| `» template_id` | string(uuid) | false | | |
| `» updated_at` | string(date-time) | false | | |
| `» warnings` | array | false | | |
| Name | Type | Required | Restrictions | Description |
|-----------------------------|------------------------------------------------------------------------------|----------|--------------|---------------------------------------------------------------------------------------------------------------------------------------------------------------------|
| `[array item]` | array | false | | |
| `» archived` | boolean | false | | |
| `» created_at` | string(date-time) | false | | |
| `» created_by` | [codersdk.MinimalUser](schemas.md#codersdkminimaluser) | false | | |
| `»» avatar_url` | string(uri) | false | | |
| `»» id` | string(uuid) | true | | |
| `»» name` | string | false | | |
| `»» username` | string | true | | |
| `» has_external_agent` | boolean | false | | |
| `» id` | string(uuid) | false | | |
| `» job` | [codersdk.ProvisionerJob](schemas.md#codersdkprovisionerjob) | false | | |
| `»» available_workers` | array | false | | |
| `»» canceled_at` | string(date-time) | false | | |
| `»» completed_at` | string(date-time) | false | | |
| `»» created_at` | string(date-time) | false | | |
| `»» error` | string | false | | |
| `»» error_code` | [codersdk.JobErrorCode](schemas.md#codersdkjoberrorcode) | false | | |
| `»» file_id` | string(uuid) | false | | |
| `»» id` | string(uuid) | false | | |
| `»» initiator_id` | string(uuid) | false | | |
| `»» input` | [codersdk.ProvisionerJobInput](schemas.md#codersdkprovisionerjobinput) | false | | |
| `»»» error` | string | false | | |
| `»»» template_version_id` | string(uuid) | false | | |
| `»»» workspace_build_id` | string(uuid) | false | | |
| `»» logs_overflowed` | boolean | false | | |
| `»» metadata` | [codersdk.ProvisionerJobMetadata](schemas.md#codersdkprovisionerjobmetadata) | false | | |
| `»»» template_display_name` | string | false | | |
| `»»» template_icon` | string | false | | |
| `»»» template_id` | string(uuid) | false | | |
| `»»» template_name` | string | false | | |
| `»»» template_version_name` | string | false | | |
| `»»» workspace_id` | string(uuid) | false | | |
| `»»» workspace_name` | string | false | | |
| `»» organization_id` | string(uuid) | false | | |
| `»» queue_position` | integer | false | | |
| `»» queue_size` | integer | false | | |
| `»» started_at` | string(date-time) | false | | |
| `»» status` | [codersdk.ProvisionerJobStatus](schemas.md#codersdkprovisionerjobstatus) | false | | |
| `»» tags` | object | false | | |
| `»»» [any property]` | string | false | | |
| `»» type` | [codersdk.ProvisionerJobType](schemas.md#codersdkprovisionerjobtype) | false | | |
| `»» worker_id` | string(uuid) | false | | |
| `»» worker_name` | string | false | | |
| `» matched_provisioners` | [codersdk.MatchedProvisioners](schemas.md#codersdkmatchedprovisioners) | false | | |
| `»» available` | integer | false | | Available is the number of provisioner daemons that are available to take jobs. This may be less than the count if some provisioners are busy or have been stopped. |
| `»» count` | integer | false | | Count is the number of provisioner daemons that matched the given tags. If the count is 0, it means no provisioner daemons matched the requested tags. |
| `»» most_recently_seen` | string(date-time) | false | | Most recently seen is the most recently seen time of the set of matched provisioners. If no provisioners matched, this field will be null. |
| `» message` | string | false | | |
| `» name` | string | false | | |
| `» organization_id` | string(uuid) | false | | |
| `» readme` | string | false | | |
| `» template_id` | string(uuid) | false | | |
| `» updated_at` | string(date-time) | false | | |
| `» warnings` | array | false | | |
#### Enumerated Values
| Property | Value(s) |
|------------------------------|--------------------------------------------------------------------------|
| `error_code` | `REQUIRED_TEMPLATE_VARIABLES` |
| `workspace_build_transition` | `delete`, `start`, `stop` |
| `status` | `canceled`, `canceling`, `failed`, `pending`, `running`, `succeeded` |
| `type` | `template_version_dry_run`, `template_version_import`, `workspace_build` |
| Property | Value(s) |
|--------------|--------------------------------------------------------------------------|
| `error_code` | `REQUIRED_TEMPLATE_VARIABLES` |
| `status` | `canceled`, `canceling`, `failed`, `pending`, `running`, `succeeded` |
| `type` | `template_version_dry_run`, `template_version_import`, `workspace_build` |
To perform this operation, you must be authenticated. [Learn more](authentication.md).
@@ -1794,7 +1785,6 @@ curl -X GET http://coder-server:8080/api/v2/templateversions/{templateversion} \
"template_id": "c6d67e98-83ea-49f0-8812-e4abae2b68bc",
"template_name": "string",
"template_version_name": "string",
"workspace_build_transition": "start",
"workspace_id": "0967198e-ec7b-4c6b-b4d3-f71244cadbe9",
"workspace_name": "string"
},
@@ -1906,7 +1896,6 @@ curl -X PATCH http://coder-server:8080/api/v2/templateversions/{templateversion}
"template_id": "c6d67e98-83ea-49f0-8812-e4abae2b68bc",
"template_name": "string",
"template_version_name": "string",
"workspace_build_transition": "start",
"workspace_id": "0967198e-ec7b-4c6b-b4d3-f71244cadbe9",
"workspace_name": "string"
},
@@ -2106,7 +2095,6 @@ curl -X POST http://coder-server:8080/api/v2/templateversions/{templateversion}/
"template_id": "c6d67e98-83ea-49f0-8812-e4abae2b68bc",
"template_name": "string",
"template_version_name": "string",
"workspace_build_transition": "start",
"workspace_id": "0967198e-ec7b-4c6b-b4d3-f71244cadbe9",
"workspace_name": "string"
},
@@ -2182,7 +2170,6 @@ curl -X GET http://coder-server:8080/api/v2/templateversions/{templateversion}/d
"template_id": "c6d67e98-83ea-49f0-8812-e4abae2b68bc",
"template_name": "string",
"template_version_name": "string",
"workspace_build_transition": "start",
"workspace_id": "0967198e-ec7b-4c6b-b4d3-f71244cadbe9",
"workspace_name": "string"
},
-6
View File
@@ -115,7 +115,6 @@ of the template will be used.
"template_id": "c6d67e98-83ea-49f0-8812-e4abae2b68bc",
"template_name": "string",
"template_version_name": "string",
"workspace_build_transition": "start",
"workspace_id": "0967198e-ec7b-4c6b-b4d3-f71244cadbe9",
"workspace_name": "string"
},
@@ -479,7 +478,6 @@ curl -X GET http://coder-server:8080/api/v2/users/{user}/workspace/{workspacenam
"template_id": "c6d67e98-83ea-49f0-8812-e4abae2b68bc",
"template_name": "string",
"template_version_name": "string",
"workspace_build_transition": "start",
"workspace_id": "0967198e-ec7b-4c6b-b4d3-f71244cadbe9",
"workspace_name": "string"
},
@@ -810,7 +808,6 @@ of the template will be used.
"template_id": "c6d67e98-83ea-49f0-8812-e4abae2b68bc",
"template_name": "string",
"template_version_name": "string",
"workspace_build_transition": "start",
"workspace_id": "0967198e-ec7b-4c6b-b4d3-f71244cadbe9",
"workspace_name": "string"
},
@@ -1119,7 +1116,6 @@ curl -X GET http://coder-server:8080/api/v2/workspaces \
"template_id": "c6d67e98-83ea-49f0-8812-e4abae2b68bc",
"template_name": "string",
"template_version_name": "string",
"workspace_build_transition": "start",
"workspace_id": "0967198e-ec7b-4c6b-b4d3-f71244cadbe9",
"workspace_name": "string"
},
@@ -1409,7 +1405,6 @@ curl -X GET http://coder-server:8080/api/v2/workspaces/{workspace} \
"template_id": "c6d67e98-83ea-49f0-8812-e4abae2b68bc",
"template_name": "string",
"template_version_name": "string",
"workspace_build_transition": "start",
"workspace_id": "0967198e-ec7b-4c6b-b4d3-f71244cadbe9",
"workspace_name": "string"
},
@@ -1976,7 +1971,6 @@ curl -X PUT http://coder-server:8080/api/v2/workspaces/{workspace}/dormant \
"template_id": "c6d67e98-83ea-49f0-8812-e4abae2b68bc",
"template_name": "string",
"template_version_name": "string",
"workspace_build_transition": "start",
"workspace_id": "0967198e-ec7b-4c6b-b4d3-f71244cadbe9",
"workspace_name": "string"
},
+4 -4
View File
@@ -54,10 +54,10 @@ Select which organization (uuid or name) to use.
### -c, --column
| | |
|---------|--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|
| Type | <code>[id\|created at\|started at\|completed at\|canceled at\|error\|error code\|status\|worker id\|worker name\|file id\|tags\|queue position\|queue size\|organization id\|initiator id\|template version id\|workspace build id\|type\|available workers\|template version name\|template id\|template name\|template display name\|template icon\|workspace id\|workspace name\|workspace build transition\|logs overflowed\|organization\|queue]</code> |
| Default | <code>created at,id,type,template display name,status,queue,tags</code> |
| | |
|---------|----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|
| Type | <code>[id\|created at\|started at\|completed at\|canceled at\|error\|error code\|status\|worker id\|worker name\|file id\|tags\|queue position\|queue size\|organization id\|initiator id\|template version id\|workspace build id\|type\|available workers\|template version name\|template id\|template name\|template display name\|template icon\|workspace id\|workspace name\|logs overflowed\|organization\|queue]</code> |
| Default | <code>created at,id,type,template display name,status,queue,tags</code> |
Columns to display in table output.
-19
View File
@@ -197,10 +197,6 @@ func TestServerDBCrypt(t *testing.T) {
gitAuthLinks, err := db.GetExternalAuthLinksByUserID(ctx, usr.ID)
require.NoError(t, err, "failed to get git auth links for user %s", usr.ID)
require.Empty(t, gitAuthLinks)
userSecrets, err := db.ListUserSecretsWithValues(ctx, usr.ID)
require.NoError(t, err, "failed to get user secrets for user %s", usr.ID)
require.Empty(t, userSecrets)
}
// Validate that the key has been revoked in the database.
@@ -246,14 +242,6 @@ func genData(t *testing.T, db database.Store) []database.User {
OAuthRefreshToken: "refresh-" + usr.ID.String(),
})
}
_ = dbgen.UserSecret(t, db, database.UserSecret{
UserID: usr.ID,
Name: "secret-" + usr.ID.String(),
Value: "value-" + usr.ID.String(),
EnvName: "",
FilePath: "",
})
users = append(users, usr)
}
}
@@ -295,13 +283,6 @@ func requireEncryptedWithCipher(ctx context.Context, t *testing.T, db database.S
require.Equal(t, c.HexDigest(), gal.OAuthAccessTokenKeyID.String)
require.Equal(t, c.HexDigest(), gal.OAuthRefreshTokenKeyID.String)
}
userSecrets, err := db.ListUserSecretsWithValues(ctx, userID)
require.NoError(t, err, "failed to get user secrets for user %s", userID)
for _, s := range userSecrets {
requireEncryptedEquals(t, c, "value-"+userID.String(), s.Value)
require.Equal(t, c.HexDigest(), s.ValueKeyID.String)
}
}
// nullCipher is a dbcrypt.Cipher that does not encrypt or decrypt.
@@ -11,7 +11,7 @@ OPTIONS:
-O, --org string, $CODER_ORGANIZATION
Select which organization (uuid or name) to use.
-c, --column [id|created at|started at|completed at|canceled at|error|error code|status|worker id|worker name|file id|tags|queue position|queue size|organization id|initiator id|template version id|workspace build id|type|available workers|template version name|template id|template name|template display name|template icon|workspace id|workspace name|workspace build transition|logs overflowed|organization|queue] (default: created at,id,type,template display name,status,queue,tags)
-c, --column [id|created at|started at|completed at|canceled at|error|error code|status|worker id|worker name|file id|tags|queue position|queue size|organization id|initiator id|template version id|workspace build id|type|available workers|template version name|template id|template name|template display name|template icon|workspace id|workspace name|logs overflowed|organization|queue] (default: created at,id,type,template display name,status,queue,tags)
Columns to display in table output.
-i, --initiator string, $CODER_PROVISIONER_JOB_LIST_INITIATOR
-58
View File
@@ -96,34 +96,6 @@ func Rotate(ctx context.Context, log slog.Logger, sqlDB *sql.DB, ciphers []Ciphe
}
log.Debug(ctx, "encrypted user chat provider key", slog.F("user_id", uid), slog.F("chat_provider_id", userProviderKey.ChatProviderID), slog.F("current", idx+1), slog.F("cipher", ciphers[0].HexDigest()))
}
userSecrets, err := cryptTx.ListUserSecretsWithValues(ctx, uid)
if err != nil {
return xerrors.Errorf("get user secrets for user %s: %w", uid, err)
}
for _, secret := range userSecrets {
if secret.ValueKeyID.Valid && secret.ValueKeyID.String == ciphers[0].HexDigest() {
log.Debug(ctx, "skipping user secret", slog.F("user_id", uid), slog.F("secret_name", secret.Name), slog.F("current", idx+1), slog.F("cipher", ciphers[0].HexDigest()))
continue
}
if _, err := cryptTx.UpdateUserSecretByUserIDAndName(ctx, database.UpdateUserSecretByUserIDAndNameParams{
UserID: uid,
Name: secret.Name,
UpdateValue: true,
Value: secret.Value,
ValueKeyID: sql.NullString{}, // dbcrypt will re-encrypt
UpdateDescription: false,
Description: "",
UpdateEnvName: false,
EnvName: "",
UpdateFilePath: false,
FilePath: "",
}); err != nil {
return xerrors.Errorf("rotate user secret user_id=%s name=%s: %w", uid, secret.Name, err)
}
log.Debug(ctx, "rotated user secret", slog.F("user_id", uid), slog.F("secret_name", secret.Name), slog.F("current", idx+1), slog.F("cipher", ciphers[0].HexDigest()))
}
return nil
}, &database.TxOptions{
Isolation: sql.LevelRepeatableRead,
@@ -263,34 +235,6 @@ func Decrypt(ctx context.Context, log slog.Logger, sqlDB *sql.DB, ciphers []Ciph
}
log.Debug(ctx, "decrypted user chat provider key", slog.F("user_id", uid), slog.F("chat_provider_id", userProviderKey.ChatProviderID), slog.F("current", idx+1))
}
userSecrets, err := tx.ListUserSecretsWithValues(ctx, uid)
if err != nil {
return xerrors.Errorf("get user secrets for user %s: %w", uid, err)
}
for _, secret := range userSecrets {
if !secret.ValueKeyID.Valid {
log.Debug(ctx, "skipping user secret", slog.F("user_id", uid), slog.F("secret_name", secret.Name), slog.F("current", idx+1))
continue
}
if _, err := tx.UpdateUserSecretByUserIDAndName(ctx, database.UpdateUserSecretByUserIDAndNameParams{
UserID: uid,
Name: secret.Name,
UpdateValue: true,
Value: secret.Value,
ValueKeyID: sql.NullString{}, // clear the key ID
UpdateDescription: false,
Description: "",
UpdateEnvName: false,
EnvName: "",
UpdateFilePath: false,
FilePath: "",
}); err != nil {
return xerrors.Errorf("decrypt user secret user_id=%s name=%s: %w", uid, secret.Name, err)
}
log.Debug(ctx, "decrypted user secret", slog.F("user_id", uid), slog.F("secret_name", secret.Name), slog.F("current", idx+1))
}
return nil
}, &database.TxOptions{
Isolation: sql.LevelRepeatableRead,
@@ -348,8 +292,6 @@ DELETE FROM external_auth_links
OR oauth_refresh_token_key_id IS NOT NULL;
DELETE FROM user_chat_provider_keys
WHERE api_key_key_id IS NOT NULL;
DELETE FROM user_secrets
WHERE value_key_id IS NOT NULL;
UPDATE chat_providers
SET api_key = '',
api_key_key_id = NULL
-54
View File
@@ -717,60 +717,6 @@ func (db *dbCrypt) UpsertMCPServerUserToken(ctx context.Context, params database
return tok, nil
}
// CreateUserSecret encrypts the secret value before persisting it via the
// wrapped store, then decrypts the stored row's value so the caller always
// sees plaintext.
func (db *dbCrypt) CreateUserSecret(ctx context.Context, params database.CreateUserSecretParams) (database.UserSecret, error) {
	var empty database.UserSecret
	if err := db.encryptField(&params.Value, &params.ValueKeyID); err != nil {
		return empty, err
	}
	created, err := db.Store.CreateUserSecret(ctx, params)
	if err != nil {
		return empty, err
	}
	if err := db.decryptField(&created.Value, created.ValueKeyID); err != nil {
		return empty, err
	}
	return created, nil
}
// GetUserSecretByUserIDAndName fetches a single secret from the underlying
// store and decrypts its value before returning it.
func (db *dbCrypt) GetUserSecretByUserIDAndName(ctx context.Context, arg database.GetUserSecretByUserIDAndNameParams) (database.UserSecret, error) {
	fetched, err := db.Store.GetUserSecretByUserIDAndName(ctx, arg)
	if err == nil {
		err = db.decryptField(&fetched.Value, fetched.ValueKeyID)
	}
	if err != nil {
		return database.UserSecret{}, err
	}
	return fetched, nil
}
// ListUserSecretsWithValues returns every secret for the given user with each
// stored value decrypted to plaintext. Any single decryption failure aborts
// the whole listing.
func (db *dbCrypt) ListUserSecretsWithValues(ctx context.Context, userID uuid.UUID) ([]database.UserSecret, error) {
	list, err := db.Store.ListUserSecretsWithValues(ctx, userID)
	if err != nil {
		return nil, err
	}
	for idx := range list {
		// Decrypt in place; decryptField mutates the value through the pointer.
		item := &list[idx]
		if decErr := db.decryptField(&item.Value, item.ValueKeyID); decErr != nil {
			return nil, decErr
		}
	}
	return list, nil
}
// UpdateUserSecretByUserIDAndName updates a secret through the wrapped store.
// The value is only re-encrypted when the caller explicitly requests a value
// change (UpdateValue); metadata-only updates leave the ciphertext untouched.
// The returned row's value is decrypted back to plaintext.
func (db *dbCrypt) UpdateUserSecretByUserIDAndName(ctx context.Context, arg database.UpdateUserSecretByUserIDAndNameParams) (database.UserSecret, error) {
	if arg.UpdateValue {
		if encErr := db.encryptField(&arg.Value, &arg.ValueKeyID); encErr != nil {
			return database.UserSecret{}, encErr
		}
	}
	updated, err := db.Store.UpdateUserSecretByUserIDAndName(ctx, arg)
	if err == nil {
		err = db.decryptField(&updated.Value, updated.ValueKeyID)
	}
	if err != nil {
		return database.UserSecret{}, err
	}
	return updated, nil
}
func (db *dbCrypt) encryptField(field *string, digest *sql.NullString) error {
// If no cipher is loaded, then we can't encrypt anything!
if db.ciphers == nil || db.primaryCipherDigest == "" {
-195
View File
@@ -1287,198 +1287,3 @@ func TestUserChatProviderKeys(t *testing.T) {
requireEncryptedEquals(t, ciphers[0], rawKey.APIKey, updatedAPIKey)
})
}
// TestUserSecrets exercises the dbCrypt wrapper over user secrets: values are
// encrypted at rest, decrypted on read, re-encrypted on value updates, left
// untouched by metadata-only updates, stored as plaintext when no cipher is
// configured, and surfaced as DecryptFailedError when the ciphertext is
// corrupt.
func TestUserSecrets(t *testing.T) {
	t.Parallel()
	ctx := context.Background()
	const (
		//nolint:gosec // test credentials
		initialValue = "super-secret-value-initial"
		//nolint:gosec // test credentials
		updatedValue = "super-secret-value-updated"
	)
	// insertUserSecret creates a fresh user plus one secret through the crypt
	// layer and sanity-checks that the returned value is plaintext tagged with
	// the primary cipher's digest (when a cipher is configured).
	insertUserSecret := func(
		t *testing.T,
		crypt *dbCrypt,
		ciphers []Cipher,
	) database.UserSecret {
		t.Helper()
		user := dbgen.User(t, crypt, database.User{})
		secret, err := crypt.CreateUserSecret(ctx, database.CreateUserSecretParams{
			ID: uuid.New(),
			UserID: user.ID,
			Name: "test-secret-" + uuid.NewString()[:8],
			Value: initialValue,
		})
		require.NoError(t, err)
		require.Equal(t, initialValue, secret.Value)
		if len(ciphers) > 0 {
			require.Equal(t, ciphers[0].HexDigest(), secret.ValueKeyID.String)
		}
		return secret
	}
	t.Run("CreateUserSecretEncryptsValue", func(t *testing.T) {
		t.Parallel()
		db, crypt, ciphers := setup(t)
		secret := insertUserSecret(t, crypt, ciphers)
		// Reading through crypt should return plaintext.
		got, err := crypt.GetUserSecretByUserIDAndName(ctx, database.GetUserSecretByUserIDAndNameParams{
			UserID: secret.UserID,
			Name: secret.Name,
		})
		require.NoError(t, err)
		require.Equal(t, initialValue, got.Value)
		// Reading through raw DB should return encrypted value.
		raw, err := db.GetUserSecretByUserIDAndName(ctx, database.GetUserSecretByUserIDAndNameParams{
			UserID: secret.UserID,
			Name: secret.Name,
		})
		require.NoError(t, err)
		require.NotEqual(t, initialValue, raw.Value)
		requireEncryptedEquals(t, ciphers[0], raw.Value, initialValue)
	})
	t.Run("ListUserSecretsWithValuesDecrypts", func(t *testing.T) {
		t.Parallel()
		_, crypt, ciphers := setup(t)
		secret := insertUserSecret(t, crypt, ciphers)
		// List must decrypt each stored value back to plaintext.
		secrets, err := crypt.ListUserSecretsWithValues(ctx, secret.UserID)
		require.NoError(t, err)
		require.Len(t, secrets, 1)
		require.Equal(t, initialValue, secrets[0].Value)
	})
	t.Run("UpdateUserSecretReEncryptsValue", func(t *testing.T) {
		t.Parallel()
		db, crypt, ciphers := setup(t)
		secret := insertUserSecret(t, crypt, ciphers)
		// Replacing the value must re-encrypt under the primary cipher.
		updated, err := crypt.UpdateUserSecretByUserIDAndName(ctx, database.UpdateUserSecretByUserIDAndNameParams{
			UserID: secret.UserID,
			Name: secret.Name,
			UpdateValue: true,
			Value: updatedValue,
			ValueKeyID: sql.NullString{},
		})
		require.NoError(t, err)
		require.Equal(t, updatedValue, updated.Value)
		require.Equal(t, ciphers[0].HexDigest(), updated.ValueKeyID.String)
		// Raw DB should have new encrypted value.
		raw, err := db.GetUserSecretByUserIDAndName(ctx, database.GetUserSecretByUserIDAndNameParams{
			UserID: secret.UserID,
			Name: secret.Name,
		})
		require.NoError(t, err)
		require.NotEqual(t, updatedValue, raw.Value)
		requireEncryptedEquals(t, ciphers[0], raw.Value, updatedValue)
	})
	t.Run("NoCipherStoresPlaintext", func(t *testing.T) {
		t.Parallel()
		db, crypt := setupNoCiphers(t)
		user := dbgen.User(t, crypt, database.User{})
		// With no cipher configured the value passes through untouched and no
		// key ID is recorded.
		secret, err := crypt.CreateUserSecret(ctx, database.CreateUserSecretParams{
			ID: uuid.New(),
			UserID: user.ID,
			Name: "plaintext-secret",
			Value: initialValue,
		})
		require.NoError(t, err)
		require.Equal(t, initialValue, secret.Value)
		require.False(t, secret.ValueKeyID.Valid)
		// Raw DB should also have plaintext.
		raw, err := db.GetUserSecretByUserIDAndName(ctx, database.GetUserSecretByUserIDAndNameParams{
			UserID: user.ID,
			Name: "plaintext-secret",
		})
		require.NoError(t, err)
		require.Equal(t, initialValue, raw.Value)
		require.False(t, raw.ValueKeyID.Valid)
	})
	t.Run("UpdateMetadataOnlySkipsEncryption", func(t *testing.T) {
		t.Parallel()
		db, crypt, ciphers := setup(t)
		secret := insertUserSecret(t, crypt, ciphers)
		// Read the raw encrypted value from the database.
		rawBefore, err := db.GetUserSecretByUserIDAndName(ctx, database.GetUserSecretByUserIDAndNameParams{
			UserID: secret.UserID,
			Name: secret.Name,
		})
		require.NoError(t, err)
		// Perform a metadata-only update (no value change).
		updated, err := crypt.UpdateUserSecretByUserIDAndName(ctx, database.UpdateUserSecretByUserIDAndNameParams{
			UserID: secret.UserID,
			Name: secret.Name,
			UpdateValue: false,
			Value: "",
			ValueKeyID: sql.NullString{},
			UpdateDescription: true,
			Description: "updated description",
			UpdateEnvName: false,
			EnvName: "",
			UpdateFilePath: false,
			FilePath: "",
		})
		require.NoError(t, err)
		require.Equal(t, "updated description", updated.Description)
		require.Equal(t, initialValue, updated.Value)
		// Read the raw encrypted value again.
		rawAfter, err := db.GetUserSecretByUserIDAndName(ctx, database.GetUserSecretByUserIDAndNameParams{
			UserID: secret.UserID,
			Name: secret.Name,
		})
		require.NoError(t, err)
		// Ciphertext and key ID must be unchanged: no re-encryption happened.
		require.Equal(t, rawBefore.Value, rawAfter.Value)
		require.Equal(t, rawBefore.ValueKeyID, rawAfter.ValueKeyID)
	})
	t.Run("GetUserSecretDecryptErr", func(t *testing.T) {
		t.Parallel()
		db, crypt, ciphers := setup(t)
		user := dbgen.User(t, db, database.User{})
		// Seed a secret whose stored value is random bytes (not valid
		// ciphertext) but which claims to be encrypted with ciphers[0].
		dbgen.UserSecret(t, db, database.UserSecret{
			UserID: user.ID,
			Name: "corrupt-secret",
			Value: fakeBase64RandomData(t, 32),
			ValueKeyID: sql.NullString{String: ciphers[0].HexDigest(), Valid: true},
		})
		_, err := crypt.GetUserSecretByUserIDAndName(ctx, database.GetUserSecretByUserIDAndNameParams{
			UserID: user.ID,
			Name: "corrupt-secret",
		})
		require.Error(t, err)
		var derr *DecryptFailedError
		require.ErrorAs(t, err, &derr)
	})
	t.Run("ListUserSecretsWithValuesDecryptErr", func(t *testing.T) {
		t.Parallel()
		db, crypt, ciphers := setup(t)
		user := dbgen.User(t, db, database.User{})
		dbgen.UserSecret(t, db, database.UserSecret{
			UserID: user.ID,
			Name: "corrupt-list-secret",
			Value: fakeBase64RandomData(t, 32),
			ValueKeyID: sql.NullString{String: ciphers[0].HexDigest(), Valid: true},
		})
		// Listing must propagate the decryption failure rather than return a
		// partially-decrypted result.
		_, err := crypt.ListUserSecretsWithValues(ctx, user.ID)
		require.Error(t, err)
		var derr *DecryptFailedError
		require.ErrorAs(t, err, &derr)
	})
}
+3 -3
View File
@@ -518,7 +518,7 @@ require (
cloud.google.com/go/logging v1.13.2 // indirect
cloud.google.com/go/longrunning v0.8.0 // indirect
cloud.google.com/go/monitoring v1.24.3 // indirect
cloud.google.com/go/storage v1.61.3 // indirect
cloud.google.com/go/storage v1.60.0 // indirect
git.sr.ht/~jackmordaunt/go-toast v1.1.2 // indirect
github.com/Azure/azure-sdk-for-go/sdk/azcore v1.20.0 // indirect
github.com/Azure/azure-sdk-for-go/sdk/internal v1.11.2 // indirect
@@ -576,8 +576,8 @@ require (
github.com/goccy/go-yaml v1.19.2 // indirect
github.com/google/go-containerregistry v0.20.7 // indirect
github.com/gorilla/websocket v1.5.4-0.20250319132907-e064f32e3674 // indirect
github.com/hashicorp/aws-sdk-go-base/v2 v2.0.0-beta.72 // indirect
github.com/hashicorp/go-getter v1.8.6 // indirect
github.com/hashicorp/aws-sdk-go-base/v2 v2.0.0-beta.70 // indirect
github.com/hashicorp/go-getter v1.8.4 // indirect
github.com/hexops/gotextdiff v1.0.3 // indirect
github.com/inconshreveable/mousetrap v1.1.0 // indirect
github.com/jackmordaunt/icns/v3 v3.0.1 // indirect
+8 -8
View File
@@ -18,8 +18,8 @@ cloud.google.com/go/longrunning v0.8.0 h1:LiKK77J3bx5gDLi4SMViHixjD2ohlkwBi+mKA7
cloud.google.com/go/longrunning v0.8.0/go.mod h1:UmErU2Onzi+fKDg2gR7dusz11Pe26aknR4kHmJJqIfk=
cloud.google.com/go/monitoring v1.24.3 h1:dde+gMNc0UhPZD1Azu6at2e79bfdztVDS5lvhOdsgaE=
cloud.google.com/go/monitoring v1.24.3/go.mod h1:nYP6W0tm3N9H/bOw8am7t62YTzZY+zUeQ+Bi6+2eonI=
cloud.google.com/go/storage v1.61.3 h1:VS//ZfBuPGDvakfD9xyPW1RGF1Vy3BWUoVZXgW1KMOg=
cloud.google.com/go/storage v1.61.3/go.mod h1:JtqK8BBB7TWv0HVGHubtUdzYYrakOQIsMLffZ2Z/HWk=
cloud.google.com/go/storage v1.60.0 h1:oBfZrSOCimggVNz9Y/bXY35uUcts7OViubeddTTVzQ8=
cloud.google.com/go/storage v1.60.0/go.mod h1:q+5196hXfejkctrnx+VYU8RKQr/L3c0cBIlrjmiAKE0=
cloud.google.com/go/trace v1.11.7 h1:kDNDX8JkaAG3R2nq1lIdkb7FCSi1rCmsEtKVsty7p+U=
cloud.google.com/go/trace v1.11.7/go.mod h1:TNn9d5V3fQVf6s4SCveVMIBS2LJUqo73GACmq/Tky0s=
dario.cat/mergo v1.0.2 h1:85+piFYR1tMbRrLcDwR18y4UKJ3aH1Tbzi24VRW1TK8=
@@ -687,8 +687,8 @@ github.com/grpc-ecosystem/grpc-gateway/v2 v2.28.0 h1:HWRh5R2+9EifMyIHV7ZV+MIZqgz
github.com/grpc-ecosystem/grpc-gateway/v2 v2.28.0/go.mod h1:JfhWUomR1baixubs02l85lZYYOm7LV6om4ceouMv45c=
github.com/hairyhenderson/go-codeowners v0.7.0 h1:s0W4wF8bdsBEjTWzwzSlsatSthWtTAF2xLgo4a4RwAo=
github.com/hairyhenderson/go-codeowners v0.7.0/go.mod h1:wUlNgQ3QjqC4z8DnM5nnCYVq/icpqXJyJOukKx5U8/Q=
github.com/hashicorp/aws-sdk-go-base/v2 v2.0.0-beta.72 h1:vTCWu1wbdYo7PEZFem/rlr01+Un+wwVmI7wiegFdRLk=
github.com/hashicorp/aws-sdk-go-base/v2 v2.0.0-beta.72/go.mod h1:Vn+BBgKQHVQYdVQ4NZDICE1Brb+JfaONyDHr3q07oQc=
github.com/hashicorp/aws-sdk-go-base/v2 v2.0.0-beta.70 h1:0HADrxxqaQkGycO1JoUUA+B4FnIkuo8d2bz/hSaTFFQ=
github.com/hashicorp/aws-sdk-go-base/v2 v2.0.0-beta.70/go.mod h1:fm2FdDCzJdtbXF7WKAMvBb5NEPouXPHFbGNYs9ShFns=
github.com/hashicorp/errwrap v1.0.0/go.mod h1:YH+1FKiLXxHSkmPseP+kNlulaMuP3n2brvKWEqk/Jc4=
github.com/hashicorp/errwrap v1.1.0 h1:OxrOeh75EUXMY8TBjag2fzXGZ40LB6IKw45YeGUDY2I=
github.com/hashicorp/errwrap v1.1.0/go.mod h1:YH+1FKiLXxHSkmPseP+kNlulaMuP3n2brvKWEqk/Jc4=
@@ -698,8 +698,8 @@ github.com/hashicorp/go-cleanhttp v0.5.2 h1:035FKYIWjmULyFRBKPs8TBQoi0x6d9G4xc9n
github.com/hashicorp/go-cleanhttp v0.5.2/go.mod h1:kO/YDlP8L1346E6Sodw+PrpBSV4/SoxCXGY6BqNFT48=
github.com/hashicorp/go-cty v1.5.0 h1:EkQ/v+dDNUqnuVpmS5fPqyY71NXVgT5gf32+57xY8g0=
github.com/hashicorp/go-cty v1.5.0/go.mod h1:lFUCG5kd8exDobgSfyj4ONE/dc822kiYMguVKdHGMLM=
github.com/hashicorp/go-getter v1.8.6 h1:9sQboWULaydVphxc4S64oAI4YqpuCk7nPmvbk131ebY=
github.com/hashicorp/go-getter v1.8.6/go.mod h1:nVH12eOV2P58dIiL3rsU6Fh3wLeJEKBOJzhMmzlSWoo=
github.com/hashicorp/go-getter v1.8.4 h1:hGEd2xsuVKgwkMtPVufq73fAmZU/x65PPcqH3cb0D9A=
github.com/hashicorp/go-getter v1.8.4/go.mod h1:x27pPGSg9kzoB147QXI8d/nDvp2IgYGcwuRjpaXE9Yg=
github.com/hashicorp/go-hclog v1.6.3 h1:Qr2kF+eVWjTiYmU7Y31tYlP1h0q/X3Nl3tPGdaB11/k=
github.com/hashicorp/go-hclog v1.6.3/go.mod h1:W4Qnvbt70Wk/zYJryRzDRU/4r0kIg0PVHBcfoyhpF5M=
github.com/hashicorp/go-multierror v1.1.1 h1:H5DkEtf6CXdFp0N0Em5UCwQpXMWke8IA0+lD48awMYo=
@@ -1322,8 +1322,8 @@ go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracegrpc v1.40.0 h1:DvJDO
go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracegrpc v1.40.0/go.mod h1:EtekO9DEJb4/jRyN4v4Qjc2yA7AtfCBuz2FynRUWTXs=
go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracehttp v1.38.0 h1:aTL7F04bJHUlztTsNGJ2l+6he8c+y/b//eR0jjjemT4=
go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracehttp v1.38.0/go.mod h1:kldtb7jDTeol0l3ewcmd8SDvx3EmIE7lyvqbasU3QC4=
go.opentelemetry.io/otel/exporters/stdout/stdoutmetric v1.40.0 h1:ZrPRak/kS4xI3AVXy8F7pipuDXmDsrO8Lg+yQjBLjw0=
go.opentelemetry.io/otel/exporters/stdout/stdoutmetric v1.40.0/go.mod h1:3y6kQCWztq6hyW8Z9YxQDDm0Je9AJoFar2G0yDcmhRk=
go.opentelemetry.io/otel/exporters/stdout/stdoutmetric v1.39.0 h1:5gn2urDL/FBnK8OkCfD1j3/ER79rUuTYmCvlXBKeYL8=
go.opentelemetry.io/otel/exporters/stdout/stdoutmetric v1.39.0/go.mod h1:0fBG6ZJxhqByfFZDwSwpZGzJU671HkwpWaNe2t4VUPI=
go.opentelemetry.io/otel/exporters/stdout/stdouttrace v1.37.0 h1:SNhVp/9q4Go/XHBkQ1/d5u9P/U+L1yaGPoi0x+mStaI=
go.opentelemetry.io/otel/exporters/stdout/stdouttrace v1.37.0/go.mod h1:tx8OOlGH6R4kLV67YaYO44GFXloEjGPZuMjEkaaqIp4=
go.opentelemetry.io/otel/metric v1.42.0 h1:2jXG+3oZLNXEPfNmnpxKDeZsFI5o4J+nz6xUlaFdF/4=
+1 -1
View File
@@ -801,7 +801,7 @@ func (jfs justFilesSystem) Open(name string) (fs.File, error) {
type RenderOAuthAllowData struct {
AppIcon string
AppName string
CancelURI htmltemplate.URL
CancelURI string
RedirectURI string
CSRFToken string
Username string
+2 -2
View File
@@ -145,7 +145,7 @@ export const watchWorkspace = (
export const watchChat = (
chatId: string,
afterMessageId?: number,
): OneWayWebSocketApi<TypesGen.ChatStreamEvent[]> => {
): OneWayWebSocketApi<TypesGen.ServerSentEvent> => {
const params = new URLSearchParams();
if (afterMessageId !== undefined && afterMessageId > 0) {
params.set("after_id", afterMessageId.toString());
@@ -161,7 +161,7 @@ export const watchChat = (
});
};
export const watchChats = (): OneWayWebSocket<TypesGen.ChatWatchEvent> => {
export const watchChats = (): OneWayWebSocket<TypesGen.ServerSentEvent> => {
const searchParams: Record<string, string> = {};
const token = API.getSessionToken();
if (token) {
@@ -1,44 +0,0 @@
import { describe, expect, it } from "vitest";
import type * as TypesGen from "#/api/typesGenerated";
import { buildOptimisticEditedMessage } from "./chatMessageEdits";
// Builds a minimal, deterministic user-authored chat message fixture.
// Callers may override the content parts; every other field is a fixed stub.
const makeUserMessage = (
  content: readonly TypesGen.ChatMessagePart[] = [
    { type: "text", text: "original" },
  ],
): TypesGen.ChatMessage => {
  const fixture: TypesGen.ChatMessage = {
    id: 1,
    chat_id: "chat-1",
    created_at: "2025-01-01T00:00:00.000Z",
    role: "user",
    content,
  };
  return fixture;
};
// Covers how buildOptimisticEditedMessage resolves media types for file
// parts in an edited message.
describe("buildOptimisticEditedMessage", () => {
  it("preserves image MIME types for newly attached files", () => {
    // "image-1" is not on the original message, so its media type must come
    // from the caller-supplied attachment metadata map.
    const message = buildOptimisticEditedMessage({
      requestContent: [{ type: "file", file_id: "image-1" }],
      originalMessage: makeUserMessage(),
      attachmentMediaTypes: new Map([["image-1", "image/png"]]),
    });
    expect(message.content).toEqual([
      { type: "file", file_id: "image-1", media_type: "image/png" },
    ]);
  });
  it("reuses existing file parts before local attachment metadata", () => {
    const existingFilePart: TypesGen.ChatFilePart = {
      type: "file",
      file_id: "existing-1",
      media_type: "image/jpeg",
    };
    // Even though the attachment map claims "text/plain", the part already
    // present on the original message wins.
    const message = buildOptimisticEditedMessage({
      requestContent: [{ type: "file", file_id: "existing-1" }],
      originalMessage: makeUserMessage([existingFilePart]),
      attachmentMediaTypes: new Map([["existing-1", "text/plain"]]),
    });
    expect(message.content).toEqual([existingFilePart]);
  });
});
-148
View File
@@ -1,148 +0,0 @@
import type { InfiniteData } from "react-query";
import type * as TypesGen from "#/api/typesGenerated";
/**
 * Projects the edit request's input parts into renderable message parts.
 * File parts already present on the original message are reused verbatim
 * (they carry server-provided metadata); newly attached files fall back to
 * the locally-known media type, or application/octet-stream when unknown.
 */
const buildOptimisticEditedContent = ({
  requestContent,
  originalMessage,
  attachmentMediaTypes,
}: {
  requestContent: readonly TypesGen.ChatInputPart[];
  originalMessage: TypesGen.ChatMessage;
  attachmentMediaTypes?: ReadonlyMap<string, string>;
}): readonly TypesGen.ChatMessagePart[] => {
  // Index the original message's file parts by their file ID.
  const knownFileParts = new Map<string, TypesGen.ChatFilePart>();
  for (const existing of originalMessage.content ?? []) {
    if (existing.type === "file" && existing.file_id) {
      knownFileParts.set(existing.file_id, existing);
    }
  }
  const toMessagePart = (
    part: TypesGen.ChatInputPart,
  ): TypesGen.ChatMessagePart => {
    switch (part.type) {
      case "text":
        return { type: "text", text: part.text ?? "" };
      case "file-reference":
        return {
          type: "file-reference",
          file_name: part.file_name ?? "",
          start_line: part.start_line ?? 1,
          end_line: part.end_line ?? 1,
          content: part.content ?? "",
        };
      default: {
        const fileId = part.file_id ?? "";
        const reused = knownFileParts.get(fileId);
        if (reused) {
          return reused;
        }
        return {
          type: "file",
          file_id: part.file_id,
          media_type:
            attachmentMediaTypes?.get(fileId) ?? "application/octet-stream",
        };
      }
    }
  };
  return requestContent.map(toMessagePart);
};
/**
 * Returns a copy of the original message whose content reflects the edit
 * request, for optimistic cache insertion before the server responds.
 */
export const buildOptimisticEditedMessage = ({
  requestContent,
  originalMessage,
  attachmentMediaTypes,
}: {
  requestContent: readonly TypesGen.ChatInputPart[];
  originalMessage: TypesGen.ChatMessage;
  attachmentMediaTypes?: ReadonlyMap<string, string>;
}): TypesGen.ChatMessage => {
  const content = buildOptimisticEditedContent({
    requestContent,
    originalMessage,
    attachmentMediaTypes,
  });
  return { ...originalMessage, content };
};
// Newest-first ordering by message ID, without mutating the input array.
const sortMessagesDescending = (
  messages: readonly TypesGen.ChatMessage[],
): TypesGen.ChatMessage[] => {
  const copy = Array.from(messages);
  copy.sort((left, right) => right.id - left.id);
  return copy;
};
// Inserts or replaces `message` in `messages` (deduplicated by ID) and
// returns the result ordered newest-first.
const upsertFirstPageMessage = (
  messages: readonly TypesGen.ChatMessage[],
  message: TypesGen.ChatMessage,
): TypesGen.ChatMessage[] => {
  const merged = new Map<number, TypesGen.ChatMessage>();
  for (const existing of messages) {
    merged.set(existing.id, existing);
  }
  // Last write wins: the incoming message replaces any existing copy.
  merged.set(message.id, message);
  return sortMessagesDescending([...merged.values()]);
};
/**
 * Optimistically rewrites the infinite-query cache after a history edit:
 * every message at or after the edited ID is dropped from every page, the
 * optimistic replacement (when provided) is upserted onto the newest page,
 * and the newest page's queued messages are overwritten when provided.
 * Returns the input unchanged when the cache is empty.
 */
export const projectEditedConversationIntoCache = ({
  currentData,
  editedMessageId,
  replacementMessage,
  queuedMessages,
}: {
  currentData: InfiniteData<TypesGen.ChatMessagesResponse> | undefined;
  editedMessageId: number;
  replacementMessage?: TypesGen.ChatMessage;
  queuedMessages?: readonly TypesGen.ChatQueuedMessage[];
}): InfiniteData<TypesGen.ChatMessagesResponse> | undefined => {
  if (!currentData?.pages?.length) {
    return currentData;
  }
  const pages = currentData.pages.map((page, pageIndex) => {
    const isNewestPage = pageIndex === 0;
    // Everything from the edited message onward is superseded by the edit.
    let messages = page.messages.filter(
      (candidate) => candidate.id < editedMessageId,
    );
    if (isNewestPage && replacementMessage) {
      messages = upsertFirstPageMessage(messages, replacementMessage);
    }
    return {
      ...page,
      messages,
      ...(isNewestPage && queuedMessages !== undefined
        ? { queued_messages: queuedMessages }
        : {}),
    };
  });
  return { ...currentData, pages };
};
/**
 * Replaces the optimistic edited message with the server's authoritative
 * response. Messages that arrived (e.g. via websocket) while the request was
 * in flight are preserved; both the optimistic copy and any duplicate of the
 * response are removed before the response is upserted onto the newest page.
 * Returns the input unchanged when the cache is empty.
 */
export const reconcileEditedMessageInCache = ({
  currentData,
  optimisticMessageId,
  responseMessage,
}: {
  currentData: InfiniteData<TypesGen.ChatMessagesResponse> | undefined;
  optimisticMessageId: number;
  responseMessage: TypesGen.ChatMessage;
}): InfiniteData<TypesGen.ChatMessagesResponse> | undefined => {
  if (!currentData?.pages?.length) {
    return currentData;
  }
  const isStale = (candidate: TypesGen.ChatMessage) =>
    candidate.id === optimisticMessageId ||
    candidate.id === responseMessage.id;
  const pages = currentData.pages.map((page, pageIndex) => {
    const kept = page.messages.filter((candidate) => !isStale(candidate));
    if (pageIndex !== 0) {
      return { ...page, messages: kept };
    }
    return { ...page, messages: upsertFirstPageMessage(kept, responseMessage) };
  });
  return { ...currentData, pages };
};
+19 -177
View File
@@ -2,7 +2,6 @@ import { QueryClient } from "react-query";
import { describe, expect, it, vi } from "vitest";
import { API } from "#/api/api";
import type * as TypesGen from "#/api/typesGenerated";
import { buildOptimisticEditedMessage } from "./chatMessageEdits";
import {
archiveChat,
cancelChatListRefetches,
@@ -796,44 +795,14 @@ describe("mutation invalidation scope", () => {
content: [{ type: "text" as const, text: `msg ${id}` }],
});
// Deterministic queued-message fixture; the ID is folded into both the
// timestamp's seconds field and the text so fixtures stay distinguishable.
const makeQueuedMessage = (
  chatId: string,
  id: number,
): TypesGen.ChatQueuedMessage => {
  const seconds = String(id).padStart(2, "0");
  return {
    id,
    chat_id: chatId,
    created_at: `2025-01-01T00:10:${seconds}Z`,
    content: [{ type: "text" as const, text: `queued ${id}` }],
  };
};
const editReq = {
content: [{ type: "text" as const, text: "edited" }],
};
// Looks up a message by ID, throwing (rather than returning undefined) so
// broken test setup surfaces immediately.
const requireMessage = (
  messages: readonly TypesGen.ChatMessage[],
  messageId: number,
): TypesGen.ChatMessage => {
  for (const candidate of messages) {
    if (candidate.id === messageId) {
      return candidate;
    }
  }
  throw new Error(`missing message ${messageId}`);
};
// Convenience wrapper: applies the shared `editReq` content to a message.
const buildOptimisticMessage = (message: TypesGen.ChatMessage) => {
  return buildOptimisticEditedMessage({
    requestContent: editReq.content,
    originalMessage: message,
  });
};
it("editChatMessage writes the optimistic replacement into cache", async () => {
it("editChatMessage optimistically removes truncated messages from cache", async () => {
const queryClient = createTestQueryClient();
const chatId = "chat-1";
const messages = [5, 4, 3, 2, 1].map((id) => makeMsg(chatId, id));
const optimisticMessage = buildOptimisticMessage(
requireMessage(messages, 3),
);
queryClient.setQueryData<InfMessages>(chatMessagesKey(chatId), {
pages: [{ messages, queued_messages: [], has_more: false }],
@@ -843,58 +812,18 @@ describe("mutation invalidation scope", () => {
const mutation = editChatMessage(queryClient, chatId);
const context = await mutation.onMutate({
messageId: 3,
optimisticMessage,
req: editReq,
});
const data = queryClient.getQueryData<InfMessages>(chatMessagesKey(chatId));
expect(data?.pages[0]?.messages.map((message) => message.id)).toEqual([
3, 2, 1,
]);
expect(data?.pages[0]?.messages[0]?.content).toEqual(
optimisticMessage.content,
);
expect(data?.pages[0]?.messages.map((m) => m.id)).toEqual([2, 1]);
expect(context?.previousData?.pages[0]?.messages).toHaveLength(5);
});
it("editChatMessage clears queued messages in cache during optimistic history edit", async () => {
const queryClient = createTestQueryClient();
const chatId = "chat-1";
const messages = [5, 4, 3, 2, 1].map((id) => makeMsg(chatId, id));
const optimisticMessage = buildOptimisticMessage(
requireMessage(messages, 3),
);
const queuedMessages = [makeQueuedMessage(chatId, 11)];
queryClient.setQueryData<InfMessages>(chatMessagesKey(chatId), {
pages: [
{
messages,
queued_messages: queuedMessages,
has_more: false,
},
],
pageParams: [undefined],
});
const mutation = editChatMessage(queryClient, chatId);
await mutation.onMutate({
messageId: 3,
optimisticMessage,
req: editReq,
});
const data = queryClient.getQueryData<InfMessages>(chatMessagesKey(chatId));
expect(data?.pages[0]?.queued_messages).toEqual([]);
});
it("editChatMessage restores cache on error", async () => {
const queryClient = createTestQueryClient();
const chatId = "chat-1";
const messages = [5, 4, 3, 2, 1].map((id) => makeMsg(chatId, id));
const optimisticMessage = buildOptimisticMessage(
requireMessage(messages, 3),
);
queryClient.setQueryData<InfMessages>(chatMessagesKey(chatId), {
pages: [{ messages, queued_messages: [], has_more: false }],
@@ -904,85 +833,22 @@ describe("mutation invalidation scope", () => {
const mutation = editChatMessage(queryClient, chatId);
const context = await mutation.onMutate({
messageId: 3,
optimisticMessage,
req: editReq,
});
expect(
queryClient.getQueryData<InfMessages>(chatMessagesKey(chatId))?.pages[0]
?.messages,
).toHaveLength(3);
).toHaveLength(2);
mutation.onError(
new Error("network failure"),
{ messageId: 3, optimisticMessage, req: editReq },
{ messageId: 3, req: editReq },
context,
);
const data = queryClient.getQueryData<InfMessages>(chatMessagesKey(chatId));
expect(data?.pages[0]?.messages.map((message) => message.id)).toEqual([
5, 4, 3, 2, 1,
]);
});
it("editChatMessage preserves websocket-upserted newer messages on success", async () => {
const queryClient = createTestQueryClient();
const chatId = "chat-1";
const messages = [5, 4, 3, 2, 1].map((id) => makeMsg(chatId, id));
const optimisticMessage = buildOptimisticMessage(
requireMessage(messages, 3),
);
const responseMessage = {
...makeMsg(chatId, 9),
content: [{ type: "text" as const, text: "edited authoritative" }],
};
const websocketMessage = {
...makeMsg(chatId, 10),
content: [{ type: "text" as const, text: "assistant follow-up" }],
role: "assistant" as const,
};
queryClient.setQueryData<InfMessages>(chatMessagesKey(chatId), {
pages: [{ messages, queued_messages: [], has_more: false }],
pageParams: [undefined],
});
const mutation = editChatMessage(queryClient, chatId);
await mutation.onMutate({
messageId: 3,
optimisticMessage,
req: editReq,
});
queryClient.setQueryData<InfMessages | undefined>(
chatMessagesKey(chatId),
(current) => {
if (!current) {
return current;
}
return {
...current,
pages: [
{
...current.pages[0],
messages: [websocketMessage, ...current.pages[0].messages],
},
...current.pages.slice(1),
],
};
},
);
mutation.onSuccess(
{ message: responseMessage },
{ messageId: 3, optimisticMessage, req: editReq },
);
const data = queryClient.getQueryData<InfMessages>(chatMessagesKey(chatId));
expect(data?.pages[0]?.messages.map((message) => message.id)).toEqual([
10, 9, 2, 1,
]);
expect(data?.pages[0]?.messages[1]?.content).toEqual(
responseMessage.content,
);
expect(data?.pages[0]?.messages.map((m) => m.id)).toEqual([5, 4, 3, 2, 1]);
});
it("editChatMessage onMutate is a no-op when cache is empty", async () => {
@@ -1024,14 +890,13 @@ describe("mutation invalidation scope", () => {
expect(data?.pages[0]?.messages.map((m) => m.id)).toEqual([3, 2, 1]);
});
it("editChatMessage onMutate updates the first page and preserves older pages", async () => {
it("editChatMessage onMutate filters across multiple pages", async () => {
const queryClient = createTestQueryClient();
const chatId = "chat-1";
// Page 0 (newest): IDs 106. Page 1 (older): IDs 51.
const page0 = [10, 9, 8, 7, 6].map((id) => makeMsg(chatId, id));
const page1 = [5, 4, 3, 2, 1].map((id) => makeMsg(chatId, id));
const optimisticMessage = buildOptimisticMessage(requireMessage(page0, 7));
queryClient.setQueryData<InfMessages>(chatMessagesKey(chatId), {
pages: [
@@ -1042,28 +907,19 @@ describe("mutation invalidation scope", () => {
});
const mutation = editChatMessage(queryClient, chatId);
await mutation.onMutate({
messageId: 7,
optimisticMessage,
req: editReq,
});
await mutation.onMutate({ messageId: 7, req: editReq });
const data = queryClient.getQueryData<InfMessages>(chatMessagesKey(chatId));
expect(data?.pages[0]?.messages.map((message) => message.id)).toEqual([
7, 6,
]);
expect(data?.pages[1]?.messages.map((message) => message.id)).toEqual([
5, 4, 3, 2, 1,
]);
// Page 0: only ID 6 survives (< 7).
expect(data?.pages[0]?.messages.map((m) => m.id)).toEqual([6]);
// Page 1: all survive (all < 7).
expect(data?.pages[1]?.messages.map((m) => m.id)).toEqual([5, 4, 3, 2, 1]);
});
it("editChatMessage onMutate keeps the optimistic replacement when editing the first message", async () => {
it("editChatMessage onMutate editing the first message empties all pages", async () => {
const queryClient = createTestQueryClient();
const chatId = "chat-1";
const messages = [5, 4, 3, 2, 1].map((id) => makeMsg(chatId, id));
const optimisticMessage = buildOptimisticMessage(
requireMessage(messages, 1),
);
queryClient.setQueryData<InfMessages>(chatMessagesKey(chatId), {
pages: [{ messages, queued_messages: [], has_more: false }],
@@ -1071,25 +927,20 @@ describe("mutation invalidation scope", () => {
});
const mutation = editChatMessage(queryClient, chatId);
await mutation.onMutate({
messageId: 1,
optimisticMessage,
req: editReq,
});
await mutation.onMutate({ messageId: 1, req: editReq });
const data = queryClient.getQueryData<InfMessages>(chatMessagesKey(chatId));
expect(data?.pages[0]?.messages.map((message) => message.id)).toEqual([1]);
// All messages have id >= 1, so the page is empty.
expect(data?.pages[0]?.messages).toHaveLength(0);
// Sibling fields survive the spread.
expect(data?.pages[0]?.queued_messages).toEqual([]);
expect(data?.pages[0]?.has_more).toBe(false);
});
it("editChatMessage onMutate keeps earlier messages when editing the latest message", async () => {
it("editChatMessage onMutate editing the latest message keeps earlier ones", async () => {
const queryClient = createTestQueryClient();
const chatId = "chat-1";
const messages = [5, 4, 3, 2, 1].map((id) => makeMsg(chatId, id));
const optimisticMessage = buildOptimisticMessage(
requireMessage(messages, 5),
);
queryClient.setQueryData<InfMessages>(chatMessagesKey(chatId), {
pages: [{ messages, queued_messages: [], has_more: false }],
@@ -1097,19 +948,10 @@ describe("mutation invalidation scope", () => {
});
const mutation = editChatMessage(queryClient, chatId);
await mutation.onMutate({
messageId: 5,
optimisticMessage,
req: editReq,
});
await mutation.onMutate({ messageId: 5, req: editReq });
const data = queryClient.getQueryData<InfMessages>(chatMessagesKey(chatId));
expect(data?.pages[0]?.messages.map((message) => message.id)).toEqual([
5, 4, 3, 2, 1,
]);
expect(data?.pages[0]?.messages[0]?.content).toEqual(
optimisticMessage.content,
);
expect(data?.pages[0]?.messages.map((m) => m.id)).toEqual([4, 3, 2, 1]);
});
it("interruptChat does not invalidate unrelated queries", async () => {
+27 -36
View File
@@ -6,10 +6,6 @@ import type {
import { API } from "#/api/api";
import type * as TypesGen from "#/api/typesGenerated";
import type { UsePaginatedQueryOptions } from "#/hooks/usePaginatedQuery";
import {
projectEditedConversationIntoCache,
reconcileEditedMessageInCache,
} from "./chatMessageEdits";
export const chatsKey = ["chats"] as const;
export const chatKey = (chatId: string) => ["chats", chatId] as const;
@@ -605,21 +601,13 @@ export const createChatMessage = (
type EditChatMessageMutationArgs = {
messageId: number;
optimisticMessage?: TypesGen.ChatMessage;
req: TypesGen.EditChatMessageRequest;
};
type EditChatMessageMutationContext = {
previousData?: InfiniteData<TypesGen.ChatMessagesResponse> | undefined;
};
export const editChatMessage = (queryClient: QueryClient, chatId: string) => ({
mutationFn: ({ messageId, req }: EditChatMessageMutationArgs) =>
API.experimental.editChatMessage(chatId, messageId, req),
onMutate: async ({
messageId,
optimisticMessage,
}: EditChatMessageMutationArgs): Promise<EditChatMessageMutationContext> => {
onMutate: async ({ messageId }: EditChatMessageMutationArgs) => {
// Cancel in-flight refetches so they don't overwrite the
// optimistic update before the mutation completes.
await queryClient.cancelQueries({
@@ -631,23 +619,40 @@ export const editChatMessage = (queryClient: QueryClient, chatId: string) => ({
InfiniteData<TypesGen.ChatMessagesResponse>
>(chatMessagesKey(chatId));
// Optimistically remove the edited message and everything
// after it. The server soft-deletes these and inserts a
// replacement with a new ID. Without this, the WebSocket
// handler's upsertCacheMessages adds new messages to the
// React Query cache without removing the soft-deleted ones,
// causing deleted messages to flash back into view until
// the full REST refetch resolves.
queryClient.setQueryData<
InfiniteData<TypesGen.ChatMessagesResponse> | undefined
>(chatMessagesKey(chatId), (current) =>
projectEditedConversationIntoCache({
currentData: current,
editedMessageId: messageId,
replacementMessage: optimisticMessage,
queuedMessages: [],
}),
);
>(chatMessagesKey(chatId), (current) => {
if (!current?.pages?.length) {
return current;
}
return {
...current,
pages: current.pages.map((page) => ({
...page,
messages: page.messages.filter((m) => m.id < messageId),
})),
};
});
return { previousData };
},
onError: (
_error: unknown,
_variables: EditChatMessageMutationArgs,
context: EditChatMessageMutationContext | undefined,
context:
| {
previousData?:
| InfiniteData<TypesGen.ChatMessagesResponse>
| undefined;
}
| undefined,
) => {
// Restore the cache on failure so the user sees the
// original messages again.
@@ -655,20 +660,6 @@ export const editChatMessage = (queryClient: QueryClient, chatId: string) => ({
queryClient.setQueryData(chatMessagesKey(chatId), context.previousData);
}
},
onSuccess: (
response: TypesGen.EditChatMessageResponse,
variables: EditChatMessageMutationArgs,
) => {
queryClient.setQueryData<
InfiniteData<TypesGen.ChatMessagesResponse> | undefined
>(chatMessagesKey(chatId), (current) =>
reconcileEditedMessageInCache({
currentData: current,
optimisticMessageId: variables.messageId,
responseMessage: response.message,
}),
);
},
onSettled: () => {
// Always reconcile with the server regardless of whether
// the mutation succeeded or failed. On success this picks
-119
View File
@@ -1,119 +0,0 @@
import { describe, expect, it, vi } from "vitest";
import { API } from "#/api/api";
import type { AuthorizationCheck, Organization } from "#/api/typesGenerated";
import { permittedOrganizations } from "./organizations";
// Mock the API module
vi.mock("#/api/api", () => ({
API: {
getOrganizations: vi.fn(),
checkAuthorization: vi.fn(),
},
}));
const MockOrg1: Organization = {
id: "org-1",
name: "org-one",
display_name: "Org One",
description: "",
icon: "",
created_at: "",
updated_at: "",
is_default: true,
};
const MockOrg2: Organization = {
id: "org-2",
name: "org-two",
display_name: "Org Two",
description: "",
icon: "",
created_at: "",
updated_at: "",
is_default: false,
};
const templateCreateCheck: AuthorizationCheck = {
object: { resource_type: "template" },
action: "create",
};
describe("permittedOrganizations", () => {
it("returns query config with correct queryKey", () => {
const config = permittedOrganizations(templateCreateCheck);
expect(config.queryKey).toEqual([
"organizations",
"permitted",
templateCreateCheck,
]);
});
it("fetches orgs and filters by permission check", async () => {
const getOrgsMock = vi.mocked(API.getOrganizations);
const checkAuthMock = vi.mocked(API.checkAuthorization);
getOrgsMock.mockResolvedValue([MockOrg1, MockOrg2]);
checkAuthMock.mockResolvedValue({
"org-1": true,
"org-2": false,
});
const config = permittedOrganizations(templateCreateCheck);
const result = await config.queryFn!();
// Should only return org-1 (which passed the check)
expect(result).toEqual([MockOrg1]);
// Verify the auth check was called with per-org checks
expect(checkAuthMock).toHaveBeenCalledWith({
checks: {
"org-1": {
...templateCreateCheck,
object: {
...templateCreateCheck.object,
organization_id: "org-1",
},
},
"org-2": {
...templateCreateCheck,
object: {
...templateCreateCheck.object,
organization_id: "org-2",
},
},
},
});
});
it("returns all orgs when all pass the check", async () => {
const getOrgsMock = vi.mocked(API.getOrganizations);
const checkAuthMock = vi.mocked(API.checkAuthorization);
getOrgsMock.mockResolvedValue([MockOrg1, MockOrg2]);
checkAuthMock.mockResolvedValue({
"org-1": true,
"org-2": true,
});
const config = permittedOrganizations(templateCreateCheck);
const result = await config.queryFn!();
expect(result).toEqual([MockOrg1, MockOrg2]);
});
it("returns empty array when no orgs pass the check", async () => {
const getOrgsMock = vi.mocked(API.getOrganizations);
const checkAuthMock = vi.mocked(API.checkAuthorization);
getOrgsMock.mockResolvedValue([MockOrg1, MockOrg2]);
checkAuthMock.mockResolvedValue({
"org-1": false,
"org-2": false,
});
const config = permittedOrganizations(templateCreateCheck);
const result = await config.queryFn!();
expect(result).toEqual([]);
});
});
+1 -27
View File
@@ -5,7 +5,6 @@ import {
type GetProvisionerJobsParams,
} from "#/api/api";
import type {
AuthorizationCheck,
CreateOrganizationRequest,
GroupSyncSettings,
Organization,
@@ -161,7 +160,7 @@ export const updateOrganizationMemberRoles = (
};
};
const organizationsKey = ["organizations"] as const;
export const organizationsKey = ["organizations"] as const;
const notAvailable = { available: false, value: undefined } as const;
@@ -296,31 +295,6 @@ export const provisionerJobs = (
};
};
/**
* Fetch organizations the current user is permitted to use for a given
* action. Fetches all organizations, runs a per-org authorization
* check, and returns only those that pass.
*/
export const permittedOrganizations = (check: AuthorizationCheck) => {
return {
queryKey: ["organizations", "permitted", check],
queryFn: async (): Promise<Organization[]> => {
const orgs = await API.getOrganizations();
const checks = Object.fromEntries(
orgs.map((org) => [
org.id,
{
...check,
object: { ...check.object, organization_id: org.id },
},
]),
);
const permissions = await API.checkAuthorization({ checks });
return orgs.filter((org) => permissions[org.id]);
},
};
};
/**
* Fetch permissions for all provided organizations.
*
-1
View File
@@ -5615,7 +5615,6 @@ export interface ProvisionerJobMetadata {
readonly template_icon: string;
readonly workspace_id?: string;
readonly workspace_name?: string;
readonly workspace_build_transition?: WorkspaceTransition;
}
// From codersdk/provisionerdaemons.go
@@ -1,14 +1,18 @@
import type { Meta, StoryObj } from "@storybook/react-vite";
import { expect, fn, screen, userEvent, waitFor, within } from "storybook/test";
import { MockOrganization, MockOrganization2 } from "#/testHelpers/entities";
import { action } from "storybook/actions";
import { userEvent, within } from "storybook/test";
import {
MockOrganization,
MockOrganization2,
MockUserOwner,
} from "#/testHelpers/entities";
import { OrganizationAutocomplete } from "./OrganizationAutocomplete";
const meta: Meta<typeof OrganizationAutocomplete> = {
title: "components/OrganizationAutocomplete",
component: OrganizationAutocomplete,
args: {
onChange: fn(),
options: [MockOrganization, MockOrganization2],
onChange: action("Selected organization"),
},
};
@@ -16,51 +20,36 @@ export default meta;
type Story = StoryObj<typeof OrganizationAutocomplete>;
export const ManyOrgs: Story = {
args: {
value: null,
parameters: {
showOrganizations: true,
user: MockUserOwner,
features: ["multiple_organizations"],
permissions: { viewDeploymentConfig: true },
queries: [
{
key: ["organizations"],
data: [MockOrganization, MockOrganization2],
},
],
},
play: async ({ canvasElement }) => {
const canvas = within(canvasElement);
const button = canvas.getByRole("button");
await userEvent.click(button);
await waitFor(() => {
expect(
screen.getByText(MockOrganization.display_name),
).toBeInTheDocument();
expect(
screen.getByText(MockOrganization2.display_name),
).toBeInTheDocument();
});
},
};
export const WithValue: Story = {
args: {
value: MockOrganization2,
},
play: async ({ canvasElement, args }) => {
const canvas = within(canvasElement);
await waitFor(() => {
expect(
canvas.getByText(MockOrganization2.display_name),
).toBeInTheDocument();
});
expect(args.onChange).not.toHaveBeenCalled();
},
};
export const OneOrg: Story = {
args: {
value: MockOrganization,
options: [MockOrganization],
},
play: async ({ canvasElement, args }) => {
const canvas = within(canvasElement);
await waitFor(() => {
expect(
canvas.getByText(MockOrganization.display_name),
).toBeInTheDocument();
});
expect(args.onChange).not.toHaveBeenCalled();
parameters: {
showOrganizations: true,
user: MockUserOwner,
features: ["multiple_organizations"],
permissions: { viewDeploymentConfig: true },
queries: [
{
key: ["organizations"],
data: [MockOrganization],
},
],
},
};
@@ -1,6 +1,9 @@
import { Check } from "lucide-react";
import { type FC, useState } from "react";
import type { Organization } from "#/api/typesGenerated";
import { type FC, useEffect, useState } from "react";
import { useQuery } from "react-query";
import { checkAuthorization } from "#/api/queries/authCheck";
import { organizations } from "#/api/queries/organizations";
import type { AuthorizationCheck, Organization } from "#/api/typesGenerated";
import { ChevronDownIcon } from "#/components/AnimatedIcons/ChevronDown";
import { Avatar } from "#/components/Avatar/Avatar";
import { Button } from "#/components/Button/Button";
@@ -19,21 +22,62 @@ import {
} from "#/components/Popover/Popover";
type OrganizationAutocompleteProps = {
value: Organization | null;
onChange: (organization: Organization | null) => void;
options: Organization[];
id?: string;
required?: boolean;
check?: AuthorizationCheck;
};
export const OrganizationAutocomplete: FC<OrganizationAutocompleteProps> = ({
value,
onChange,
options,
id,
required,
check,
}) => {
const [open, setOpen] = useState(false);
const [selected, setSelected] = useState<Organization | null>(null);
const organizationsQuery = useQuery(organizations());
const checks =
check &&
organizationsQuery.data &&
Object.fromEntries(
organizationsQuery.data.map((org) => [
org.id,
{
...check,
object: { ...check.object, organization_id: org.id },
},
]),
);
const permissionsQuery = useQuery({
...checkAuthorization({ checks: checks ?? {} }),
enabled: Boolean(check && organizationsQuery.data),
});
// If an authorization check was provided, filter the organizations based on
// the results of that check.
let options = organizationsQuery.data ?? [];
if (check) {
options = permissionsQuery.data
? options.filter((org) => permissionsQuery.data[org.id])
: [];
}
// Unfortunate: this useEffect sets a default org value
// if only one is available and is necessary as the autocomplete loads
// its own data. Until we refactor, proceed cautiously!
useEffect(() => {
const org = options[0];
if (options.length !== 1 || org === selected) {
return;
}
setSelected(org);
onChange(org);
}, [options, selected, onChange]);
return (
<Popover open={open} onOpenChange={setOpen}>
@@ -46,14 +90,14 @@ export const OrganizationAutocomplete: FC<OrganizationAutocompleteProps> = ({
data-testid="organization-autocomplete"
className="w-full justify-start gap-2 font-normal"
>
{value ? (
{selected ? (
<>
<Avatar
size="sm"
src={value.icon}
fallback={value.display_name}
src={selected.icon}
fallback={selected.display_name}
/>
<span className="truncate">{value.display_name}</span>
<span className="truncate">{selected.display_name}</span>
</>
) : (
<span className="text-content-secondary">
@@ -77,6 +121,7 @@ export const OrganizationAutocomplete: FC<OrganizationAutocompleteProps> = ({
key={org.id}
value={`${org.display_name} ${org.name}`}
onSelect={() => {
setSelected(org);
onChange(org);
setOpen(false);
}}
@@ -89,7 +134,7 @@ export const OrganizationAutocomplete: FC<OrganizationAutocompleteProps> = ({
<span className="truncate">
{org.display_name || org.name}
</span>
{value?.id === org.id && (
{selected?.id === org.id && (
<Check className="ml-auto size-icon-sm shrink-0" />
)}
</CommandItem>
+1 -2
View File
@@ -30,8 +30,7 @@ export function useTime<T>(func: () => T, options: UseTimeOptions = {}): T {
}
const handle = setInterval(() => {
const next = thunk();
setComputedValue(() => next);
setComputedValue(() => thunk());
}, interval);
return () => {
@@ -1,5 +1,5 @@
import type { Meta, StoryObj } from "@storybook/react-vite";
import { expect, screen, spyOn, userEvent, within } from "storybook/test";
import { screen, spyOn, userEvent, within } from "storybook/test";
import { API } from "#/api/api";
import { getPreferredProxy } from "#/contexts/ProxyContext";
import { chromatic } from "#/testHelpers/chromatic";
@@ -57,13 +57,6 @@ export const HasError: Story = {
agent: undefined,
},
},
play: async ({ canvasElement }) => {
const canvas = within(canvasElement);
const moreActionsButton = canvas.getByRole("button", {
name: "Dev Container actions",
});
expect(moreActionsButton).toBeVisible();
},
};
export const NoPorts: Story = {};
@@ -130,13 +123,6 @@ export const NoContainerOrSubAgent: Story = {
},
subAgents: [],
},
play: async ({ canvasElement }) => {
const canvas = within(canvasElement);
const moreActionsButton = canvas.getByRole("button", {
name: "Dev Container actions",
});
expect(moreActionsButton).toBeVisible();
},
};
export const NoContainerOrAgentOrName: Story = {
@@ -274,7 +274,7 @@ export const AgentDevcontainerCard: FC<AgentDevcontainerCardProps> = ({
/>
)}
{!isTransitioning && (
{showDevcontainerControls && (
<AgentDevcontainerMoreActions
deleteDevContainer={deleteDevcontainerMutation.mutate}
/>
@@ -89,7 +89,7 @@ export const WorkspaceBuildLogs: FC<WorkspaceBuildLogsProps> = ({
<div
className={cn(
"logs-header",
"flex items-center border-solid border-0 border-b last:border-b-0 border-border font-sans",
"flex items-center border-solid border-0 border-b border-border font-sans",
"bg-surface-primary text-xs font-semibold leading-none",
"first-of-type:pt-4",
)}
@@ -100,24 +100,22 @@ export const WorkspaceTimings: FC<WorkspaceTimingsProps> = ({
return (
<div className="rounded-lg border-solid bg-surface-primary">
<div className="flex items-center justify-between px-4 py-1.5 relative">
<Button
disabled={isLoading}
variant="subtle"
onClick={() => setIsOpen((o) => !o)}
className="after:content-[''] after:absolute after:inset-0"
>
<ChevronDownIcon open={isOpen} />
<span>Build timeline</span>
</Button>
<span className="ml-auto text-sm text-content-secondary pr-2">
<Button
disabled={isLoading}
variant="subtle"
className="w-full flex items-center"
onClick={() => setIsOpen((o) => !o)}
>
<ChevronDownIcon open={isOpen} className="size-4 mr-4" />
<span>Build timeline</span>
<span className="ml-auto text-content-secondary">
{isLoading ? (
<Skeleton variant="text" width={40} height={16} />
) : (
displayProvisioningTime()
)}
</span>
</div>
</Button>
{!isLoading && (
<Collapse in={isOpen}>
<div
@@ -198,6 +198,14 @@ const buildQueries = (
];
};
/**
* Wrap a chat stream event payload in the JSON string format that
* OneWayWebSocket expects when receiving a WebSocket message event.
* The result is a `ServerSentEvent` of type `"data"` serialised to JSON.
*/
const wrapSSE = (payload: unknown): string =>
JSON.stringify({ type: "data", data: payload });
// ---------------------------------------------------------------------------
// Meta
// ---------------------------------------------------------------------------
@@ -848,20 +856,17 @@ export const StreamedSubagentTitle: Story = {
"/chats/": [
{
event: "message",
data: JSON.stringify([
{
type: "message_part",
chat_id: CHAT_ID,
message_part: {
part: {
type: "tool-call",
tool_call_id: "tool-subagent-stream-1",
tool_name: "spawn_agent",
args_delta: '{"title":"Streamed Child"',
},
data: wrapSSE({
type: "message_part",
message_part: {
part: {
type: "tool-call",
tool_call_id: "tool-subagent-stream-1",
tool_name: "spawn_agent",
args_delta: '{"title":"Streamed Child"',
},
},
] satisfies TypesGen.ChatStreamEvent[]),
}),
},
],
},
@@ -1145,18 +1150,15 @@ export const StreamedReasoning: Story = {
"/chats/": [
{
event: "message",
data: JSON.stringify([
{
type: "message_part",
chat_id: CHAT_ID,
message_part: {
part: {
type: "reasoning",
text: "Streaming reasoning body",
},
data: wrapSSE({
type: "message_part",
message_part: {
part: {
type: "reasoning",
text: "Streaming reasoning body",
},
},
] satisfies TypesGen.ChatStreamEvent[]),
}),
},
],
},
@@ -1228,20 +1230,18 @@ export const WithWaitAgentComputerUseVNC: Story = {
"/chats/": [
{
event: "message",
data: JSON.stringify([
{
type: "message_part",
chat_id: CHAT_ID,
message_part: {
part: {
type: "tool-call",
tool_call_id: "tool-wait-desktop",
tool_name: "wait_agent",
args_delta: '{"chat_id":"desktop-child-1"}',
},
data: wrapSSE({
type: "message_part",
chat_id: CHAT_ID,
message_part: {
part: {
type: "tool-call",
tool_call_id: "tool-wait-desktop",
tool_name: "wait_agent",
args_delta: '{"chat_id":"desktop-child-1"}',
},
},
] satisfies TypesGen.ChatStreamEvent[]),
}),
},
],
},
@@ -4,12 +4,9 @@ import { beforeEach, describe, expect, it, vi } from "vitest";
import {
draftInputStorageKeyPrefix,
getPersistedDraftInputValue,
restoreOptimisticRequestSnapshot,
useConversationEditingState,
} from "./AgentChatPage";
import type { ChatMessageInputRef } from "./components/AgentChatInput";
import { createChatStore } from "./components/ChatConversation/chatStore";
import type { PendingAttachment } from "./components/ChatPageContent";
type MockChatInputHandle = {
handle: ChatMessageInputRef;
@@ -87,41 +84,6 @@ describe("getPersistedDraftInputValue", () => {
});
});
describe("restoreOptimisticRequestSnapshot", () => {
it("restores queued messages, stream output, status, and stream error", () => {
const store = createChatStore();
store.setQueuedMessages([
{
id: 9,
chat_id: "chat-abc-123",
created_at: "2025-01-01T00:00:00.000Z",
content: [{ type: "text" as const, text: "queued" }],
},
]);
store.setChatStatus("running");
store.applyMessagePart({ type: "text", text: "partial response" });
store.setStreamError({ kind: "generic", message: "old error" });
const previousSnapshot = store.getSnapshot();
store.batch(() => {
store.setQueuedMessages([]);
store.setChatStatus("pending");
store.clearStreamState();
store.clearStreamError();
});
restoreOptimisticRequestSnapshot(store, previousSnapshot);
const restoredSnapshot = store.getSnapshot();
expect(restoredSnapshot.queuedMessages).toEqual(
previousSnapshot.queuedMessages,
);
expect(restoredSnapshot.chatStatus).toBe(previousSnapshot.chatStatus);
expect(restoredSnapshot.streamState).toBe(previousSnapshot.streamState);
expect(restoredSnapshot.streamError).toEqual(previousSnapshot.streamError);
});
});
describe("useConversationEditingState", () => {
const chatID = "chat-abc-123";
const expectedKey = `${draftInputStorageKeyPrefix}${chatID}`;
@@ -365,64 +327,6 @@ describe("useConversationEditingState", () => {
unmount();
});
it("forwards pending attachments through history-edit send", async () => {
const { result, onSend, unmount } = renderEditing();
const attachments: PendingAttachment[] = [
{ fileId: "file-1", mediaType: "image/png" },
];
act(() => {
result.current.handleEditUserMessage(7, "hello");
});
await act(async () => {
await result.current.handleSendFromInput("hello", attachments);
});
expect(onSend).toHaveBeenCalledWith("hello", attachments, 7);
unmount();
});
it("restores the edit draft and file-block seed when an edit submission fails", async () => {
const { result, onSend, unmount } = renderEditing();
const mockInput = createMockChatInputHandle("edited message");
const fileBlocks = [
{ type: "file", file_id: "file-1", media_type: "image/png" },
] as const;
result.current.chatInputRef.current = mockInput.handle;
onSend.mockRejectedValueOnce(new Error("boom"));
const editorState = JSON.stringify({
root: {
children: [
{
children: [{ text: "edited message" }],
type: "paragraph",
},
],
type: "root",
},
});
act(() => {
result.current.handleEditUserMessage(7, "edited message", fileBlocks);
result.current.handleContentChange("edited message", editorState, false);
});
await act(async () => {
await expect(
result.current.handleSendFromInput("edited message"),
).rejects.toThrow("boom");
});
expect(mockInput.clear).toHaveBeenCalled();
expect(result.current.inputValueRef.current).toBe("edited message");
expect(result.current.editingMessageId).toBe(7);
expect(result.current.editingFileBlocks).toEqual(fileBlocks);
expect(result.current.editorInitialValue).toBe("edited message");
expect(result.current.initialEditorState).toBe(editorState);
unmount();
});
it("clears the composer and persisted draft after a successful send", async () => {
localStorage.setItem(expectedKey, "draft to clear");
const { result, onSend, unmount } = renderEditing();
+30 -136
View File
@@ -11,7 +11,6 @@ import { toast } from "sonner";
import type { UrlTransform } from "streamdown";
import { API, watchWorkspace } from "#/api/api";
import { isApiError } from "#/api/errors";
import { buildOptimisticEditedMessage } from "#/api/queries/chatMessageEdits";
import {
chat,
chatDesktopEnabled,
@@ -52,14 +51,11 @@ import {
getWorkspaceAgent,
} from "./components/ChatConversation/chatHelpers";
import {
type ChatStore,
type ChatStoreState,
selectChatStatus,
useChatSelector,
useChatStore,
} from "./components/ChatConversation/chatStore";
import { useWorkspaceCreationWatcher } from "./components/ChatConversation/useWorkspaceCreationWatcher";
import type { PendingAttachment } from "./components/ChatPageContent";
import {
getDefaultMCPSelection,
getSavedMCPSelection,
@@ -105,47 +101,12 @@ export function getPersistedDraftInputValue(
).text;
}
/** @internal Exported for testing. */
export const restoreOptimisticRequestSnapshot = (
store: Pick<
ChatStore,
| "batch"
| "setChatStatus"
| "setQueuedMessages"
| "setStreamError"
| "setStreamState"
>,
snapshot: Pick<
ChatStoreState,
"chatStatus" | "queuedMessages" | "streamError" | "streamState"
>,
): void => {
store.batch(() => {
store.setQueuedMessages(snapshot.queuedMessages);
store.setChatStatus(snapshot.chatStatus);
store.setStreamState(snapshot.streamState);
store.setStreamError(snapshot.streamError);
});
};
const buildAttachmentMediaTypes = (
attachments?: readonly PendingAttachment[],
): ReadonlyMap<string, string> | undefined => {
if (!attachments?.length) {
return undefined;
}
return new Map(
attachments.map(({ fileId, mediaType }) => [fileId, mediaType]),
);
};
/** @internal Exported for testing. */
export function useConversationEditingState(deps: {
chatID: string | undefined;
onSend: (
message: string,
attachments?: readonly PendingAttachment[],
fileIds?: string[],
editedMessageID?: number,
) => Promise<void>;
onDeleteQueuedMessage: (id: number) => Promise<void>;
@@ -169,9 +130,6 @@ export function useConversationEditingState(deps: {
};
},
);
const serializedEditorStateRef = useRef<string | undefined>(
initialEditorState,
);
// Monotonic counter to force LexicalComposer remount.
const [remountKey, setRemountKey] = useState(0);
@@ -218,7 +176,6 @@ export function useConversationEditingState(deps: {
editorInitialValue: text,
initialEditorState: undefined,
});
serializedEditorStateRef.current = undefined;
setRemountKey((k) => k + 1);
inputValueRef.current = text;
setEditingFileBlocks(fileBlocks ?? []);
@@ -231,7 +188,6 @@ export function useConversationEditingState(deps: {
editorInitialValue: savedText,
initialEditorState: savedState,
});
serializedEditorStateRef.current = savedState;
setRemountKey((k) => k + 1);
inputValueRef.current = savedText;
setEditingMessageId(null);
@@ -265,7 +221,6 @@ export function useConversationEditingState(deps: {
editorInitialValue: text,
initialEditorState: undefined,
});
serializedEditorStateRef.current = undefined;
setRemountKey((k) => k + 1);
inputValueRef.current = text;
setEditingFileBlocks(fileBlocks);
@@ -278,7 +233,6 @@ export function useConversationEditingState(deps: {
editorInitialValue: savedText,
initialEditorState: savedState,
});
serializedEditorStateRef.current = savedState;
setRemountKey((k) => k + 1);
inputValueRef.current = savedText;
setEditingQueuedMessageID(null);
@@ -286,48 +240,25 @@ export function useConversationEditingState(deps: {
setEditingFileBlocks([]);
};
// Clears the composer for an in-flight history edit and
// returns a rollback function that restores the editing draft
// if the send fails.
const clearInputForHistoryEdit = (message: string) => {
const snapshot = {
editorState: serializedEditorStateRef.current,
fileBlocks: editingFileBlocks,
messageId: editingMessageId,
};
// Wraps the parent onSend to clear local input/editing state
// and handle queue-edit deletion.
const handleSendFromInput = async (message: string, fileIds?: string[]) => {
const editedMessageID =
editingMessageId !== null ? editingMessageId : undefined;
const queueEditID = editingQueuedMessageID;
chatInputRef.current?.clear();
inputValueRef.current = "";
setEditingMessageId(null);
return () => {
setDraftState({
editorInitialValue: message,
initialEditorState: snapshot.editorState,
});
serializedEditorStateRef.current = snapshot.editorState;
setRemountKey((k) => k + 1);
inputValueRef.current = message;
setEditingMessageId(snapshot.messageId);
setEditingFileBlocks(snapshot.fileBlocks);
};
};
// Clears all input and editing state after a successful send.
const finalizeSuccessfulSend = (
editedMessageID: number | undefined,
queueEditID: number | null,
) => {
await onSend(message, fileIds, editedMessageID);
// Clear input and editing state on success.
chatInputRef.current?.clear();
if (!isMobileViewport()) {
chatInputRef.current?.focus();
}
inputValueRef.current = "";
serializedEditorStateRef.current = undefined;
if (draftStorageKey) {
localStorage.removeItem(draftStorageKey);
}
if (editedMessageID !== undefined) {
if (editingMessageId !== null) {
setEditingMessageId(null);
setDraftBeforeHistoryEdit(null);
setEditingFileBlocks([]);
}
@@ -339,41 +270,12 @@ export function useConversationEditingState(deps: {
}
};
// Wraps the parent onSend to clear local input/editing state
// and handle queue-edit deletion.
const handleSendFromInput = async (
message: string,
attachments?: readonly PendingAttachment[],
) => {
const editedMessageID =
editingMessageId !== null ? editingMessageId : undefined;
const queueEditID = editingQueuedMessageID;
const sendPromise = onSend(message, attachments, editedMessageID);
// For history edits, clear input immediately and prepare
// a rollback in case the send fails.
const rollback =
editedMessageID !== undefined
? clearInputForHistoryEdit(message)
: undefined;
try {
await sendPromise;
} catch (error) {
rollback?.();
throw error;
}
finalizeSuccessfulSend(editedMessageID, queueEditID);
};
const handleContentChange = (
content: string,
serializedEditorState: string,
hasFileReferences: boolean,
) => {
inputValueRef.current = content;
serializedEditorStateRef.current = serializedEditorState;
// Don't overwrite the persisted draft while editing a
// history or queued message — the original draft (possibly
@@ -528,6 +430,9 @@ const AgentChatPage: FC = () => {
} = useOutletContext<AgentsOutletContext>();
const queryClient = useQueryClient();
const [selectedModel, setSelectedModel] = useState("");
const [pendingEditMessageId, setPendingEditMessageId] = useState<
number | null
>(null);
const scrollToBottomRef = useRef<(() => void) | null>(null);
const chatInputRef = useRef<ChatMessageInputRef | null>(null);
const inputValueRef = useRef(
@@ -870,7 +775,7 @@ const AgentChatPage: FC = () => {
const handleSend = async (
message: string,
attachments?: readonly PendingAttachment[],
fileIds?: string[],
editedMessageID?: number,
) => {
const chatInputHandle = (
@@ -885,9 +790,7 @@ const AgentChatPage: FC = () => {
(p) => p.type === "file-reference",
);
const hasContent =
message.trim() ||
(attachments && attachments.length > 0) ||
hasFileReferences;
message.trim() || (fileIds && fileIds.length > 0) || hasFileReferences;
if (!hasContent || isSubmissionPending || !agentId || !hasModelOptions) {
return;
}
@@ -915,41 +818,28 @@ const AgentChatPage: FC = () => {
}
}
// Add pre-uploaded file attachments.
if (attachments && attachments.length > 0) {
for (const { fileId } of attachments) {
// Add pre-uploaded file references.
if (fileIds && fileIds.length > 0) {
for (const fileId of fileIds) {
content.push({ type: "file", file_id: fileId });
}
}
if (editedMessageID !== undefined) {
const request: TypesGen.EditChatMessageRequest = { content };
const originalEditedMessage = chatMessagesList?.find(
(existingMessage) => existingMessage.id === editedMessageID,
);
const optimisticMessage = originalEditedMessage
? buildOptimisticEditedMessage({
requestContent: request.content,
originalMessage: originalEditedMessage,
attachmentMediaTypes: buildAttachmentMediaTypes(attachments),
})
: undefined;
const previousSnapshot = store.getSnapshot();
clearChatErrorReason(agentId);
clearStreamError();
store.batch(() => {
store.setQueuedMessages([]);
store.setChatStatus("running");
store.clearStreamState();
});
setPendingEditMessageId(editedMessageID);
scrollToBottomRef.current?.();
try {
await editMessage({
messageId: editedMessageID,
optimisticMessage,
req: request,
});
store.clearStreamState();
store.setChatStatus("running");
setPendingEditMessageId(null);
} catch (error) {
restoreOptimisticRequestSnapshot(store, previousSnapshot);
setPendingEditMessageId(null);
handleUsageLimitError(error);
throw error;
}
@@ -1028,8 +918,10 @@ const AgentChatPage: FC = () => {
const handlePromoteQueuedMessage = async (id: number) => {
const previousSnapshot = store.getSnapshot();
const previousQueuedMessages = previousSnapshot.queuedMessages;
const previousChatStatus = previousSnapshot.chatStatus;
store.setQueuedMessages(
previousSnapshot.queuedMessages.filter((message) => message.id !== id),
previousQueuedMessages.filter((message) => message.id !== id),
);
store.clearStreamState();
if (agentId) {
@@ -1045,7 +937,8 @@ const AgentChatPage: FC = () => {
store.upsertDurableMessage(promotedMessage);
upsertCacheMessages([promotedMessage]);
} catch (error) {
restoreOptimisticRequestSnapshot(store, previousSnapshot);
store.setQueuedMessages(previousQueuedMessages);
store.setChatStatus(previousChatStatus);
handleUsageLimitError(error);
throw error;
}
@@ -1240,6 +1133,7 @@ const AgentChatPage: FC = () => {
workspaceAgent={workspaceAgent}
store={store}
editing={editing}
pendingEditMessageId={pendingEditMessageId}
effectiveSelectedModel={effectiveSelectedModel}
setSelectedModel={setSelectedModel}
modelOptions={modelOptions}
@@ -113,6 +113,7 @@ const StoryAgentChatPageView: FC<StoryProps> = ({ editing, ...overrides }) => {
parentChat: undefined as TypesGen.Chat | undefined,
isArchived: false,
store: createChatStore(),
pendingEditMessageId: null as number | null,
effectiveSelectedModel: defaultModelConfigID,
setSelectedModel: fn(),
modelOptions: defaultModelOptions,
@@ -504,6 +505,22 @@ export const EditingMessage: Story = {
),
};
/**
 * The saving state while an edit is in progress: `pendingEditMessageId`
 * marks message 3 as the one being saved and `isSubmissionPending`
 * blocks further submissions, so the pending ("Saving edit...")
 * indicator renders on that message.
 */
export const EditingSaving: Story = {
  render: () => (
    <StoryAgentChatPageView
      store={buildStoreWithMessages(editingMessages)}
      editing={{
        editingMessageId: 3,
        editorInitialValue: "Now tell me a better joke",
      }}
      pendingEditMessageId={3}
      isSubmissionPending
    />
  ),
};
// ---------------------------------------------------------------------------
// AgentChatPageNotFoundView stories
// ---------------------------------------------------------------------------
@@ -36,7 +36,6 @@ import {
import type { useChatStore } from "./components/ChatConversation/chatStore";
import type { ModelSelectorOption } from "./components/ChatElements";
import { DesktopPanelContext } from "./components/ChatElements/tools/DesktopPanelContext";
import type { PendingAttachment } from "./components/ChatPageContent";
import { ChatPageInput, ChatPageTimeline } from "./components/ChatPageContent";
import { ChatScrollContainer } from "./components/ChatScrollContainer";
import { ChatTopBar } from "./components/ChatTopBar";
@@ -70,10 +69,7 @@ interface EditingState {
fileBlocks: readonly ChatMessagePart[],
) => void;
handleCancelQueueEdit: () => void;
handleSendFromInput: (
message: string,
attachments?: readonly PendingAttachment[],
) => void;
handleSendFromInput: (message: string, fileIds?: string[]) => void;
handleContentChange: (
content: string,
serializedEditorState: string,
@@ -96,6 +92,7 @@ interface AgentChatPageViewProps {
// Editing state.
editing: EditingState;
pendingEditMessageId: number | null;
// Model/input configuration.
effectiveSelectedModel: string;
@@ -182,6 +179,7 @@ export const AgentChatPageView: FC<AgentChatPageViewProps> = ({
workspace,
store,
editing,
pendingEditMessageId,
effectiveSelectedModel,
setSelectedModel,
modelOptions,
@@ -389,6 +387,7 @@ export const AgentChatPageView: FC<AgentChatPageViewProps> = ({
persistedError={persistedError}
onEditUserMessage={editing.handleEditUserMessage}
editingMessageId={editing.editingMessageId}
savingMessageId={pendingEditMessageId}
urlTransform={urlTransform}
mcpServers={mcpServers}
/>
+22 -1
View File
@@ -51,6 +51,7 @@ import {
chatDetailErrorsEqual,
} from "./utils/usageLimitMessage";
// Type guard for SSE events from the chat list watch endpoint.
// Shallow-compare two ChatDiffStatus objects by their meaningful
// fields, ignoring refreshed_at/stale_at which change on every poll.
function diffStatusEqual(
@@ -74,6 +75,19 @@ function diffStatusEqual(
);
}
// Type guard for SSE payloads from the chat list watch endpoint.
// A valid event is an object with a string `kind` tag and a `chat`
// object carrying at least an `id` field.
function isChatListSSEEvent(
  data: unknown,
): data is { kind: string; chat: TypesGen.Chat } {
  if (typeof data !== "object" || data === null) {
    return false;
  }
  const candidate = data as Record<string, unknown>;
  if (typeof candidate.kind !== "string") {
    return false;
  }
  const chat = candidate.chat;
  return typeof chat === "object" && chat !== null && "id" in chat;
}
export type { AgentsOutletContext } from "./AgentsPageView";
const AgentsPage: FC = () => {
@@ -481,7 +495,14 @@ const AgentsPage: FC = () => {
console.warn("Failed to parse chat watch event:", event.parseError);
return;
}
const chatEvent = event.parsedMessage;
const sse = event.parsedMessage;
if (sse?.type !== "data" || !sse.data) {
return;
}
if (!isChatListSSEEvent(sse.data)) {
return;
}
const chatEvent = sse.data;
const updatedChat = chatEvent.chat;
// Read the previous status from the infinite chat list
// cache before we write the update below. The per-chat
@@ -668,8 +668,9 @@ export const AgentChatInput: FC<AgentChatInputProps> = ({
<div className="flex items-center justify-between border-b border-border-warning/50 px-3 py-1.5">
<span className="flex items-center gap-1.5 text-xs font-medium text-content-warning">
<PencilIcon className="h-3.5 w-3.5" />
Editing will delete all subsequent messages and restart the
conversation here.
{isLoading
? "Saving edit..."
: "Editing will delete all subsequent messages and restart the conversation here."}
</span>
<Button
type="button"
@@ -939,59 +939,3 @@ export const ThinkingOnlyAssistantSpacing: Story = {
expect(canvas.getByText("Any progress?")).toBeInTheDocument();
},
};
/**
 * Regression: sources-only assistant messages must have consistent
 * bottom spacing before the next user bubble. A spacer div fills the
 * gap that would normally come from the hidden action bar.
 */
export const SourcesOnlyAssistantSpacing: Story = {
  args: {
    ...defaultArgs,
    parsedMessages: buildMessages([
      {
        ...baseMessage,
        id: 1,
        role: "user",
        content: [{ type: "text", text: "Can you share your sources?" }],
      },
      // Assistant message whose content is ONLY source parts — no
      // markdown text — so no copy/action bar is rendered for it.
      {
        ...baseMessage,
        id: 2,
        role: "assistant",
        content: [
          {
            type: "source",
            url: "https://example.com/docs",
            title: "Documentation",
          },
          {
            type: "source",
            url: "https://example.com/api",
            title: "API Reference",
          },
        ],
      },
      {
        ...baseMessage,
        id: 3,
        role: "user",
        content: [{ type: "text", text: "Thanks!" }],
      },
    ]),
  },
  play: async ({ canvasElement }) => {
    const canvas = within(canvasElement);
    expect(canvas.getByText("Can you share your sources?")).toBeInTheDocument();
    expect(canvas.getByText("Thanks!")).toBeInTheDocument();
    // Expand the collapsed sources summary, then verify both source
    // links are shown.
    await userEvent.click(
      canvas.getByRole("button", { name: /searched 2 results/i }),
    );
    expect(
      canvas.getByRole("link", { name: "Documentation" }),
    ).toBeInTheDocument();
    expect(
      canvas.getByRole("link", { name: "API Reference" }),
    ).toBeInTheDocument();
  },
};
@@ -12,6 +12,7 @@ import type { UrlTransform } from "streamdown";
import type * as TypesGen from "#/api/typesGenerated";
import { Button } from "#/components/Button/Button";
import { CopyButton } from "#/components/CopyButton/CopyButton";
import { Spinner } from "#/components/Spinner/Spinner";
import {
Tooltip,
TooltipContent,
@@ -426,6 +427,7 @@ const ChatMessageItem = memo<{
fileBlocks?: readonly TypesGen.ChatMessagePart[],
) => void;
editingMessageId?: number | null;
savingMessageId?: number | null;
isAfterEditingMessage?: boolean;
hideActions?: boolean;
@@ -444,6 +446,7 @@ const ChatMessageItem = memo<{
parsed,
onEditUserMessage,
editingMessageId,
savingMessageId,
isAfterEditingMessage = false,
hideActions = false,
fadeFromBottom = false,
@@ -455,6 +458,7 @@ const ChatMessageItem = memo<{
showDesktopPreviews,
}) => {
const isUser = message.role === "user";
const isSavingMessage = savingMessageId === message.id;
const [previewImage, setPreviewImage] = useState<string | null>(null);
const [previewText, setPreviewText] = useState<string | null>(null);
if (
@@ -512,11 +516,6 @@ const ChatMessageItem = memo<{
userInlineContent.length > 0 || Boolean(parsed.markdown?.trim());
const hasFileBlocks = userFileBlocks.length > 0;
const hasCopyableContent = Boolean(parsed.markdown.trim());
const needsAssistantBottomSpacer =
!hideActions &&
!isUser &&
!hasCopyableContent &&
(Boolean(parsed.reasoning) || parsed.sources.length > 0);
const conversationItemProps: { role: "user" | "assistant" } = {
role: isUser ? "user" : "assistant",
@@ -537,6 +536,7 @@ const ChatMessageItem = memo<{
"rounded-lg border border-solid border-border-default bg-surface-secondary px-3 py-2 font-sans shadow-sm transition-shadow",
editingMessageId === message.id &&
"border-surface-secondary shadow-[0_0_0_2px_hsla(var(--border-warning),0.6)]",
isSavingMessage && "ring-2 ring-content-secondary/40",
fadeFromBottom && "relative overflow-hidden",
)}
style={
@@ -567,6 +567,13 @@ const ChatMessageItem = memo<{
: parsed.markdown || ""}
</span>
)}
{isSavingMessage && (
<Spinner
className="mt-0.5 h-3.5 w-3.5 shrink-0 text-content-secondary"
aria-label="Saving message edit"
loading
/>
)}
</div>
)}
{hasFileBlocks && (
@@ -663,9 +670,12 @@ const ChatMessageItem = memo<{
</div>
)}
{/* Spacer for assistant messages without an action bar
(e.g. reasoning-only or sources-only) so they have
consistent bottom padding before the next user bubble. */}
{needsAssistantBottomSpacer && <div className="min-h-6" />}
(e.g. thinking-only) so they have consistent bottom
padding before the next user bubble. */}
{!hideActions &&
!isUser &&
!hasCopyableContent &&
Boolean(parsed.reasoning) && <div className="min-h-6" />}
{previewImage && (
<ImageLightbox
src={previewImage}
@@ -692,6 +702,7 @@ const StickyUserMessage = memo<{
fileBlocks?: readonly TypesGen.ChatMessagePart[],
) => void;
editingMessageId?: number | null;
savingMessageId?: number | null;
isAfterEditingMessage?: boolean;
}>(
({
@@ -699,6 +710,7 @@ const StickyUserMessage = memo<{
parsed,
onEditUserMessage,
editingMessageId,
savingMessageId,
isAfterEditingMessage = false,
}) => {
const [isStuck, setIsStuck] = useState(false);
@@ -923,6 +935,7 @@ const StickyUserMessage = memo<{
parsed={parsed}
onEditUserMessage={handleEditUserMessage}
editingMessageId={editingMessageId}
savingMessageId={savingMessageId}
isAfterEditingMessage={isAfterEditingMessage}
/>
</div>
@@ -965,6 +978,7 @@ const StickyUserMessage = memo<{
parsed={parsed}
onEditUserMessage={handleEditUserMessage}
editingMessageId={editingMessageId}
savingMessageId={savingMessageId}
isAfterEditingMessage={isAfterEditingMessage}
fadeFromBottom
/>
@@ -986,6 +1000,7 @@ interface ConversationTimelineProps {
fileBlocks?: readonly TypesGen.ChatMessagePart[],
) => void;
editingMessageId?: number | null;
savingMessageId?: number | null;
urlTransform?: UrlTransform;
mcpServers?: readonly TypesGen.MCPServerConfig[];
computerUseSubagentIds?: Set<string>;
@@ -999,6 +1014,7 @@ export const ConversationTimeline = memo<ConversationTimelineProps>(
subagentTitles,
onEditUserMessage,
editingMessageId,
savingMessageId,
urlTransform,
mcpServers,
computerUseSubagentIds,
@@ -1025,7 +1041,7 @@ export const ConversationTimeline = memo<ConversationTimelineProps>(
}
return (
<div data-testid="conversation-timeline" className="flex flex-col gap-2">
<div className="flex flex-col gap-2">
{parsedMessages.map(({ message, parsed }, msgIdx) => {
if (message.role === "user") {
return (
@@ -1035,6 +1051,7 @@ export const ConversationTimeline = memo<ConversationTimelineProps>(
parsed={parsed}
onEditUserMessage={onEditUserMessage}
editingMessageId={editingMessageId}
savingMessageId={savingMessageId}
isAfterEditingMessage={afterEditingMessageIds.has(message.id)}
/>
);
@@ -1048,6 +1065,7 @@ export const ConversationTimeline = memo<ConversationTimelineProps>(
key={message.id}
message={message}
parsed={parsed}
savingMessageId={savingMessageId}
urlTransform={urlTransform}
isAfterEditingMessage={afterEditingMessageIds.has(message.id)}
hideActions={!isLastInChain}
@@ -70,7 +70,7 @@ export const LiveStreamTailContent = ({
}
return (
<div className="flex flex-col gap-2">
<div className="flex flex-col gap-3">
{shouldRenderEmptyState && (
<div className="py-12 text-center text-content-secondary">
<p className="text-sm">Start a conversation with your agent.</p>
@@ -190,27 +190,6 @@ describe("setChatStatus", () => {
});
});
// ---------------------------------------------------------------------------
// setStreamState
// ---------------------------------------------------------------------------
describe("setStreamState", () => {
  it("does not notify when setting the same stream state reference", () => {
    const store = createChatStore();
    // Seed the store with stream state by applying a message part.
    store.applyMessagePart({ type: "text", text: "hello" });
    const streamState = store.getSnapshot().streamState;
    expect(streamState).not.toBeNull();
    let notified = false;
    store.subscribe(() => {
      notified = true;
    });
    // Re-setting the identical reference must be a no-op: subscribers
    // should not fire when nothing actually changed.
    store.setStreamState(streamState);
    expect(notified).toBe(false);
  });
});
// ---------------------------------------------------------------------------
// setStreamError / clearStreamError
// ---------------------------------------------------------------------------
@@ -55,7 +55,7 @@ vi.mock("#/api/api", () => ({
}));
type MessageListener = (
payload: OneWayMessageEvent<TypesGen.ChatStreamEvent[]>,
payload: OneWayMessageEvent<TypesGen.ServerSentEvent>,
) => void;
type ErrorListener = (payload: Event) => void;
type OpenListener = (payload: Event) => void;
@@ -67,7 +67,6 @@ type MockSocketHelpers = {
emitOpen: () => void;
emitData: (event: TypesGen.ChatStreamEvent) => void;
emitDataBatch: (events: readonly TypesGen.ChatStreamEvent[]) => void;
emitParseError: () => void;
emitError: () => void;
emitClose: () => void;
};
@@ -144,30 +143,26 @@ const createMockSocket = (): MockSocket => {
removeEventListener,
close: vi.fn(),
emitData: (event) => {
const payload: OneWayMessageEvent<TypesGen.ChatStreamEvent[]> = {
const payload: OneWayMessageEvent<TypesGen.ServerSentEvent> = {
sourceEvent: {} as MessageEvent<string>,
parseError: undefined,
parsedMessage: [event],
parsedMessage: {
type: "data",
data: event,
},
};
for (const listener of messageListeners) {
listener(payload);
}
},
emitDataBatch: (events) => {
const payload: OneWayMessageEvent<TypesGen.ChatStreamEvent[]> = {
const payload: OneWayMessageEvent<TypesGen.ServerSentEvent> = {
sourceEvent: {} as MessageEvent<string>,
parseError: undefined,
parsedMessage: events as TypesGen.ChatStreamEvent[],
};
for (const listener of messageListeners) {
listener(payload);
}
},
emitParseError: () => {
const payload: OneWayMessageEvent<TypesGen.ChatStreamEvent[]> = {
sourceEvent: {} as MessageEvent<string>,
parseError: new Error("bad json"),
parsedMessage: undefined,
parsedMessage: {
type: "data",
data: events,
},
};
for (const listener of messageListeners) {
listener(payload);
@@ -4213,306 +4208,4 @@ describe("store/cache desync protection", () => {
expect(result.current.orderedMessageIDs).toEqual([1]);
});
});
it("reflects optimistic and authoritative history-edit cache updates through the normal sync effect", async () => {
immediateAnimationFrame();
const chatID = "chat-local-edit-sync";
const msg1 = makeMessage(chatID, 1, "user", "first");
const msg2 = makeMessage(chatID, 2, "assistant", "second");
const msg3 = makeMessage(chatID, 3, "user", "third");
const optimisticReplacement = {
...msg3,
content: [{ type: "text" as const, text: "edited draft" }],
};
const authoritativeReplacement = makeMessage(chatID, 9, "user", "edited");
const mockSocket = createMockSocket();
mockWatchChatReturn(mockSocket);
const queryClient = createTestQueryClient();
const wrapper: FC<PropsWithChildren> = ({ children }) => (
<QueryClientProvider client={queryClient}>{children}</QueryClientProvider>
);
const initialOptions = {
chatID,
chatMessages: [msg1, msg2, msg3],
chatRecord: makeChat(chatID),
chatMessagesData: {
messages: [msg1, msg2, msg3],
queued_messages: [],
has_more: false,
},
chatQueuedMessages: [] as TypesGen.ChatQueuedMessage[],
setChatErrorReason: vi.fn(),
clearChatErrorReason: vi.fn(),
};
const { result, rerender } = renderHook(
(options: Parameters<typeof useChatStore>[0]) => {
const { store } = useChatStore(options);
return {
store,
messagesByID: useChatSelector(store, selectMessagesByID),
orderedMessageIDs: useChatSelector(store, selectOrderedMessageIDs),
};
},
{ initialProps: initialOptions, wrapper },
);
await waitFor(() => {
expect(result.current.orderedMessageIDs).toEqual([1, 2, 3]);
});
act(() => {
mockSocket.emitOpen();
});
rerender({
...initialOptions,
chatMessages: [msg1, msg2, optimisticReplacement],
chatMessagesData: {
messages: [msg1, msg2, optimisticReplacement],
queued_messages: [],
has_more: false,
},
});
await waitFor(() => {
expect(result.current.orderedMessageIDs).toEqual([1, 2, 3]);
expect(result.current.messagesByID.get(3)?.content).toEqual(
optimisticReplacement.content,
);
});
rerender({
...initialOptions,
chatMessages: [msg1, msg2, authoritativeReplacement],
chatMessagesData: {
messages: [msg1, msg2, authoritativeReplacement],
queued_messages: [],
has_more: false,
},
});
await waitFor(() => {
expect(result.current.orderedMessageIDs).toEqual([1, 2, 9]);
expect(result.current.messagesByID.has(3)).toBe(false);
expect(result.current.messagesByID.get(9)?.content).toEqual(
authoritativeReplacement.content,
);
});
});
});
describe("parse errors", () => {
it("surfaces parseError as streamError", async () => {
immediateAnimationFrame();
const chatID = "chat-parse-error";
const mockSocket = createMockSocket();
mockWatchChatReturn(mockSocket);
const queryClient = createTestQueryClient();
const wrapper = ({ children }: PropsWithChildren) => (
<QueryClientProvider client={queryClient}>{children}</QueryClientProvider>
);
const setChatErrorReason = vi.fn();
const clearChatErrorReason = vi.fn();
const { result } = renderHook(
() => {
const { store } = useChatStore({
chatID,
chatMessages: [],
chatRecord: makeChat(chatID),
chatMessagesData: {
messages: [],
queued_messages: [],
has_more: false,
},
chatQueuedMessages: [],
setChatErrorReason,
clearChatErrorReason,
});
return {
streamError: useChatSelector(store, selectStreamError),
chatStatus: useChatSelector(store, selectChatStatus),
};
},
{ wrapper },
);
await waitFor(() => {
expect(watchChat).toHaveBeenCalledWith(chatID, undefined);
});
act(() => {
mockSocket.emitParseError();
});
await waitFor(() => {
expect(result.current.streamError).toEqual({
kind: "generic",
message: "Failed to parse chat stream update.",
});
});
expect(result.current.chatStatus).not.toBe("error");
});
it("does not corrupt in-progress stream state", async () => {
immediateAnimationFrame();
const chatID = "chat-parse-no-corrupt";
const existingMessage = makeMessage(chatID, 1, "user", "hello");
const mockSocket = createMockSocket();
mockWatchChatReturn(mockSocket);
const queryClient = createTestQueryClient();
const wrapper = ({ children }: PropsWithChildren) => (
<QueryClientProvider client={queryClient}>{children}</QueryClientProvider>
);
const setChatErrorReason = vi.fn();
const clearChatErrorReason = vi.fn();
const { result } = renderHook(
() => {
const { store } = useChatStore({
chatID,
chatMessages: [existingMessage],
chatRecord: makeChat(chatID),
chatMessagesData: {
messages: [existingMessage],
queued_messages: [],
has_more: false,
},
chatQueuedMessages: [],
setChatErrorReason,
clearChatErrorReason,
});
return {
streamState: useChatSelector(store, selectStreamState),
streamError: useChatSelector(store, selectStreamError),
};
},
{ wrapper },
);
await waitFor(() => {
expect(watchChat).toHaveBeenCalledWith(chatID, 1);
});
// Build up some stream state first.
act(() => {
mockSocket.emitData({
type: "message_part",
chat_id: chatID,
message_part: {
role: "assistant",
part: { type: "text", text: "partial response" },
},
});
});
await waitFor(() => {
expect(result.current.streamState?.blocks).toEqual([
{ type: "response", text: "partial response" },
]);
});
// Fire a parse error and verify the existing stream blocks survive.
act(() => {
mockSocket.emitParseError();
});
await waitFor(() => {
expect(result.current.streamError).toEqual({
kind: "generic",
message: "Failed to parse chat stream update.",
});
});
expect(result.current.streamState?.blocks).toEqual([
{ type: "response", text: "partial response" },
]);
});
it("continues processing after parse error", async () => {
immediateAnimationFrame();
const chatID = "chat-parse-recover";
const existingMessage = makeMessage(chatID, 1, "user", "hello");
const mockSocket = createMockSocket();
mockWatchChatReturn(mockSocket);
const queryClient = createTestQueryClient();
const wrapper = ({ children }: PropsWithChildren) => (
<QueryClientProvider client={queryClient}>{children}</QueryClientProvider>
);
const setChatErrorReason = vi.fn();
const clearChatErrorReason = vi.fn();
const { result } = renderHook(
() => {
const { store } = useChatStore({
chatID,
chatMessages: [existingMessage],
chatRecord: makeChat(chatID),
chatMessagesData: {
messages: [existingMessage],
queued_messages: [],
has_more: false,
},
chatQueuedMessages: [],
setChatErrorReason,
clearChatErrorReason,
});
return {
streamState: useChatSelector(store, selectStreamState),
streamError: useChatSelector(store, selectStreamError),
};
},
{ wrapper },
);
await waitFor(() => {
expect(watchChat).toHaveBeenCalledWith(chatID, 1);
});
// Trigger a parse error first.
act(() => {
mockSocket.emitParseError();
});
await waitFor(() => {
expect(result.current.streamError).toEqual({
kind: "generic",
message: "Failed to parse chat stream update.",
});
});
// Send a valid message_part after the parse error.
act(() => {
mockSocket.emitData({
type: "message_part",
chat_id: chatID,
message_part: {
role: "assistant",
part: { type: "text", text: "recovered" },
},
});
});
// The stream should process the new part normally.
await waitFor(() => {
expect(result.current.streamState?.blocks).toEqual([
{ type: "response", text: "recovered" },
]);
});
// streamError is sticky and is not cleared by valid messages.
expect(result.current.streamError).toEqual({
kind: "generic",
message: "Failed to parse chat stream update.",
});
});
});
@@ -174,7 +174,6 @@ export type ChatStore = {
queuedMessages: readonly TypesGen.ChatQueuedMessage[] | undefined,
) => void;
setChatStatus: (status: TypesGen.ChatStatus | null) => void;
setStreamState: (streamState: StreamState | null) => void;
setStreamError: (reason: ChatDetailError | null) => void;
clearStreamError: () => void;
setRetryState: (state: RetryState | null) => void;
@@ -413,20 +412,6 @@ export const createChatStore = (): ChatStore => {
chatStatus: status,
}));
},
// Replace the current stream state. No-ops (and does not notify
// subscribers) when the same reference is passed back in.
setStreamState: (streamState) => {
  // Fast path: identical reference means nothing changed.
  if (state.streamState === streamState) {
    return;
  }
  setState((current) => {
    // Re-check inside the updater in case state changed between the
    // guard above and this call; returning `current` unchanged
    // avoids a spurious notification.
    if (current.streamState === streamState) {
      return current;
    }
    return {
      ...current,
      streamState,
    };
  });
},
setStreamError: (reason) => {
setState((current) => {
if (chatDetailErrorsEqual(current.streamError, reason)) {
@@ -6,6 +6,7 @@ import type * as TypesGen from "#/api/typesGenerated";
import type { OneWayMessageEvent } from "#/utils/OneWayWebSocket";
import { createReconnectingWebSocket } from "#/utils/reconnectingWebSocket";
import type { ChatDetailError } from "../../utils/usageLimitMessage";
import { asNumber, asString } from "../ChatElements/runtimeTypeUtils";
import {
type ChatStore,
type ChatStoreState,
@@ -16,24 +17,50 @@ import {
} from "./chatStore";
import type { RetryState } from "./types";
// Narrow an unknown payload to a ChatStreamEvent by checking for a
// string `type` discriminator.
const isChatStreamEvent = (data: unknown): data is TypesGen.ChatStreamEvent => {
  if (typeof data !== "object" || data === null) {
    return false;
  }
  const record = data as Record<string, unknown>;
  return "type" in record && typeof record.type === "string";
};
// True when the payload is an array and every element is a chat
// stream event.
const isChatStreamEventArray = (
  data: unknown,
): data is TypesGen.ChatStreamEvent[] => {
  if (!Array.isArray(data)) {
    return false;
  }
  for (const item of data) {
    if (!isChatStreamEvent(item)) {
      return false;
    }
  }
  return true;
};
// Normalize a raw SSE data payload into a list of stream events: a
// single event becomes a one-element list, a well-formed array passes
// through, and anything else yields an empty list.
const toChatStreamEvents = (data: unknown): TypesGen.ChatStreamEvent[] => {
  if (isChatStreamEvent(data)) {
    return [data];
  }
  return isChatStreamEventArray(data) ? data : [];
};
const normalizeChatDetailError = (
error: TypesGen.ChatStreamError | undefined,
error: TypesGen.ChatStreamError | Record<string, unknown> | undefined,
): ChatDetailError => ({
message: error?.message.trim() || "Chat processing failed.",
kind: error?.kind?.trim() || "generic",
provider: error?.provider?.trim() || undefined,
retryable: error?.retryable,
statusCode: error?.status_code,
message: asString(error?.message).trim() || "Chat processing failed.",
kind: asString(error?.kind).trim() || "generic",
provider: asString(error?.provider).trim() || undefined,
retryable:
typeof error?.retryable === "boolean" ? error.retryable : undefined,
statusCode: asNumber(error?.status_code),
});
const normalizeRetryState = (retry: TypesGen.ChatStreamRetry): RetryState => ({
attempt: Math.max(1, retry.attempt),
error: retry.error.trim() || "Retrying request shortly.",
kind: retry.kind?.trim() || "generic",
provider: retry.provider?.trim() || undefined,
delayMs: retry.delay_ms,
retryingAt: retry.retrying_at.trim() || undefined,
});
// Coerce a loosely-typed retry payload into RetryState, substituting
// safe defaults for missing or malformed fields. Optional fields are
// omitted entirely (rather than set to undefined) when absent.
const normalizeRetryState = (retry: TypesGen.ChatStreamRetry): RetryState => {
  const attempt = asNumber(retry.attempt) ?? 1;
  const delayMs = asNumber(retry.delay_ms);
  const retryingAt = asString(retry.retrying_at).trim();
  const normalized: RetryState = {
    // Attempt counts are 1-based; clamp anything below 1.
    attempt: Math.max(1, attempt),
    error: asString(retry.error).trim() || "Retrying request shortly.",
    kind: asString(retry.kind).trim() || "generic",
    provider: asString(retry.provider).trim() || undefined,
  };
  if (delayMs !== undefined) {
    normalized.delayMs = delayMs;
  }
  if (retryingAt) {
    normalized.retryingAt = retryingAt;
  }
  return normalized;
};
const shouldSurfaceReconnectState = (state: ChatStoreState): boolean =>
state.streamError === null &&
@@ -206,9 +233,10 @@ export const useChatStore = (
const fetchedIDs = new Set(chatMessages.map((m) => m.id));
// Only classify a store-held ID as stale if it was
// present in the PREVIOUS sync's fetched data. IDs
// added to the store after the last sync (for example
// by the WS handler) are new, not stale, and must not
// trigger the destructive replaceMessages path.
// added to the store after the last sync (by the WS
// handler or handleSend) are new, not stale, and
// must not trigger the destructive replaceMessages
// path.
const prevIDs = new Set(prev.map((m) => m.id));
const hasStaleEntries =
contentChanged &&
@@ -391,7 +419,7 @@ export const useChatStore = (
};
const handleMessage = (
payload: OneWayMessageEvent<TypesGen.ChatStreamEvent[]>,
payload: OneWayMessageEvent<TypesGen.ServerSentEvent>,
) => {
if (disposed) {
return;
@@ -403,8 +431,11 @@ export const useChatStore = (
});
return;
}
if (payload.parsedMessage.type !== "data") {
return;
}
const streamEvents = payload.parsedMessage;
const streamEvents = toChatStreamEvents(payload.parsedMessage.data);
if (streamEvents.length === 0) {
return;
}
@@ -21,7 +21,6 @@ import {
} from "#/components/Tooltip/Tooltip";
import { formatTokenCount } from "#/utils/analytics";
import { formatCostMicros } from "#/utils/currency";
import { paginateItems } from "#/utils/paginateItems";
interface ChatCostSummaryViewProps {
summary: TypesGen.ChatCostSummary | undefined;
@@ -96,19 +95,25 @@ export const ChatCostSummaryView: FC<ChatCostSummaryViewProps> = ({
}
const modelPageSize = 10;
const {
pagedItems: pagedModels,
clampedPage: clampedModelPage,
hasPreviousPage: hasModelPrev,
hasNextPage: hasModelNext,
} = paginateItems(summary.by_model, modelPageSize, modelPage);
const modelMaxPage = Math.max(
1,
Math.ceil(summary.by_model.length / modelPageSize),
);
const clampedModelPage = Math.min(modelPage, modelMaxPage);
const pagedModels = summary.by_model.slice(
(clampedModelPage - 1) * modelPageSize,
clampedModelPage * modelPageSize,
);
const chatPageSize = 10;
const {
pagedItems: pagedChats,
clampedPage: clampedChatPage,
hasPreviousPage: hasChatPrev,
hasNextPage: hasChatNext,
} = paginateItems(summary.by_chat, chatPageSize, chatPage);
const chatMaxPage = Math.max(
1,
Math.ceil(summary.by_chat.length / chatPageSize),
);
const clampedChatPage = Math.min(chatPage, chatMaxPage);
const pagedChats = summary.by_chat.slice(
(clampedChatPage - 1) * chatPageSize,
clampedChatPage * chatPageSize,
);
const usageLimit = summary.usage_limit;
const showUsageLimitCard = usageLimit?.is_limited === true;
@@ -328,8 +333,10 @@ export const ChatCostSummaryView: FC<ChatCostSummaryViewProps> = ({
currentPage={clampedModelPage}
pageSize={modelPageSize}
onPageChange={setModelPage}
hasPreviousPage={hasModelPrev}
hasNextPage={hasModelNext}
hasPreviousPage={clampedModelPage > 1}
hasNextPage={
clampedModelPage * modelPageSize < summary.by_model.length
}
/>
</div>
)}
@@ -396,8 +403,10 @@ export const ChatCostSummaryView: FC<ChatCostSummaryViewProps> = ({
currentPage={clampedChatPage}
pageSize={chatPageSize}
onPageChange={setChatPage}
hasPreviousPage={hasChatPrev}
hasNextPage={hasChatNext}
hasPreviousPage={clampedChatPage > 1}
hasNextPage={
clampedChatPage * chatPageSize < summary.by_chat.length
}
/>
</div>
)}
@@ -48,6 +48,7 @@ interface ChatPageTimelineProps {
fileBlocks?: readonly TypesGen.ChatMessagePart[],
) => void;
editingMessageId?: number | null;
savingMessageId?: number | null;
urlTransform?: UrlTransform;
mcpServers?: readonly TypesGen.MCPServerConfig[];
}
@@ -58,6 +59,7 @@ export const ChatPageTimeline: FC<ChatPageTimelineProps> = ({
persistedError,
onEditUserMessage,
editingMessageId,
savingMessageId,
urlTransform,
mcpServers,
}) => {
@@ -84,10 +86,7 @@ export const ChatPageTimeline: FC<ChatPageTimelineProps> = ({
return (
<Profiler id="AgentChat" onRender={onRenderProfiler}>
<div
data-testid="chat-timeline-wrapper"
className="mx-auto flex w-full max-w-3xl flex-col gap-2 py-6"
>
<div className="mx-auto flex w-full max-w-3xl flex-col gap-3 py-6">
{/* VNC sessions for completed agents may already be
terminated, so inline desktop previews are disabled
via showDesktopPreviews={false} to avoid a perpetual
@@ -98,6 +97,7 @@ export const ChatPageTimeline: FC<ChatPageTimelineProps> = ({
subagentTitles={subagentTitles}
onEditUserMessage={onEditUserMessage}
editingMessageId={editingMessageId}
savingMessageId={savingMessageId}
urlTransform={urlTransform}
mcpServers={mcpServers}
computerUseSubagentIds={computerUseSubagentIds}
@@ -118,18 +118,10 @@ export const ChatPageTimeline: FC<ChatPageTimelineProps> = ({
);
};
export type PendingAttachment = {
fileId: string;
mediaType: string;
};
interface ChatPageInputProps {
store: ChatStoreHandle;
compressionThreshold: number | undefined;
onSend: (
message: string,
attachments?: readonly PendingAttachment[],
) => Promise<void> | void;
onSend: (message: string, fileIds?: string[]) => void;
onDeleteQueuedMessage: (id: number) => Promise<void>;
onPromoteQueuedMessage: (id: number) => Promise<void>;
onInterrupt: () => void;
@@ -320,10 +312,9 @@ export const ChatPageInput: FC<ChatPageInputProps> = ({
<AgentChatInput
onSend={(message) => {
void (async () => {
// Collect uploaded attachment metadata for the optimistic
// transcript builder while keeping the server payload
// shape unchanged downstream.
const pendingAttachments: PendingAttachment[] = [];
// Collect file IDs from already-uploaded attachments.
// Skip files in error state (e.g. too large).
const fileIds: string[] = [];
let skippedErrors = 0;
for (const file of attachments) {
const state = uploadStates.get(file);
@@ -332,10 +323,7 @@ export const ChatPageInput: FC<ChatPageInputProps> = ({
continue;
}
if (state?.status === "uploaded" && state.fileId) {
pendingAttachments.push({
fileId: state.fileId,
mediaType: file.type || "application/octet-stream",
});
fileIds.push(state.fileId);
}
}
if (skippedErrors > 0) {
@@ -343,10 +331,9 @@ export const ChatPageInput: FC<ChatPageInputProps> = ({
`${skippedErrors} attachment${skippedErrors > 1 ? "s" : ""} could not be sent (upload failed)`,
);
}
const attachmentArg =
pendingAttachments.length > 0 ? pendingAttachments : undefined;
const fileArg = fileIds.length > 0 ? fileIds : undefined;
try {
await onSend(message, attachmentArg);
await onSend(message, fileArg);
} catch {
// Attachments preserved for retry on failure.
return;
@@ -1,7 +1,7 @@
import dayjs from "dayjs";
import relativeTime from "dayjs/plugin/relativeTime";
import { CodeIcon, ExternalLinkIcon } from "lucide-react";
import { type FC, useState } from "react";
import type { FC } from "react";
import { Area, AreaChart, CartesianGrid, XAxis, YAxis } from "recharts";
import type * as TypesGen from "#/api/typesGenerated";
import { Button } from "#/components/Button/Button";
@@ -11,7 +11,6 @@ import {
ChartTooltip,
ChartTooltipContent,
} from "#/components/Chart/Chart";
import { PaginationWidgetBase } from "#/components/PaginationWidget/PaginationWidgetBase";
import {
Table,
TableBody,
@@ -22,7 +21,6 @@ import {
} from "#/components/Table/Table";
import { cn } from "#/utils/cn";
import { formatCostMicros } from "#/utils/currency";
import { paginateItems } from "#/utils/paginateItems";
import { PrStateIcon } from "./GitPanel/GitPanel";
dayjs.extend(relativeTime);
@@ -288,8 +286,6 @@ const TimeRangeFilter: FC<{
// Main view
// ---------------------------------------------------------------------------
const RECENT_PRS_PAGE_SIZE = 10;
export const PRInsightsView: FC<PRInsightsViewProps> = ({
data,
timeRange,
@@ -298,18 +294,6 @@ export const PRInsightsView: FC<PRInsightsViewProps> = ({
const { summary, time_series, by_model, recent_prs } = data;
const isEmpty = summary.total_prs_created === 0;
// Client-side pagination for recent PRs table.
// Page resets to 1 on data refresh because the parent unmounts this
// component during loading. Clamping ensures the page is valid if the
// list shrinks without a full remount.
const [recentPrsPage, setRecentPrsPage] = useState(1);
const {
pagedItems: pagedRecentPrs,
clampedPage: clampedRecentPrsPage,
hasPreviousPage: hasRecentPrsPrev,
hasNextPage: hasRecentPrsNext,
} = paginateItems(recent_prs, RECENT_PRS_PAGE_SIZE, recentPrsPage);
return (
<div className="space-y-8">
{/* ── Header ── */}
@@ -370,8 +354,8 @@ export const PRInsightsView: FC<PRInsightsViewProps> = ({
</div>
</section>
{/* ── Model breakdown + Recent PRs ── */}
<div className="space-y-6">
{/* ── Model breakdown + Recent PRs side by side ── */}
<div className="grid grid-cols-1 gap-6 lg:grid-cols-2">
{/* ── Model performance (simplified) ── */}
{by_model.length > 0 && (
<section>
@@ -429,7 +413,7 @@ export const PRInsightsView: FC<PRInsightsViewProps> = ({
{recent_prs.length > 0 && (
<section>
<div className="mb-4">
<SectionTitle>Pull requests</SectionTitle>
<SectionTitle>Recent</SectionTitle>
</div>
<div className="overflow-hidden rounded-lg border border-border-default">
<Table className="table-fixed text-sm">
@@ -452,7 +436,7 @@ export const PRInsightsView: FC<PRInsightsViewProps> = ({
</TableRow>
</TableHeader>{" "}
<TableBody>
{pagedRecentPrs.map((pr) => (
{recent_prs.map((pr) => (
<TableRow
key={pr.chat_id}
className="border-t border-border-default transition-colors hover:bg-surface-secondary/50"
@@ -496,18 +480,6 @@ export const PRInsightsView: FC<PRInsightsViewProps> = ({
</TableBody>
</Table>
</div>
{recent_prs.length > RECENT_PRS_PAGE_SIZE && (
<div className="pt-4">
<PaginationWidgetBase
totalRecords={recent_prs.length}
currentPage={clampedRecentPrsPage}
pageSize={RECENT_PRS_PAGE_SIZE}
onPageChange={setRecentPrsPage}
hasPreviousPage={hasRecentPrsPrev}
hasNextPage={hasRecentPrsNext}
/>
</div>
)}
</section>
)}
</div>
@@ -1,7 +1,10 @@
import type { Meta, StoryObj } from "@storybook/react-vite";
import { action } from "storybook/actions";
import { expect, screen, userEvent, waitFor } from "storybook/test";
import { getProvisionerDaemonsKey } from "#/api/queries/organizations";
import { screen, userEvent } from "storybook/test";
import {
getProvisionerDaemonsKey,
organizationsKey,
} from "#/api/queries/organizations";
import {
MockDefaultOrganization,
MockOrganization2,
@@ -58,20 +61,40 @@ export const StarterTemplateWithOrgPicker: Story = {
},
};
// Query key used by permittedOrganizations() in the form.
const permittedOrgsKey = [
  "organizations",
  "permitted",
  // NOTE(review): this tuple must mirror the arguments the form passes to
  // permittedOrganizations() exactly, or the mocked query data will not be
  // matched — confirm against #/api/queries/organizations.
  { object: { resource_type: "template" }, action: "create" },
];
// Builds a single-entry authorization check map, keyed by organization ID,
// asserting permission to create a template in that organization. Spread
// several results together to mock a multi-org authorization response.
const canCreateTemplate = (organizationId: string) => {
  const check = {
    object: {
      resource_type: "template",
      organization_id: organizationId,
    },
    action: "create",
  };
  return { [organizationId]: check };
};
export const StarterTemplateWithProvisionerWarning: Story = {
parameters: {
queries: [
{
key: permittedOrgsKey,
key: organizationsKey,
data: [MockDefaultOrganization, MockOrganization2],
},
{
key: [
"authorization",
{
checks: {
...canCreateTemplate(MockDefaultOrganization.id),
...canCreateTemplate(MockOrganization2.id),
},
},
],
data: {
[MockDefaultOrganization.id]: true,
[MockOrganization2.id]: true,
},
},
{
key: getProvisionerDaemonsKey(MockOrganization2.id),
data: [],
@@ -94,11 +117,27 @@ export const StarterTemplatePermissionsCheck: Story = {
parameters: {
queries: [
{
// Only MockDefaultOrganization passes the permission
// check; MockOrganization2 is filtered out by the
// permittedOrganizations query.
key: permittedOrgsKey,
data: [MockDefaultOrganization],
key: organizationsKey,
data: [MockDefaultOrganization, MockOrganization2],
},
{
key: [
"authorization",
{
checks: {
...canCreateTemplate(MockDefaultOrganization.id),
...canCreateTemplate(MockOrganization2.id),
},
},
],
data: {
[MockDefaultOrganization.id]: true,
[MockOrganization2.id]: false,
},
},
{
key: getProvisionerDaemonsKey(MockOrganization2.id),
data: [],
},
],
},
@@ -107,14 +146,7 @@ export const StarterTemplatePermissionsCheck: Story = {
showOrganizationPicker: true,
},
play: async () => {
// When only one org passes the permission check, it should be
// auto-selected in the picker.
const organizationPicker = screen.getByTestId("organization-autocomplete");
await waitFor(() =>
expect(organizationPicker).toHaveTextContent(
MockDefaultOrganization.display_name,
),
);
await userEvent.click(organizationPicker);
},
};
@@ -7,10 +7,7 @@ import { type FC, useState } from "react";
import { useQuery } from "react-query";
import { useSearchParams } from "react-router";
import * as Yup from "yup";
import {
permittedOrganizations,
provisionerDaemons,
} from "#/api/queries/organizations";
import { provisionerDaemons } from "#/api/queries/organizations";
import type {
CreateTemplateVersionRequest,
Organization,
@@ -194,10 +191,6 @@ type CreateTemplateFormProps = (
showOrganizationPicker?: boolean;
};
// Stable reference for empty org options to avoid re-render loops
// in the render-time state adjustment pattern.
const emptyOrgs: Organization[] = [];
export const CreateTemplateForm: FC<CreateTemplateFormProps> = (props) => {
const [searchParams] = useSearchParams();
const [selectedOrg, setSelectedOrg] = useState<Organization | null>(null);
@@ -229,34 +222,6 @@ export const CreateTemplateForm: FC<CreateTemplateFormProps> = (props) => {
});
const getFieldHelpers = getFormHelpers<CreateTemplateFormData>(form, error);
const permittedOrgsQuery = useQuery({
...permittedOrganizations({
object: { resource_type: "template" },
action: "create",
}),
enabled: Boolean(showOrganizationPicker),
});
const orgOptions = permittedOrgsQuery.data ?? emptyOrgs;
// Clear invalid selections when permission filtering removes the
// selected org. Uses the React render-time adjustment pattern.
const [prevOrgOptions, setPrevOrgOptions] = useState(orgOptions);
if (orgOptions !== prevOrgOptions) {
setPrevOrgOptions(orgOptions);
if (selectedOrg && !orgOptions.some((o) => o.id === selectedOrg.id)) {
setSelectedOrg(null);
void form.setFieldValue("organization", "");
}
}
// Auto-select when exactly one org is available and nothing is
// selected. Runs every render (not gated on options change) so it
// works when mock data is available synchronously on first render.
if (orgOptions.length === 1 && selectedOrg === null) {
setSelectedOrg(orgOptions[0]);
void form.setFieldValue("organization", orgOptions[0].name || "");
}
const { data: provisioners } = useQuery({
...provisionerDaemons(selectedOrg?.id ?? ""),
enabled: showOrganizationPicker && Boolean(selectedOrg),
@@ -298,10 +263,9 @@ export const CreateTemplateForm: FC<CreateTemplateFormProps> = (props) => {
<div className="flex flex-col gap-2">
<Label htmlFor="organization">Organization</Label>
<OrganizationAutocomplete
{...getFieldHelpers("organization")}
id="organization"
required
value={selectedOrg}
options={orgOptions}
onChange={(newValue) => {
setSelectedOrg(newValue);
void form.setFieldValue(
@@ -309,6 +273,10 @@ export const CreateTemplateForm: FC<CreateTemplateFormProps> = (props) => {
newValue?.name || "",
);
}}
check={{
object: { resource_type: "template" },
action: "create",
}}
/>
</div>
</>
@@ -1,6 +1,8 @@
import type { Meta, StoryObj } from "@storybook/react-vite";
import { action } from "storybook/actions";
import { userEvent, within } from "storybook/test";
import { organizationsKey } from "#/api/queries/organizations";
import type { Organization } from "#/api/typesGenerated";
import {
MockOrganization,
MockOrganization2,
@@ -24,20 +26,37 @@ type Story = StoryObj<typeof CreateUserForm>;
export const Ready: Story = {};
// Query key used by permittedOrganizations() in the form.
const permittedOrgsKey = [
  "organizations",
  "permitted",
  // NOTE(review): shape must match the arguments CreateUserForm passes to
  // permittedOrganizations() for the storybook mock to be picked up —
  // confirm against #/api/queries/organizations.
  { object: { resource_type: "organization_member" }, action: "create" },
];
// Builds the mocked react-query entry for the batch authorization check the
// form issues: one "create organization_member" check per organization, with
// every check resolving to true in the mocked response data.
const permissionCheckQuery = (organizations: Organization[]) => {
  const checks: Record<string, unknown> = {};
  const data: Record<string, boolean> = {};
  for (const org of organizations) {
    checks[org.id] = {
      action: "create",
      object: {
        resource_type: "organization_member",
        organization_id: org.id,
      },
    };
    data[org.id] = true;
  }
  return { key: ["authorization", { checks }], data };
};
export const WithOrganizations: Story = {
parameters: {
queries: [
{
key: permittedOrgsKey,
key: organizationsKey,
data: [MockOrganization, MockOrganization2],
},
permissionCheckQuery([MockOrganization, MockOrganization2]),
],
},
args: {
@@ -1,11 +1,9 @@
import { useFormik } from "formik";
import { Check } from "lucide-react";
import { Select as SelectPrimitive } from "radix-ui";
import { type FC, useState } from "react";
import { useQuery } from "react-query";
import type { FC } from "react";
import * as Yup from "yup";
import { hasApiFieldErrors, isApiError } from "#/api/errors";
import { permittedOrganizations } from "#/api/queries/organizations";
import type * as TypesGen from "#/api/typesGenerated";
import { ErrorAlert } from "#/components/Alert/ErrorAlert";
import { Button } from "#/components/Button/Button";
@@ -92,10 +90,6 @@ interface CreateUserFormProps {
serviceAccountsEnabled: boolean;
}
// Stable reference for empty org options to avoid re-render loops
// in the render-time state adjustment pattern.
const emptyOrgs: TypesGen.Organization[] = [];
export const CreateUserForm: FC<CreateUserFormProps> = ({
error,
isLoading,
@@ -131,38 +125,6 @@ export const CreateUserForm: FC<CreateUserFormProps> = ({
enableReinitialize: true,
});
const [selectedOrg, setSelectedOrg] = useState<TypesGen.Organization | null>(
null,
);
const permittedOrgsQuery = useQuery({
...permittedOrganizations({
object: { resource_type: "organization_member" },
action: "create",
}),
enabled: showOrganizations,
});
const orgOptions = permittedOrgsQuery.data ?? emptyOrgs;
// Clear invalid selections when permission filtering removes the
// selected org. Uses the React render-time adjustment pattern.
const [prevOrgOptions, setPrevOrgOptions] = useState(orgOptions);
if (orgOptions !== prevOrgOptions) {
setPrevOrgOptions(orgOptions);
if (selectedOrg && !orgOptions.some((o) => o.id === selectedOrg.id)) {
setSelectedOrg(null);
void form.setFieldValue("organization", "");
}
}
// Auto-select when exactly one org is available and nothing is
// selected. Runs every render (not gated on options change) so it
// works when mock data is available synchronously on first render.
if (orgOptions.length === 1 && selectedOrg === null) {
setSelectedOrg(orgOptions[0]);
void form.setFieldValue("organization", orgOptions[0].id ?? "");
}
const getFieldHelpers = getFormHelpers(form, error);
const isServiceAccount = form.values.login_type === "none";
@@ -212,14 +174,16 @@ export const CreateUserForm: FC<CreateUserFormProps> = ({
<div className="flex flex-col gap-2">
<Label htmlFor="organization">Organization</Label>
<OrganizationAutocomplete
{...getFieldHelpers("organization")}
id="organization"
required
value={selectedOrg}
options={orgOptions}
onChange={(newValue) => {
setSelectedOrg(newValue);
void form.setFieldValue("organization", newValue?.id ?? "");
}}
check={{
object: { resource_type: "organization_member" },
action: "create",
}}
/>
</div>
)}
-64
View File
@@ -1,64 +0,0 @@
import { describe, expect, it } from "vitest";
import { paginateItems } from "./paginateItems";
// 25 items numbered 1-25 for readable assertions.
const items = Array.from({ length: 25 }, (_, i) => i + 1);
describe("paginateItems", () => {
it("returns the first page of items", () => {
const result = paginateItems(items, 10, 1);
expect(result.pagedItems).toEqual([1, 2, 3, 4, 5, 6, 7, 8, 9, 10]);
expect(result.clampedPage).toBe(1);
expect(result.totalPages).toBe(3);
expect(result.hasPreviousPage).toBe(false);
expect(result.hasNextPage).toBe(true);
});
it("returns a partial last page", () => {
const result = paginateItems(items, 10, 3);
expect(result.pagedItems).toEqual([21, 22, 23, 24, 25]);
expect(result.clampedPage).toBe(3);
expect(result.totalPages).toBe(3);
expect(result.hasPreviousPage).toBe(true);
expect(result.hasNextPage).toBe(false);
});
it("clamps currentPage down when beyond total pages", () => {
const result = paginateItems(items, 10, 99);
expect(result.clampedPage).toBe(3);
expect(result.pagedItems).toEqual([21, 22, 23, 24, 25]);
expect(result.hasPreviousPage).toBe(true);
expect(result.hasNextPage).toBe(false);
});
it("clamps currentPage up when 0", () => {
const result = paginateItems(items, 10, 0);
expect(result.clampedPage).toBe(1);
expect(result.pagedItems).toEqual([1, 2, 3, 4, 5, 6, 7, 8, 9, 10]);
expect(result.hasPreviousPage).toBe(false);
expect(result.hasNextPage).toBe(true);
});
it("clamps currentPage up when negative", () => {
const result = paginateItems(items, 10, -5);
expect(result.clampedPage).toBe(1);
expect(result.pagedItems).toEqual([1, 2, 3, 4, 5, 6, 7, 8, 9, 10]);
expect(result.hasPreviousPage).toBe(false);
expect(result.hasNextPage).toBe(true);
});
it("returns empty pagedItems with clampedPage=1 for an empty array", () => {
const result = paginateItems([], 10, 1);
expect(result.pagedItems).toEqual([]);
expect(result.clampedPage).toBe(1);
expect(result.totalPages).toBe(1);
expect(result.hasPreviousPage).toBe(false);
expect(result.hasNextPage).toBe(false);
});
it("reports hasPreviousPage correctly for middle pages", () => {
const result = paginateItems(items, 10, 2);
expect(result.hasPreviousPage).toBe(true);
expect(result.hasNextPage).toBe(true);
});
});
-25
View File
@@ -1,25 +0,0 @@
/**
 * Slices `items` down to the page addressed by `currentPage`, clamping the
 * requested page into the valid range first.
 *
 * @param items - Full list to paginate; never mutated.
 * @param pageSize - Number of items per page.
 * @param currentPage - 1-based requested page; values outside
 *   [1, totalPages] are clamped rather than rejected.
 * @returns The page's items plus the clamped page number, total page
 *   count, and previous/next availability flags.
 */
export function paginateItems<T>(
  items: readonly T[],
  pageSize: number,
  currentPage: number,
): {
  pagedItems: T[];
  clampedPage: number;
  totalPages: number;
  hasPreviousPage: boolean;
  hasNextPage: boolean;
} {
  // An empty list still reports one (empty) page so clamping below can
  // always anchor to page 1.
  const totalPages = Math.max(1, Math.ceil(items.length / pageSize));

  // Clamp the requested page into [1, totalPages].
  let clampedPage = currentPage;
  if (clampedPage < 1) {
    clampedPage = 1;
  } else if (clampedPage > totalPages) {
    clampedPage = totalPages;
  }

  const start = (clampedPage - 1) * pageSize;
  const end = start + pageSize;
  return {
    pagedItems: items.slice(start, end),
    clampedPage,
    totalPages,
    hasPreviousPage: clampedPage > 1,
    // True only when items exist beyond the end of this page.
    hasNextPage: end < items.length,
  };
}