Compare commits

..

3 Commits

Author SHA1 Message Date
Danielle Maywood 37b0cbbcb7 feat(site): add DateRangePicker to analytics page 2026-04-07 17:52:09 +00:00
Danielle Maywood 669328ebd9 feat(site): wrap behavior page sections in CollapsibleSection cards 2026-04-07 17:51:55 +00:00
Danielle Maywood de846d8ea1 feat(site): add CollapsibleSection component and shared utilities 2026-04-07 17:48:04 +00:00
78 changed files with 2303 additions and 3026 deletions
-6
View File
@@ -14175,9 +14175,6 @@ const docTemplate = `{
},
"count": {
"type": "integer"
},
"count_cap": {
"type": "integer"
}
}
},
@@ -14499,9 +14496,6 @@ const docTemplate = `{
},
"count": {
"type": "integer"
},
"count_cap": {
"type": "integer"
}
}
},
-6
View File
@@ -12739,9 +12739,6 @@
},
"count": {
"type": "integer"
},
"count_cap": {
"type": "integer"
}
}
},
@@ -13042,9 +13039,6 @@
},
"count": {
"type": "integer"
},
"count_cap": {
"type": "integer"
}
}
},
+1 -8
View File
@@ -26,11 +26,6 @@ import (
"github.com/coder/coder/v2/codersdk"
)
// Limit the count query to avoid a slow sequential scan due to joins
// on a large table. Set to 0 to disable capping (but also see the note
// in the SQL query).
const auditLogCountCap = 2000
// @Summary Get audit logs
// @ID get-audit-logs
// @Security CoderSessionToken
@@ -71,7 +66,7 @@ func (api *API) auditLogs(rw http.ResponseWriter, r *http.Request) {
countFilter.Username = ""
}
countFilter.CountCap = auditLogCountCap
// Use the same filters to count the number of audit logs
count, err := api.Database.CountAuditLogs(ctx, countFilter)
if dbauthz.IsNotAuthorizedError(err) {
httpapi.Forbidden(rw)
@@ -86,7 +81,6 @@ func (api *API) auditLogs(rw http.ResponseWriter, r *http.Request) {
httpapi.Write(ctx, rw, http.StatusOK, codersdk.AuditLogResponse{
AuditLogs: []codersdk.AuditLog{},
Count: 0,
CountCap: auditLogCountCap,
})
return
}
@@ -104,7 +98,6 @@ func (api *API) auditLogs(rw http.ResponseWriter, r *http.Request) {
httpapi.Write(ctx, rw, http.StatusOK, codersdk.AuditLogResponse{
AuditLogs: api.convertAuditLogs(ctx, dblogs),
Count: count,
CountCap: auditLogCountCap,
})
}
+40 -27
View File
@@ -2155,12 +2155,17 @@ func (q *querier) DeleteUserChatProviderKey(ctx context.Context, arg database.De
return q.db.DeleteUserChatProviderKey(ctx, arg)
}
func (q *querier) DeleteUserSecretByUserIDAndName(ctx context.Context, arg database.DeleteUserSecretByUserIDAndNameParams) error {
obj := rbac.ResourceUserSecret.WithOwner(arg.UserID.String())
if err := q.authorizeContext(ctx, policy.ActionDelete, obj); err != nil {
func (q *querier) DeleteUserSecret(ctx context.Context, id uuid.UUID) error {
// First get the secret to check ownership
secret, err := q.GetUserSecret(ctx, id)
if err != nil {
return err
}
return q.db.DeleteUserSecretByUserIDAndName(ctx, arg)
if err := q.authorizeContext(ctx, policy.ActionDelete, secret); err != nil {
return err
}
return q.db.DeleteUserSecret(ctx, id)
}
func (q *querier) DeleteWebpushSubscriptionByUserIDAndEndpoint(ctx context.Context, arg database.DeleteWebpushSubscriptionByUserIDAndEndpointParams) error {
@@ -4123,6 +4128,19 @@ func (q *querier) GetUserNotificationPreferences(ctx context.Context, userID uui
return q.db.GetUserNotificationPreferences(ctx, userID)
}
// GetUserSecret fetches a single user secret by ID, enforcing a read
// check against the secret row itself so only principals authorized
// for that specific row (presumably its owner — confirm against the
// RBAC policy for ResourceUserSecret) can see it.
func (q *querier) GetUserSecret(ctx context.Context, id uuid.UUID) (database.UserSecret, error) {
	// Load the row before authorizing: the RBAC object is derived from
	// the secret itself, so the check cannot happen before the fetch.
	secret, err := q.db.GetUserSecret(ctx, id)
	if err != nil {
		return database.UserSecret{}, err
	}
	if authErr := q.authorizeContext(ctx, policy.ActionRead, secret); authErr != nil {
		return database.UserSecret{}, authErr
	}
	return secret, nil
}
func (q *querier) GetUserSecretByUserIDAndName(ctx context.Context, arg database.GetUserSecretByUserIDAndNameParams) (database.UserSecret, error) {
obj := rbac.ResourceUserSecret.WithOwner(arg.UserID.String())
if err := q.authorizeContext(ctx, policy.ActionRead, obj); err != nil {
@@ -5506,7 +5524,7 @@ func (q *querier) ListUserChatCompactionThresholds(ctx context.Context, userID u
return q.db.ListUserChatCompactionThresholds(ctx, userID)
}
func (q *querier) ListUserSecrets(ctx context.Context, userID uuid.UUID) ([]database.ListUserSecretsRow, error) {
func (q *querier) ListUserSecrets(ctx context.Context, userID uuid.UUID) ([]database.UserSecret, error) {
obj := rbac.ResourceUserSecret.WithOwner(userID.String())
if err := q.authorizeContext(ctx, policy.ActionRead, obj); err != nil {
return nil, err
@@ -5514,16 +5532,6 @@ func (q *querier) ListUserSecrets(ctx context.Context, userID uuid.UUID) ([]data
return q.db.ListUserSecrets(ctx, userID)
}
// ListUserSecretsWithValues returns every secret for userID including
// the decrypted value. Because values are exposed, the caller must hold
// system-level read permission; REST API handlers should use
// ListUserSecrets (metadata only) instead. Intended for system contexts
// such as the provisioner and the agent manifest.
func (q *querier) ListUserSecretsWithValues(ctx context.Context, userID uuid.UUID) ([]database.UserSecret, error) {
	// Gate on the system resource rather than the owner: this bypasses
	// per-user scoping precisely because it returns secret values.
	err := q.authorizeContext(ctx, policy.ActionRead, rbac.ResourceSystem)
	if err != nil {
		return nil, err
	}
	return q.db.ListUserSecretsWithValues(ctx, userID)
}
func (q *querier) ListWorkspaceAgentPortShares(ctx context.Context, workspaceID uuid.UUID) ([]database.WorkspaceAgentPortShare, error) {
workspace, err := q.db.GetWorkspaceByID(ctx, workspaceID)
if err != nil {
@@ -5774,15 +5782,15 @@ func (q *querier) UpdateChatByID(ctx context.Context, arg database.UpdateChatByI
return q.db.UpdateChatByID(ctx, arg)
}
func (q *querier) UpdateChatHeartbeats(ctx context.Context, arg database.UpdateChatHeartbeatsParams) ([]uuid.UUID, error) {
// The batch heartbeat is a system-level operation filtered by
// worker_id. Authorization is enforced by the AsChatd context
// at the call site rather than per-row, because checking each
// row individually would defeat the purpose of batching.
if err := q.authorizeContext(ctx, policy.ActionUpdate, rbac.ResourceChat); err != nil {
return nil, err
func (q *querier) UpdateChatHeartbeat(ctx context.Context, arg database.UpdateChatHeartbeatParams) (int64, error) {
chat, err := q.db.GetChatByID(ctx, arg.ID)
if err != nil {
return 0, err
}
return q.db.UpdateChatHeartbeats(ctx, arg)
if err := q.authorizeContext(ctx, policy.ActionUpdate, chat); err != nil {
return 0, err
}
return q.db.UpdateChatHeartbeat(ctx, arg)
}
func (q *querier) UpdateChatLabelsByID(ctx context.Context, arg database.UpdateChatLabelsByIDParams) (database.Chat, error) {
@@ -6624,12 +6632,17 @@ func (q *querier) UpdateUserRoles(ctx context.Context, arg database.UpdateUserRo
return q.db.UpdateUserRoles(ctx, arg)
}
func (q *querier) UpdateUserSecretByUserIDAndName(ctx context.Context, arg database.UpdateUserSecretByUserIDAndNameParams) (database.UserSecret, error) {
obj := rbac.ResourceUserSecret.WithOwner(arg.UserID.String())
if err := q.authorizeContext(ctx, policy.ActionUpdate, obj); err != nil {
func (q *querier) UpdateUserSecret(ctx context.Context, arg database.UpdateUserSecretParams) (database.UserSecret, error) {
// First get the secret to check ownership
secret, err := q.db.GetUserSecret(ctx, arg.ID)
if err != nil {
return database.UserSecret{}, err
}
return q.db.UpdateUserSecretByUserIDAndName(ctx, arg)
if err := q.authorizeContext(ctx, policy.ActionUpdate, secret); err != nil {
return database.UserSecret{}, err
}
return q.db.UpdateUserSecret(ctx, arg)
}
func (q *querier) UpdateUserStatus(ctx context.Context, arg database.UpdateUserStatusParams) (database.User, error) {
+29 -29
View File
@@ -842,15 +842,15 @@ func (s *MethodTestSuite) TestChats() {
dbm.EXPECT().UpdateChatStatusPreserveUpdatedAt(gomock.Any(), arg).Return(chat, nil).AnyTimes()
check.Args(arg).Asserts(chat, policy.ActionUpdate).Returns(chat)
}))
s.Run("UpdateChatHeartbeats", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
resultID := uuid.New()
arg := database.UpdateChatHeartbeatsParams{
IDs: []uuid.UUID{resultID},
s.Run("UpdateChatHeartbeat", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
chat := testutil.Fake(s.T(), faker, database.Chat{})
arg := database.UpdateChatHeartbeatParams{
ID: chat.ID,
WorkerID: uuid.New(),
Now: time.Now(),
}
dbm.EXPECT().UpdateChatHeartbeats(gomock.Any(), arg).Return([]uuid.UUID{resultID}, nil).AnyTimes()
check.Args(arg).Asserts(rbac.ResourceChat, policy.ActionUpdate).Returns([]uuid.UUID{resultID})
dbm.EXPECT().GetChatByID(gomock.Any(), chat.ID).Return(chat, nil).AnyTimes()
dbm.EXPECT().UpdateChatHeartbeat(gomock.Any(), arg).Return(int64(1), nil).AnyTimes()
check.Args(arg).Asserts(chat, policy.ActionUpdate).Returns(int64(1))
}))
s.Run("UpdateChatMessageByID", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
chat := testutil.Fake(s.T(), faker, database.Chat{})
@@ -5346,20 +5346,19 @@ func (s *MethodTestSuite) TestUserSecrets() {
Asserts(rbac.ResourceUserSecret.WithOwner(user.ID.String()), policy.ActionRead).
Returns(secret)
}))
s.Run("GetUserSecret", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
secret := testutil.Fake(s.T(), faker, database.UserSecret{})
dbm.EXPECT().GetUserSecret(gomock.Any(), secret.ID).Return(secret, nil).AnyTimes()
check.Args(secret.ID).
Asserts(secret, policy.ActionRead).
Returns(secret)
}))
s.Run("ListUserSecrets", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
user := testutil.Fake(s.T(), faker, database.User{})
row := testutil.Fake(s.T(), faker, database.ListUserSecretsRow{UserID: user.ID})
dbm.EXPECT().ListUserSecrets(gomock.Any(), user.ID).Return([]database.ListUserSecretsRow{row}, nil).AnyTimes()
secret := testutil.Fake(s.T(), faker, database.UserSecret{UserID: user.ID})
dbm.EXPECT().ListUserSecrets(gomock.Any(), user.ID).Return([]database.UserSecret{secret}, nil).AnyTimes()
check.Args(user.ID).
Asserts(rbac.ResourceUserSecret.WithOwner(user.ID.String()), policy.ActionRead).
Returns([]database.ListUserSecretsRow{row})
}))
s.Run("ListUserSecretsWithValues", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
user := testutil.Fake(s.T(), faker, database.User{})
secret := testutil.Fake(s.T(), faker, database.UserSecret{UserID: user.ID})
dbm.EXPECT().ListUserSecretsWithValues(gomock.Any(), user.ID).Return([]database.UserSecret{secret}, nil).AnyTimes()
check.Args(user.ID).
Asserts(rbac.ResourceSystem, policy.ActionRead).
Returns([]database.UserSecret{secret})
}))
s.Run("CreateUserSecret", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
@@ -5371,21 +5370,22 @@ func (s *MethodTestSuite) TestUserSecrets() {
Asserts(rbac.ResourceUserSecret.WithOwner(user.ID.String()), policy.ActionCreate).
Returns(ret)
}))
s.Run("UpdateUserSecretByUserIDAndName", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
user := testutil.Fake(s.T(), faker, database.User{})
updated := testutil.Fake(s.T(), faker, database.UserSecret{UserID: user.ID})
arg := database.UpdateUserSecretByUserIDAndNameParams{UserID: user.ID, Name: "test"}
dbm.EXPECT().UpdateUserSecretByUserIDAndName(gomock.Any(), arg).Return(updated, nil).AnyTimes()
s.Run("UpdateUserSecret", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
secret := testutil.Fake(s.T(), faker, database.UserSecret{})
updated := testutil.Fake(s.T(), faker, database.UserSecret{ID: secret.ID})
arg := database.UpdateUserSecretParams{ID: secret.ID}
dbm.EXPECT().GetUserSecret(gomock.Any(), secret.ID).Return(secret, nil).AnyTimes()
dbm.EXPECT().UpdateUserSecret(gomock.Any(), arg).Return(updated, nil).AnyTimes()
check.Args(arg).
Asserts(rbac.ResourceUserSecret.WithOwner(user.ID.String()), policy.ActionUpdate).
Asserts(secret, policy.ActionUpdate).
Returns(updated)
}))
s.Run("DeleteUserSecretByUserIDAndName", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
user := testutil.Fake(s.T(), faker, database.User{})
arg := database.DeleteUserSecretByUserIDAndNameParams{UserID: user.ID, Name: "test"}
dbm.EXPECT().DeleteUserSecretByUserIDAndName(gomock.Any(), arg).Return(nil).AnyTimes()
check.Args(arg).
Asserts(rbac.ResourceUserSecret.WithOwner(user.ID.String()), policy.ActionDelete).
s.Run("DeleteUserSecret", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
secret := testutil.Fake(s.T(), faker, database.UserSecret{})
dbm.EXPECT().GetUserSecret(gomock.Any(), secret.ID).Return(secret, nil).AnyTimes()
dbm.EXPECT().DeleteUserSecret(gomock.Any(), secret.ID).Return(nil).AnyTimes()
check.Args(secret.ID).
Asserts(secret, policy.ActionRead, secret, policy.ActionDelete).
Returns()
}))
}
-1
View File
@@ -1597,7 +1597,6 @@ func UserSecret(t testing.TB, db database.Store, seed database.UserSecret) datab
Name: takeFirst(seed.Name, "secret-name"),
Description: takeFirst(seed.Description, "secret description"),
Value: takeFirst(seed.Value, "secret value"),
ValueKeyID: seed.ValueKeyID,
EnvName: takeFirst(seed.EnvName, "SECRET_ENV_NAME"),
FilePath: takeFirst(seed.FilePath, "~/secret/file/path"),
})
+21 -21
View File
@@ -712,11 +712,11 @@ func (m queryMetricsStore) DeleteUserChatProviderKey(ctx context.Context, arg da
return r0
}
func (m queryMetricsStore) DeleteUserSecretByUserIDAndName(ctx context.Context, arg database.DeleteUserSecretByUserIDAndNameParams) error {
func (m queryMetricsStore) DeleteUserSecret(ctx context.Context, id uuid.UUID) error {
start := time.Now()
r0 := m.s.DeleteUserSecretByUserIDAndName(ctx, arg)
m.queryLatencies.WithLabelValues("DeleteUserSecretByUserIDAndName").Observe(time.Since(start).Seconds())
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "DeleteUserSecretByUserIDAndName").Inc()
r0 := m.s.DeleteUserSecret(ctx, id)
m.queryLatencies.WithLabelValues("DeleteUserSecret").Observe(time.Since(start).Seconds())
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "DeleteUserSecret").Inc()
return r0
}
@@ -2624,6 +2624,14 @@ func (m queryMetricsStore) GetUserNotificationPreferences(ctx context.Context, u
return r0, r1
}
// GetUserSecret delegates to the wrapped store and records the query's
// latency and call count under the "GetUserSecret" label.
func (m queryMetricsStore) GetUserSecret(ctx context.Context, id uuid.UUID) (database.UserSecret, error) {
	begin := time.Now()
	secret, err := m.s.GetUserSecret(ctx, id)
	m.queryLatencies.WithLabelValues("GetUserSecret").Observe(time.Since(begin).Seconds())
	m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "GetUserSecret").Inc()
	return secret, err
}
func (m queryMetricsStore) GetUserSecretByUserIDAndName(ctx context.Context, arg database.GetUserSecretByUserIDAndNameParams) (database.UserSecret, error) {
start := time.Now()
r0, r1 := m.s.GetUserSecretByUserIDAndName(ctx, arg)
@@ -3912,7 +3920,7 @@ func (m queryMetricsStore) ListUserChatCompactionThresholds(ctx context.Context,
return r0, r1
}
func (m queryMetricsStore) ListUserSecrets(ctx context.Context, userID uuid.UUID) ([]database.ListUserSecretsRow, error) {
func (m queryMetricsStore) ListUserSecrets(ctx context.Context, userID uuid.UUID) ([]database.UserSecret, error) {
start := time.Now()
r0, r1 := m.s.ListUserSecrets(ctx, userID)
m.queryLatencies.WithLabelValues("ListUserSecrets").Observe(time.Since(start).Seconds())
@@ -3920,14 +3928,6 @@ func (m queryMetricsStore) ListUserSecrets(ctx context.Context, userID uuid.UUID
return r0, r1
}
// ListUserSecretsWithValues delegates to the wrapped store and records
// the query's latency and call count under the
// "ListUserSecretsWithValues" label.
func (m queryMetricsStore) ListUserSecretsWithValues(ctx context.Context, userID uuid.UUID) ([]database.UserSecret, error) {
	begin := time.Now()
	secrets, err := m.s.ListUserSecretsWithValues(ctx, userID)
	m.queryLatencies.WithLabelValues("ListUserSecretsWithValues").Observe(time.Since(begin).Seconds())
	m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "ListUserSecretsWithValues").Inc()
	return secrets, err
}
func (m queryMetricsStore) ListWorkspaceAgentPortShares(ctx context.Context, workspaceID uuid.UUID) ([]database.WorkspaceAgentPortShare, error) {
start := time.Now()
r0, r1 := m.s.ListWorkspaceAgentPortShares(ctx, workspaceID)
@@ -4136,11 +4136,11 @@ func (m queryMetricsStore) UpdateChatByID(ctx context.Context, arg database.Upda
return r0, r1
}
func (m queryMetricsStore) UpdateChatHeartbeats(ctx context.Context, arg database.UpdateChatHeartbeatsParams) ([]uuid.UUID, error) {
func (m queryMetricsStore) UpdateChatHeartbeat(ctx context.Context, arg database.UpdateChatHeartbeatParams) (int64, error) {
start := time.Now()
r0, r1 := m.s.UpdateChatHeartbeats(ctx, arg)
m.queryLatencies.WithLabelValues("UpdateChatHeartbeats").Observe(time.Since(start).Seconds())
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "UpdateChatHeartbeats").Inc()
r0, r1 := m.s.UpdateChatHeartbeat(ctx, arg)
m.queryLatencies.WithLabelValues("UpdateChatHeartbeat").Observe(time.Since(start).Seconds())
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "UpdateChatHeartbeat").Inc()
return r0, r1
}
@@ -4696,11 +4696,11 @@ func (m queryMetricsStore) UpdateUserRoles(ctx context.Context, arg database.Upd
return r0, r1
}
func (m queryMetricsStore) UpdateUserSecretByUserIDAndName(ctx context.Context, arg database.UpdateUserSecretByUserIDAndNameParams) (database.UserSecret, error) {
func (m queryMetricsStore) UpdateUserSecret(ctx context.Context, arg database.UpdateUserSecretParams) (database.UserSecret, error) {
start := time.Now()
r0, r1 := m.s.UpdateUserSecretByUserIDAndName(ctx, arg)
m.queryLatencies.WithLabelValues("UpdateUserSecretByUserIDAndName").Observe(time.Since(start).Seconds())
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "UpdateUserSecretByUserIDAndName").Inc()
r0, r1 := m.s.UpdateUserSecret(ctx, arg)
m.queryLatencies.WithLabelValues("UpdateUserSecret").Observe(time.Since(start).Seconds())
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "UpdateUserSecret").Inc()
return r0, r1
}
+36 -36
View File
@@ -1199,18 +1199,18 @@ func (mr *MockStoreMockRecorder) DeleteUserChatProviderKey(ctx, arg any) *gomock
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteUserChatProviderKey", reflect.TypeOf((*MockStore)(nil).DeleteUserChatProviderKey), ctx, arg)
}
// DeleteUserSecretByUserIDAndName mocks base method.
func (m *MockStore) DeleteUserSecretByUserIDAndName(ctx context.Context, arg database.DeleteUserSecretByUserIDAndNameParams) error {
// DeleteUserSecret mocks base method.
func (m *MockStore) DeleteUserSecret(ctx context.Context, id uuid.UUID) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "DeleteUserSecretByUserIDAndName", ctx, arg)
ret := m.ctrl.Call(m, "DeleteUserSecret", ctx, id)
ret0, _ := ret[0].(error)
return ret0
}
// DeleteUserSecretByUserIDAndName indicates an expected call of DeleteUserSecretByUserIDAndName.
func (mr *MockStoreMockRecorder) DeleteUserSecretByUserIDAndName(ctx, arg any) *gomock.Call {
// DeleteUserSecret indicates an expected call of DeleteUserSecret.
func (mr *MockStoreMockRecorder) DeleteUserSecret(ctx, id any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteUserSecretByUserIDAndName", reflect.TypeOf((*MockStore)(nil).DeleteUserSecretByUserIDAndName), ctx, arg)
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteUserSecret", reflect.TypeOf((*MockStore)(nil).DeleteUserSecret), ctx, id)
}
// DeleteWebpushSubscriptionByUserIDAndEndpoint mocks base method.
@@ -4907,6 +4907,21 @@ func (mr *MockStoreMockRecorder) GetUserNotificationPreferences(ctx, userID any)
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetUserNotificationPreferences", reflect.TypeOf((*MockStore)(nil).GetUserNotificationPreferences), ctx, userID)
}
// GetUserSecret mocks base method.
func (m *MockStore) GetUserSecret(ctx context.Context, id uuid.UUID) (database.UserSecret, error) {
	m.ctrl.T.Helper()
	results := m.ctrl.Call(m, "GetUserSecret", ctx, id)
	secret, _ := results[0].(database.UserSecret)
	err, _ := results[1].(error)
	return secret, err
}

// GetUserSecret indicates an expected call of GetUserSecret.
func (mr *MockStoreMockRecorder) GetUserSecret(ctx, id any) *gomock.Call {
	mr.mock.ctrl.T.Helper()
	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetUserSecret", reflect.TypeOf((*MockStore)(nil).GetUserSecret), ctx, id)
}
// GetUserSecretByUserIDAndName mocks base method.
func (m *MockStore) GetUserSecretByUserIDAndName(ctx context.Context, arg database.GetUserSecretByUserIDAndNameParams) (database.UserSecret, error) {
m.ctrl.T.Helper()
@@ -7397,10 +7412,10 @@ func (mr *MockStoreMockRecorder) ListUserChatCompactionThresholds(ctx, userID an
}
// ListUserSecrets mocks base method.
func (m *MockStore) ListUserSecrets(ctx context.Context, userID uuid.UUID) ([]database.ListUserSecretsRow, error) {
func (m *MockStore) ListUserSecrets(ctx context.Context, userID uuid.UUID) ([]database.UserSecret, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "ListUserSecrets", ctx, userID)
ret0, _ := ret[0].([]database.ListUserSecretsRow)
ret0, _ := ret[0].([]database.UserSecret)
ret1, _ := ret[1].(error)
return ret0, ret1
}
@@ -7411,21 +7426,6 @@ func (mr *MockStoreMockRecorder) ListUserSecrets(ctx, userID any) *gomock.Call {
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ListUserSecrets", reflect.TypeOf((*MockStore)(nil).ListUserSecrets), ctx, userID)
}
// ListUserSecretsWithValues mocks base method.
func (m *MockStore) ListUserSecretsWithValues(ctx context.Context, userID uuid.UUID) ([]database.UserSecret, error) {
	m.ctrl.T.Helper()
	results := m.ctrl.Call(m, "ListUserSecretsWithValues", ctx, userID)
	secrets, _ := results[0].([]database.UserSecret)
	err, _ := results[1].(error)
	return secrets, err
}

// ListUserSecretsWithValues indicates an expected call of ListUserSecretsWithValues.
func (mr *MockStoreMockRecorder) ListUserSecretsWithValues(ctx, userID any) *gomock.Call {
	mr.mock.ctrl.T.Helper()
	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ListUserSecretsWithValues", reflect.TypeOf((*MockStore)(nil).ListUserSecretsWithValues), ctx, userID)
}
// ListWorkspaceAgentPortShares mocks base method.
func (m *MockStore) ListWorkspaceAgentPortShares(ctx context.Context, workspaceID uuid.UUID) ([]database.WorkspaceAgentPortShare, error) {
m.ctrl.T.Helper()
@@ -7835,19 +7835,19 @@ func (mr *MockStoreMockRecorder) UpdateChatByID(ctx, arg any) *gomock.Call {
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateChatByID", reflect.TypeOf((*MockStore)(nil).UpdateChatByID), ctx, arg)
}
// UpdateChatHeartbeats mocks base method.
func (m *MockStore) UpdateChatHeartbeats(ctx context.Context, arg database.UpdateChatHeartbeatsParams) ([]uuid.UUID, error) {
// UpdateChatHeartbeat mocks base method.
func (m *MockStore) UpdateChatHeartbeat(ctx context.Context, arg database.UpdateChatHeartbeatParams) (int64, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "UpdateChatHeartbeats", ctx, arg)
ret0, _ := ret[0].([]uuid.UUID)
ret := m.ctrl.Call(m, "UpdateChatHeartbeat", ctx, arg)
ret0, _ := ret[0].(int64)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// UpdateChatHeartbeats indicates an expected call of UpdateChatHeartbeats.
func (mr *MockStoreMockRecorder) UpdateChatHeartbeats(ctx, arg any) *gomock.Call {
// UpdateChatHeartbeat indicates an expected call of UpdateChatHeartbeat.
func (mr *MockStoreMockRecorder) UpdateChatHeartbeat(ctx, arg any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateChatHeartbeats", reflect.TypeOf((*MockStore)(nil).UpdateChatHeartbeats), ctx, arg)
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateChatHeartbeat", reflect.TypeOf((*MockStore)(nil).UpdateChatHeartbeat), ctx, arg)
}
// UpdateChatLabelsByID mocks base method.
@@ -8854,19 +8854,19 @@ func (mr *MockStoreMockRecorder) UpdateUserRoles(ctx, arg any) *gomock.Call {
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateUserRoles", reflect.TypeOf((*MockStore)(nil).UpdateUserRoles), ctx, arg)
}
// UpdateUserSecretByUserIDAndName mocks base method.
func (m *MockStore) UpdateUserSecretByUserIDAndName(ctx context.Context, arg database.UpdateUserSecretByUserIDAndNameParams) (database.UserSecret, error) {
// UpdateUserSecret mocks base method.
func (m *MockStore) UpdateUserSecret(ctx context.Context, arg database.UpdateUserSecretParams) (database.UserSecret, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "UpdateUserSecretByUserIDAndName", ctx, arg)
ret := m.ctrl.Call(m, "UpdateUserSecret", ctx, arg)
ret0, _ := ret[0].(database.UserSecret)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// UpdateUserSecretByUserIDAndName indicates an expected call of UpdateUserSecretByUserIDAndName.
func (mr *MockStoreMockRecorder) UpdateUserSecretByUserIDAndName(ctx, arg any) *gomock.Call {
// UpdateUserSecret indicates an expected call of UpdateUserSecret.
func (mr *MockStoreMockRecorder) UpdateUserSecret(ctx, arg any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateUserSecretByUserIDAndName", reflect.TypeOf((*MockStore)(nil).UpdateUserSecretByUserIDAndName), ctx, arg)
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateUserSecret", reflect.TypeOf((*MockStore)(nil).UpdateUserSecret), ctx, arg)
}
// UpdateUserStatus mocks base method.
-2
View File
@@ -584,7 +584,6 @@ func (q *sqlQuerier) CountAuthorizedAuditLogs(ctx context.Context, arg CountAudi
arg.DateTo,
arg.BuildReason,
arg.RequestID,
arg.CountCap,
)
if err != nil {
return 0, err
@@ -721,7 +720,6 @@ func (q *sqlQuerier) CountAuthorizedConnectionLogs(ctx context.Context, arg Coun
arg.WorkspaceID,
arg.ConnectionID,
arg.Status,
arg.CountCap,
)
if err != nil {
return 0, err
@@ -145,13 +145,5 @@ func extractWhereClause(query string) string {
// Remove SQL comments
whereClause = regexp.MustCompile(`(?m)--.*$`).ReplaceAllString(whereClause, "")
// Normalize indentation so subquery wrapping doesn't cause
// mismatches.
lines := strings.Split(whereClause, "\n")
for i, line := range lines {
lines[i] = strings.TrimLeft(line, " \t")
}
whereClause = strings.Join(lines, "\n")
return strings.TrimSpace(whereClause)
}
+17 -27
View File
@@ -81,8 +81,8 @@ func newMsgQueue(ctx context.Context, l Listener, le ListenerWithErr) *msgQueue
}
func (q *msgQueue) run() {
var batch [maxDrainBatch]msgOrErr
for {
// wait until there is something on the queue or we are closed
q.cond.L.Lock()
for q.size == 0 && !q.closed {
q.cond.Wait()
@@ -91,32 +91,28 @@ func (q *msgQueue) run() {
q.cond.L.Unlock()
return
}
// Drain up to maxDrainBatch items while holding the lock.
n := min(q.size, maxDrainBatch)
for i := range n {
batch[i] = q.q[q.front]
q.front = (q.front + 1) % BufferSize
}
q.size -= n
item := q.q[q.front]
q.front = (q.front + 1) % BufferSize
q.size--
q.cond.L.Unlock()
// Dispatch each message individually without holding the lock.
for i := range n {
item := batch[i]
if item.err == nil {
if q.l != nil {
q.l(q.ctx, item.msg)
continue
}
if q.le != nil {
q.le(q.ctx, item.msg, nil)
continue
}
// process item without holding lock
if item.err == nil {
// real message
if q.l != nil {
q.l(q.ctx, item.msg)
continue
}
if q.le != nil {
q.le(q.ctx, nil, item.err)
q.le(q.ctx, item.msg, nil)
continue
}
// unhittable
continue
}
// if the listener wants errors, send it.
if q.le != nil {
q.le(q.ctx, nil, item.err)
}
}
}
@@ -237,12 +233,6 @@ type PGPubsub struct {
// for a subscriber before dropping messages.
const BufferSize = 2048
// maxDrainBatch is the maximum number of messages to drain from the ring
// buffer per iteration. Batching amortizes the cost of mutex
// acquire/release and cond.Wait across many messages, improving drain
// throughput during bursts.
const maxDrainBatch = 256
// Subscribe calls the listener when an event matching the name is received.
func (p *PGPubsub) Subscribe(event string, listener Listener) (cancel func(), err error) {
return p.subscribeQueue(event, newMsgQueue(context.Background(), listener, nil))
+7 -14
View File
@@ -152,7 +152,7 @@ type sqlcQuerier interface {
DeleteTask(ctx context.Context, arg DeleteTaskParams) (uuid.UUID, error)
DeleteUserChatCompactionThreshold(ctx context.Context, arg DeleteUserChatCompactionThresholdParams) error
DeleteUserChatProviderKey(ctx context.Context, arg DeleteUserChatProviderKeyParams) error
DeleteUserSecretByUserIDAndName(ctx context.Context, arg DeleteUserSecretByUserIDAndNameParams) error
DeleteUserSecret(ctx context.Context, id uuid.UUID) error
DeleteWebpushSubscriptionByUserIDAndEndpoint(ctx context.Context, arg DeleteWebpushSubscriptionByUserIDAndEndpointParams) error
DeleteWebpushSubscriptions(ctx context.Context, ids []uuid.UUID) error
DeleteWorkspaceACLByID(ctx context.Context, id uuid.UUID) error
@@ -598,6 +598,7 @@ type sqlcQuerier interface {
GetUserLinkByUserIDLoginType(ctx context.Context, arg GetUserLinkByUserIDLoginTypeParams) (UserLink, error)
GetUserLinksByUserID(ctx context.Context, userID uuid.UUID) ([]UserLink, error)
GetUserNotificationPreferences(ctx context.Context, userID uuid.UUID) ([]NotificationPreference, error)
GetUserSecret(ctx context.Context, id uuid.UUID) (UserSecret, error)
GetUserSecretByUserIDAndName(ctx context.Context, arg GetUserSecretByUserIDAndNameParams) (UserSecret, error)
// GetUserStatusCounts returns the count of users in each status over time.
// The time range is inclusively defined by the start_time and end_time parameters.
@@ -817,13 +818,7 @@ type sqlcQuerier interface {
ListProvisionerKeysByOrganizationExcludeReserved(ctx context.Context, organizationID uuid.UUID) ([]ProvisionerKey, error)
ListTasks(ctx context.Context, arg ListTasksParams) ([]Task, error)
ListUserChatCompactionThresholds(ctx context.Context, userID uuid.UUID) ([]UserConfig, error)
// Returns metadata only (no value or value_key_id) for the
// REST API list and get endpoints.
ListUserSecrets(ctx context.Context, userID uuid.UUID) ([]ListUserSecretsRow, error)
// Returns all columns including the secret value. Used by the
// provisioner (build-time injection) and the agent manifest
// (runtime injection).
ListUserSecretsWithValues(ctx context.Context, userID uuid.UUID) ([]UserSecret, error)
ListUserSecrets(ctx context.Context, userID uuid.UUID) ([]UserSecret, error)
ListWorkspaceAgentPortShares(ctx context.Context, workspaceID uuid.UUID) ([]WorkspaceAgentPortShare, error)
MarkAllInboxNotificationsAsRead(ctx context.Context, arg MarkAllInboxNotificationsAsReadParams) error
OIDCClaimFieldValues(ctx context.Context, arg OIDCClaimFieldValuesParams) ([]string, error)
@@ -875,11 +870,9 @@ type sqlcQuerier interface {
UpdateAPIKeyByID(ctx context.Context, arg UpdateAPIKeyByIDParams) error
UpdateChatBuildAgentBinding(ctx context.Context, arg UpdateChatBuildAgentBindingParams) (Chat, error)
UpdateChatByID(ctx context.Context, arg UpdateChatByIDParams) (Chat, error)
// Bumps the heartbeat timestamp for the given set of chat IDs,
// provided they are still running and owned by the specified
// worker. Returns the IDs that were actually updated so the
// caller can detect stolen or completed chats via set-difference.
UpdateChatHeartbeats(ctx context.Context, arg UpdateChatHeartbeatsParams) ([]uuid.UUID, error)
// Bumps the heartbeat timestamp for a running chat so that other
// replicas know the worker is still alive.
UpdateChatHeartbeat(ctx context.Context, arg UpdateChatHeartbeatParams) (int64, error)
UpdateChatLabelsByID(ctx context.Context, arg UpdateChatLabelsByIDParams) (Chat, error)
// Updates the cached injected context parts (AGENTS.md +
// skills) on the chat row. Called only when context changes
@@ -962,7 +955,7 @@ type sqlcQuerier interface {
UpdateUserProfile(ctx context.Context, arg UpdateUserProfileParams) (User, error)
UpdateUserQuietHoursSchedule(ctx context.Context, arg UpdateUserQuietHoursScheduleParams) (User, error)
UpdateUserRoles(ctx context.Context, arg UpdateUserRolesParams) (User, error)
UpdateUserSecretByUserIDAndName(ctx context.Context, arg UpdateUserSecretByUserIDAndNameParams) (UserSecret, error)
UpdateUserSecret(ctx context.Context, arg UpdateUserSecretParams) (UserSecret, error)
UpdateUserStatus(ctx context.Context, arg UpdateUserStatusParams) (User, error)
UpdateUserTaskNotificationAlertDismissed(ctx context.Context, arg UpdateUserTaskNotificationAlertDismissedParams) (bool, error)
UpdateUserTerminalFont(ctx context.Context, arg UpdateUserTerminalFontParams) (UserConfig, error)
+30 -45
View File
@@ -7339,7 +7339,13 @@ func TestUserSecretsCRUDOperations(t *testing.T) {
require.NoError(t, err)
assert.Equal(t, secretID, createdSecret.ID)
// 2. READ by UserID and Name
// 2. READ by ID
readSecret, err := db.GetUserSecret(ctx, createdSecret.ID)
require.NoError(t, err)
assert.Equal(t, createdSecret.ID, readSecret.ID)
assert.Equal(t, "workflow-secret", readSecret.Name)
// 3. READ by UserID and Name
readByNameParams := database.GetUserSecretByUserIDAndNameParams{
UserID: testUser.ID,
Name: "workflow-secret",
@@ -7347,43 +7353,33 @@ func TestUserSecretsCRUDOperations(t *testing.T) {
readByNameSecret, err := db.GetUserSecretByUserIDAndName(ctx, readByNameParams)
require.NoError(t, err)
assert.Equal(t, createdSecret.ID, readByNameSecret.ID)
assert.Equal(t, "workflow-secret", readByNameSecret.Name)
// 3. LIST (metadata only)
// 4. LIST
secrets, err := db.ListUserSecrets(ctx, testUser.ID)
require.NoError(t, err)
require.Len(t, secrets, 1)
assert.Equal(t, createdSecret.ID, secrets[0].ID)
// 4. LIST with values
secretsWithValues, err := db.ListUserSecretsWithValues(ctx, testUser.ID)
require.NoError(t, err)
require.Len(t, secretsWithValues, 1)
assert.Equal(t, "workflow-value", secretsWithValues[0].Value)
// 5. UPDATE (partial - only description)
updateParams := database.UpdateUserSecretByUserIDAndNameParams{
UserID: testUser.ID,
Name: "workflow-secret",
UpdateDescription: true,
Description: "Updated workflow description",
// 5. UPDATE
updateParams := database.UpdateUserSecretParams{
ID: createdSecret.ID,
Description: "Updated workflow description",
Value: "updated-workflow-value",
EnvName: "UPDATED_WORKFLOW_ENV",
FilePath: "/updated/workflow/path",
}
updatedSecret, err := db.UpdateUserSecretByUserIDAndName(ctx, updateParams)
updatedSecret, err := db.UpdateUserSecret(ctx, updateParams)
require.NoError(t, err)
assert.Equal(t, "Updated workflow description", updatedSecret.Description)
assert.Equal(t, "workflow-value", updatedSecret.Value) // Value unchanged
assert.Equal(t, "WORKFLOW_ENV", updatedSecret.EnvName) // EnvName unchanged
assert.Equal(t, "updated-workflow-value", updatedSecret.Value)
// 6. DELETE
err = db.DeleteUserSecretByUserIDAndName(ctx, database.DeleteUserSecretByUserIDAndNameParams{
UserID: testUser.ID,
Name: "workflow-secret",
})
err = db.DeleteUserSecret(ctx, createdSecret.ID)
require.NoError(t, err)
// Verify deletion
_, err = db.GetUserSecretByUserIDAndName(ctx, readByNameParams)
_, err = db.GetUserSecret(ctx, createdSecret.ID)
require.Error(t, err)
assert.Contains(t, err.Error(), "no rows in result set")
@@ -7453,13 +7449,9 @@ func TestUserSecretsCRUDOperations(t *testing.T) {
})
// Verify both secrets exist
_, err = db.GetUserSecretByUserIDAndName(ctx, database.GetUserSecretByUserIDAndNameParams{
UserID: testUser.ID, Name: secret1.Name,
})
_, err = db.GetUserSecret(ctx, secret1.ID)
require.NoError(t, err)
_, err = db.GetUserSecretByUserIDAndName(ctx, database.GetUserSecretByUserIDAndNameParams{
UserID: testUser.ID, Name: secret2.Name,
})
_, err = db.GetUserSecret(ctx, secret2.ID)
require.NoError(t, err)
})
}
@@ -7482,14 +7474,14 @@ func TestUserSecretsAuthorization(t *testing.T) {
org := dbgen.Organization(t, db, database.Organization{})
// Create secrets for users
_ = dbgen.UserSecret(t, db, database.UserSecret{
user1Secret := dbgen.UserSecret(t, db, database.UserSecret{
UserID: user1.ID,
Name: "user1-secret",
Description: "User 1's secret",
Value: "user1-value",
})
_ = dbgen.UserSecret(t, db, database.UserSecret{
user2Secret := dbgen.UserSecret(t, db, database.UserSecret{
UserID: user2.ID,
Name: "user2-secret",
Description: "User 2's secret",
@@ -7499,8 +7491,7 @@ func TestUserSecretsAuthorization(t *testing.T) {
testCases := []struct {
name string
subject rbac.Subject
lookupUserID uuid.UUID
lookupName string
secretID uuid.UUID
expectedAccess bool
}{
{
@@ -7510,8 +7501,7 @@ func TestUserSecretsAuthorization(t *testing.T) {
Roles: rbac.RoleIdentifiers{rbac.RoleMember()},
Scope: rbac.ScopeAll,
},
lookupUserID: user1.ID,
lookupName: "user1-secret",
secretID: user1Secret.ID,
expectedAccess: true,
},
{
@@ -7521,8 +7511,7 @@ func TestUserSecretsAuthorization(t *testing.T) {
Roles: rbac.RoleIdentifiers{rbac.RoleMember()},
Scope: rbac.ScopeAll,
},
lookupUserID: user2.ID,
lookupName: "user2-secret",
secretID: user2Secret.ID,
expectedAccess: false,
},
{
@@ -7532,8 +7521,7 @@ func TestUserSecretsAuthorization(t *testing.T) {
Roles: rbac.RoleIdentifiers{rbac.RoleOwner()},
Scope: rbac.ScopeAll,
},
lookupUserID: user1.ID,
lookupName: "user1-secret",
secretID: user1Secret.ID,
expectedAccess: false,
},
{
@@ -7543,8 +7531,7 @@ func TestUserSecretsAuthorization(t *testing.T) {
Roles: rbac.RoleIdentifiers{rbac.ScopedRoleOrgAdmin(org.ID)},
Scope: rbac.ScopeAll,
},
lookupUserID: user1.ID,
lookupName: "user1-secret",
secretID: user1Secret.ID,
expectedAccess: false,
},
}
@@ -7556,10 +7543,8 @@ func TestUserSecretsAuthorization(t *testing.T) {
authCtx := dbauthz.As(ctx, tc.subject)
_, err := authDB.GetUserSecretByUserIDAndName(authCtx, database.GetUserSecretByUserIDAndNameParams{
UserID: tc.lookupUserID,
Name: tc.lookupName,
})
// Test GetUserSecret
_, err := authDB.GetUserSecret(authCtx, tc.secretID)
if tc.expectedAccess {
require.NoError(t, err, "expected access to be granted")
+259 -362
View File
@@ -2275,105 +2275,93 @@ func (q *sqlQuerier) UpdateAPIKeyByID(ctx context.Context, arg UpdateAPIKeyByIDP
}
const countAuditLogs = `-- name: CountAuditLogs :one
SELECT COUNT(*) FROM (
SELECT 1
FROM audit_logs
LEFT JOIN users ON audit_logs.user_id = users.id
LEFT JOIN organizations ON audit_logs.organization_id = organizations.id
-- First join on workspaces to get the initial workspace create
-- to workspace build 1 id. This is because the first create is
-- is a different audit log than subsequent starts.
LEFT JOIN workspaces ON audit_logs.resource_type = 'workspace'
AND audit_logs.resource_id = workspaces.id
-- Get the reason from the build if the resource type
-- is a workspace_build
LEFT JOIN workspace_builds wb_build ON audit_logs.resource_type = 'workspace_build'
AND audit_logs.resource_id = wb_build.id
-- Get the reason from the build #1 if this is the first
-- workspace create.
LEFT JOIN workspace_builds wb_workspace ON audit_logs.resource_type = 'workspace'
AND audit_logs.action = 'create'
AND workspaces.id = wb_workspace.workspace_id
AND wb_workspace.build_number = 1
WHERE
-- Filter resource_type
CASE
WHEN $1::text != '' THEN resource_type = $1::resource_type
ELSE true
END
-- Filter resource_id
AND CASE
WHEN $2::uuid != '00000000-0000-0000-0000-000000000000'::uuid THEN resource_id = $2
ELSE true
END
-- Filter organization_id
AND CASE
WHEN $3::uuid != '00000000-0000-0000-0000-000000000000'::uuid THEN audit_logs.organization_id = $3
ELSE true
END
-- Filter by resource_target
AND CASE
WHEN $4::text != '' THEN resource_target = $4
ELSE true
END
-- Filter action
AND CASE
WHEN $5::text != '' THEN action = $5::audit_action
ELSE true
END
-- Filter by user_id
AND CASE
WHEN $6::uuid != '00000000-0000-0000-0000-000000000000'::uuid THEN user_id = $6
ELSE true
END
-- Filter by username
AND CASE
WHEN $7::text != '' THEN user_id = (
SELECT id
FROM users
WHERE lower(username) = lower($7)
AND deleted = false
)
ELSE true
END
-- Filter by user_email
AND CASE
WHEN $8::text != '' THEN users.email = $8
ELSE true
END
-- Filter by date_from
AND CASE
WHEN $9::timestamp with time zone != '0001-01-01 00:00:00Z' THEN "time" >= $9
ELSE true
END
-- Filter by date_to
AND CASE
WHEN $10::timestamp with time zone != '0001-01-01 00:00:00Z' THEN "time" <= $10
ELSE true
END
-- Filter by build_reason
AND CASE
WHEN $11::text != '' THEN COALESCE(wb_build.reason::text, wb_workspace.reason::text) = $11
ELSE true
END
-- Filter request_id
AND CASE
WHEN $12::uuid != '00000000-0000-0000-0000-000000000000'::uuid THEN audit_logs.request_id = $12
ELSE true
END
-- Authorize Filter clause will be injected below in CountAuthorizedAuditLogs
-- @authorize_filter
-- Avoid a slow scan on a large table with joins. The caller
-- passes the count cap and we add 1 so the frontend can detect
-- capping and show "... of N+". A cap of 0 means no limit (NULLIF
-- -> NULL + 1 = NULL).
-- NOTE: Parameterizing this so that we can easily change from,
-- e.g., 2000 to 5000. However, use literal NULL (or no LIMIT)
-- here if disabling the capping on a large table permanently.
-- This way the PG planner can plan parallel execution for
-- potential large wins.
LIMIT NULLIF($13::int, 0) + 1
) AS limited_count
SELECT COUNT(*)
FROM audit_logs
LEFT JOIN users ON audit_logs.user_id = users.id
LEFT JOIN organizations ON audit_logs.organization_id = organizations.id
-- First join on workspaces to get the initial workspace create
-- to workspace build 1 id. This is because the first create is
-- is a different audit log than subsequent starts.
LEFT JOIN workspaces ON audit_logs.resource_type = 'workspace'
AND audit_logs.resource_id = workspaces.id
-- Get the reason from the build if the resource type
-- is a workspace_build
LEFT JOIN workspace_builds wb_build ON audit_logs.resource_type = 'workspace_build'
AND audit_logs.resource_id = wb_build.id
-- Get the reason from the build #1 if this is the first
-- workspace create.
LEFT JOIN workspace_builds wb_workspace ON audit_logs.resource_type = 'workspace'
AND audit_logs.action = 'create'
AND workspaces.id = wb_workspace.workspace_id
AND wb_workspace.build_number = 1
WHERE
-- Filter resource_type
CASE
WHEN $1::text != '' THEN resource_type = $1::resource_type
ELSE true
END
-- Filter resource_id
AND CASE
WHEN $2::uuid != '00000000-0000-0000-0000-000000000000'::uuid THEN resource_id = $2
ELSE true
END
-- Filter organization_id
AND CASE
WHEN $3::uuid != '00000000-0000-0000-0000-000000000000'::uuid THEN audit_logs.organization_id = $3
ELSE true
END
-- Filter by resource_target
AND CASE
WHEN $4::text != '' THEN resource_target = $4
ELSE true
END
-- Filter action
AND CASE
WHEN $5::text != '' THEN action = $5::audit_action
ELSE true
END
-- Filter by user_id
AND CASE
WHEN $6::uuid != '00000000-0000-0000-0000-000000000000'::uuid THEN user_id = $6
ELSE true
END
-- Filter by username
AND CASE
WHEN $7::text != '' THEN user_id = (
SELECT id
FROM users
WHERE lower(username) = lower($7)
AND deleted = false
)
ELSE true
END
-- Filter by user_email
AND CASE
WHEN $8::text != '' THEN users.email = $8
ELSE true
END
-- Filter by date_from
AND CASE
WHEN $9::timestamp with time zone != '0001-01-01 00:00:00Z' THEN "time" >= $9
ELSE true
END
-- Filter by date_to
AND CASE
WHEN $10::timestamp with time zone != '0001-01-01 00:00:00Z' THEN "time" <= $10
ELSE true
END
-- Filter by build_reason
AND CASE
WHEN $11::text != '' THEN COALESCE(wb_build.reason::text, wb_workspace.reason::text) = $11
ELSE true
END
-- Filter request_id
AND CASE
WHEN $12::uuid != '00000000-0000-0000-0000-000000000000'::uuid THEN audit_logs.request_id = $12
ELSE true
END
-- Authorize Filter clause will be injected below in CountAuthorizedAuditLogs
-- @authorize_filter
`
type CountAuditLogsParams struct {
@@ -2389,7 +2377,6 @@ type CountAuditLogsParams struct {
DateTo time.Time `db:"date_to" json:"date_to"`
BuildReason string `db:"build_reason" json:"build_reason"`
RequestID uuid.UUID `db:"request_id" json:"request_id"`
CountCap int32 `db:"count_cap" json:"count_cap"`
}
func (q *sqlQuerier) CountAuditLogs(ctx context.Context, arg CountAuditLogsParams) (int64, error) {
@@ -2406,7 +2393,6 @@ func (q *sqlQuerier) CountAuditLogs(ctx context.Context, arg CountAuditLogsParam
arg.DateTo,
arg.BuildReason,
arg.RequestID,
arg.CountCap,
)
var count int64
err := row.Scan(&count)
@@ -6615,49 +6601,30 @@ func (q *sqlQuerier) UpdateChatByID(ctx context.Context, arg UpdateChatByIDParam
return i, err
}
const updateChatHeartbeats = `-- name: UpdateChatHeartbeats :many
const updateChatHeartbeat = `-- name: UpdateChatHeartbeat :execrows
UPDATE
chats
SET
heartbeat_at = $1::timestamptz
heartbeat_at = NOW()
WHERE
id = ANY($2::uuid[])
AND worker_id = $3::uuid
id = $1::uuid
AND worker_id = $2::uuid
AND status = 'running'::chat_status
RETURNING id
`
type UpdateChatHeartbeatsParams struct {
Now time.Time `db:"now" json:"now"`
IDs []uuid.UUID `db:"ids" json:"ids"`
WorkerID uuid.UUID `db:"worker_id" json:"worker_id"`
type UpdateChatHeartbeatParams struct {
ID uuid.UUID `db:"id" json:"id"`
WorkerID uuid.UUID `db:"worker_id" json:"worker_id"`
}
// Bumps the heartbeat timestamp for the given set of chat IDs,
// provided they are still running and owned by the specified
// worker. Returns the IDs that were actually updated so the
// caller can detect stolen or completed chats via set-difference.
func (q *sqlQuerier) UpdateChatHeartbeats(ctx context.Context, arg UpdateChatHeartbeatsParams) ([]uuid.UUID, error) {
rows, err := q.db.QueryContext(ctx, updateChatHeartbeats, arg.Now, pq.Array(arg.IDs), arg.WorkerID)
// Bumps the heartbeat timestamp for a running chat so that other
// replicas know the worker is still alive.
func (q *sqlQuerier) UpdateChatHeartbeat(ctx context.Context, arg UpdateChatHeartbeatParams) (int64, error) {
result, err := q.db.ExecContext(ctx, updateChatHeartbeat, arg.ID, arg.WorkerID)
if err != nil {
return nil, err
return 0, err
}
defer rows.Close()
var items []uuid.UUID
for rows.Next() {
var id uuid.UUID
if err := rows.Scan(&id); err != nil {
return nil, err
}
items = append(items, id)
}
if err := rows.Close(); err != nil {
return nil, err
}
if err := rows.Err(); err != nil {
return nil, err
}
return items, nil
return result.RowsAffected()
}
const updateChatLabelsByID = `-- name: UpdateChatLabelsByID :one
@@ -7604,113 +7571,110 @@ func (q *sqlQuerier) BatchUpsertConnectionLogs(ctx context.Context, arg BatchUps
}
const countConnectionLogs = `-- name: CountConnectionLogs :one
SELECT COUNT(*) AS count FROM (
SELECT 1
FROM
connection_logs
JOIN users AS workspace_owner ON
connection_logs.workspace_owner_id = workspace_owner.id
LEFT JOIN users ON
connection_logs.user_id = users.id
JOIN organizations ON
connection_logs.organization_id = organizations.id
WHERE
-- Filter organization_id
CASE
WHEN $1 :: uuid != '00000000-0000-0000-0000-000000000000'::uuid THEN
connection_logs.organization_id = $1
ELSE true
END
-- Filter by workspace owner username
AND CASE
WHEN $2 :: text != '' THEN
workspace_owner_id = (
SELECT id FROM users
WHERE lower(username) = lower($2) AND deleted = false
)
ELSE true
END
-- Filter by workspace_owner_id
AND CASE
WHEN $3 :: uuid != '00000000-0000-0000-0000-000000000000'::uuid THEN
workspace_owner_id = $3
ELSE true
END
-- Filter by workspace_owner_email
AND CASE
WHEN $4 :: text != '' THEN
workspace_owner_id = (
SELECT id FROM users
WHERE email = $4 AND deleted = false
)
ELSE true
END
-- Filter by type
AND CASE
WHEN $5 :: text != '' THEN
type = $5 :: connection_type
ELSE true
END
-- Filter by user_id
AND CASE
WHEN $6 :: uuid != '00000000-0000-0000-0000-000000000000'::uuid THEN
user_id = $6
ELSE true
END
-- Filter by username
AND CASE
WHEN $7 :: text != '' THEN
user_id = (
SELECT id FROM users
WHERE lower(username) = lower($7) AND deleted = false
)
ELSE true
END
-- Filter by user_email
AND CASE
WHEN $8 :: text != '' THEN
users.email = $8
ELSE true
END
-- Filter by connected_after
AND CASE
WHEN $9 :: timestamp with time zone != '0001-01-01 00:00:00Z' THEN
connect_time >= $9
ELSE true
END
-- Filter by connected_before
AND CASE
WHEN $10 :: timestamp with time zone != '0001-01-01 00:00:00Z' THEN
connect_time <= $10
ELSE true
END
-- Filter by workspace_id
AND CASE
WHEN $11 :: uuid != '00000000-0000-0000-0000-000000000000'::uuid THEN
connection_logs.workspace_id = $11
ELSE true
END
-- Filter by connection_id
AND CASE
WHEN $12 :: uuid != '00000000-0000-0000-0000-000000000000'::uuid THEN
connection_logs.connection_id = $12
ELSE true
END
-- Filter by whether the session has a disconnect_time
AND CASE
WHEN $13 :: text != '' THEN
(($13 = 'ongoing' AND disconnect_time IS NULL) OR
($13 = 'completed' AND disconnect_time IS NOT NULL)) AND
-- Exclude web events, since we don't know their close time.
"type" NOT IN ('workspace_app', 'port_forwarding')
ELSE true
END
-- Authorize Filter clause will be injected below in
-- CountAuthorizedConnectionLogs
-- @authorize_filter
-- NOTE: See the CountAuditLogs LIMIT note.
LIMIT NULLIF($14::int, 0) + 1
) AS limited_count
SELECT
COUNT(*) AS count
FROM
connection_logs
JOIN users AS workspace_owner ON
connection_logs.workspace_owner_id = workspace_owner.id
LEFT JOIN users ON
connection_logs.user_id = users.id
JOIN organizations ON
connection_logs.organization_id = organizations.id
WHERE
-- Filter organization_id
CASE
WHEN $1 :: uuid != '00000000-0000-0000-0000-000000000000'::uuid THEN
connection_logs.organization_id = $1
ELSE true
END
-- Filter by workspace owner username
AND CASE
WHEN $2 :: text != '' THEN
workspace_owner_id = (
SELECT id FROM users
WHERE lower(username) = lower($2) AND deleted = false
)
ELSE true
END
-- Filter by workspace_owner_id
AND CASE
WHEN $3 :: uuid != '00000000-0000-0000-0000-000000000000'::uuid THEN
workspace_owner_id = $3
ELSE true
END
-- Filter by workspace_owner_email
AND CASE
WHEN $4 :: text != '' THEN
workspace_owner_id = (
SELECT id FROM users
WHERE email = $4 AND deleted = false
)
ELSE true
END
-- Filter by type
AND CASE
WHEN $5 :: text != '' THEN
type = $5 :: connection_type
ELSE true
END
-- Filter by user_id
AND CASE
WHEN $6 :: uuid != '00000000-0000-0000-0000-000000000000'::uuid THEN
user_id = $6
ELSE true
END
-- Filter by username
AND CASE
WHEN $7 :: text != '' THEN
user_id = (
SELECT id FROM users
WHERE lower(username) = lower($7) AND deleted = false
)
ELSE true
END
-- Filter by user_email
AND CASE
WHEN $8 :: text != '' THEN
users.email = $8
ELSE true
END
-- Filter by connected_after
AND CASE
WHEN $9 :: timestamp with time zone != '0001-01-01 00:00:00Z' THEN
connect_time >= $9
ELSE true
END
-- Filter by connected_before
AND CASE
WHEN $10 :: timestamp with time zone != '0001-01-01 00:00:00Z' THEN
connect_time <= $10
ELSE true
END
-- Filter by workspace_id
AND CASE
WHEN $11 :: uuid != '00000000-0000-0000-0000-000000000000'::uuid THEN
connection_logs.workspace_id = $11
ELSE true
END
-- Filter by connection_id
AND CASE
WHEN $12 :: uuid != '00000000-0000-0000-0000-000000000000'::uuid THEN
connection_logs.connection_id = $12
ELSE true
END
-- Filter by whether the session has a disconnect_time
AND CASE
WHEN $13 :: text != '' THEN
(($13 = 'ongoing' AND disconnect_time IS NULL) OR
($13 = 'completed' AND disconnect_time IS NOT NULL)) AND
-- Exclude web events, since we don't know their close time.
"type" NOT IN ('workspace_app', 'port_forwarding')
ELSE true
END
-- Authorize Filter clause will be injected below in
-- CountAuthorizedConnectionLogs
-- @authorize_filter
`
type CountConnectionLogsParams struct {
@@ -7727,7 +7691,6 @@ type CountConnectionLogsParams struct {
WorkspaceID uuid.UUID `db:"workspace_id" json:"workspace_id"`
ConnectionID uuid.UUID `db:"connection_id" json:"connection_id"`
Status string `db:"status" json:"status"`
CountCap int32 `db:"count_cap" json:"count_cap"`
}
func (q *sqlQuerier) CountConnectionLogs(ctx context.Context, arg CountConnectionLogsParams) (int64, error) {
@@ -7745,7 +7708,6 @@ func (q *sqlQuerier) CountConnectionLogs(ctx context.Context, arg CountConnectio
arg.WorkspaceID,
arg.ConnectionID,
arg.Status,
arg.CountCap,
)
var count int64
err := row.Scan(&count)
@@ -22639,30 +22601,21 @@ INSERT INTO user_secrets (
name,
description,
value,
value_key_id,
env_name,
file_path
) VALUES (
$1,
$2,
$3,
$4,
$5,
$6,
$7,
$8
$1, $2, $3, $4, $5, $6, $7
) RETURNING id, user_id, name, description, value, env_name, file_path, created_at, updated_at, value_key_id
`
type CreateUserSecretParams struct {
ID uuid.UUID `db:"id" json:"id"`
UserID uuid.UUID `db:"user_id" json:"user_id"`
Name string `db:"name" json:"name"`
Description string `db:"description" json:"description"`
Value string `db:"value" json:"value"`
ValueKeyID sql.NullString `db:"value_key_id" json:"value_key_id"`
EnvName string `db:"env_name" json:"env_name"`
FilePath string `db:"file_path" json:"file_path"`
ID uuid.UUID `db:"id" json:"id"`
UserID uuid.UUID `db:"user_id" json:"user_id"`
Name string `db:"name" json:"name"`
Description string `db:"description" json:"description"`
Value string `db:"value" json:"value"`
EnvName string `db:"env_name" json:"env_name"`
FilePath string `db:"file_path" json:"file_path"`
}
func (q *sqlQuerier) CreateUserSecret(ctx context.Context, arg CreateUserSecretParams) (UserSecret, error) {
@@ -22672,7 +22625,6 @@ func (q *sqlQuerier) CreateUserSecret(ctx context.Context, arg CreateUserSecretP
arg.Name,
arg.Description,
arg.Value,
arg.ValueKeyID,
arg.EnvName,
arg.FilePath,
)
@@ -22692,24 +22644,41 @@ func (q *sqlQuerier) CreateUserSecret(ctx context.Context, arg CreateUserSecretP
return i, err
}
const deleteUserSecretByUserIDAndName = `-- name: DeleteUserSecretByUserIDAndName :exec
const deleteUserSecret = `-- name: DeleteUserSecret :exec
DELETE FROM user_secrets
WHERE user_id = $1 AND name = $2
WHERE id = $1
`
type DeleteUserSecretByUserIDAndNameParams struct {
UserID uuid.UUID `db:"user_id" json:"user_id"`
Name string `db:"name" json:"name"`
}
func (q *sqlQuerier) DeleteUserSecretByUserIDAndName(ctx context.Context, arg DeleteUserSecretByUserIDAndNameParams) error {
_, err := q.db.ExecContext(ctx, deleteUserSecretByUserIDAndName, arg.UserID, arg.Name)
func (q *sqlQuerier) DeleteUserSecret(ctx context.Context, id uuid.UUID) error {
_, err := q.db.ExecContext(ctx, deleteUserSecret, id)
return err
}
const getUserSecret = `-- name: GetUserSecret :one
SELECT id, user_id, name, description, value, env_name, file_path, created_at, updated_at, value_key_id FROM user_secrets
WHERE id = $1
`
func (q *sqlQuerier) GetUserSecret(ctx context.Context, id uuid.UUID) (UserSecret, error) {
row := q.db.QueryRowContext(ctx, getUserSecret, id)
var i UserSecret
err := row.Scan(
&i.ID,
&i.UserID,
&i.Name,
&i.Description,
&i.Value,
&i.EnvName,
&i.FilePath,
&i.CreatedAt,
&i.UpdatedAt,
&i.ValueKeyID,
)
return i, err
}
const getUserSecretByUserIDAndName = `-- name: GetUserSecretByUserIDAndName :one
SELECT id, user_id, name, description, value, env_name, file_path, created_at, updated_at, value_key_id
FROM user_secrets
SELECT id, user_id, name, description, value, env_name, file_path, created_at, updated_at, value_key_id FROM user_secrets
WHERE user_id = $1 AND name = $2
`
@@ -22737,76 +22706,17 @@ func (q *sqlQuerier) GetUserSecretByUserIDAndName(ctx context.Context, arg GetUs
}
const listUserSecrets = `-- name: ListUserSecrets :many
SELECT
id, user_id, name, description,
env_name, file_path,
created_at, updated_at
FROM user_secrets
SELECT id, user_id, name, description, value, env_name, file_path, created_at, updated_at, value_key_id FROM user_secrets
WHERE user_id = $1
ORDER BY name ASC
`
type ListUserSecretsRow struct {
ID uuid.UUID `db:"id" json:"id"`
UserID uuid.UUID `db:"user_id" json:"user_id"`
Name string `db:"name" json:"name"`
Description string `db:"description" json:"description"`
EnvName string `db:"env_name" json:"env_name"`
FilePath string `db:"file_path" json:"file_path"`
CreatedAt time.Time `db:"created_at" json:"created_at"`
UpdatedAt time.Time `db:"updated_at" json:"updated_at"`
}
// Returns metadata only (no value or value_key_id) for the
// REST API list and get endpoints.
func (q *sqlQuerier) ListUserSecrets(ctx context.Context, userID uuid.UUID) ([]ListUserSecretsRow, error) {
func (q *sqlQuerier) ListUserSecrets(ctx context.Context, userID uuid.UUID) ([]UserSecret, error) {
rows, err := q.db.QueryContext(ctx, listUserSecrets, userID)
if err != nil {
return nil, err
}
defer rows.Close()
var items []ListUserSecretsRow
for rows.Next() {
var i ListUserSecretsRow
if err := rows.Scan(
&i.ID,
&i.UserID,
&i.Name,
&i.Description,
&i.EnvName,
&i.FilePath,
&i.CreatedAt,
&i.UpdatedAt,
); err != nil {
return nil, err
}
items = append(items, i)
}
if err := rows.Close(); err != nil {
return nil, err
}
if err := rows.Err(); err != nil {
return nil, err
}
return items, nil
}
const listUserSecretsWithValues = `-- name: ListUserSecretsWithValues :many
SELECT id, user_id, name, description, value, env_name, file_path, created_at, updated_at, value_key_id
FROM user_secrets
WHERE user_id = $1
ORDER BY name ASC
`
// Returns all columns including the secret value. Used by the
// provisioner (build-time injection) and the agent manifest
// (runtime injection).
func (q *sqlQuerier) ListUserSecretsWithValues(ctx context.Context, userID uuid.UUID) ([]UserSecret, error) {
rows, err := q.db.QueryContext(ctx, listUserSecretsWithValues, userID)
if err != nil {
return nil, err
}
defer rows.Close()
var items []UserSecret
for rows.Next() {
var i UserSecret
@@ -22835,46 +22745,33 @@ func (q *sqlQuerier) ListUserSecretsWithValues(ctx context.Context, userID uuid.
return items, nil
}
const updateUserSecretByUserIDAndName = `-- name: UpdateUserSecretByUserIDAndName :one
const updateUserSecret = `-- name: UpdateUserSecret :one
UPDATE user_secrets
SET
value = CASE WHEN $1::bool THEN $2 ELSE value END,
value_key_id = CASE WHEN $1::bool THEN $3 ELSE value_key_id END,
description = CASE WHEN $4::bool THEN $5 ELSE description END,
env_name = CASE WHEN $6::bool THEN $7 ELSE env_name END,
file_path = CASE WHEN $8::bool THEN $9 ELSE file_path END,
updated_at = CURRENT_TIMESTAMP
WHERE user_id = $10 AND name = $11
description = $2,
value = $3,
env_name = $4,
file_path = $5,
updated_at = CURRENT_TIMESTAMP
WHERE id = $1
RETURNING id, user_id, name, description, value, env_name, file_path, created_at, updated_at, value_key_id
`
type UpdateUserSecretByUserIDAndNameParams struct {
UpdateValue bool `db:"update_value" json:"update_value"`
Value string `db:"value" json:"value"`
ValueKeyID sql.NullString `db:"value_key_id" json:"value_key_id"`
UpdateDescription bool `db:"update_description" json:"update_description"`
Description string `db:"description" json:"description"`
UpdateEnvName bool `db:"update_env_name" json:"update_env_name"`
EnvName string `db:"env_name" json:"env_name"`
UpdateFilePath bool `db:"update_file_path" json:"update_file_path"`
FilePath string `db:"file_path" json:"file_path"`
UserID uuid.UUID `db:"user_id" json:"user_id"`
Name string `db:"name" json:"name"`
type UpdateUserSecretParams struct {
ID uuid.UUID `db:"id" json:"id"`
Description string `db:"description" json:"description"`
Value string `db:"value" json:"value"`
EnvName string `db:"env_name" json:"env_name"`
FilePath string `db:"file_path" json:"file_path"`
}
func (q *sqlQuerier) UpdateUserSecretByUserIDAndName(ctx context.Context, arg UpdateUserSecretByUserIDAndNameParams) (UserSecret, error) {
row := q.db.QueryRowContext(ctx, updateUserSecretByUserIDAndName,
arg.UpdateValue,
arg.Value,
arg.ValueKeyID,
arg.UpdateDescription,
func (q *sqlQuerier) UpdateUserSecret(ctx context.Context, arg UpdateUserSecretParams) (UserSecret, error) {
row := q.db.QueryRowContext(ctx, updateUserSecret,
arg.ID,
arg.Description,
arg.UpdateEnvName,
arg.Value,
arg.EnvName,
arg.UpdateFilePath,
arg.FilePath,
arg.UserID,
arg.Name,
)
var i UserSecret
err := row.Scan(
+88 -99
View File
@@ -149,105 +149,94 @@ VALUES (
RETURNING *;
-- name: CountAuditLogs :one
SELECT COUNT(*) FROM (
SELECT 1
FROM audit_logs
LEFT JOIN users ON audit_logs.user_id = users.id
LEFT JOIN organizations ON audit_logs.organization_id = organizations.id
-- First join on workspaces to get the initial workspace create
-- to workspace build 1 id. This is because the first create is
-- is a different audit log than subsequent starts.
LEFT JOIN workspaces ON audit_logs.resource_type = 'workspace'
AND audit_logs.resource_id = workspaces.id
-- Get the reason from the build if the resource type
-- is a workspace_build
LEFT JOIN workspace_builds wb_build ON audit_logs.resource_type = 'workspace_build'
AND audit_logs.resource_id = wb_build.id
-- Get the reason from the build #1 if this is the first
-- workspace create.
LEFT JOIN workspace_builds wb_workspace ON audit_logs.resource_type = 'workspace'
AND audit_logs.action = 'create'
AND workspaces.id = wb_workspace.workspace_id
AND wb_workspace.build_number = 1
WHERE
-- Filter resource_type
CASE
WHEN @resource_type::text != '' THEN resource_type = @resource_type::resource_type
ELSE true
END
-- Filter resource_id
AND CASE
WHEN @resource_id::uuid != '00000000-0000-0000-0000-000000000000'::uuid THEN resource_id = @resource_id
ELSE true
END
-- Filter organization_id
AND CASE
WHEN @organization_id::uuid != '00000000-0000-0000-0000-000000000000'::uuid THEN audit_logs.organization_id = @organization_id
ELSE true
END
-- Filter by resource_target
AND CASE
WHEN @resource_target::text != '' THEN resource_target = @resource_target
ELSE true
END
-- Filter action
AND CASE
WHEN @action::text != '' THEN action = @action::audit_action
ELSE true
END
-- Filter by user_id
AND CASE
WHEN @user_id::uuid != '00000000-0000-0000-0000-000000000000'::uuid THEN user_id = @user_id
ELSE true
END
-- Filter by username
AND CASE
WHEN @username::text != '' THEN user_id = (
SELECT id
FROM users
WHERE lower(username) = lower(@username)
AND deleted = false
)
ELSE true
END
-- Filter by user_email
AND CASE
WHEN @email::text != '' THEN users.email = @email
ELSE true
END
-- Filter by date_from
AND CASE
WHEN @date_from::timestamp with time zone != '0001-01-01 00:00:00Z' THEN "time" >= @date_from
ELSE true
END
-- Filter by date_to
AND CASE
WHEN @date_to::timestamp with time zone != '0001-01-01 00:00:00Z' THEN "time" <= @date_to
ELSE true
END
-- Filter by build_reason
AND CASE
WHEN @build_reason::text != '' THEN COALESCE(wb_build.reason::text, wb_workspace.reason::text) = @build_reason
ELSE true
END
-- Filter request_id
AND CASE
WHEN @request_id::uuid != '00000000-0000-0000-0000-000000000000'::uuid THEN audit_logs.request_id = @request_id
ELSE true
END
-- Authorize Filter clause will be injected below in CountAuthorizedAuditLogs
-- @authorize_filter
-- Avoid a slow scan on a large table with joins. The caller
-- passes the count cap and we add 1 so the frontend can detect
-- capping and show "... of N+". A cap of 0 means no limit (NULLIF
-- -> NULL + 1 = NULL).
-- NOTE: Parameterizing this so that we can easily change from,
-- e.g., 2000 to 5000. However, use literal NULL (or no LIMIT)
-- here if disabling the capping on a large table permanently.
-- This way the PG planner can plan parallel execution for
-- potential large wins.
LIMIT NULLIF(@count_cap::int, 0) + 1
) AS limited_count;
SELECT COUNT(*)
FROM audit_logs
LEFT JOIN users ON audit_logs.user_id = users.id
LEFT JOIN organizations ON audit_logs.organization_id = organizations.id
-- First join on workspaces to get the initial workspace create
-- to workspace build 1 id. This is because the first create is
-- is a different audit log than subsequent starts.
LEFT JOIN workspaces ON audit_logs.resource_type = 'workspace'
AND audit_logs.resource_id = workspaces.id
-- Get the reason from the build if the resource type
-- is a workspace_build
LEFT JOIN workspace_builds wb_build ON audit_logs.resource_type = 'workspace_build'
AND audit_logs.resource_id = wb_build.id
-- Get the reason from the build #1 if this is the first
-- workspace create.
LEFT JOIN workspace_builds wb_workspace ON audit_logs.resource_type = 'workspace'
AND audit_logs.action = 'create'
AND workspaces.id = wb_workspace.workspace_id
AND wb_workspace.build_number = 1
WHERE
-- Filter resource_type
CASE
WHEN @resource_type::text != '' THEN resource_type = @resource_type::resource_type
ELSE true
END
-- Filter resource_id
AND CASE
WHEN @resource_id::uuid != '00000000-0000-0000-0000-000000000000'::uuid THEN resource_id = @resource_id
ELSE true
END
-- Filter organization_id
AND CASE
WHEN @organization_id::uuid != '00000000-0000-0000-0000-000000000000'::uuid THEN audit_logs.organization_id = @organization_id
ELSE true
END
-- Filter by resource_target
AND CASE
WHEN @resource_target::text != '' THEN resource_target = @resource_target
ELSE true
END
-- Filter action
AND CASE
WHEN @action::text != '' THEN action = @action::audit_action
ELSE true
END
-- Filter by user_id
AND CASE
WHEN @user_id::uuid != '00000000-0000-0000-0000-000000000000'::uuid THEN user_id = @user_id
ELSE true
END
-- Filter by username
AND CASE
WHEN @username::text != '' THEN user_id = (
SELECT id
FROM users
WHERE lower(username) = lower(@username)
AND deleted = false
)
ELSE true
END
-- Filter by user_email
AND CASE
WHEN @email::text != '' THEN users.email = @email
ELSE true
END
-- Filter by date_from
AND CASE
WHEN @date_from::timestamp with time zone != '0001-01-01 00:00:00Z' THEN "time" >= @date_from
ELSE true
END
-- Filter by date_to
AND CASE
WHEN @date_to::timestamp with time zone != '0001-01-01 00:00:00Z' THEN "time" <= @date_to
ELSE true
END
-- Filter by build_reason
AND CASE
WHEN @build_reason::text != '' THEN COALESCE(wb_build.reason::text, wb_workspace.reason::text) = @build_reason
ELSE true
END
-- Filter request_id
AND CASE
WHEN @request_id::uuid != '00000000-0000-0000-0000-000000000000'::uuid THEN audit_logs.request_id = @request_id
ELSE true
END
-- Authorize Filter clause will be injected below in CountAuthorizedAuditLogs
-- @authorize_filter
;
-- name: DeleteOldAuditLogConnectionEvents :exec
DELETE FROM audit_logs
+6 -9
View File
@@ -674,20 +674,17 @@ WHERE
status = 'running'::chat_status
AND heartbeat_at < @stale_threshold::timestamptz;
-- name: UpdateChatHeartbeats :many
-- Bumps the heartbeat timestamp for the given set of chat IDs,
-- provided they are still running and owned by the specified
-- worker. Returns the IDs that were actually updated so the
-- caller can detect stolen or completed chats via set-difference.
-- name: UpdateChatHeartbeat :execrows
-- Bumps the heartbeat timestamp for a running chat so that other
-- replicas know the worker is still alive.
UPDATE
chats
SET
heartbeat_at = @now::timestamptz
heartbeat_at = NOW()
WHERE
id = ANY(@ids::uuid[])
id = @id::uuid
AND worker_id = @worker_id::uuid
AND status = 'running'::chat_status
RETURNING id;
AND status = 'running'::chat_status;
-- name: GetChatDiffStatusByChatID :one
SELECT
+105 -107
View File
@@ -133,113 +133,111 @@ OFFSET
@offset_opt;
-- name: CountConnectionLogs :one
SELECT COUNT(*) AS count FROM (
SELECT 1
FROM
connection_logs
JOIN users AS workspace_owner ON
connection_logs.workspace_owner_id = workspace_owner.id
LEFT JOIN users ON
connection_logs.user_id = users.id
JOIN organizations ON
connection_logs.organization_id = organizations.id
WHERE
-- Filter organization_id
CASE
WHEN @organization_id :: uuid != '00000000-0000-0000-0000-000000000000'::uuid THEN
connection_logs.organization_id = @organization_id
ELSE true
END
-- Filter by workspace owner username
AND CASE
WHEN @workspace_owner :: text != '' THEN
workspace_owner_id = (
SELECT id FROM users
WHERE lower(username) = lower(@workspace_owner) AND deleted = false
)
ELSE true
END
-- Filter by workspace_owner_id
AND CASE
WHEN @workspace_owner_id :: uuid != '00000000-0000-0000-0000-000000000000'::uuid THEN
workspace_owner_id = @workspace_owner_id
ELSE true
END
-- Filter by workspace_owner_email
AND CASE
WHEN @workspace_owner_email :: text != '' THEN
workspace_owner_id = (
SELECT id FROM users
WHERE email = @workspace_owner_email AND deleted = false
)
ELSE true
END
-- Filter by type
AND CASE
WHEN @type :: text != '' THEN
type = @type :: connection_type
ELSE true
END
-- Filter by user_id
AND CASE
WHEN @user_id :: uuid != '00000000-0000-0000-0000-000000000000'::uuid THEN
user_id = @user_id
ELSE true
END
-- Filter by username
AND CASE
WHEN @username :: text != '' THEN
user_id = (
SELECT id FROM users
WHERE lower(username) = lower(@username) AND deleted = false
)
ELSE true
END
-- Filter by user_email
AND CASE
WHEN @user_email :: text != '' THEN
users.email = @user_email
ELSE true
END
-- Filter by connected_after
AND CASE
WHEN @connected_after :: timestamp with time zone != '0001-01-01 00:00:00Z' THEN
connect_time >= @connected_after
ELSE true
END
-- Filter by connected_before
AND CASE
WHEN @connected_before :: timestamp with time zone != '0001-01-01 00:00:00Z' THEN
connect_time <= @connected_before
ELSE true
END
-- Filter by workspace_id
AND CASE
WHEN @workspace_id :: uuid != '00000000-0000-0000-0000-000000000000'::uuid THEN
connection_logs.workspace_id = @workspace_id
ELSE true
END
-- Filter by connection_id
AND CASE
WHEN @connection_id :: uuid != '00000000-0000-0000-0000-000000000000'::uuid THEN
connection_logs.connection_id = @connection_id
ELSE true
END
-- Filter by whether the session has a disconnect_time
AND CASE
WHEN @status :: text != '' THEN
((@status = 'ongoing' AND disconnect_time IS NULL) OR
(@status = 'completed' AND disconnect_time IS NOT NULL)) AND
-- Exclude web events, since we don't know their close time.
"type" NOT IN ('workspace_app', 'port_forwarding')
ELSE true
END
-- Authorize Filter clause will be injected below in
-- CountAuthorizedConnectionLogs
-- @authorize_filter
-- NOTE: See the CountAuditLogs LIMIT note.
LIMIT NULLIF(@count_cap::int, 0) + 1
) AS limited_count;
SELECT
COUNT(*) AS count
FROM
connection_logs
JOIN users AS workspace_owner ON
connection_logs.workspace_owner_id = workspace_owner.id
LEFT JOIN users ON
connection_logs.user_id = users.id
JOIN organizations ON
connection_logs.organization_id = organizations.id
WHERE
-- Filter organization_id
CASE
WHEN @organization_id :: uuid != '00000000-0000-0000-0000-000000000000'::uuid THEN
connection_logs.organization_id = @organization_id
ELSE true
END
-- Filter by workspace owner username
AND CASE
WHEN @workspace_owner :: text != '' THEN
workspace_owner_id = (
SELECT id FROM users
WHERE lower(username) = lower(@workspace_owner) AND deleted = false
)
ELSE true
END
-- Filter by workspace_owner_id
AND CASE
WHEN @workspace_owner_id :: uuid != '00000000-0000-0000-0000-000000000000'::uuid THEN
workspace_owner_id = @workspace_owner_id
ELSE true
END
-- Filter by workspace_owner_email
AND CASE
WHEN @workspace_owner_email :: text != '' THEN
workspace_owner_id = (
SELECT id FROM users
WHERE email = @workspace_owner_email AND deleted = false
)
ELSE true
END
-- Filter by type
AND CASE
WHEN @type :: text != '' THEN
type = @type :: connection_type
ELSE true
END
-- Filter by user_id
AND CASE
WHEN @user_id :: uuid != '00000000-0000-0000-0000-000000000000'::uuid THEN
user_id = @user_id
ELSE true
END
-- Filter by username
AND CASE
WHEN @username :: text != '' THEN
user_id = (
SELECT id FROM users
WHERE lower(username) = lower(@username) AND deleted = false
)
ELSE true
END
-- Filter by user_email
AND CASE
WHEN @user_email :: text != '' THEN
users.email = @user_email
ELSE true
END
-- Filter by connected_after
AND CASE
WHEN @connected_after :: timestamp with time zone != '0001-01-01 00:00:00Z' THEN
connect_time >= @connected_after
ELSE true
END
-- Filter by connected_before
AND CASE
WHEN @connected_before :: timestamp with time zone != '0001-01-01 00:00:00Z' THEN
connect_time <= @connected_before
ELSE true
END
-- Filter by workspace_id
AND CASE
WHEN @workspace_id :: uuid != '00000000-0000-0000-0000-000000000000'::uuid THEN
connection_logs.workspace_id = @workspace_id
ELSE true
END
-- Filter by connection_id
AND CASE
WHEN @connection_id :: uuid != '00000000-0000-0000-0000-000000000000'::uuid THEN
connection_logs.connection_id = @connection_id
ELSE true
END
-- Filter by whether the session has a disconnect_time
AND CASE
WHEN @status :: text != '' THEN
((@status = 'ongoing' AND disconnect_time IS NULL) OR
(@status = 'completed' AND disconnect_time IS NOT NULL)) AND
-- Exclude web events, since we don't know their close time.
"type" NOT IN ('workspace_app', 'port_forwarding')
ELSE true
END
-- Authorize Filter clause will be injected below in
-- CountAuthorizedConnectionLogs
-- @authorize_filter
;
-- name: DeleteOldConnectionLogs :execrows
WITH old_logs AS (
+18 -39
View File
@@ -1,26 +1,14 @@
-- name: GetUserSecretByUserIDAndName :one
SELECT *
FROM user_secrets
WHERE user_id = @user_id AND name = @name;
SELECT * FROM user_secrets
WHERE user_id = $1 AND name = $2;
-- name: GetUserSecret :one
SELECT * FROM user_secrets
WHERE id = $1;
-- name: ListUserSecrets :many
-- Returns metadata only (no value or value_key_id) for the
-- REST API list and get endpoints.
SELECT
id, user_id, name, description,
env_name, file_path,
created_at, updated_at
FROM user_secrets
WHERE user_id = @user_id
ORDER BY name ASC;
-- name: ListUserSecretsWithValues :many
-- Returns all columns including the secret value. Used by the
-- provisioner (build-time injection) and the agent manifest
-- (runtime injection).
SELECT *
FROM user_secrets
WHERE user_id = @user_id
SELECT * FROM user_secrets
WHERE user_id = $1
ORDER BY name ASC;
-- name: CreateUserSecret :one
@@ -30,32 +18,23 @@ INSERT INTO user_secrets (
name,
description,
value,
value_key_id,
env_name,
file_path
) VALUES (
@id,
@user_id,
@name,
@description,
@value,
@value_key_id,
@env_name,
@file_path
$1, $2, $3, $4, $5, $6, $7
) RETURNING *;
-- name: UpdateUserSecretByUserIDAndName :one
-- name: UpdateUserSecret :one
UPDATE user_secrets
SET
value = CASE WHEN @update_value::bool THEN @value ELSE value END,
value_key_id = CASE WHEN @update_value::bool THEN @value_key_id ELSE value_key_id END,
description = CASE WHEN @update_description::bool THEN @description ELSE description END,
env_name = CASE WHEN @update_env_name::bool THEN @env_name ELSE env_name END,
file_path = CASE WHEN @update_file_path::bool THEN @file_path ELSE file_path END,
updated_at = CURRENT_TIMESTAMP
WHERE user_id = @user_id AND name = @name
description = $2,
value = $3,
env_name = $4,
file_path = $5,
updated_at = CURRENT_TIMESTAMP
WHERE id = $1
RETURNING *;
-- name: DeleteUserSecretByUserIDAndName :exec
-- name: DeleteUserSecret :exec
DELETE FROM user_secrets
WHERE user_id = @user_id AND name = @name;
WHERE id = $1;
-34
View File
@@ -298,40 +298,6 @@ neq(input.object.owner, "");
ExpectedSQL: p("'' = 'org-id'"),
VariableConverter: regosql.ChatConverter(),
},
{
Name: "AuditLogUUID",
Queries: []string{
`"8c0b9bdc-a013-4b14-a49b-5747bc335708" = input.object.org_owner`,
`input.object.org_owner != ""`,
`neq(input.object.org_owner, "8c0b9bdc-a013-4b14-a49b-5747bc335708")`,
`input.object.org_owner in {"8c0b9bdc-a013-4b14-a49b-5747bc335708", "05f58202-4bfc-43ce-9ba4-5ff6e0174a71"}`,
`"read" in input.object.acl_group_list[input.object.org_owner]`,
},
ExpectedSQL: p(
p("audit_logs.organization_id = '8c0b9bdc-a013-4b14-a49b-5747bc335708'::uuid") + " OR " +
p("audit_logs.organization_id IS NOT NULL") + " OR " +
p("audit_logs.organization_id != '8c0b9bdc-a013-4b14-a49b-5747bc335708'::uuid") + " OR " +
p("audit_logs.organization_id = ANY(ARRAY ['05f58202-4bfc-43ce-9ba4-5ff6e0174a71'::uuid,'8c0b9bdc-a013-4b14-a49b-5747bc335708'::uuid])") + " OR " +
"(false)"),
VariableConverter: regosql.AuditLogConverter(),
},
{
Name: "ConnectionLogUUID",
Queries: []string{
`"8c0b9bdc-a013-4b14-a49b-5747bc335708" = input.object.org_owner`,
`input.object.org_owner != ""`,
`neq(input.object.org_owner, "8c0b9bdc-a013-4b14-a49b-5747bc335708")`,
`input.object.org_owner in {"8c0b9bdc-a013-4b14-a49b-5747bc335708"}`,
`"read" in input.object.acl_group_list[input.object.org_owner]`,
},
ExpectedSQL: p(
p("connection_logs.organization_id = '8c0b9bdc-a013-4b14-a49b-5747bc335708'::uuid") + " OR " +
p("connection_logs.organization_id IS NOT NULL") + " OR " +
p("connection_logs.organization_id != '8c0b9bdc-a013-4b14-a49b-5747bc335708'::uuid") + " OR " +
p("connection_logs.organization_id = ANY(ARRAY ['8c0b9bdc-a013-4b14-a49b-5747bc335708'::uuid])") + " OR " +
"(false)"),
VariableConverter: regosql.ConnectionLogConverter(),
},
}
for _, tc := range testCases {
+2 -2
View File
@@ -53,7 +53,7 @@ func WorkspaceConverter() *sqltypes.VariableConverter {
func AuditLogConverter() *sqltypes.VariableConverter {
matcher := sqltypes.NewVariableConverter().RegisterMatcher(
resourceIDMatcher(),
sqltypes.UUIDVarMatcher("audit_logs.organization_id", []string{"input", "object", "org_owner"}),
sqltypes.StringVarMatcher("COALESCE(audit_logs.organization_id :: text, '')", []string{"input", "object", "org_owner"}),
// Audit logs have no user owner, only owned by an organization.
sqltypes.AlwaysFalse(userOwnerMatcher()),
)
@@ -67,7 +67,7 @@ func AuditLogConverter() *sqltypes.VariableConverter {
func ConnectionLogConverter() *sqltypes.VariableConverter {
matcher := sqltypes.NewVariableConverter().RegisterMatcher(
resourceIDMatcher(),
sqltypes.UUIDVarMatcher("connection_logs.organization_id", []string{"input", "object", "org_owner"}),
sqltypes.StringVarMatcher("COALESCE(connection_logs.organization_id :: text, '')", []string{"input", "object", "org_owner"}),
// Connection logs have no user owner, only owned by an organization.
sqltypes.AlwaysFalse(userOwnerMatcher()),
)
-114
View File
@@ -1,114 +0,0 @@
package sqltypes
import (
"fmt"
"strings"
"github.com/open-policy-agent/opa/ast"
"golang.org/x/xerrors"
)
var (
_ VariableMatcher = astUUIDVar{}
_ Node = astUUIDVar{}
_ SupportsEquality = astUUIDVar{}
)
// astUUIDVar is a variable that represents a UUID column. Unlike
// astStringVar it emits native UUID comparisons (column = 'val'::uuid)
// instead of text-based ones (COALESCE(column::text, '') = 'val').
// This allows PostgreSQL to use indexes on UUID columns.
type astUUIDVar struct {
Source RegoSource
FieldPath []string
ColumnString string
}
func UUIDVarMatcher(sqlColumn string, regoPath []string) VariableMatcher {
return astUUIDVar{FieldPath: regoPath, ColumnString: sqlColumn}
}
func (astUUIDVar) UseAs() Node { return astUUIDVar{} }
func (u astUUIDVar) ConvertVariable(rego ast.Ref) (Node, bool) {
left, err := RegoVarPath(u.FieldPath, rego)
if err == nil && len(left) == 0 {
return astUUIDVar{
Source: RegoSource(rego.String()),
FieldPath: u.FieldPath,
ColumnString: u.ColumnString,
}, true
}
return nil, false
}
func (u astUUIDVar) SQLString(_ *SQLGenerator) string {
return u.ColumnString
}
// EqualsSQLString handles equality comparisons for UUID columns.
// Rego always produces string literals, so we accept AstString and
// cast the literal to ::uuid in the output SQL. This lets PG use
// native UUID indexes instead of falling back to text comparisons.
// nolint:revive
func (u astUUIDVar) EqualsSQLString(cfg *SQLGenerator, not bool, other Node) (string, error) {
switch other.UseAs().(type) {
case AstString:
// The other side is a rego string literal like
// "8c0b9bdc-a013-4b14-a49b-5747bc335708". Emit a comparison
// that casts the literal to uuid so PG can use indexes:
// column = 'val'::uuid
// instead of the text-based:
// 'val' = COALESCE(column::text, '')
s, ok := other.(AstString)
if !ok {
return "", xerrors.Errorf("expected AstString, got %T", other)
}
if s.Value == "" {
// Empty string in rego means "no value". Compare the
// column against NULL since UUID columns represent
// absent values as NULL, not empty strings.
op := "IS NULL"
if not {
op = "IS NOT NULL"
}
return fmt.Sprintf("%s %s", u.ColumnString, op), nil
}
return fmt.Sprintf("%s %s '%s'::uuid",
u.ColumnString, equalsOp(not), s.Value), nil
case astUUIDVar:
return basicSQLEquality(cfg, not, u, other), nil
default:
return "", xerrors.Errorf("unsupported equality: %T %s %T",
u, equalsOp(not), other)
}
}
// ContainedInSQL implements SupportsContainedIn so that a UUID column
// can appear in membership checks like `col = ANY(ARRAY[...])`. The
// array elements are rego strings, so we cast each to ::uuid.
func (u astUUIDVar) ContainedInSQL(_ *SQLGenerator, haystack Node) (string, error) {
arr, ok := haystack.(ASTArray)
if !ok {
return "", xerrors.Errorf("unsupported containedIn: %T in %T", u, haystack)
}
if len(arr.Value) == 0 {
return "false", nil
}
// Build ARRAY['uuid1'::uuid, 'uuid2'::uuid, ...]
values := make([]string, 0, len(arr.Value))
for _, v := range arr.Value {
s, ok := v.(AstString)
if !ok {
return "", xerrors.Errorf("expected AstString array element, got %T", v)
}
values = append(values, fmt.Sprintf("'%s'::uuid", s.Value))
}
return fmt.Sprintf("%s = ANY(ARRAY [%s])",
u.ColumnString,
strings.Join(values, ",")), nil
}
+1 -2
View File
@@ -66,7 +66,7 @@ func AuditLogs(ctx context.Context, db database.Store, query string) (database.G
}
// Prepare the count filter, which uses the same parameters as the GetAuditLogsOffsetParams.
// nolint:exhaustruct // UserID and CountCap are not obtained from the query parameters.
// nolint:exhaustruct // UserID is not obtained from the query parameters.
countFilter := database.CountAuditLogsParams{
RequestID: filter.RequestID,
ResourceID: filter.ResourceID,
@@ -123,7 +123,6 @@ func ConnectionLogs(ctx context.Context, db database.Store, query string, apiKey
}
// This MUST be kept in sync with the above
// nolint:exhaustruct // CountCap is not obtained from the query parameters.
countFilter := database.CountConnectionLogsParams{
OrganizationID: filter.OrganizationID,
WorkspaceOwner: filter.WorkspaceOwner,
+17 -35
View File
@@ -19,7 +19,6 @@ import (
"github.com/google/uuid"
"github.com/prometheus/client_golang/prometheus"
"go.opentelemetry.io/otel/trace"
"golang.org/x/sync/singleflight"
"golang.org/x/xerrors"
"tailscale.com/derp"
"tailscale.com/tailcfg"
@@ -390,7 +389,6 @@ type MultiAgentController struct {
// connections to the destination
tickets map[uuid.UUID]map[uuid.UUID]struct{}
coordination *tailnet.BasicCoordination
sendGroup singleflight.Group
cancel context.CancelFunc
expireOldAgentsDone chan struct{}
@@ -420,44 +418,28 @@ func (m *MultiAgentController) New(client tailnet.CoordinatorClient) tailnet.Clo
func (m *MultiAgentController) ensureAgent(agentID uuid.UUID) error {
m.mu.Lock()
defer m.mu.Unlock()
_, ok := m.connectionTimes[agentID]
if ok {
m.connectionTimes[agentID] = time.Now()
m.mu.Unlock()
return nil
}
m.mu.Unlock()
m.logger.Debug(context.Background(),
"subscribing to agent", slog.F("agent_id", agentID))
_, err, _ := m.sendGroup.Do(agentID.String(), func() (interface{}, error) {
m.mu.Lock()
coord := m.coordination
m.mu.Unlock()
if coord == nil {
return nil, xerrors.New("no active coordination")
// If we don't have the agent, subscribe.
if !ok {
m.logger.Debug(context.Background(),
"subscribing to agent", slog.F("agent_id", agentID))
if m.coordination != nil {
err := m.coordination.Client.Send(&proto.CoordinateRequest{
AddTunnel: &proto.CoordinateRequest_Tunnel{Id: agentID[:]},
})
if err != nil {
err = xerrors.Errorf("subscribe agent: %w", err)
m.coordination.SendErr(err)
_ = m.coordination.Client.Close()
m.coordination = nil
return err
}
}
err := coord.Client.Send(&proto.CoordinateRequest{
AddTunnel: &proto.CoordinateRequest_Tunnel{Id: agentID[:]},
})
if err != nil {
return nil, err
}
m.mu.Lock()
m.tickets[agentID] = map[uuid.UUID]struct{}{}
m.mu.Unlock()
return nil, nil
})
if err != nil {
m.logger.Error(context.Background(), "ensureAgent send failed",
slog.F("agent_id", agentID), slog.Error(err))
return xerrors.Errorf("send AddTunnel: %w", err)
}
m.mu.Lock()
m.connectionTimes[agentID] = time.Now()
m.mu.Unlock()
return nil
}
+8 -13
View File
@@ -730,10 +730,7 @@ func (s *Server) workspaceAgentPTY(rw http.ResponseWriter, r *http.Request) {
if !ok {
return
}
log := s.Logger.With(
slog.F("agent_id", appToken.AgentID),
slog.F("workspace_id", appToken.WorkspaceID),
)
log := s.Logger.With(slog.F("agent_id", appToken.AgentID))
log.Debug(ctx, "resolved PTY request")
values := r.URL.Query()
@@ -768,21 +765,19 @@ func (s *Server) workspaceAgentPTY(rw http.ResponseWriter, r *http.Request) {
})
return
}
go httpapi.HeartbeatClose(ctx, s.Logger, cancel, conn)
ctx, wsNetConn := WebsocketNetConn(ctx, conn, websocket.MessageBinary)
defer wsNetConn.Close() // Also closes conn.
go httpapi.HeartbeatClose(ctx, log, cancel, conn)
dialStart := time.Now()
agentConn, release, err := s.AgentProvider.AgentConn(ctx, appToken.AgentID)
if err != nil {
log.Debug(ctx, "dial workspace agent", slog.Error(err), slog.F("elapsed_ms", time.Since(dialStart).Milliseconds()))
log.Debug(ctx, "dial workspace agent", slog.Error(err))
_ = conn.Close(websocket.StatusInternalError, httpapi.WebsocketCloseSprintf("dial workspace agent: %s", err))
return
}
defer release()
log.Debug(ctx, "dialed workspace agent", slog.F("elapsed_ms", time.Since(dialStart).Milliseconds()))
log.Debug(ctx, "dialed workspace agent")
// #nosec G115 - Safe conversion for terminal height/width which are expected to be within uint16 range (0-65535)
ptNetConn, err := agentConn.ReconnectingPTY(ctx, reconnect, uint16(height), uint16(width), r.URL.Query().Get("command"), func(arp *workspacesdk.AgentReconnectingPTYInit) {
arp.Container = container
@@ -790,12 +785,12 @@ func (s *Server) workspaceAgentPTY(rw http.ResponseWriter, r *http.Request) {
arp.BackendType = backendType
})
if err != nil {
log.Debug(ctx, "dial reconnecting pty server in workspace agent", slog.Error(err), slog.F("elapsed_ms", time.Since(dialStart).Milliseconds()))
log.Debug(ctx, "dial reconnecting pty server in workspace agent", slog.Error(err))
_ = conn.Close(websocket.StatusInternalError, httpapi.WebsocketCloseSprintf("dial: %s", err))
return
}
defer ptNetConn.Close()
log.Debug(ctx, "obtained PTY", slog.F("elapsed_ms", time.Since(dialStart).Milliseconds()))
log.Debug(ctx, "obtained PTY")
report := newStatsReportFromSignedToken(*appToken)
s.collectStats(report)
@@ -805,7 +800,7 @@ func (s *Server) workspaceAgentPTY(rw http.ResponseWriter, r *http.Request) {
}()
agentssh.Bicopy(ctx, wsNetConn, ptNetConn)
log.Debug(ctx, "pty Bicopy finished", slog.F("elapsed_ms", time.Since(dialStart).Milliseconds()))
log.Debug(ctx, "pty Bicopy finished")
}
func (s *Server) collectStats(stats StatsReport) {
+28 -124
View File
@@ -7,7 +7,6 @@ import (
"encoding/json"
"errors"
"fmt"
"maps"
"net/http"
"slices"
"strconv"
@@ -152,12 +151,6 @@ type Server struct {
inFlightChatStaleAfter time.Duration
chatHeartbeatInterval time.Duration
// heartbeatMu guards heartbeatRegistry.
heartbeatMu sync.Mutex
// heartbeatRegistry maps chat IDs to their cancel functions
// and workspace state for the centralized heartbeat loop.
heartbeatRegistry map[uuid.UUID]*heartbeatEntry
// wakeCh is signaled by SendMessage, EditMessage, CreateChat,
// and PromoteQueued so the run loop calls processOnce
// immediately instead of waiting for the next ticker.
@@ -713,17 +706,6 @@ type chatStreamState struct {
bufferRetainedAt time.Time
}
// heartbeatEntry tracks a single chat's cancel function and workspace
// state for the centralized heartbeat loop. Instead of spawning a
// per-chat goroutine, processChat registers an entry here and the
// single heartbeatLoop goroutine handles all chats.
type heartbeatEntry struct {
cancelWithCause context.CancelCauseFunc
chatID uuid.UUID
workspaceID uuid.NullUUID
logger slog.Logger
}
// resetDropCounters zeroes the rate-limiting state for both buffer
// and subscriber drop warnings. The caller must hold s.mu.
func (s *chatStreamState) resetDropCounters() {
@@ -2438,8 +2420,8 @@ func New(cfg Config) *Server {
clock: clk,
recordingSem: make(chan struct{}, maxConcurrentRecordingUploads),
wakeCh: make(chan struct{}, 1),
heartbeatRegistry: make(map[uuid.UUID]*heartbeatEntry),
}
//nolint:gocritic // The chat processor uses a scoped chatd context.
ctx = dbauthz.AsChatd(ctx)
@@ -2479,9 +2461,6 @@ func (p *Server) start(ctx context.Context) {
// to handle chats orphaned by crashed or redeployed workers.
p.recoverStaleChats(ctx)
// Single heartbeat loop for all chats on this replica.
go p.heartbeatLoop(ctx)
acquireTicker := p.clock.NewTicker(
p.pendingChatAcquireInterval,
"chatd",
@@ -2751,97 +2730,6 @@ func (p *Server) cleanupStreamIfIdle(chatID uuid.UUID, state *chatStreamState) {
p.workspaceMCPToolsCache.Delete(chatID)
}
// registerHeartbeat enrolls a chat in the centralized batch
// heartbeat loop. Must be called after chatCtx is created.
func (p *Server) registerHeartbeat(entry *heartbeatEntry) {
p.heartbeatMu.Lock()
defer p.heartbeatMu.Unlock()
if _, exists := p.heartbeatRegistry[entry.chatID]; exists {
p.logger.Warn(context.Background(),
"duplicate heartbeat registration, skipping",
slog.F("chat_id", entry.chatID))
return
}
p.heartbeatRegistry[entry.chatID] = entry
}
// unregisterHeartbeat removes a chat from the centralized
// heartbeat loop when chat processing finishes.
func (p *Server) unregisterHeartbeat(chatID uuid.UUID) {
p.heartbeatMu.Lock()
defer p.heartbeatMu.Unlock()
delete(p.heartbeatRegistry, chatID)
}
// heartbeatLoop runs in a single goroutine, issuing one batch
// heartbeat query per interval for all registered chats.
func (p *Server) heartbeatLoop(ctx context.Context) {
ticker := p.clock.NewTicker(p.chatHeartbeatInterval, "chatd", "batch-heartbeat")
defer ticker.Stop()
for {
select {
case <-ctx.Done():
return
case <-ticker.C:
p.heartbeatTick(ctx)
}
}
}
// heartbeatTick issues a single batch UPDATE for all running chats
// owned by this worker. Chats missing from the result set are
// interrupted (stolen by another replica or already completed).
func (p *Server) heartbeatTick(ctx context.Context) {
// Snapshot the registry under the lock.
p.heartbeatMu.Lock()
snapshot := maps.Clone(p.heartbeatRegistry)
p.heartbeatMu.Unlock()
if len(snapshot) == 0 {
return
}
// Collect the IDs we believe we own.
ids := slices.Collect(maps.Keys(snapshot))
//nolint:gocritic // AsChatd provides narrowly-scoped daemon
// access for batch-updating heartbeats.
chatdCtx := dbauthz.AsChatd(ctx)
updatedIDs, err := p.db.UpdateChatHeartbeats(chatdCtx, database.UpdateChatHeartbeatsParams{
IDs: ids,
WorkerID: p.workerID,
Now: p.clock.Now(),
})
if err != nil {
p.logger.Error(ctx, "batch heartbeat failed", slog.Error(err))
return
}
// Build a set of IDs that were successfully updated.
updated := make(map[uuid.UUID]struct{}, len(updatedIDs))
for _, id := range updatedIDs {
updated[id] = struct{}{}
}
// Interrupt registered chats that were not in the result
// (stolen by another replica or already completed).
for id, entry := range snapshot {
if _, ok := updated[id]; !ok {
entry.logger.Warn(ctx, "chat not in batch heartbeat result, interrupting")
entry.cancelWithCause(chatloop.ErrInterrupted)
continue
}
// Bump workspace usage for surviving chats.
newWsID := p.trackWorkspaceUsage(ctx, entry.chatID, entry.workspaceID, entry.logger)
// Update workspace ID in the registry for next tick.
p.heartbeatMu.Lock()
if current, exists := p.heartbeatRegistry[id]; exists {
current.workspaceID = newWsID
}
p.heartbeatMu.Unlock()
}
}
func (p *Server) Subscribe(
ctx context.Context,
chatID uuid.UUID,
@@ -3687,17 +3575,33 @@ func (p *Server) processChat(ctx context.Context, chat database.Chat) {
}
}()
// Register with the centralized heartbeat loop instead of
// running a per-chat goroutine. The loop issues a single batch
// UPDATE for all chats on this worker and detects stolen chats
// via set-difference.
p.registerHeartbeat(&heartbeatEntry{
cancelWithCause: cancel,
chatID: chat.ID,
workspaceID: chat.WorkspaceID,
logger: logger,
})
defer p.unregisterHeartbeat(chat.ID)
// Periodically update the heartbeat so other replicas know this
// worker is still alive. The goroutine stops when chatCtx is
// canceled (either by completion or interruption).
go func() {
ticker := p.clock.NewTicker(p.chatHeartbeatInterval, "chatd", "heartbeat")
defer ticker.Stop()
for {
select {
case <-chatCtx.Done():
return
case <-ticker.C:
rows, err := p.db.UpdateChatHeartbeat(chatCtx, database.UpdateChatHeartbeatParams{
ID: chat.ID,
WorkerID: p.workerID,
})
if err != nil {
logger.Warn(chatCtx, "failed to update chat heartbeat", slog.Error(err))
continue
}
if rows == 0 {
cancel(chatloop.ErrInterrupted)
return
}
chat.WorkspaceID = p.trackWorkspaceUsage(chatCtx, chat.ID, chat.WorkspaceID, logger)
}
}
}()
// Start buffering stream events BEFORE publishing the running
// status. This closes a race where a subscriber sees
-129
View File
@@ -21,7 +21,6 @@ import (
dbpubsub "github.com/coder/coder/v2/coderd/database/pubsub"
coderdpubsub "github.com/coder/coder/v2/coderd/pubsub"
"github.com/coder/coder/v2/coderd/x/chatd/chaterror"
"github.com/coder/coder/v2/coderd/x/chatd/chatloop"
"github.com/coder/coder/v2/coderd/x/chatd/chatprovider"
"github.com/coder/coder/v2/coderd/x/chatd/chattest"
"github.com/coder/coder/v2/coderd/x/chatd/chattool"
@@ -2072,7 +2071,6 @@ func TestProcessChat_IgnoresStaleControlNotification(t *testing.T) {
workerID: workerID,
chatHeartbeatInterval: time.Minute,
configCache: newChatConfigCache(ctx, db, clock),
heartbeatRegistry: make(map[uuid.UUID]*heartbeatEntry),
}
// Publish a stale "pending" notification on the control channel
@@ -2135,130 +2133,3 @@ func TestProcessChat_IgnoresStaleControlNotification(t *testing.T) {
require.Equal(t, database.ChatStatusError, finalStatus,
"processChat should have reached runChat (error), not been interrupted (waiting)")
}
// TestHeartbeatTick_StolenChatIsInterrupted verifies that when the
// batch heartbeat UPDATE does not return a registered chat's ID
// (because another replica stole it or it was completed), the
// heartbeat tick cancels that chat's context with ErrInterrupted
// while leaving surviving chats untouched.
func TestHeartbeatTick_StolenChatIsInterrupted(t *testing.T) {
t.Parallel()
ctx := testutil.Context(t, testutil.WaitShort)
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true})
ctrl := gomock.NewController(t)
db := dbmock.NewMockStore(ctrl)
clock := quartz.NewMock(t)
workerID := uuid.New()
server := &Server{
db: db,
logger: logger,
clock: clock,
workerID: workerID,
chatHeartbeatInterval: time.Minute,
heartbeatRegistry: make(map[uuid.UUID]*heartbeatEntry),
}
// Create three chats with independent cancel functions.
chat1 := uuid.New()
chat2 := uuid.New()
chat3 := uuid.New()
_, cancel1 := context.WithCancelCause(ctx)
_, cancel2 := context.WithCancelCause(ctx)
ctx3, cancel3 := context.WithCancelCause(ctx)
server.registerHeartbeat(&heartbeatEntry{
cancelWithCause: cancel1,
chatID: chat1,
logger: logger,
})
server.registerHeartbeat(&heartbeatEntry{
cancelWithCause: cancel2,
chatID: chat2,
logger: logger,
})
server.registerHeartbeat(&heartbeatEntry{
cancelWithCause: cancel3,
chatID: chat3,
logger: logger,
})
// The batch UPDATE returns only chat1 and chat2 —
// chat3 was "stolen" by another replica.
db.EXPECT().UpdateChatHeartbeats(gomock.Any(), gomock.Any()).DoAndReturn(
func(_ context.Context, params database.UpdateChatHeartbeatsParams) ([]uuid.UUID, error) {
require.Equal(t, workerID, params.WorkerID)
require.Len(t, params.IDs, 3)
// Return only chat1 and chat2 as surviving.
return []uuid.UUID{chat1, chat2}, nil
},
)
server.heartbeatTick(ctx)
// chat3's context should be canceled with ErrInterrupted.
require.ErrorIs(t, context.Cause(ctx3), chatloop.ErrInterrupted,
"stolen chat should be interrupted")
// chat3 should have been removed from the registry by
// unregister (in production this happens via defer in
// processChat). The heartbeat tick itself does not
// unregister — it only cancels. Verify the entry is
// still present (processChat's defer would clean it up).
server.heartbeatMu.Lock()
_, chat1Exists := server.heartbeatRegistry[chat1]
_, chat2Exists := server.heartbeatRegistry[chat2]
_, chat3Exists := server.heartbeatRegistry[chat3]
server.heartbeatMu.Unlock()
require.True(t, chat1Exists, "surviving chat1 should remain registered")
require.True(t, chat2Exists, "surviving chat2 should remain registered")
require.True(t, chat3Exists,
"stolen chat3 should still be in registry (processChat defer removes it)")
}
// TestHeartbeatTick_DBErrorDoesNotInterruptChats verifies that a
// transient database failure causes the tick to log and return
// without canceling any registered chats.
func TestHeartbeatTick_DBErrorDoesNotInterruptChats(t *testing.T) {
t.Parallel()
ctx := testutil.Context(t, testutil.WaitShort)
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true})
ctrl := gomock.NewController(t)
db := dbmock.NewMockStore(ctrl)
clock := quartz.NewMock(t)
server := &Server{
db: db,
logger: logger,
clock: clock,
workerID: uuid.New(),
chatHeartbeatInterval: time.Minute,
heartbeatRegistry: make(map[uuid.UUID]*heartbeatEntry),
}
chatID := uuid.New()
chatCtx, cancel := context.WithCancelCause(ctx)
server.registerHeartbeat(&heartbeatEntry{
cancelWithCause: cancel,
chatID: chatID,
logger: logger,
})
// Simulate a transient DB error.
db.EXPECT().UpdateChatHeartbeats(gomock.Any(), gomock.Any()).Return(
nil, xerrors.New("connection reset"),
)
server.heartbeatTick(ctx)
// Chat should NOT be interrupted — the tick logged and
// returned early.
require.NoError(t, chatCtx.Err(),
"chat context should not be canceled on transient DB error")
}
+7 -12
View File
@@ -474,7 +474,7 @@ func TestArchiveChatInterruptsActiveProcessing(t *testing.T) {
require.Equal(t, 1, userMessages, "expected queued message to stay queued after archive")
}
func TestUpdateChatHeartbeatsRequiresOwnership(t *testing.T) {
func TestUpdateChatHeartbeatRequiresOwnership(t *testing.T) {
t.Parallel()
db, ps := dbtestutil.NewDB(t)
@@ -501,24 +501,19 @@ func TestUpdateChatHeartbeatsRequiresOwnership(t *testing.T) {
})
require.NoError(t, err)
// Wrong worker_id should return no IDs.
ids, err := db.UpdateChatHeartbeats(ctx, database.UpdateChatHeartbeatsParams{
IDs: []uuid.UUID{chat.ID},
rows, err := db.UpdateChatHeartbeat(ctx, database.UpdateChatHeartbeatParams{
ID: chat.ID,
WorkerID: uuid.New(),
Now: time.Now(),
})
require.NoError(t, err)
require.Empty(t, ids)
require.Equal(t, int64(0), rows)
// Correct worker_id should return the chat's ID.
ids, err = db.UpdateChatHeartbeats(ctx, database.UpdateChatHeartbeatsParams{
IDs: []uuid.UUID{chat.ID},
rows, err = db.UpdateChatHeartbeat(ctx, database.UpdateChatHeartbeatParams{
ID: chat.ID,
WorkerID: workerID,
Now: time.Now(),
})
require.NoError(t, err)
require.Len(t, ids, 1)
require.Equal(t, chat.ID, ids[0])
require.Equal(t, int64(1), rows)
}
func TestSendMessageQueueBehaviorQueuesWhenBusy(t *testing.T) {
+5 -34
View File
@@ -9,7 +9,6 @@ import (
"fmt"
"net/http"
"net/url"
"slices"
"strings"
"sync"
"time"
@@ -50,11 +49,10 @@ const connectTimeout = 10 * time.Second
const toolCallTimeout = 60 * time.Second
// ConnectAll connects to all configured MCP servers, discovers
// their tools, and returns them as fantasy.AgentTool values.
// Tools are sorted by their prefixed name so callers
// receive a deterministic order. It skips servers that fail to
// connect and logs warnings. The returned cleanup function
// must be called to close all connections.
// their tools, and returns them as fantasy.AgentTool values. It
// skips servers that fail to connect and logs warnings. The
// returned cleanup function must be called to close all
// connections.
func ConnectAll(
ctx context.Context,
logger slog.Logger,
@@ -110,9 +108,7 @@ func ConnectAll(
}
mu.Lock()
if mcpClient != nil {
clients = append(clients, mcpClient)
}
clients = append(clients, mcpClient)
tools = append(tools, serverTools...)
mu.Unlock()
return nil
@@ -123,31 +119,6 @@ func ConnectAll(
// discarded.
_ = eg.Wait()
// Sort tools by prefixed name for deterministic ordering
// regardless of goroutine completion order. Ties, possible
// when the __ separator produces ambiguous prefixed names,
// are broken by config ID. Stable prompt construction
// depends on consistent tool ordering.
slices.SortFunc(tools, func(a, b fantasy.AgentTool) int {
// All tools in this slice are mcpToolWrapper values
// created by connectOne above, so these checked
// assertions should always succeed. The config ID
// tiebreaker resolves the __ separator ambiguity
// documented at the top of this file.
aTool, ok := a.(MCPToolIdentifier)
if !ok {
panic(fmt.Sprintf("unexpected tool type %T", a))
}
bTool, ok := b.(MCPToolIdentifier)
if !ok {
panic(fmt.Sprintf("unexpected tool type %T", b))
}
return cmp.Or(
cmp.Compare(a.Info().Name, b.Info().Name),
cmp.Compare(aTool.MCPServerConfigID().String(), bTool.MCPServerConfigID().String()),
)
})
return tools, cleanup
}
-126
View File
@@ -63,17 +63,6 @@ func greetTool() mcpserver.ServerTool {
}
}
// makeTool returns a ServerTool with the given name and a
// no-op handler that always returns "ok".
func makeTool(name string) mcpserver.ServerTool {
return mcpserver.ServerTool{
Tool: mcp.NewTool(name),
Handler: func(_ context.Context, _ mcp.CallToolRequest) (*mcp.CallToolResult, error) {
return mcp.NewToolResultText("ok"), nil
},
}
}
// makeConfig builds a database.MCPServerConfig suitable for tests.
func makeConfig(slug, url string) database.MCPServerConfig {
return database.MCPServerConfig{
@@ -209,121 +198,6 @@ func TestConnectAll_MultipleServers(t *testing.T) {
assert.Contains(t, names, "beta__greet")
}
func TestConnectAll_NoToolsAfterFiltering(t *testing.T) {
t.Parallel()
ctx := context.Background()
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true})
ts := newTestMCPServer(t, echoTool())
cfg := makeConfig("filtered", ts.URL)
cfg.ToolAllowList = []string{"greet"}
tools, cleanup := mcpclient.ConnectAll(
ctx,
logger,
[]database.MCPServerConfig{cfg},
nil,
)
require.Empty(t, tools)
assert.NotPanics(t, cleanup)
}
func TestConnectAll_DeterministicOrder(t *testing.T) {
t.Parallel()
t.Run("AcrossServers", func(t *testing.T) {
t.Parallel()
ctx := context.Background()
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true})
ts1 := newTestMCPServer(t, makeTool("zebra"))
ts2 := newTestMCPServer(t, makeTool("alpha"))
ts3 := newTestMCPServer(t, makeTool("middle"))
tools, cleanup := mcpclient.ConnectAll(
ctx,
logger,
[]database.MCPServerConfig{
makeConfig("srv3", ts3.URL),
makeConfig("srv1", ts1.URL),
makeConfig("srv2", ts2.URL),
},
nil,
)
t.Cleanup(cleanup)
require.Len(t, tools, 3)
// Sorted by full prefixed name (slug__tool), so slug
// order determines the sequence, not the tool name.
assert.Equal(t,
[]string{"srv1__zebra", "srv2__alpha", "srv3__middle"},
toolNames(tools),
)
})
t.Run("WithMultiToolServer", func(t *testing.T) {
t.Parallel()
ctx := context.Background()
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true})
multi := newTestMCPServer(t, makeTool("zeta"), makeTool("beta"))
other := newTestMCPServer(t, makeTool("gamma"))
tools, cleanup := mcpclient.ConnectAll(
ctx,
logger,
[]database.MCPServerConfig{
makeConfig("zzz", multi.URL),
makeConfig("aaa", other.URL),
},
nil,
)
t.Cleanup(cleanup)
require.Len(t, tools, 3)
assert.Equal(t,
[]string{"aaa__gamma", "zzz__beta", "zzz__zeta"},
toolNames(tools),
)
})
t.Run("TiebreakByConfigID", func(t *testing.T) {
t.Parallel()
ctx := context.Background()
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true})
ts1 := newTestMCPServer(t, makeTool("b__z"))
ts2 := newTestMCPServer(t, makeTool("z"))
// Use fixed UUIDs so the tiebreaker order is
// predictable. Both servers produce the same prefixed
// name, a__b__z, due to the __ separator ambiguity.
cfg1 := makeConfig("a", ts1.URL)
cfg1.ID = uuid.MustParse("00000000-0000-0000-0000-000000000002")
cfg2 := makeConfig("a__b", ts2.URL)
cfg2.ID = uuid.MustParse("00000000-0000-0000-0000-000000000001")
tools, cleanup := mcpclient.ConnectAll(
ctx,
logger,
[]database.MCPServerConfig{cfg1, cfg2},
nil,
)
t.Cleanup(cleanup)
require.Len(t, tools, 2)
assert.Equal(t, []string{"a__b__z", "a__b__z"}, toolNames(tools))
id0 := tools[0].(mcpclient.MCPToolIdentifier).MCPServerConfigID()
id1 := tools[1].(mcpclient.MCPToolIdentifier).MCPServerConfigID()
assert.Equal(t, cfg2.ID, id0, "lower config ID should sort first")
assert.Equal(t, cfg1.ID, id1, "higher config ID should sort second")
})
}
func TestConnectAll_AuthHeaders(t *testing.T) {
t.Parallel()
ctx := context.Background()
-1
View File
@@ -212,7 +212,6 @@ type AuditLogsRequest struct {
type AuditLogResponse struct {
AuditLogs []AuditLog `json:"audit_logs"`
Count int64 `json:"count"`
CountCap int64 `json:"count_cap"`
}
type CreateTestAuditLogRequest struct {
-1
View File
@@ -96,7 +96,6 @@ type ConnectionLogsRequest struct {
type ConnectionLogResponse struct {
ConnectionLogs []ConnectionLog `json:"connection_logs"`
Count int64 `json:"count"`
CountCap int64 `json:"count_cap"`
}
func (c *Client) ConnectionLogs(ctx context.Context, req ConnectionLogsRequest) (ConnectionLogResponse, error) {
+1 -2
View File
@@ -90,8 +90,7 @@ curl -X GET http://coder-server:8080/api/v2/audit?limit=0 \
"user_agent": "string"
}
],
"count": 0,
"count_cap": 0
"count": 0
}
```
+1 -2
View File
@@ -291,8 +291,7 @@ curl -X GET http://coder-server:8080/api/v2/connectionlog?limit=0 \
"workspace_owner_username": "string"
}
],
"count": 0,
"count_cap": 0
"count": 0
}
```
+2 -6
View File
@@ -1740,8 +1740,7 @@
"user_agent": "string"
}
],
"count": 0,
"count_cap": 0
"count": 0
}
```
@@ -1751,7 +1750,6 @@
|--------------|-------------------------------------------------|----------|--------------|-------------|
| `audit_logs` | array of [codersdk.AuditLog](#codersdkauditlog) | false | | |
| `count` | integer | false | | |
| `count_cap` | integer | false | | |
## codersdk.AuthMethod
@@ -2175,8 +2173,7 @@ AuthorizationObject can represent a "set" of objects, such as: all workspaces in
"workspace_owner_username": "string"
}
],
"count": 0,
"count_cap": 0
"count": 0
}
```
@@ -2186,7 +2183,6 @@ AuthorizationObject can represent a "set" of objects, such as: all workspaces in
|-------------------|-----------------------------------------------------------|----------|--------------|-------------|
| `connection_logs` | array of [codersdk.ConnectionLog](#codersdkconnectionlog) | false | | |
| `count` | integer | false | | |
| `count_cap` | integer | false | | |
## codersdk.ConnectionLogSSHInfo
-6
View File
@@ -16,9 +16,6 @@ import (
"github.com/coder/coder/v2/codersdk"
)
// NOTE: See the auditLogCountCap note.
const connectionLogCountCap = 2000
// @Summary Get connection logs
// @ID get-connection-logs
// @Security CoderSessionToken
@@ -52,7 +49,6 @@ func (api *API) connectionLogs(rw http.ResponseWriter, r *http.Request) {
// #nosec G115 - Safe conversion as pagination limit is expected to be within int32 range
filter.LimitOpt = int32(page.Limit)
countFilter.CountCap = connectionLogCountCap
count, err := api.Database.CountConnectionLogs(ctx, countFilter)
if dbauthz.IsNotAuthorizedError(err) {
httpapi.Forbidden(rw)
@@ -67,7 +63,6 @@ func (api *API) connectionLogs(rw http.ResponseWriter, r *http.Request) {
httpapi.Write(ctx, rw, http.StatusOK, codersdk.ConnectionLogResponse{
ConnectionLogs: []codersdk.ConnectionLog{},
Count: 0,
CountCap: connectionLogCountCap,
})
return
}
@@ -85,7 +80,6 @@ func (api *API) connectionLogs(rw http.ResponseWriter, r *http.Request) {
httpapi.Write(ctx, rw, http.StatusOK, codersdk.ConnectionLogResponse{
ConnectionLogs: convertConnectionLogs(dblogs),
Count: count,
CountCap: connectionLogCountCap,
})
}
+7 -27
View File
@@ -12,7 +12,6 @@ import (
"github.com/cenkalti/backoff/v4"
"github.com/google/uuid"
"golang.org/x/sync/singleflight"
"golang.org/x/xerrors"
gProto "google.golang.org/protobuf/proto"
@@ -34,9 +33,9 @@ const (
eventReadyForHandshake = "tailnet_ready_for_handshake"
HeartbeatPeriod = time.Second * 2
MissedHeartbeats = 3
numQuerierWorkers = 40
numQuerierWorkers = 10
numBinderWorkers = 10
numTunnelerWorkers = 20
numTunnelerWorkers = 10
numHandshakerWorkers = 5
dbMaxBackoff = 10 * time.Second
cleanupPeriod = time.Hour
@@ -771,9 +770,6 @@ func (m *mapper) bestToUpdate(best map[uuid.UUID]mapping) *proto.CoordinateRespo
for k := range m.sent {
if _, ok := best[k]; !ok {
m.logger.Debug(m.ctx, "peer no longer in best mappings, sending DISCONNECTED",
slog.F("peer_id", k),
)
resp.PeerUpdates = append(resp.PeerUpdates, &proto.CoordinateResponse_PeerUpdate{
Id: agpl.UUIDToByteSlice(k),
Kind: proto.CoordinateResponse_PeerUpdate_DISCONNECTED,
@@ -824,8 +820,6 @@ type querier struct {
mu sync.Mutex
mappers map[mKey]*mapper
healthy bool
resyncGroup singleflight.Group
}
func newQuerier(ctx context.Context,
@@ -964,7 +958,7 @@ func (q *querier) cleanupConn(c *connIO) {
// maxBatchSize is the maximum number of keys to process in a single batch
// query.
const maxBatchSize = 200
const maxBatchSize = 50
func (q *querier) peerUpdateWorker() {
defer q.wg.Done()
@@ -1213,13 +1207,8 @@ func (q *querier) subscribe() {
func (q *querier) listenPeer(_ context.Context, msg []byte, err error) {
if xerrors.Is(err, pubsub.ErrDroppedMessages) {
q.logger.Warn(q.ctx, "pubsub may have dropped peer updates")
// Schedule a full resync asynchronously so we don't block the
// pubsub drain goroutine. Singleflight coalesces concurrent
// resync requests.
go q.resyncGroup.Do("resync", func() (any, error) {
q.resyncPeerMappings()
return nil, nil
})
// we need to schedule a full resync of peer mappings
q.resyncPeerMappings()
return
}
if err != nil {
@@ -1245,13 +1234,8 @@ func (q *querier) listenPeer(_ context.Context, msg []byte, err error) {
func (q *querier) listenTunnel(_ context.Context, msg []byte, err error) {
if xerrors.Is(err, pubsub.ErrDroppedMessages) {
q.logger.Warn(q.ctx, "pubsub may have dropped tunnel updates")
// Schedule a full resync asynchronously so we don't block the
// pubsub drain goroutine. Singleflight coalesces concurrent
// resync requests.
go q.resyncGroup.Do("resync", func() (any, error) {
q.resyncPeerMappings()
return nil, nil
})
// we need to schedule a full resync of peer mappings
q.resyncPeerMappings()
return
}
if err != nil {
@@ -1617,10 +1601,6 @@ func (h *heartbeats) filter(mappings []mapping) []mapping {
// the only mapping available for it. Newer mappings will take
// precedence.
m.kind = proto.CoordinateResponse_PeerUpdate_LOST
h.logger.Debug(h.ctx, "mapping rewritten to LOST due to missed heartbeats",
slog.F("peer_id", m.peer),
slog.F("coordinator_id", m.coordinator),
)
}
}
+2 -2
View File
@@ -76,11 +76,11 @@ replace github.com/aquasecurity/trivy => github.com/coder/trivy v0.0.0-202603091
// https://github.com/spf13/afero/pull/487
replace github.com/spf13/afero => github.com/aslilac/afero v0.0.0-20250403163713-f06e86036696
// Forked from coder/fantasy (cj/go1.25 branch) which adds:
// Forked from kylecarbs/fantasy (cj/go1.25 branch) which adds:
// 1) Anthropic computer use + thinking effort
// 2) Go 1.25 downgrade for Windows CI compat
// 3) ibetitsmike/fantasy#4 skip ephemeral replay items when store=false
replace charm.land/fantasy => github.com/coder/fantasy v0.0.0-20260325145725-112927d9b6d8
replace charm.land/fantasy => github.com/kylecarbs/fantasy v0.0.0-20260325145725-112927d9b6d8
replace github.com/charmbracelet/anthropic-sdk-go => github.com/kylecarbs/anthropic-sdk-go v0.0.0-20260223140439-63879b0b8dab
+2 -2
View File
@@ -322,8 +322,6 @@ github.com/coder/bubbletea v1.2.2-0.20241212190825-007a1cdb2c41 h1:SBN/DA63+ZHwu
github.com/coder/bubbletea v1.2.2-0.20241212190825-007a1cdb2c41/go.mod h1:I9ULxr64UaOSUv7hcb3nX4kowodJCVS7vt7VVJk/kW4=
github.com/coder/clistat v1.2.1 h1:P9/10njXMyj5cWzIU5wkRsSy5LVQH49+tcGMsAgWX0w=
github.com/coder/clistat v1.2.1/go.mod h1:m7SC0uj88eEERgvF8Kn6+w6XF21BeSr+15f7GoLAw0A=
github.com/coder/fantasy v0.0.0-20260325145725-112927d9b6d8 h1:n+6v+yT1B6V4oSGPmXFh7mul1E+RzG9rnqp50Vb7M/w=
github.com/coder/fantasy v0.0.0-20260325145725-112927d9b6d8/go.mod h1:ktfNX0xDpIKeggZbP/j5IYJci6pyMOR3WmZSfz9XLYw=
github.com/coder/flog v1.1.0 h1:kbAes1ai8fIS5OeV+QAnKBQE22ty1jRF/mcAwHpLBa4=
github.com/coder/flog v1.1.0/go.mod h1:UQlQvrkJBvnRGo69Le8E24Tcl5SJleAAR7gYEHzAmdQ=
github.com/coder/go-httpstat v0.0.0-20230801153223-321c88088322 h1:m0lPZjlQ7vdVpRBPKfYIFlmgevoTkBxB10wv6l2gOaU=
@@ -815,6 +813,8 @@ github.com/kylecarbs/anthropic-sdk-go v0.0.0-20260223140439-63879b0b8dab h1:5UMY
github.com/kylecarbs/anthropic-sdk-go v0.0.0-20260223140439-63879b0b8dab/go.mod h1:hqlYqR7uPKOKfnNeicUbZp0Ps0GeYFlKYtwh5HGDCx8=
github.com/kylecarbs/chroma/v2 v2.0.0-20240401211003-9e036e0631f3 h1:Z9/bo5PSeMutpdiKYNt/TTSfGM1Ll0naj3QzYX9VxTc=
github.com/kylecarbs/chroma/v2 v2.0.0-20240401211003-9e036e0631f3/go.mod h1:BUGjjsD+ndS6eX37YgTchSEG+Jg9Jv1GiZs9sqPqztk=
github.com/kylecarbs/fantasy v0.0.0-20260325145725-112927d9b6d8 h1:fZ0208U3B438fDSHCc/GNioPIyaFqn6eBsQTO61QtrI=
github.com/kylecarbs/fantasy v0.0.0-20260325145725-112927d9b6d8/go.mod h1:ktfNX0xDpIKeggZbP/j5IYJci6pyMOR3WmZSfz9XLYw=
github.com/kylecarbs/openai-go/v3 v3.0.0-20260319113850-9477dcaedcae h1:xlFZNX4nnxpj9Cf6mTwD3pirXGNtBJ/6COsf9iZmsL0=
github.com/kylecarbs/openai-go/v3 v3.0.0-20260319113850-9477dcaedcae/go.mod h1:cdufnVK14cWcT9qA1rRtrXx4FTRsgbDPW7Ia7SS5cZo=
github.com/kylecarbs/spinner v1.18.2-0.20220329160715-20702b5af89e h1:OP0ZMFeZkUnOzTFRfpuK3m7Kp4fNvC6qN+exwj7aI4M=
+35 -44
View File
@@ -68,17 +68,17 @@ func runRelease(ctx context.Context, inv *serpent.Invocation, executor ReleaseEx
return xerrors.Errorf("detecting branch: %w", err)
}
// Match release branches (release/X.Y). RCs are tagged
// from main, not from release branches.
branchRe := regexp.MustCompile(`^release/(\d+)\.(\d+)$`)
// Match standard release branches (release/2.32) and RC
// branches (release/2.32-rc.0).
branchRe := regexp.MustCompile(`^release/(\d+)\.(\d+)(?:-rc\.(\d+))?$`)
m := branchRe.FindStringSubmatch(currentBranch)
if m == nil {
warnf(w, "Current branch %q is not a release branch (release/X.Y).", currentBranch)
warnf(w, "Current branch %q is not a release branch (release/X.Y or release/X.Y-rc.N).", currentBranch)
branchInput, err := cliui.Prompt(inv, cliui.PromptOptions{
Text: "Enter the release branch to use (e.g. release/2.21)",
Text: "Enter the release branch to use (e.g. release/2.21 or release/2.21-rc.0)",
Validate: func(s string) error {
if !branchRe.MatchString(s) {
return xerrors.New("must be in format release/X.Y (e.g. release/2.21)")
return xerrors.New("must be in format release/X.Y or release/X.Y-rc.N (e.g. release/2.21 or release/2.21-rc.0)")
}
return nil
},
@@ -91,6 +91,10 @@ func runRelease(ctx context.Context, inv *serpent.Invocation, executor ReleaseEx
}
branchMajor, _ := strconv.Atoi(m[1])
branchMinor, _ := strconv.Atoi(m[2])
branchRC := -1 // -1 means not an RC branch.
if m[3] != "" {
branchRC, _ = strconv.Atoi(m[3])
}
successf(w, "Using release branch: %s", currentBranch)
// --- Fetch & sync check ---
@@ -134,44 +138,31 @@ func runRelease(ctx context.Context, inv *serpent.Invocation, executor ReleaseEx
}
}
// changelogBaseRef is the git ref used as the starting point
// for release notes generation. When a tag already exists in
// this minor series we use it directly. For the first release
// on a new minor no matching tag exists, so we compute the
// merge-base with the previous minor's release branch instead.
// This works even when that branch has no tags yet (it was
// just cut and pushed). As a last resort we fall back to the
// latest reachable tag from a previous minor.
var changelogBaseRef string
if prevVersion != nil {
changelogBaseRef = prevVersion.String()
} else {
prevReleaseBranch := fmt.Sprintf("release/%d.%d", branchMajor, branchMinor-1)
if err := gitRun("fetch", "--quiet", "origin", prevReleaseBranch); err != nil {
warnf(w, "Could not fetch %s: %v", prevReleaseBranch, err)
}
if mb, mbErr := gitOutput("merge-base", "HEAD", "origin/"+prevReleaseBranch); mbErr == nil && mb != "" {
changelogBaseRef = mb
infof(w, "Using merge-base with %s as changelog base: %s", prevReleaseBranch, mb[:12])
} else {
// No previous release branch found; fall back to
// the latest reachable tag from a previous minor.
for _, t := range mergedTags {
if t.Major == branchMajor && t.Minor < branchMinor {
changelogBaseRef = t.String()
break
}
}
}
}
var suggested version
if prevVersion == nil {
infof(w, "No previous release tag found on this branch.")
suggested = version{Major: branchMajor, Minor: branchMinor, Patch: 0}
if branchRC >= 0 {
suggested.Pre = fmt.Sprintf("rc.%d", branchRC)
}
} else {
infof(w, "Previous release tag: %s", prevVersion.String())
suggested = version{Major: prevVersion.Major, Minor: prevVersion.Minor, Patch: prevVersion.Patch + 1}
if branchRC >= 0 {
// On an RC branch, suggest the next RC for
// the same base version.
nextRC := 0
if prevVersion.IsRC() {
nextRC = prevVersion.rcNumber() + 1
}
suggested = version{
Major: prevVersion.Major,
Minor: prevVersion.Minor,
Patch: prevVersion.Patch,
Pre: fmt.Sprintf("rc.%d", nextRC),
}
} else {
suggested = version{Major: prevVersion.Major, Minor: prevVersion.Minor, Patch: prevVersion.Patch + 1}
}
}
fmt.Fprintln(w)
@@ -375,8 +366,8 @@ func runRelease(ctx context.Context, inv *serpent.Invocation, executor ReleaseEx
infof(w, "Generating release notes...")
commitRange := "HEAD"
if changelogBaseRef != "" {
commitRange = changelogBaseRef + "..HEAD"
if prevVersion != nil {
commitRange = prevVersion.String() + "..HEAD"
}
commits, err := commitLog(commitRange)
@@ -482,16 +473,16 @@ func runRelease(ctx context.Context, inv *serpent.Invocation, executor ReleaseEx
}
if !hasContent {
prevStr := "the beginning of time"
if changelogBaseRef != "" {
prevStr = changelogBaseRef
if prevVersion != nil {
prevStr = prevVersion.String()
}
fmt.Fprintf(&notes, "\n_No changes since %s._\n", prevStr)
}
// Compare link.
if changelogBaseRef != "" {
if prevVersion != nil {
fmt.Fprintf(&notes, "\nCompare: [`%s...%s`](https://github.com/%s/%s/compare/%s...%s)\n",
changelogBaseRef, newVersion, owner, repo, changelogBaseRef, newVersion)
prevVersion, newVersion, owner, repo, prevVersion, newVersion)
}
// Container image.
-2
View File
@@ -913,7 +913,6 @@ export interface AuditLog {
export interface AuditLogResponse {
readonly audit_logs: readonly AuditLog[];
readonly count: number;
readonly count_cap: number;
}
// From codersdk/audit.go
@@ -2270,7 +2269,6 @@ export interface ConnectionLog {
export interface ConnectionLogResponse {
readonly connection_logs: readonly ConnectionLog[];
readonly count: number;
readonly count_cap: number;
}
// From codersdk/connectionlog.go
+2 -2
View File
@@ -35,10 +35,10 @@ export const AvatarData: FC<AvatarDataProps> = ({
}
return (
<div className="flex items-center w-full gap-3">
<div className="flex items-center w-full min-w-0 gap-3">
{avatar}
<div className="flex flex-col w-full">
<div className="flex flex-col w-full min-w-0">
<span className="text-sm font-semibold text-content-primary">
{title}
</span>
+57 -31
View File
@@ -1,6 +1,7 @@
import type { Interpolation, Theme } from "@emotion/react";
import type { FC, HTMLAttributes } from "react";
import type { LogLevel } from "#/api/typesGenerated";
import { cn } from "#/utils/cn";
import { MONOSPACE_FONT_FAMILY } from "#/theme/constants";
const DEFAULT_LOG_LINE_SIDE_PADDING = 24;
@@ -16,40 +17,65 @@ type LogLineProps = {
level: LogLevel;
} & HTMLAttributes<HTMLPreElement>;
export const LogLine: FC<LogLineProps> = ({ level, className, ...props }) => {
export const LogLine: FC<LogLineProps> = ({ level, ...divProps }) => {
return (
<pre
{...props}
className={cn(
"logs-line",
"m-0 break-all flex items-center h-auto",
"text-[13px] text-content-primary font-mono",
level === "error" &&
"bg-surface-error text-content-error [&_.dashed-line]:bg-border-error",
level === "debug" &&
"bg-surface-sky text-content-sky [&_.dashed-line]:bg-border-sky",
level === "warn" &&
"bg-surface-warning text-content-warning [&_.dashed-line]:bg-border-warning",
className,
)}
style={{
padding: `0 var(--log-line-side-padding, ${DEFAULT_LOG_LINE_SIDE_PADDING}px)`,
}}
css={styles.line}
className={`${level} ${divProps.className} logs-line`}
{...divProps}
/>
);
};
export const LogLinePrefix: FC<HTMLAttributes<HTMLSpanElement>> = ({
className,
...props
}) => {
return (
<pre
className={cn(
"select-none m-0 inline-block text-content-secondary mr-6",
className,
)}
{...props}
/>
);
export const LogLinePrefix: FC<HTMLAttributes<HTMLSpanElement>> = (props) => {
return <pre css={styles.prefix} {...props} />;
};
const styles = {
line: (theme) => ({
margin: 0,
wordBreak: "break-all",
display: "flex",
alignItems: "center",
fontSize: 13,
color: theme.palette.text.primary,
fontFamily: MONOSPACE_FONT_FAMILY,
height: "auto",
padding: `0 var(--log-line-side-padding, ${DEFAULT_LOG_LINE_SIDE_PADDING}px)`,
"&.error": {
backgroundColor: theme.roles.error.background,
color: theme.roles.error.text,
"& .dashed-line": {
backgroundColor: theme.roles.error.outline,
},
},
"&.debug": {
backgroundColor: theme.roles.notice.background,
color: theme.roles.notice.text,
"& .dashed-line": {
backgroundColor: theme.roles.notice.outline,
},
},
"&.warn": {
backgroundColor: theme.roles.warning.background,
color: theme.roles.warning.text,
"& .dashed-line": {
backgroundColor: theme.roles.warning.outline,
},
},
}),
prefix: (theme) => ({
userSelect: "none",
margin: 0,
display: "inline-block",
color: theme.palette.text.secondary,
marginRight: 24,
}),
} satisfies Record<string, Interpolation<Theme>>;
+17 -12
View File
@@ -1,6 +1,6 @@
import type { Interpolation, Theme } from "@emotion/react";
import dayjs from "dayjs";
import type { FC } from "react";
import { cn } from "#/utils/cn";
import { type Line, LogLine, LogLinePrefix } from "./LogLine";
export const DEFAULT_LOG_LINE_SIDE_PADDING = 24;
@@ -17,17 +17,7 @@ export const Logs: FC<LogsProps> = ({
className = "",
}) => {
return (
<div
className={cn(
"logs-container",
"min-h-40 py-2 rounded-lg overflow-x-auto bg-surface-primary",
"[&:not(:last-child)]:border-0",
"[&:not(:last-child)]:border-solid",
"[&:not(:last-child)]:border-b-border",
"[&:not(:last-child)]:rounded-none",
className,
)}
>
<div css={styles.root} className={`${className} logs-container`}>
<div className="min-w-fit">
{lines.map((line) => (
<LogLine key={line.id} level={line.level}>
@@ -43,3 +33,18 @@ export const Logs: FC<LogsProps> = ({
</div>
);
};
const styles = {
root: (theme) => ({
minHeight: 156,
padding: "8px 0",
borderRadius: 8,
overflowX: "auto",
background: theme.palette.background.default,
"&:not(:last-child)": {
borderBottom: `1px solid ${theme.palette.divider}`,
borderRadius: 0,
},
}),
} satisfies Record<string, Interpolation<Theme>>;
@@ -7,7 +7,6 @@ type PaginationHeaderProps = {
limit: number;
totalRecords: number | undefined;
currentOffsetStart: number | undefined;
countIsCapped?: boolean;
// Temporary escape hatch until Workspaces can be switched over to using
// PaginationContainer
@@ -19,7 +18,6 @@ export const PaginationAmount: FC<PaginationHeaderProps> = ({
limit,
totalRecords,
currentOffsetStart,
countIsCapped,
className,
}) => {
const theme = useTheme();
@@ -54,16 +52,10 @@ export const PaginationAmount: FC<PaginationHeaderProps> = ({
<strong>
{(
currentOffsetStart +
(countIsCapped
? limit - 1
: Math.min(limit - 1, totalRecords - currentOffsetStart))
Math.min(limit - 1, totalRecords - currentOffsetStart)
).toLocaleString()}
</strong>{" "}
of{" "}
<strong>
{totalRecords.toLocaleString()}
{countIsCapped && "+"}
</strong>{" "}
of <strong>{totalRecords.toLocaleString()}</strong>{" "}
{paginationUnitLabel}
</div>
)}
@@ -18,7 +18,6 @@ export const mockPaginationResultBase: ResultBase = {
limit: 25,
hasNextPage: false,
hasPreviousPage: false,
countIsCapped: false,
goToPreviousPage: () => {},
goToNextPage: () => {},
goToFirstPage: () => {},
@@ -34,7 +33,6 @@ export const mockInitialRenderResult: PaginationResult = {
hasPreviousPage: false,
totalRecords: undefined,
totalPages: undefined,
countIsCapped: false,
};
export const mockSuccessResult: PaginationResult = {
@@ -94,7 +94,7 @@ export const FirstPageWithTonsOfData: Story = {
currentPage: 2,
currentOffsetStart: 1000,
totalRecords: 123_456,
totalPages: 4939,
totalPages: 1235,
hasPreviousPage: false,
hasNextPage: true,
isPlaceholderData: false,
@@ -135,54 +135,3 @@ export const SecondPageWithData: Story = {
children: <div>New data for page 2</div>,
},
};
export const CappedCountFirstPage: Story = {
args: {
query: {
...mockPaginationResultBase,
isSuccess: true,
currentPage: 1,
currentOffsetStart: 1,
totalRecords: 2000,
totalPages: 80,
hasPreviousPage: false,
hasNextPage: true,
isPlaceholderData: false,
countIsCapped: true,
},
},
};
export const CappedCountMiddlePage: Story = {
args: {
query: {
...mockPaginationResultBase,
isSuccess: true,
currentPage: 3,
currentOffsetStart: 51,
totalRecords: 2000,
totalPages: 80,
hasPreviousPage: true,
hasNextPage: true,
isPlaceholderData: false,
countIsCapped: true,
},
},
};
export const CappedCountBeyondKnownPages: Story = {
args: {
query: {
...mockPaginationResultBase,
isSuccess: true,
currentPage: 85,
currentOffsetStart: 2101,
totalRecords: 2000,
totalPages: 85,
hasPreviousPage: true,
hasNextPage: true,
isPlaceholderData: false,
countIsCapped: true,
},
},
};
@@ -27,14 +27,12 @@ export const PaginationContainer: FC<PaginationProps> = ({
totalRecords={query.totalRecords}
currentOffsetStart={query.currentOffsetStart}
paginationUnitLabel={paginationUnitLabel}
countIsCapped={query.countIsCapped}
className="justify-end"
/>
{query.isSuccess && (
<PaginationWidgetBase
totalRecords={query.totalRecords}
totalPages={query.totalPages}
currentPage={query.currentPage}
pageSize={query.limit}
onPageChange={query.onPageChange}
@@ -12,10 +12,6 @@ export type PaginationWidgetBaseProps = {
hasPreviousPage?: boolean;
hasNextPage?: boolean;
/** Override the computed totalPages.
* Used when, e.g., the row count is capped and the user navigates beyond
* the known range, so totalPages stays at least as high as currentPage. */
totalPages?: number;
};
export const PaginationWidgetBase: FC<PaginationWidgetBaseProps> = ({
@@ -25,9 +21,8 @@ export const PaginationWidgetBase: FC<PaginationWidgetBaseProps> = ({
onPageChange,
hasPreviousPage,
hasNextPage,
totalPages: totalPagesProp,
}) => {
const totalPages = totalPagesProp ?? Math.ceil(totalRecords / pageSize);
const totalPages = Math.ceil(totalRecords / pageSize);
if (totalPages < 2) {
return null;
+1
View File
@@ -0,0 +1 @@
export { useTabOverflowKebabMenu } from "./useTabOverflowKebabMenu";
@@ -1,274 +0,0 @@
import {
type RefObject,
useCallback,
useEffect,
useRef,
useState,
} from "react";
type TabValue = {
value: string;
};
type UseKebabMenuOptions<T extends TabValue> = {
tabs: readonly T[];
enabled: boolean;
isActive: boolean;
overflowTriggerWidth?: number;
};
type UseKebabMenuResult<T extends TabValue> = {
containerRef: RefObject<HTMLDivElement | null>;
visibleTabs: T[];
overflowTabs: T[];
getTabMeasureProps: (tabValue: string) => Record<string, string>;
};
const ALWAYS_VISIBLE_TABS_COUNT = 1;
const DATA_ATTR_TAB_VALUE = "data-tab-overflow-item-value";
/**
* Splits tabs into visible and overflow groups based on container width.
*
* Tabs must render with `getTabMeasureProps()` so this hook can measure
* trigger widths from the DOM.
*/
export const useKebabMenu = <T extends TabValue>({
tabs,
enabled,
isActive,
overflowTriggerWidth = 44,
}: UseKebabMenuOptions<T>): UseKebabMenuResult<T> => {
const containerRef = useRef<HTMLDivElement>(null);
const tabsRef = useRef<readonly T[]>(tabs);
tabsRef.current = tabs;
const previousTabsRef = useRef<readonly T[]>(tabs);
const availableWidthRef = useRef<number | null>(null);
// Width cache prevents oscillation when overflow tabs are not mounted.
const tabWidthByValueRef = useRef<Record<string, number>>({});
const [overflowTabValues, setTabValues] = useState<string[]>([]);
const recalculateOverflow = useCallback(
(availableWidth: number) => {
if (!enabled || !isActive) {
// Keep this update idempotent to avoid render loops.
setTabValues((currentValues) => {
if (currentValues.length === 0) {
return currentValues;
}
return [];
});
return;
}
const container = containerRef.current;
if (!container) {
return;
}
const currentTabs = tabsRef.current;
const tabWidthByValue = measureTabWidths({
tabs: currentTabs,
container,
previousTabWidthByValue: tabWidthByValueRef.current,
});
tabWidthByValueRef.current = tabWidthByValue;
const nextOverflowValues = calculateTabValues({
tabs: currentTabs,
availableWidth,
tabWidthByValue,
overflowTriggerWidth,
});
setTabValues((currentValues) => {
// Avoid state updates when the computed overflow did not change.
if (areStringArraysEqual(currentValues, nextOverflowValues)) {
return currentValues;
}
return nextOverflowValues;
});
},
[enabled, isActive, overflowTriggerWidth],
);
useEffect(() => {
if (previousTabsRef.current === tabs) {
// No change in tabs, no need to recalculate.
return;
}
previousTabsRef.current = tabs;
if (availableWidthRef.current === null) {
// First mount, no width available yet.
return;
}
recalculateOverflow(availableWidthRef.current);
}, [recalculateOverflow, tabs]);
useEffect(() => {
const container = containerRef.current;
if (!container) {
return;
}
// Recompute whenever ResizeObserver reports a container width change.
const observer = new ResizeObserver(([entry]) => {
if (!entry) {
return;
}
availableWidthRef.current = entry.contentRect.width;
recalculateOverflow(entry.contentRect.width);
});
observer.observe(container);
return () => observer.disconnect();
}, [recalculateOverflow]);
const overflowTabValuesSet = new Set(overflowTabValues);
const { visibleTabs, overflowTabs } = tabs.reduce<{
visibleTabs: T[];
overflowTabs: T[];
}>(
(tabGroups, tab) => {
if (overflowTabValuesSet.has(tab.value)) {
tabGroups.overflowTabs.push(tab);
} else {
tabGroups.visibleTabs.push(tab);
}
return tabGroups;
},
{ visibleTabs: [], overflowTabs: [] },
);
const getTabMeasureProps = (tabValue: string) => {
return { [DATA_ATTR_TAB_VALUE]: tabValue };
};
return {
containerRef,
visibleTabs,
overflowTabs,
getTabMeasureProps,
};
};
const calculateTabValues = <T extends TabValue>({
tabs,
availableWidth,
tabWidthByValue,
overflowTriggerWidth,
}: {
tabs: readonly T[];
availableWidth: number;
tabWidthByValue: Readonly<Record<string, number>>;
overflowTriggerWidth: number;
}): string[] => {
const tabWidthByValueMap = new Map<string, number>();
for (const tab of tabs) {
tabWidthByValueMap.set(tab.value, tabWidthByValue[tab.value] ?? 0);
}
const firstOptionalTabIndex = Math.min(
ALWAYS_VISIBLE_TABS_COUNT,
tabs.length,
);
if (firstOptionalTabIndex >= tabs.length) {
return [];
}
const alwaysVisibleTabs = tabs.slice(0, firstOptionalTabIndex);
const optionalTabs = tabs.slice(firstOptionalTabIndex);
const alwaysVisibleWidth = alwaysVisibleTabs.reduce((total, tab) => {
return total + (tabWidthByValueMap.get(tab.value) ?? 0);
}, 0);
const firstTabIndex = findFirstTabIndex({
optionalTabs,
optionalTabWidths: optionalTabs.map((tab) => {
return tabWidthByValueMap.get(tab.value) ?? 0;
}),
startingUsedWidth: alwaysVisibleWidth,
availableWidth,
overflowTriggerWidth,
});
if (firstTabIndex === -1) {
return [];
}
return optionalTabs
.slice(firstTabIndex)
.map((overflowTab) => overflowTab.value);
};
/**
 * Reads the rendered width of each tab element found in the container.
 * Tabs not currently in the DOM (e.g. already collapsed into the
 * overflow menu) keep their previously measured width.
 */
const measureTabWidths = <T extends TabValue>({
  tabs,
  container,
  previousTabWidthByValue,
}: {
  tabs: readonly T[];
  container: HTMLDivElement;
  previousTabWidthByValue: Readonly<Record<string, number>>;
}): Record<string, number> => {
  // Start from the previous snapshot so missing elements do not erase
  // their last known measurement.
  const measured: Record<string, number> = { ...previousTabWidthByValue };
  for (const { value } of tabs) {
    const element = container.querySelector<HTMLElement>(
      `[${DATA_ATTR_TAB_VALUE}="${value}"]`,
    );
    if (element) {
      measured[value] = element.offsetWidth;
    }
  }
  return measured;
};
/**
 * Finds the index within optionalTabs of the first tab that no longer
 * fits in availableWidth, accumulating widths starting from
 * startingUsedWidth. While more tabs remain after the current one,
 * overflowTriggerWidth is reserved for the kebab trigger button, since
 * a later overflow would require rendering the trigger alongside this
 * tab. Returns -1 when every tab fits.
 */
const findFirstTabIndex = ({
  optionalTabs,
  optionalTabWidths,
  startingUsedWidth,
  availableWidth,
  overflowTriggerWidth,
}: {
  optionalTabs: readonly TabValue[];
  optionalTabWidths: readonly number[];
  startingUsedWidth: number;
  availableWidth: number;
  overflowTriggerWidth: number;
}): number => {
  // A plain loop with an early return replaces the previous
  // reduce-with-found-flag, which kept iterating after the answer was
  // known and obscured the short-circuit intent.
  let usedWidth = startingUsedWidth;
  for (let index = 0; index < optionalTabs.length; index++) {
    const tabWidth = optionalTabWidths[index] ?? 0;
    const hasMoreTabs = index < optionalTabs.length - 1;
    // Reserve kebab trigger width whenever additional tabs remain.
    const widthNeeded =
      usedWidth + tabWidth + (hasMoreTabs ? overflowTriggerWidth : 0);
    if (widthNeeded > availableWidth) {
      return index;
    }
    usedWidth += tabWidth;
  }
  return -1;
};
/**
 * Shallow, order-sensitive equality for two string arrays: true only
 * when both have the same length and identical elements at every index.
 */
const areStringArraysEqual = (
  left: readonly string[],
  right: readonly string[],
): boolean => {
  if (left.length !== right.length) {
    return false;
  }
  for (let index = 0; index < left.length; index++) {
    if (left[index] !== right[index]) {
      return false;
    }
  }
  return true;
};
@@ -0,0 +1,157 @@
import {
type RefObject,
useCallback,
useEffect,
useLayoutEffect,
useMemo,
useRef,
useState,
} from "react";
// Minimal shape a tab must have to participate in overflow handling;
// `value` doubles as the identity used for measurement and filtering.
type TabLike = {
  value: string;
};

type UseTabOverflowKebabMenuOptions<TTab extends TabLike> = {
  // Full, ordered list of tabs to lay out.
  tabs: readonly TTab[];
  // When false, overflow is cleared and all tabs render inline.
  enabled: boolean;
  // Whether the tab panel is currently visible; measurement and the
  // resize observer only run while active.
  isActive: boolean;
  // Number of leading tabs that are never collapsed into the menu.
  alwaysVisibleTabsCount?: number;
  // Width in px reserved for the kebab trigger button when overflowing.
  overflowTriggerWidthPx?: number;
};

type UseTabOverflowKebabMenuResult<TTab extends TabLike> = {
  // Attach to the element whose clientWidth bounds the tab row.
  containerRef: RefObject<HTMLDivElement | null>;
  visibleTabs: TTab[];
  overflowTabs: TTab[];
  // Returns the data attribute to spread onto each tab trigger so it
  // can be located and measured in the DOM.
  getTabMeasureProps: (tabValue: string) => Record<string, string>;
};

// Data attribute used to find tab elements for width measurement.
const DATA_ATTR_TAB_VALUE = "data-tab-overflow-item-value";
/**
 * Collapses tabs that do not fit in their container into an overflow
 * ("kebab") menu.
 *
 * Each rendered tab is located via a data attribute (applied with
 * getTabMeasureProps) and measured; the tab list is then walked in
 * order, accumulating widths until the container's clientWidth is
 * exceeded. The first `alwaysVisibleTabsCount` tabs are never
 * collapsed, and `overflowTriggerWidthPx` is reserved for the trigger
 * button whenever at least one later tab would overflow.
 */
export const useTabOverflowKebabMenu = <TTab extends TabLike>({
  tabs,
  enabled,
  isActive,
  alwaysVisibleTabsCount = 1,
  overflowTriggerWidthPx = 44,
}: UseTabOverflowKebabMenuOptions<TTab>): UseTabOverflowKebabMenuResult<TTab> => {
  const containerRef = useRef<HTMLDivElement>(null);
  // Cache of measured widths keyed by tab value. Kept in a ref so tabs
  // currently hidden inside the overflow menu retain their last known
  // width across recalculations.
  const tabWidthByValueRef = useRef<Record<string, number>>({});
  const [overflowTabValues, setOverflowTabValues] = useState<string[]>([]);

  const recalculateOverflow = useCallback(() => {
    // Every exit path commits through this equality-guarded setter so a
    // recalculation never triggers a render when the result is
    // unchanged. (Previously the early-exit paths called
    // setOverflowTabValues([]) unconditionally with a fresh array,
    // causing needless re-renders when state was already empty.)
    const commitOverflowValues = (nextValues: string[]) => {
      setOverflowTabValues((currentValues) => {
        const isUnchanged =
          currentValues.length === nextValues.length &&
          currentValues.every((value, index) => value === nextValues[index]);
        return isUnchanged ? currentValues : nextValues;
      });
    };

    if (!enabled) {
      commitOverflowValues([]);
      return;
    }
    const container = containerRef.current;
    if (!container) {
      return;
    }
    // Refresh measurements for every tab currently in the DOM.
    for (const tab of tabs) {
      const tabElement = container.querySelector<HTMLElement>(
        `[${DATA_ATTR_TAB_VALUE}="${tab.value}"]`,
      );
      if (tabElement) {
        tabWidthByValueRef.current[tab.value] = tabElement.offsetWidth;
      }
    }
    const alwaysVisibleTabs = tabs.slice(0, alwaysVisibleTabsCount);
    const optionalTabs = tabs.slice(alwaysVisibleTabsCount);
    if (optionalTabs.length === 0) {
      commitOverflowValues([]);
      return;
    }
    const alwaysVisibleWidth = alwaysVisibleTabs.reduce((total, tab) => {
      return total + (tabWidthByValueRef.current[tab.value] ?? 0);
    }, 0);
    const availableWidth = container.clientWidth;

    let usedWidth = alwaysVisibleWidth;
    const nextOverflowValues: string[] = [];
    for (let i = 0; i < optionalTabs.length; i++) {
      const tab = optionalTabs[i];
      const tabWidth = tabWidthByValueRef.current[tab.value] ?? 0;
      const hasMoreTabsAfterCurrent = i < optionalTabs.length - 1;
      // Reserve room for the kebab trigger whenever more tabs remain:
      // if this tab fits but a later one does not, the trigger must be
      // rendered alongside it.
      const widthNeeded =
        usedWidth +
        tabWidth +
        (hasMoreTabsAfterCurrent ? overflowTriggerWidthPx : 0);
      if (widthNeeded <= availableWidth) {
        usedWidth += tabWidth;
        continue;
      }
      // Everything from the first non-fitting tab onward overflows.
      nextOverflowValues.push(
        ...optionalTabs.slice(i).map((overflowTab) => overflowTab.value),
      );
      break;
    }
    commitOverflowValues(nextOverflowValues);
  }, [alwaysVisibleTabsCount, enabled, overflowTriggerWidthPx, tabs]);

  // Measure synchronously before paint when the tab set changes or the
  // panel becomes active, avoiding a visible flash of overflowing tabs.
  useLayoutEffect(() => {
    if (!isActive) {
      return;
    }
    recalculateOverflow();
  }, [isActive, recalculateOverflow]);

  // Recompute whenever the container is resized while active.
  useEffect(() => {
    if (!isActive) {
      return;
    }
    const container = containerRef.current;
    if (!container) {
      return;
    }
    const observer = new ResizeObserver(() => {
      recalculateOverflow();
    });
    observer.observe(container);
    return () => observer.disconnect();
  }, [isActive, recalculateOverflow]);

  const overflowTabValuesSet = useMemo(
    () => new Set(overflowTabValues),
    [overflowTabValues],
  );
  const visibleTabs = useMemo(
    () => tabs.filter((tab) => !overflowTabValuesSet.has(tab.value)),
    [tabs, overflowTabValuesSet],
  );
  const overflowTabs = useMemo(
    () => tabs.filter((tab) => overflowTabValuesSet.has(tab.value)),
    [tabs, overflowTabValuesSet],
  );

  // Spread onto each tab trigger so recalculateOverflow can locate and
  // measure it in the DOM.
  const getTabMeasureProps = useCallback((tabValue: string) => {
    return { [DATA_ATTR_TAB_VALUE]: tabValue };
  }, []);

  return {
    containerRef,
    visibleTabs,
    overflowTabs,
    getTabMeasureProps,
  };
};
-72
View File
@@ -258,78 +258,6 @@ describe(usePaginatedQuery.name, () => {
});
});
describe("Capped count behavior", () => {
const mockQueryKey = vi.fn(() => ["mock"]);
// Returns count 2001 (capped) with items on pages up to page 84
// (84 * 25 = 2100 items total).
const mockCappedQueryFn = vi.fn(({ pageNumber, limit }) => {
const totalItems = 2100;
const offset = (pageNumber - 1) * limit;
// Returns 0 items when the requested page is past the end, simulating
// an empty server response.
const itemsOnPage = Math.max(0, Math.min(limit, totalItems - offset));
return Promise.resolve({
data: new Array(itemsOnPage).fill(pageNumber),
count: 2001,
count_cap: 2000,
});
});
it("Caps totalRecords at 2000 when count exceeds cap", async () => {
const { result } = await render({
queryKey: mockQueryKey,
queryFn: mockCappedQueryFn,
});
await waitFor(() => expect(result.current.isSuccess).toBe(true));
expect(result.current.totalRecords).toBe(2000);
});
it("hasNextPage is true when count is capped", async () => {
const { result } = await render(
{ queryKey: mockQueryKey, queryFn: mockCappedQueryFn },
"/?page=80",
);
await waitFor(() => expect(result.current.isSuccess).toBe(true));
expect(result.current.hasNextPage).toBe(true);
});
it("hasPreviousPage is true when count is capped and page is beyond cap", async () => {
const { result } = await render(
{ queryKey: mockQueryKey, queryFn: mockCappedQueryFn },
"/?page=83",
);
await waitFor(() => expect(result.current.isSuccess).toBe(true));
expect(result.current.hasPreviousPage).toBe(true);
});
it("Does not redirect to last page when count is capped and page is valid", async () => {
const { result } = await render(
{ queryKey: mockQueryKey, queryFn: mockCappedQueryFn },
"/?page=83",
);
await waitFor(() => expect(result.current.isSuccess).toBe(true));
// Should stay on page 83 — not redirect to page 80.
expect(result.current.currentPage).toBe(83);
});
it("Redirects to last known page when navigating beyond actual data", async () => {
const { result } = await render(
{ queryKey: mockQueryKey, queryFn: mockCappedQueryFn },
"/?page=999",
);
// Page 999 has no items. Should redirect to page 81
// (ceil(2001 / 25) = 81), the last page guaranteed to
// have data.
await waitFor(() => expect(result.current.currentPage).toBe(81));
});
});
describe("Passing in searchParams property", () => {
const mockQueryKey = vi.fn(() => ["mock"]);
const mockQueryFn = vi.fn(({ pageNumber, limit }) =>
+8 -50
View File
@@ -144,44 +144,16 @@ export function usePaginatedQuery<
placeholderData: keepPreviousData,
});
const count = query.data?.count;
const countCap = query.data?.count_cap;
const countIsCapped =
countCap !== undefined &&
countCap > 0 &&
count !== undefined &&
count > countCap;
const totalRecords = countIsCapped ? countCap : count;
let totalPages =
totalRecords !== undefined
? Math.max(
Math.ceil(totalRecords / limit),
// True count is not known; let them navigate forward
// until they hit an empty page (checked below).
countIsCapped ? currentPage : 0,
)
: undefined;
// When the true count is unknown, the user can navigate past
// all actual data. If that happens, we need to redirect (via
// updatePageIfInvalid) to the last page guaranteed to be not
// empty.
const pageIsEmpty =
query.data != null &&
!Object.values(query.data).some((v) => Array.isArray(v) && v.length > 0);
if (pageIsEmpty) {
totalPages = count !== undefined ? Math.ceil(count / limit) : 1;
}
const totalRecords = query.data?.count;
const totalPages =
totalRecords !== undefined ? Math.ceil(totalRecords / limit) : undefined;
const hasNextPage =
totalRecords !== undefined &&
((countIsCapped && !pageIsEmpty) ||
limit + currentPageOffset < totalRecords);
totalRecords !== undefined && limit + currentPageOffset < totalRecords;
const hasPreviousPage =
totalRecords !== undefined &&
currentPage > 1 &&
((countIsCapped && !pageIsEmpty) ||
currentPageOffset - limit < totalRecords);
currentPageOffset - limit < totalRecords;
const queryClient = useQueryClient();
const prefetchPage = useEffectEvent((newPage: number) => {
@@ -252,14 +224,10 @@ export function usePaginatedQuery<
});
useEffect(() => {
if (
!query.isFetching &&
totalPages !== undefined &&
currentPage > totalPages
) {
if (!query.isFetching && totalPages !== undefined) {
void updatePageIfInvalid(totalPages);
}
}, [updatePageIfInvalid, query.isFetching, totalPages, currentPage]);
}, [updatePageIfInvalid, query.isFetching, totalPages]);
const onPageChange = (newPage: number) => {
// Page 1 is the only page that can be safely navigated to without knowing
@@ -268,12 +236,7 @@ export function usePaginatedQuery<
return;
}
// If the true count is unknown, we allow navigating past the
// known page range.
const upperBound = countIsCapped
? Number.MAX_SAFE_INTEGER
: (totalPages ?? 1);
const cleanedInput = clamp(Math.trunc(newPage), 1, upperBound);
const cleanedInput = clamp(Math.trunc(newPage), 1, totalPages ?? 1);
if (Number.isNaN(cleanedInput)) {
return;
}
@@ -311,7 +274,6 @@ export function usePaginatedQuery<
totalRecords: totalRecords as number,
totalPages: totalPages as number,
currentOffsetStart: currentPageOffset + 1,
countIsCapped,
}
: {
isSuccess: false,
@@ -320,7 +282,6 @@ export function usePaginatedQuery<
totalRecords: undefined,
totalPages: undefined,
currentOffsetStart: undefined,
countIsCapped: false as const,
}),
};
@@ -362,7 +323,6 @@ export type PaginationResultInfo = {
totalRecords: undefined;
totalPages: undefined;
currentOffsetStart: undefined;
countIsCapped: false;
}
| {
isSuccess: true;
@@ -371,7 +331,6 @@ export type PaginationResultInfo = {
totalRecords: number;
totalPages: number;
currentOffsetStart: number;
countIsCapped: boolean;
}
);
@@ -458,7 +417,6 @@ type QueryPageParamsWithPayload<TPayload = never> = QueryPageParams & {
*/
export type PaginatedData = {
count: number;
count_cap?: number;
};
/**
@@ -2,7 +2,7 @@ import type { Meta, StoryObj } from "@storybook/react-vite";
import { expect, spyOn, userEvent, waitFor, within } from "storybook/test";
import { API } from "#/api/api";
import { workspaceAgentContainersKey } from "#/api/queries/workspaces";
import type { WorkspaceAgentLogSource } from "#/api/typesGenerated";
import type * as TypesGen from "#/api/typesGenerated";
import { getPreferredProxy } from "#/contexts/ProxyContext";
import { chromatic } from "#/testHelpers/chromatic";
import * as M from "#/testHelpers/entities";
@@ -76,8 +76,6 @@ const defaultAgentMetadata = [
},
];
const fixedLogTimestamp = "2021-05-05T00:00:00.000Z";
const logs = [
"\x1b[91mCloning Git repository...",
"\x1b[2;37;41mStarting Docker Daemon...",
@@ -89,10 +87,10 @@ const logs = [
level: "info",
output: line,
source_id: M.MockWorkspaceAgentLogSource.id,
created_at: fixedLogTimestamp,
created_at: new Date().toISOString(),
}));
const installScriptLogSource: WorkspaceAgentLogSource = {
const installScriptLogSource: TypesGen.WorkspaceAgentLogSource = {
...M.MockWorkspaceAgentLogSource,
id: "f2ee4b8d-b09d-4f4e-a1f1-5e4adf7d53bb",
display_name: "Install Script",
@@ -104,24 +102,60 @@ const tabbedLogs = [
level: "info",
output: "startup: preparing workspace",
source_id: M.MockWorkspaceAgentLogSource.id,
created_at: fixedLogTimestamp,
created_at: new Date().toISOString(),
},
{
id: 101,
level: "info",
output: "install: pnpm install",
source_id: installScriptLogSource.id,
created_at: fixedLogTimestamp,
created_at: new Date().toISOString(),
},
{
id: 102,
level: "info",
output: "install: setup complete",
source_id: installScriptLogSource.id,
created_at: fixedLogTimestamp,
created_at: new Date().toISOString(),
},
];
// Enough distinct log sources to force the log-tab list to overflow
// into the kebab menu inside a narrow (320px) container.
const overflowLogSources: TypesGen.WorkspaceAgentLogSource[] = [
  M.MockWorkspaceAgentLogSource,
  {
    ...M.MockWorkspaceAgentLogSource,
    id: "58f5db69-5f78-496f-bce1-0686f5525aa1",
    display_name: "code-server",
    icon: "/icon/code.svg",
  },
  {
    ...M.MockWorkspaceAgentLogSource,
    id: "f39d758c-bce2-4f41-8d70-58fdb1f0f729",
    display_name: "Install and start AgentAPI",
    icon: "/icon/claude.svg",
  },
  {
    ...M.MockWorkspaceAgentLogSource,
    id: "bf7529b8-1787-4a20-b54f-eb894680e48f",
    display_name: "Mux",
    icon: "/icon/mux.svg",
  },
  {
    ...M.MockWorkspaceAgentLogSource,
    id: "0d6ebde6-c534-4551-9f91-bfd98bfb04f4",
    display_name: "Portable Desktop",
    icon: "/icon/portable-desktop.svg",
  },
];

// One log line per source so every source has entries and therefore
// produces a visible tab; output embeds the source name for assertions.
const overflowLogs = overflowLogSources.map((source, index) => ({
  id: 200 + index,
  level: "info",
  output: `${source.display_name}: line`,
  source_id: source.id,
  created_at: new Date().toISOString(),
}));
const meta: Meta<typeof AgentRow> = {
title: "components/AgentRow",
component: AgentRow,
@@ -404,3 +438,44 @@ export const LogsTabs: Story = {
await expect(canvas.getByText("install: pnpm install")).toBeVisible();
},
};
// Exercises the log-tab overflow behavior: renders the agent row in a
// narrow container so some log-source tabs collapse into the kebab
// ("More log tabs") menu, then selects one from the menu and verifies
// its log line is shown.
export const LogsTabsOverflow: Story = {
  args: {
    agent: {
      ...M.MockWorkspaceAgentReady,
      logs_length: overflowLogs.length,
      log_sources: overflowLogSources,
    },
  },
  parameters: {
    webSocket: [
      {
        event: "message",
        data: JSON.stringify(overflowLogs),
      },
    ],
  },
  // Constrain width to force tab overflow.
  render: (args) => (
    <div className="max-w-[320px]">
      <AgentRow {...args} />
    </div>
  ),
  play: async ({ canvasElement }) => {
    const canvas = within(canvasElement);
    // The overflow menu portals into the document body, so query there.
    const page = within(canvasElement.ownerDocument.body);
    await userEvent.click(canvas.getByRole("button", { name: "Logs" }));
    await userEvent.click(
      canvas.getByRole("button", { name: "More log tabs" }),
    );
    const overflowItems = await page.findAllByRole("menuitemradio");
    const selectedItem = overflowItems[0];
    const selectedSource = selectedItem.textContent;
    if (!selectedSource) {
      throw new Error("Overflow menu item must have text content.");
    }
    await userEvent.click(selectedItem);
    await waitFor(() =>
      expect(canvas.getByText(`${selectedSource}: line`)).toBeVisible(),
    );
  },
};
+60 -69
View File
@@ -8,9 +8,10 @@ import {
} from "lucide-react";
import {
type FC,
type ReactNode,
useCallback,
useEffect,
useLayoutEffect,
useMemo,
useRef,
useState,
} from "react";
@@ -41,7 +42,7 @@ import {
TabsList,
TabsTrigger,
} from "#/components/Tabs/Tabs";
import { useKebabMenu } from "#/components/Tabs/utils/useKebabMenu";
import { useTabOverflowKebabMenu } from "#/components/Tabs/utils";
import { useProxy } from "#/contexts/ProxyContext";
import { useClipboard } from "#/hooks/useClipboard";
import { useFeatureVisibility } from "#/modules/dashboard/useFeatureVisibility";
@@ -161,7 +162,7 @@ export const AgentRow: FC<AgentRowProps> = ({
// This is a bit of a hack on the react-window API to get the scroll position.
// If we're scrolled to the bottom, we want to keep the list scrolled to the bottom.
// This makes it feel similar to a terminal that auto-scrolls downwards!
const handleLogScroll = (props: ListOnScrollProps) => {
const handleLogScroll = useCallback((props: ListOnScrollProps) => {
if (
props.scrollOffset === 0 ||
props.scrollUpdateWasRequested ||
@@ -178,7 +179,7 @@ export const AgentRow: FC<AgentRowProps> = ({
logListDivRef.current.scrollHeight -
(props.scrollOffset + parent.clientHeight);
setBottomOfLogs(distanceFromBottom < AGENT_LOG_LINE_HEIGHT);
};
}, []);
const devcontainers = useAgentContainers(agent);
@@ -210,56 +211,59 @@ export const AgentRow: FC<AgentRowProps> = ({
);
const [selectedLogTab, setSelectedLogTab] = useState("all");
const sourceLogTabs = agent.log_sources
.filter((logSource) => {
// Remove the logSources that have no entries.
return agentLogs.some(
(log) =>
log.source_id === logSource.id && (log.output?.length ?? 0) > 0,
);
})
.map((logSource) => ({
// Show the icon for the log source if it has one.
// In the startup script case, we show a bespoke play icon.
startIcon: logSource.icon ? (
<ExternalImage
src={logSource.icon}
alt=""
className="size-icon-xs shrink-0"
/>
) : logSource.display_name === STARTUP_SCRIPT_DISPLAY_NAME ? (
<PlayIcon className="size-icon-xs shrink-0" />
) : null,
title: logSource.display_name,
value: logSource.id,
}));
const startupScriptLogTab = sourceLogTabs.find(
(tab) => tab.title === STARTUP_SCRIPT_DISPLAY_NAME,
);
const sortedSourceLogTabs = sourceLogTabs
.filter((tab) => tab !== startupScriptLogTab)
.sort((a, b) => a.title.localeCompare(b.title));
const logTabs: {
startIcon?: ReactNode;
title: string;
value: string;
}[] = [
{
title: "All Logs",
value: "all",
},
...(startupScriptLogTab ? [startupScriptLogTab] : []),
...sortedSourceLogTabs,
];
const logTabs = useMemo(() => {
const sourceLogTabs = agent.log_sources
.filter((logSource) => {
// Remove the logSources that have no entries.
return agentLogs.some(
(log) =>
log.source_id === logSource.id && (log.output?.length ?? 0) > 0,
);
})
.map((logSource) => ({
// Show the icon for the log source if it has one.
// In the startup script case, we show a bespoke play icon.
startIcon: logSource.icon ? (
<ExternalImage
src={logSource.icon}
alt=""
className="size-icon-xs shrink-0"
/>
) : logSource.display_name === STARTUP_SCRIPT_DISPLAY_NAME ? (
<PlayIcon className="size-icon-xs shrink-0" />
) : null,
title: logSource.display_name,
value: logSource.id,
}));
const startupScriptLogTab = sourceLogTabs.find(
(tab) => tab.title === STARTUP_SCRIPT_DISPLAY_NAME,
);
const sortedSourceLogTabs = sourceLogTabs
.filter((tab) => tab !== startupScriptLogTab)
.sort((a, b) => a.title.localeCompare(b.title));
return [
{
title: "All Logs",
value: "all",
},
...(startupScriptLogTab ? [startupScriptLogTab] : []),
...sortedSourceLogTabs,
] as {
startIcon?: React.ReactNode;
title: string;
value: string;
}[];
}, [agent.log_sources, agentLogs]);
const {
containerRef: logTabsListContainerRef,
visibleTabs: visibleLogTabs,
overflowTabs: overflowLogTabs,
getTabMeasureProps,
} = useKebabMenu({
} = useTabOverflowKebabMenu({
tabs: logTabs,
enabled: true,
isActive: showLogs,
alwaysVisibleTabsCount: 1,
});
const overflowLogTabValuesSet = new Set(
overflowLogTabs.map((tab) => tab.value),
@@ -275,29 +279,16 @@ export const AgentRow: FC<AgentRowProps> = ({
level: log.level,
sourceId: log.source_id,
}));
const allLogsText = agentLogs.map((log) => log.output).join("\n");
const selectedLogsText = selectedLogs.map((log) => log.output).join("\n");
const hasSelectedLogs = selectedLogs.length > 0;
const hasAnyLogs = agentLogs.length > 0;
const { showCopiedSuccess, copyToClipboard } = useClipboard();
const downloadableLogSets = logTabs
.filter((tab) => tab.value !== "all")
.map((tab) => {
const logsText = agentLogs
.filter((log) => log.source_id === tab.value)
.map((log) => log.output)
.join("\n");
const filenameSuffix = tab.title
.toLowerCase()
.replaceAll(/[^a-z0-9]+/g, "-")
.replaceAll(/(^-|-$)/g, "");
return {
label: tab.title,
filenameSuffix: filenameSuffix || tab.value,
logsText,
startIcon: tab.startIcon,
};
});
const selectedLogTabTitle =
logTabs.find((tab) => tab.value === selectedLogTab)?.title ?? "Logs";
const sanitizedTabTitle = selectedLogTabTitle
.toLowerCase()
.replaceAll(/[^a-z0-9]+/g, "-")
.replaceAll(/(^-|-$)/g, "");
const logFilenameSuffix = sanitizedTabTitle || "logs";
return (
<div
@@ -556,9 +547,9 @@ export const AgentRow: FC<AgentRowProps> = ({
</Button>
<DownloadSelectedAgentLogsButton
agentName={agent.name}
logSets={downloadableLogSets}
allLogsText={allLogsText}
disabled={!hasAnyLogs}
filenameSuffix={logFilenameSuffix}
logsText={selectedLogsText}
disabled={!hasSelectedLogs}
/>
</div>
</div>
@@ -1,27 +1,14 @@
import { saveAs } from "file-saver";
import { ChevronDownIcon, DownloadIcon, PackageIcon } from "lucide-react";
import { type FC, type ReactNode, useState } from "react";
import { DownloadIcon } from "lucide-react";
import { type FC, useState } from "react";
import { toast } from "sonner";
import { getErrorDetail } from "#/api/errors";
import { Button } from "#/components/Button/Button";
import {
DropdownMenu,
DropdownMenuContent,
DropdownMenuItem,
DropdownMenuTrigger,
} from "#/components/DropdownMenu/DropdownMenu";
type DownloadableLogSet = {
label: string;
filenameSuffix: string;
logsText: string;
startIcon?: ReactNode;
};
type DownloadSelectedAgentLogsButtonProps = {
agentName: string;
logSets: readonly DownloadableLogSet[];
allLogsText: string;
filenameSuffix: string;
logsText: string;
disabled?: boolean;
download?: (file: Blob, filename: string) => void | Promise<void>;
};
@@ -30,13 +17,13 @@ export const DownloadSelectedAgentLogsButton: FC<
DownloadSelectedAgentLogsButtonProps
> = ({
agentName,
logSets,
allLogsText,
filenameSuffix,
logsText,
disabled = false,
download = saveAs,
}) => {
const [isDownloading, setIsDownloading] = useState(false);
const downloadLogs = async (logsText: string, filenameSuffix: string) => {
const handleDownload = async () => {
try {
setIsDownloading(true);
const file = new Blob([logsText], { type: "text/plain" });
@@ -50,40 +37,15 @@ export const DownloadSelectedAgentLogsButton: FC<
}
};
const hasAllLogs = allLogsText.length > 0;
return (
<DropdownMenu>
<DropdownMenuTrigger asChild>
<Button variant="subtle" size="sm" disabled={disabled || isDownloading}>
<DownloadIcon />
{isDownloading ? "Downloading..." : "Download logs"}
<ChevronDownIcon className="size-icon-sm" />
</Button>
</DropdownMenuTrigger>
<DropdownMenuContent align="end">
<DropdownMenuItem
disabled={!hasAllLogs}
onSelect={() => {
downloadLogs(allLogsText, "all-logs");
}}
>
<PackageIcon />
Download all logs
</DropdownMenuItem>
{logSets.map((logSet) => (
<DropdownMenuItem
key={logSet.filenameSuffix}
disabled={logSet.logsText.length === 0}
onSelect={() => {
downloadLogs(logSet.logsText, logSet.filenameSuffix);
}}
>
{logSet.startIcon}
<span>Download {logSet.label}</span>
</DropdownMenuItem>
))}
</DropdownMenuContent>
</DropdownMenu>
<Button
variant="subtle"
size="sm"
disabled={disabled || isDownloading}
onClick={handleDownload}
>
<DownloadIcon />
{isDownloading ? "Downloading..." : "Download logs"}
</Button>
);
};
@@ -1,19 +1,22 @@
import dayjs, { type Dayjs } from "dayjs";
import { type FC, useState } from "react";
import { useQuery } from "react-query";
import { useSearchParams } from "react-router";
import { chatCostSummary } from "#/api/queries/chats";
import type { DateRangeValue } from "#/components/DateRangePicker/DateRangePicker";
import { ScrollArea } from "#/components/ScrollArea/ScrollArea";
import { useAuthContext } from "#/contexts/auth/AuthProvider";
import { AgentAnalyticsPageView } from "./AgentAnalyticsPageView";
import { AgentPageHeader } from "./components/AgentPageHeader";
const createDateRange = (now?: Dayjs) => {
const startDateSearchParam = "startDate";
const endDateSearchParam = "endDate";
const getDefaultDateRange = (now?: Dayjs): DateRangeValue => {
const end = now ?? dayjs();
const start = end.subtract(30, "day");
return {
startDate: start.toISOString(),
endDate: end.toISOString(),
rangeLabel: `${start.format("MMM D")} ${end.format("MMM D, YYYY")}`,
startDate: end.subtract(30, "day").toDate(),
endDate: end.toDate(),
};
};
@@ -24,13 +27,44 @@ interface AgentAnalyticsPageProps {
const AgentAnalyticsPage: FC<AgentAnalyticsPageProps> = ({ now }) => {
const { user } = useAuthContext();
const [anchor] = useState<Dayjs>(() => dayjs());
const dateRange = createDateRange(now ?? anchor);
const [searchParams, setSearchParams] = useSearchParams();
const startDateParam = searchParams.get(startDateSearchParam)?.trim() ?? "";
const endDateParam = searchParams.get(endDateSearchParam)?.trim() ?? "";
const [defaultDateRange] = useState(() => getDefaultDateRange(now));
let dateRange = defaultDateRange;
let hasExplicitDateRange = false;
if (startDateParam && endDateParam) {
const parsedStartDate = new Date(startDateParam);
const parsedEndDate = new Date(endDateParam);
if (
!Number.isNaN(parsedStartDate.getTime()) &&
!Number.isNaN(parsedEndDate.getTime()) &&
parsedStartDate.getTime() <= parsedEndDate.getTime()
) {
dateRange = {
startDate: parsedStartDate,
endDate: parsedEndDate,
};
hasExplicitDateRange = true;
}
}
const onDateRangeChange = (value: DateRangeValue) => {
setSearchParams((prev) => {
const next = new URLSearchParams(prev);
next.set(startDateSearchParam, value.startDate.toISOString());
next.set(endDateSearchParam, value.endDate.toISOString());
return next;
});
};
const summaryQuery = useQuery({
...chatCostSummary(user?.id ?? "me", {
start_date: dateRange.startDate,
end_date: dateRange.endDate,
start_date: dateRange.startDate.toISOString(),
end_date: dateRange.endDate.toISOString(),
}),
enabled: Boolean(user?.id),
});
@@ -43,7 +77,9 @@ const AgentAnalyticsPage: FC<AgentAnalyticsPageProps> = ({ now }) => {
isLoading={summaryQuery.isLoading}
error={summaryQuery.error}
onRetry={() => void summaryQuery.refetch()}
rangeLabel={dateRange.rangeLabel}
dateRange={dateRange}
onDateRangeChange={onDateRangeChange}
hasExplicitDateRange={hasExplicitDateRange}
/>
</ScrollArea>
);
@@ -1,15 +1,21 @@
import { BarChart3Icon } from "lucide-react";
import type { FC } from "react";
import type { ChatCostSummary } from "#/api/typesGenerated";
import {
DateRangePicker,
type DateRangeValue,
} from "#/components/DateRangePicker/DateRangePicker";
import { ChatCostSummaryView } from "./components/ChatCostSummaryView";
import { SectionHeader } from "./components/SectionHeader";
import { toInclusiveDateRange } from "./utils/dateRange";
interface AgentAnalyticsPageViewProps {
summary: ChatCostSummary | undefined;
isLoading: boolean;
error: unknown;
onRetry: () => void;
rangeLabel: string;
dateRange: DateRangeValue;
onDateRangeChange: (value: DateRangeValue) => void;
hasExplicitDateRange: boolean;
}
export const AgentAnalyticsPageView: FC<AgentAnalyticsPageViewProps> = ({
@@ -17,8 +23,15 @@ export const AgentAnalyticsPageView: FC<AgentAnalyticsPageViewProps> = ({
isLoading,
error,
onRetry,
rangeLabel,
dateRange,
onDateRangeChange,
hasExplicitDateRange,
}) => {
const displayDateRange = toInclusiveDateRange(
dateRange,
hasExplicitDateRange,
);
return (
<div className="flex flex-col p-4 pt-8">
<div className="mx-auto w-full max-w-3xl">
@@ -26,10 +39,10 @@ export const AgentAnalyticsPageView: FC<AgentAnalyticsPageViewProps> = ({
label="Analytics"
description="Review your personal Coder Agents usage and cost breakdowns."
action={
<div className="flex items-center gap-2 text-xs text-content-secondary">
<BarChart3Icon className="h-4 w-4" />
<span>{rangeLabel}</span>
</div>
<DateRangePicker
value={displayDateRange}
onChange={onDateRangeChange}
/>
}
/>
+18 -24
View File
@@ -646,24 +646,17 @@ const AgentChatPage: FC = () => {
const isRegenerateTitleDisabled = isArchived || isRegeneratingThisChat;
const chatLastModelConfigID = chatRecord?.last_model_config_id;
// Destructure mutation results directly so the React Compiler
// tracks stable primitives/functions instead of the whole result
// object (TanStack Query v5 recreates it every render via object
// spread). Keeping no intermediate variable prevents future code
// from accidentally closing over the unstable object.
const { isPending: isSendPending, mutateAsync: sendMessage } = useMutation(
const sendMutation = useMutation(
createChatMessage(queryClient, agentId ?? ""),
);
const { isPending: isEditPending, mutateAsync: editMessage } = useMutation(
editChatMessage(queryClient, agentId ?? ""),
);
const { isPending: isInterruptPending, mutateAsync: interrupt } = useMutation(
const editMutation = useMutation(editChatMessage(queryClient, agentId ?? ""));
const interruptMutation = useMutation(
interruptChat(queryClient, agentId ?? ""),
);
const { mutateAsync: deleteQueuedMessage } = useMutation(
const deleteQueuedMutation = useMutation(
deleteChatQueuedMessage(queryClient, agentId ?? ""),
);
const { mutateAsync: promoteQueuedMessage } = useMutation(
const promoteQueuedMutation = useMutation(
promoteChatQueuedMessage(queryClient, agentId ?? ""),
);
@@ -761,7 +754,9 @@ const AgentChatPage: FC = () => {
hasUserFixableModelProviders,
});
const isSubmissionPending =
isSendPending || isEditPending || isInterruptPending;
sendMutation.isPending ||
editMutation.isPending ||
interruptMutation.isPending;
const isInputDisabled = !hasModelOptions || isArchived;
const handleUsageLimitError = (error: unknown): void => {
@@ -847,7 +842,7 @@ const AgentChatPage: FC = () => {
setPendingEditMessageId(editedMessageID);
scrollToBottomRef.current?.();
try {
await editMessage({
await editMutation.mutateAsync({
messageId: editedMessageID,
req: request,
});
@@ -878,9 +873,9 @@ const AgentChatPage: FC = () => {
// For queued sends the WebSocket status events handle
// clearing; for non-queued sends we clear explicitly
// below. Clearing eagerly causes a visible cutoff.
let response: Awaited<ReturnType<typeof sendMessage>>;
let response: Awaited<ReturnType<typeof sendMutation.mutateAsync>>;
try {
response = await sendMessage(request);
response = await sendMutation.mutateAsync(request);
} catch (error) {
handleUsageLimitError(error);
throw error;
@@ -913,10 +908,10 @@ const AgentChatPage: FC = () => {
};
const handleInterrupt = () => {
if (!agentId || isInterruptPending) {
if (!agentId || interruptMutation.isPending) {
return;
}
void interrupt();
void interruptMutation.mutateAsync();
};
const handleDeleteQueuedMessage = async (id: number) => {
@@ -925,7 +920,7 @@ const AgentChatPage: FC = () => {
previousQueuedMessages.filter((message) => message.id !== id),
);
try {
await deleteQueuedMessage(id);
await deleteQueuedMutation.mutateAsync(id);
} catch (error) {
store.setQueuedMessages(previousQueuedMessages);
throw error;
@@ -946,7 +941,7 @@ const AgentChatPage: FC = () => {
store.clearStreamError();
store.setChatStatus("pending");
try {
const promotedMessage = await promoteQueuedMessage(id);
const promotedMessage = await promoteQueuedMutation.mutateAsync(id);
// Insert the promoted message into the store and cache
// immediately so it appears in the timeline without
// waiting for the WebSocket to deliver it.
@@ -995,8 +990,7 @@ const AgentChatPage: FC = () => {
? `ssh ${workspaceAgent.name}.${workspace.name}.${workspace.owner_name}.${sshConfigQuery.data.hostname_suffix}`
: undefined;
// See mutation destructuring comment above (React Compiler).
const { mutate: generateKey } = useMutation({
const generateKeyMutation = useMutation({
mutationFn: () => API.getApiKey(),
});
@@ -1011,7 +1005,7 @@ const AgentChatPage: FC = () => {
const repoRoots = Array.from(gitWatcher.repositories.keys()).sort();
const folder = repoRoots[0] ?? workspaceAgent.expanded_directory;
generateKey(undefined, {
generateKeyMutation.mutate(undefined, {
onSuccess: ({ key }) => {
location.href = getVSCodeHref(editor, {
owner: workspace.owner_name,
@@ -1147,7 +1141,7 @@ const AgentChatPage: FC = () => {
compressionThreshold={compressionThreshold}
isInputDisabled={isInputDisabled}
isSubmissionPending={isSubmissionPending}
isInterruptPending={isInterruptPending}
isInterruptPending={interruptMutation.isPending}
isSidebarCollapsed={isSidebarCollapsed}
onToggleSidebarCollapsed={onToggleSidebarCollapsed}
showSidebarPanel={showSidebarPanel}
@@ -372,7 +372,7 @@ export const DefaultAutostopNotVisibleToNonAdmin: Story = {
const canvas = within(canvasElement);
// Personal Instructions should be visible.
await canvas.findByText("Personal Instructions");
await canvas.findByText("Instructions");
// Admin-only sections should not be present.
expect(canvas.queryByText("Workspace Autostop Fallback")).toBeNull();
@@ -412,7 +412,7 @@ export const InvisibleUnicodeWarningUserPrompt: Story = {
play: async ({ canvasElement }) => {
const canvas = within(canvasElement);
await canvas.findByText("Personal Instructions");
await canvas.findByText("Instructions");
const alert = await canvas.findByText(/invisible Unicode/);
expect(alert).toBeInTheDocument();
@@ -454,7 +454,7 @@ export const NoWarningForCleanPrompt: Story = {
play: async ({ canvasElement }) => {
const canvas = within(canvasElement);
await canvas.findByText("Personal Instructions");
await canvas.findByText("Instructions");
await canvas.findByText("System Instructions");
expect(canvas.queryByText(/invisible Unicode/)).toBeNull();
@@ -9,6 +9,13 @@ import { Switch } from "#/components/Switch/Switch";
import { cn } from "#/utils/cn";
import { countInvisibleCharacters } from "#/utils/invisibleUnicode";
import { AdminBadge } from "./components/AdminBadge";
import {
CollapsibleSection,
CollapsibleSectionContent,
CollapsibleSectionDescription,
CollapsibleSectionHeader,
CollapsibleSectionTitle,
} from "./components/CollapsibleSection";
import { DurationField } from "./components/DurationField/DurationField";
import { SectionHeader } from "./components/SectionHeader";
import { TextPreviewDialog } from "./components/TextPreviewDialog";
@@ -246,142 +253,43 @@ export const AgentSettingsBehaviorPageView: FC<
};
return (
<>
<div className="space-y-6">
<SectionHeader
label="Behavior"
description="Custom instructions that shape how the agent responds in your conversations."
/>
{/* ── Personal prompt (always visible) ── */}
<form
className="space-y-2"
onSubmit={(event) => void handleSaveUserPrompt(event)}
>
<h3 className="m-0 text-[13px] font-semibold text-content-primary">
Personal Instructions
</h3>
<p className="!mt-0.5 m-0 text-xs text-content-secondary">
Applied to all your conversations. Only visible to you.
</p>
<TextareaAutosize
className={cn(
textareaBaseClassName,
isUserPromptOverflowing && textareaOverflowClassName,
)}
placeholder="Additional behavior, style, and tone preferences"
value={userPromptDraft}
onChange={(event) => setLocalUserEdit(event.target.value)}
onHeightChange={(height) =>
setIsUserPromptOverflowing(height >= textareaMaxHeight)
}
disabled={isPromptSaving}
minRows={1}
/>
{userInvisibleCharCount > 0 && (
<Alert severity="warning">
<AlertDescription>
This text contains {userInvisibleCharCount} invisible Unicode{" "}
{userInvisibleCharCount !== 1 ? "characters" : "character"} that
could hide content. These will be stripped on save.
</AlertDescription>
</Alert>
)}
<div className="flex justify-end gap-2">
<Button
size="sm"
variant="outline"
type="button"
onClick={() => setLocalUserEdit("")}
disabled={isPromptSaving || !userPromptDraft}
>
Clear
</Button>
<Button
size="sm"
type="submit"
disabled={isPromptSaving || !isUserPromptDirty}
>
Save
</Button>
</div>
{isSaveUserPromptError && (
<p className="m-0 text-xs text-content-destructive">
Failed to save personal instructions.
</p>
)}
</form>
<hr className="my-5 border-0 border-t border-solid border-border" />
<UserCompactionThresholdSettings
modelConfigs={modelConfigsData ?? []}
modelConfigsError={modelConfigsError}
isLoadingModelConfigs={isLoadingModelConfigs}
thresholds={thresholds}
isThresholdsLoading={isThresholdsLoading}
thresholdsError={thresholdsError}
onSaveThreshold={onSaveThreshold}
onResetThreshold={onResetThreshold}
/>
{/* ── Admin system prompt (admin only) ── */}
{canSetSystemPrompt && (
<>
<hr className="my-5 border-0 border-t border-solid border-border" />
{/* Personal prompt (always visible) */}
<CollapsibleSection>
<CollapsibleSectionHeader>
<CollapsibleSectionTitle>Instructions</CollapsibleSectionTitle>
<CollapsibleSectionDescription>Applied to all your conversations. Only visible to you.</CollapsibleSectionDescription>
</CollapsibleSectionHeader>
<CollapsibleSectionContent>
<form
className="space-y-2"
onSubmit={(event) => void handleSaveSystemPrompt(event)}
onSubmit={(event) => void handleSaveUserPrompt(event)}
>
<div className="flex items-center gap-2">
<h3 className="m-0 text-[13px] font-semibold text-content-primary">
System Instructions
</h3>
<AdminBadge />
</div>
<div className="flex items-center justify-between gap-4">
<div className="flex min-w-0 items-center gap-2 text-xs font-medium text-content-primary">
<span>Include Coder Agents default system prompt</span>
<Button
size="xs"
variant="subtle"
type="button"
onClick={() => setShowDefaultPromptPreview(true)}
disabled={!hasLoadedSystemPrompt}
className="min-w-0 px-0 text-content-link hover:text-content-link"
>
Preview
</Button>
</div>
<Switch
checked={includeDefaultDraft}
onCheckedChange={setLocalIncludeDefault}
aria-label="Include Coder Agents default system prompt"
disabled={isSystemPromptDisabled}
/>
</div>
<p className="!mt-0.5 m-0 text-xs text-content-secondary">
{includeDefaultDraft
? "The built-in Coder Agents prompt is prepended. Additional instructions below are appended."
: "Only the additional instructions below are used. When empty, no deployment-wide system prompt is sent."}
</p>
<TextareaAutosize
className={cn(
textareaBaseClassName,
isSystemPromptOverflowing && textareaOverflowClassName,
isUserPromptOverflowing && textareaOverflowClassName,
)}
placeholder="Additional instructions for all users"
value={systemPromptDraft}
onChange={(event) => setLocalEdit(event.target.value)}
placeholder="Additional behavior, style, and tone preferences"
value={userPromptDraft}
onChange={(event) => setLocalUserEdit(event.target.value)}
onHeightChange={(height) =>
setIsSystemPromptOverflowing(height >= textareaMaxHeight)
setIsUserPromptOverflowing(height >= textareaMaxHeight)
}
disabled={isSystemPromptDisabled}
disabled={isPromptSaving}
minRows={1}
/>
{systemInvisibleCharCount > 0 && (
{userInvisibleCharCount > 0 && (
<Alert severity="warning">
<AlertDescription>
This text contains {systemInvisibleCharCount} invisible
Unicode{" "}
{systemInvisibleCharCount !== 1 ? "characters" : "character"}{" "}
that could hide content. These will be stripped on save.
This text contains {userInvisibleCharCount} invisible Unicode{" "}
{userInvisibleCharCount !== 1 ? "characters" : "character"} that
could hide content. These will be stripped on save.
</AlertDescription>
</Alert>
)}
@@ -390,164 +298,304 @@ export const AgentSettingsBehaviorPageView: FC<
size="sm"
variant="outline"
type="button"
onClick={() => setLocalEdit("")}
disabled={isSystemPromptDisabled || !systemPromptDraft}
onClick={() => setLocalUserEdit("")}
disabled={isPromptSaving || !userPromptDraft}
>
Clear
</Button>
<Button
size="sm"
type="submit"
disabled={isSystemPromptDisabled || !isSystemPromptDirty}
disabled={isPromptSaving || !isUserPromptDirty}
>
Save
</Button>{" "}
</Button>
</div>
{isSaveSystemPromptError && (
{isSaveUserPromptError && (
<p className="m-0 text-xs text-content-destructive">
Failed to save system prompt.
Failed to save personal instructions.
</p>
)}
</form>
<hr className="my-5 border-0 border-t border-solid border-border" />
<div className="space-y-2">
<div className="flex items-center gap-2">
<h3 className="m-0 text-[13px] font-semibold text-content-primary">
Virtual Desktop
</h3>
<AdminBadge />
</div>
<div className="flex items-center justify-between gap-4">
<div className="!mt-0.5 m-0 flex-1 text-xs text-content-secondary">
<p className="m-0">
Allow agents to use a virtual, graphical desktop within
workspaces. Requires the{" "}
<Link
href="https://registry.coder.com/modules/coder/portabledesktop"
target="_blank"
size="sm"
>
portabledesktop module
</Link>{" "}
to be installed in the workspace and the Anthropic provider to
be configured.
</p>
<p className="mt-2 mb-0 font-semibold text-content-secondary">
Warning: This is a work-in-progress feature, and you're likely
to encounter bugs if you enable it.
</p>
</CollapsibleSectionContent>
</CollapsibleSection>
{/* Context compaction */}
<CollapsibleSection>
<CollapsibleSectionHeader>
<CollapsibleSectionTitle>Context compaction</CollapsibleSectionTitle>
<CollapsibleSectionDescription>Control when conversation context is automatically summarized for each model. Setting 100% means the conversation will never auto-compact.</CollapsibleSectionDescription>
</CollapsibleSectionHeader>
<CollapsibleSectionContent>
<UserCompactionThresholdSettings
hideHeader
modelConfigs={modelConfigsData ?? []}
modelConfigsError={modelConfigsError}
isLoadingModelConfigs={isLoadingModelConfigs}
thresholds={thresholds}
isThresholdsLoading={isThresholdsLoading}
thresholdsError={thresholdsError}
onSaveThreshold={onSaveThreshold}
onResetThreshold={onResetThreshold}
/>
</CollapsibleSectionContent>
</CollapsibleSection>
{/* Admin system prompt (admin only) */}
{canSetSystemPrompt && (
<>
<CollapsibleSection>
<CollapsibleSectionHeader>
<div className="flex items-center gap-2">
<CollapsibleSectionTitle>System Instructions</CollapsibleSectionTitle>
<AdminBadge />
</div>
<Switch
checked={desktopEnabled}
onCheckedChange={(checked) =>
onSaveDesktopEnabled({ enable_desktop: checked })
}
aria-label="Enable"
disabled={isSavingDesktopEnabled}
/>
</div>
{isSaveDesktopEnabledError && (
<p className="m-0 text-xs text-content-destructive">
Failed to save desktop setting.
</p>
)}
</div>
<hr className="my-5 border-0 border-t border-solid border-border" />
<form
className="space-y-2"
onSubmit={(event) => void handleSaveChatWorkspaceTTL(event)}
>
<div className="flex items-center gap-2">
<h3 className="m-0 text-[13px] font-semibold text-content-primary">
Workspace Autostop Fallback
</h3>
<AdminBadge />
</div>
<div className="flex items-center justify-between gap-4">
<p className="!mt-0.5 m-0 flex-1 text-xs text-content-secondary">
Set a default autostop for agent-created workspaces that don't
have one defined in their template. Template-defined autostop
rules always take precedence. Active conversations will extend
the stop time.
</p>
<Switch
checked={isAutostopEnabled}
onCheckedChange={handleToggleAutostop}
aria-label="Enable default autostop"
disabled={isSavingWorkspaceTTL || isWorkspaceTTLLoading}
/>{" "}
</div>
{isAutostopEnabled && (
<DurationField
valueMs={ttlMs}
onChange={handleTTLChange}
label="Autostop Fallback"
disabled={isSavingWorkspaceTTL || isWorkspaceTTLLoading}
error={isTTLOverMax || isTTLZero}
helperText={
isTTLZero
? "Duration must be greater than zero."
: isTTLOverMax
? "Must not exceed 30 days (720 hours)."
: undefined
}
/>
)}
{isAutostopEnabled && (
<div className="flex justify-end">
<Button
size="sm"
type="submit"
disabled={
isSavingWorkspaceTTL ||
!isTTLDirty ||
isTTLOverMax ||
isTTLZero
<CollapsibleSectionDescription>
{includeDefaultDraft
? "The built-in Coder Agents prompt is prepended. Additional instructions below are appended."
: "Only the additional instructions below are used. When empty, no deployment-wide system prompt is sent."}
</CollapsibleSectionDescription>
</CollapsibleSectionHeader>
<CollapsibleSectionContent>
<form
className="space-y-2"
onSubmit={(event) => void handleSaveSystemPrompt(event)}
>
<div className="flex items-center justify-between gap-4">
<div className="flex min-w-0 items-center gap-2 text-xs font-medium text-content-primary">
<span>Include Coder Agents default system prompt</span>
<Button
size="xs"
variant="subtle"
type="button"
onClick={() => setShowDefaultPromptPreview(true)}
disabled={!hasLoadedSystemPrompt}
className="min-w-0 px-0 text-content-link hover:text-content-link"
>
Preview
</Button>
</div>
<Switch
checked={includeDefaultDraft}
onCheckedChange={setLocalIncludeDefault}
aria-label="Include Coder Agents default system prompt"
disabled={isSystemPromptDisabled}
/>
</div>
<TextareaAutosize
className={cn(
textareaBaseClassName,
isSystemPromptOverflowing && textareaOverflowClassName,
)}
placeholder="Additional instructions for all users"
value={systemPromptDraft}
onChange={(event) => setLocalEdit(event.target.value)}
onHeightChange={(height) =>
setIsSystemPromptOverflowing(height >= textareaMaxHeight)
}
>
Save
</Button>
disabled={isSystemPromptDisabled}
minRows={1}
/>
{systemInvisibleCharCount > 0 && (
<Alert severity="warning">
<AlertDescription>
This text contains {systemInvisibleCharCount} invisible
Unicode{" "}
{systemInvisibleCharCount !== 1
? "characters"
: "character"}{" "}
that could hide content. These will be stripped on save.
</AlertDescription>
</Alert>
)}
<div className="flex justify-end gap-2">
<Button
size="sm"
variant="outline"
type="button"
onClick={() => setLocalEdit("")}
disabled={isSystemPromptDisabled || !systemPromptDraft}
>
Clear
</Button>
<Button
size="sm"
type="submit"
disabled={isSystemPromptDisabled || !isSystemPromptDirty}
>
Save
</Button>
</div>
{isSaveSystemPromptError && (
<p className="m-0 text-xs text-content-destructive">
Failed to save system prompt.
</p>
)}
</form>
</CollapsibleSectionContent>
</CollapsibleSection>
<CollapsibleSection>
<CollapsibleSectionHeader>
<div className="flex items-center gap-2">
<CollapsibleSectionTitle>Virtual Desktop</CollapsibleSectionTitle>
<AdminBadge />
</div>
)}
{isSaveWorkspaceTTLError && (
<p className="m-0 text-xs text-content-destructive">
Failed to save autostop setting.
</p>
)}
{isWorkspaceTTLLoadError && (
<p className="m-0 text-xs text-content-destructive">
Failed to load autostop setting.
</p>
)}
</form>
</CollapsibleSectionHeader>
<CollapsibleSectionContent>
<div className="flex items-start gap-4">
<div className="min-w-0 flex-1 space-y-1">
<span className="text-sm font-medium text-content-primary">
Enable virtual desktop
</span>
<p className="m-0 text-xs text-content-secondary">
Allow agents to use a virtual, graphical desktop within
workspaces. Requires the{" "}
<Link
href="https://registry.coder.com/modules/coder/portabledesktop"
target="_blank"
size="sm"
>
portabledesktop module
</Link>{" "}
to be installed in the workspace and the Anthropic provider to
be configured.
</p>
<p className="m-0 text-xs text-content-secondary font-semibold">
Warning: This is a work-in-progress feature, and you're likely
to encounter bugs if you enable it.
</p>
</div>
<Switch
checked={desktopEnabled}
onCheckedChange={(checked) =>
onSaveDesktopEnabled({ enable_desktop: checked })
}
aria-label="Enable"
disabled={isSavingDesktopEnabled}
/>
</div>{" "}
{isSaveDesktopEnabledError && (
<p className="m-0 text-xs text-content-destructive">
Failed to save desktop setting.
</p>
)}
</CollapsibleSectionContent>
</CollapsibleSection>
<CollapsibleSection>
<CollapsibleSectionHeader>
<div className="flex items-center gap-2">
<CollapsibleSectionTitle>Workspace Autostop Fallback</CollapsibleSectionTitle>
<AdminBadge />
</div>
</CollapsibleSectionHeader>
<CollapsibleSectionContent>
<div className="flex items-start gap-4">
<div className="min-w-0 flex-1 space-y-1">
<span className="text-sm font-medium text-content-primary">
Enable autostop fallback
</span>
<p className="m-0 text-xs text-content-secondary">
Set a default autostop for agent-created workspaces that don't
have one defined in their template. Template-defined autostop
rules always take precedence. Active conversations will extend
the stop time.
</p>
</div>
<Switch
checked={isAutostopEnabled}
onCheckedChange={handleToggleAutostop}
aria-label="Enable default autostop"
disabled={isSavingWorkspaceTTL || isWorkspaceTTLLoading}
/>
</div>
<form
className="space-y-2 pt-4"
onSubmit={(event) => void handleSaveChatWorkspaceTTL(event)}
>
{" "}
{isAutostopEnabled && (
<DurationField
valueMs={ttlMs}
onChange={handleTTLChange}
label="Autostop Fallback"
disabled={isSavingWorkspaceTTL || isWorkspaceTTLLoading}
error={isTTLOverMax || isTTLZero}
helperText={
isTTLZero
? "Duration must be greater than zero."
: isTTLOverMax
? "Must not exceed 30 days (720 hours)."
: undefined
}
/>
)}
{isAutostopEnabled && (
<div className="flex justify-end">
<Button
size="sm"
type="submit"
disabled={
isSavingWorkspaceTTL ||
!isTTLDirty ||
isTTLOverMax ||
isTTLZero
}
>
Save
</Button>
</div>
)}
{isSaveWorkspaceTTLError && (
<p className="m-0 text-xs text-content-destructive">
Failed to save autostop setting.
</p>
)}
{isWorkspaceTTLLoadError && (
<p className="m-0 text-xs text-content-destructive">
Failed to load autostop setting.
</p>
)}
</form>
</CollapsibleSectionContent>
</CollapsibleSection>
</>
)}
<hr className="my-5 border-0 border-t border-solid border-border" />
{/* ── Kyleosophy toggle (always visible) ── */}
<div className="space-y-2">
<h3 className="m-0 text-[13px] font-semibold text-content-primary">
Kyleosophy
</h3>
<div className="flex items-center justify-between gap-4">
<p className="!mt-0.5 m-0 flex-1 text-xs text-content-secondary">
Replace the standard completion chime. IYKYK.
{kylesophyForced && (
<span className="ml-1 font-semibold">
Kyleosophy is mandatory on <code>dev.coder.com</code>.
{/* Kyleosophy toggle (always visible) */}
<CollapsibleSection>
<CollapsibleSectionHeader>
<CollapsibleSectionTitle>Kyleosophy</CollapsibleSectionTitle>
</CollapsibleSectionHeader>
<CollapsibleSectionContent>
<div className="flex items-start gap-4">
<div className="min-w-0 flex-1 space-y-1">
<span className="text-sm font-medium text-content-primary">
Enable Kyleosophy
</span>
)}
</p>
<Switch
checked={kylesophyEnabled}
onCheckedChange={(checked) => {
setKylesophyEnabled(checked);
setLocalKylesophy(checked);
}}
aria-label="Enable Kyleosophy"
disabled={kylesophyForced}
/>
</div>
</div>
<p className="m-0 text-xs text-content-secondary">
{" "}
Replace the standard completion chime. IYKYK.
{kylesophyForced && (
<span className="ml-1 font-semibold">
Kyleosophy is mandatory on <code>dev.coder.com</code>.
</span>
)}
</p>
</div>
<Switch
checked={kylesophyEnabled}
onCheckedChange={(checked) => {
setKylesophyEnabled(checked);
setLocalKylesophy(checked);
}}
aria-label="Enable Kyleosophy"
disabled={kylesophyForced}
/>
</div>
</CollapsibleSectionContent>
</CollapsibleSection>
{showDefaultPromptPreview && (
<TextPreviewDialog
content={defaultSystemPrompt}
@@ -555,6 +603,6 @@ export const AgentSettingsBehaviorPageView: FC<
onClose={() => setShowDefaultPromptPreview(false)}
/>
)}
</>
</div>
);
};
@@ -1,63 +0,0 @@
import { BookOpenIcon, LoaderIcon, TriangleAlertIcon } from "lucide-react";
import type React from "react";
import { ScrollArea } from "#/components/ScrollArea/ScrollArea";
import {
Tooltip,
TooltipContent,
TooltipTrigger,
} from "#/components/Tooltip/Tooltip";
import { cn } from "#/utils/cn";
import { Response } from "../Response";
import { ToolCollapsible } from "./ToolCollapsible";
import type { ToolStatus } from "./utils";
export const ReadSkillTool: React.FC<{
label: string;
body: string;
status: ToolStatus;
isError: boolean;
errorMessage?: string;
}> = ({ label, body, status, isError, errorMessage }) => {
const hasContent = body.length > 0;
const isRunning = status === "running";
return (
<ToolCollapsible
className="w-full"
hasContent={hasContent}
header={
<>
<BookOpenIcon className="h-4 w-4 shrink-0 text-content-secondary" />
<span className={cn("text-sm", "text-content-secondary")}>
{isRunning ? `Reading ${label}` : `Read ${label}`}
</span>
{isError && (
<Tooltip>
<TooltipTrigger asChild>
<TriangleAlertIcon className="h-3.5 w-3.5 shrink-0 text-content-secondary" />
</TooltipTrigger>
<TooltipContent>
{errorMessage || "Failed to read skill"}
</TooltipContent>
</Tooltip>
)}
{isRunning && (
<LoaderIcon className="h-3.5 w-3.5 shrink-0 animate-spin motion-reduce:animate-none text-content-secondary" />
)}
</>
}
>
{body && (
<ScrollArea
className="mt-1.5 rounded-md border border-solid border-border-default"
viewportClassName="max-h-64"
scrollBarClassName="w-1.5"
>
<div className="px-3 py-2">
<Response>{body}</Response>
</div>
</ScrollArea>
)}
</ToolCollapsible>
);
};
@@ -1342,12 +1342,6 @@ export const ReadSkillCompleted: Story = {
play: async ({ canvasElement }) => {
const canvas = within(canvasElement);
expect(canvas.getByText(/Read skill deep-review/)).toBeInTheDocument();
// Expand the collapsible to verify markdown body renders.
const toggle = canvas.getByRole("button");
await userEvent.click(toggle);
await waitFor(() => {
expect(canvas.getByText("Deep Review Skill")).toBeInTheDocument();
});
},
};
@@ -1399,12 +1393,6 @@ export const ReadSkillFileCompleted: Story = {
expect(
canvas.getByText(/Read deep-review\/roles\/security-reviewer\.md/),
).toBeInTheDocument();
// Expand the collapsible to verify markdown content renders.
const toggle = canvas.getByRole("button");
await userEvent.click(toggle);
await waitFor(() => {
expect(canvas.getByText("Security Reviewer Role")).toBeInTheDocument();
});
},
};
@@ -23,7 +23,6 @@ import { ListTemplatesTool } from "./ListTemplatesTool";
import { ProcessOutputTool } from "./ProcessOutputTool";
import { ProposePlanTool } from "./ProposePlanTool";
import { ReadFileTool } from "./ReadFileTool";
import { ReadSkillTool } from "./ReadSkillTool";
import { ReadTemplateTool } from "./ReadTemplateTool";
import { SubagentTool } from "./SubagentTool";
import { ToolCollapsible } from "./ToolCollapsible";
@@ -211,55 +210,6 @@ const ReadFileRenderer: FC<ToolRendererProps> = ({
);
};
const ReadSkillRenderer: FC<ToolRendererProps> = ({
status,
args,
result,
isError,
}) => {
const parsedArgs = parseArgs(args);
const skillName = parsedArgs ? asString(parsedArgs.name) : "";
const rec = asRecord(result);
const body = rec ? asString(rec.body) : "";
return (
<ReadSkillTool
label={skillName ? `skill ${skillName}` : "skill"}
body={body}
status={status}
isError={isError}
errorMessage={rec ? asString(rec.error || rec.message) : undefined}
/>
);
};
const ReadSkillFileRenderer: FC<ToolRendererProps> = ({
status,
args,
result,
isError,
}) => {
const parsedArgs = parseArgs(args);
const skillName = parsedArgs ? asString(parsedArgs.name) : "";
const filePath = parsedArgs ? asString(parsedArgs.path) : "";
const label =
skillName && filePath
? `${skillName}/${filePath}`
: skillName || filePath || "skill file";
const rec = asRecord(result);
const content = rec ? asString(rec.content) : "";
return (
<ReadSkillTool
label={label}
body={content}
status={status}
isError={isError}
errorMessage={rec ? asString(rec.error || rec.message) : undefined}
/>
);
};
const WriteFileRenderer: FC<ToolRendererProps> = ({
status,
args,
@@ -717,8 +667,6 @@ const toolRenderers: Record<string, FC<ToolRendererProps>> = {
create_workspace: CreateWorkspaceRenderer,
list_templates: ListTemplatesRenderer,
read_template: ReadTemplateRenderer,
read_skill: ReadSkillRenderer,
read_skill_file: ReadSkillFileRenderer,
spawn_agent: SubagentRenderer,
wait_agent: SubagentRenderer,
message_agent: SubagentRenderer,
@@ -255,7 +255,7 @@ export const ProviderAccordionCards: Story = {
expect(body.queryByText("OpenAI")).not.toBeInTheDocument();
await userEvent.click(body.getByRole("button", { name: /OpenRouter/i }));
await expect(await body.findByLabelText("Base URL")).toBeInTheDocument();
await expect(body.getByLabelText("Base URL")).toBeInTheDocument();
},
};
@@ -462,9 +462,6 @@ export const CreateAndUpdateProvider: Story = {
await waitFor(() => {
expect(args.onCreateProvider).toHaveBeenCalledTimes(1);
});
await waitFor(() => {
expect(body.getByRole("button", { name: "Save changes" })).toBeDisabled();
});
expect(args.onCreateProvider).toHaveBeenCalledWith(
expect.objectContaining({
provider: "openai",
@@ -285,7 +285,7 @@ export const ChatModelAdminPanel: FC<ChatModelAdminPanelProps> = ({
const modelConfigsUnavailable = modelConfigsData === null;
return (
<div className={cn("flex min-h-full flex-col space-y-3", className)}>
<div className={cn("flex min-h-full flex-col space-y-6", className)}>
{isLoading && (
<div className="flex items-center gap-1.5 text-xs text-content-secondary">
<Spinner className="h-4 w-4" loading />
@@ -1,11 +1,5 @@
import { useFormik } from "formik";
import {
ChevronDownIcon,
ChevronLeftIcon,
ChevronRightIcon,
InfoIcon,
PencilIcon,
} from "lucide-react";
import { ChevronLeftIcon, InfoIcon, PencilIcon } from "lucide-react";
import { type FC, useState } from "react";
import * as Yup from "yup";
import type * as TypesGen from "#/api/typesGenerated";
@@ -41,6 +35,13 @@ import {
} from "#/components/Tooltip/Tooltip";
import { cn } from "#/utils/cn";
import { getFormHelpers } from "#/utils/formUtils";
import {
CollapsibleSection,
CollapsibleSectionContent,
CollapsibleSectionDescription,
CollapsibleSectionHeader,
CollapsibleSectionTitle,
} from "../CollapsibleSection";
import type { ProviderState } from "./ChatModelAdminPanel";
import {
GeneralModelConfigFields,
@@ -116,9 +117,6 @@ export const ModelForm: FC<ModelFormProps> = ({
}) => {
const isEditing = Boolean(editingModel);
const isDefaultModel = isEditing && editingModel?.is_default === true;
const [showAdvanced, setShowAdvanced] = useState(false);
const [showPricing, setShowPricing] = useState(false);
const [showProviderConfig, setShowProviderConfig] = useState(false);
const [confirmingDelete, setConfirmingDelete] = useState(false);
const canManageModels = Boolean(
@@ -488,29 +486,18 @@ export const ModelForm: FC<ModelFormProps> = ({
</div>
{/* Usage Tracking */}
<div className="border-0 border-t border-solid border-border pt-4">
<button
type="button"
onClick={() => setShowPricing((v) => !v)}
className="flex w-full cursor-pointer items-start justify-between border-0 bg-transparent p-0 text-left transition-colors hover:text-content-primary"
>
<div>
<h3 className="m-0 text-sm font-medium text-content-primary">
Cost Tracking{" "}
</h3>
<p className="m-0 text-xs text-content-secondary">
Set per-token pricing so Coder can track costs and enforce
spending limits.
</p>
</div>
{showPricing ? (
<ChevronDownIcon className="mt-0.5 h-4 w-4 shrink-0 text-content-secondary" />
) : (
<ChevronRightIcon className="mt-0.5 h-4 w-4 shrink-0 text-content-secondary" />
)}
</button>
{showPricing && (
<div className="grid grid-cols-2 gap-3 pt-3 sm:grid-cols-4">
<CollapsibleSection variant="inline" defaultOpen={false}>
<CollapsibleSectionHeader>
<CollapsibleSectionTitle as="h3">
Cost Tracking
</CollapsibleSectionTitle>
<CollapsibleSectionDescription>
Set per-token pricing so Coder can track costs and enforce
spending limits.
</CollapsibleSectionDescription>
</CollapsibleSectionHeader>
<CollapsibleSectionContent>
<div className="grid grid-cols-2 gap-3 sm:grid-cols-4">
<PricingModelConfigFields
provider={selectedProviderState.provider}
form={form}
@@ -518,33 +505,21 @@ export const ModelForm: FC<ModelFormProps> = ({
disabled={isSaving}
/>
</div>
)}
</div>
</CollapsibleSectionContent>
</CollapsibleSection>
{/* Provider Configuration */}
<div className="border-0 border-t border-solid border-border pt-4">
<button
type="button"
onClick={() => setShowProviderConfig((v) => !v)}
className="flex w-full cursor-pointer items-start justify-between border-0 bg-transparent p-0 text-left transition-colors hover:text-content-primary"
>
<CollapsibleSection variant="inline" defaultOpen={false}>
<CollapsibleSectionHeader>
<CollapsibleSectionTitle as="h3">
Provider Configuration
</CollapsibleSectionTitle>
<CollapsibleSectionDescription>
Tune provider-specific behavior like reasoning, tool calling,
and web search.
</CollapsibleSectionDescription>
</CollapsibleSectionHeader>
<CollapsibleSectionContent>
<div>
<h3 className="m-0 text-sm font-medium text-content-primary">
Provider Configuration
</h3>
<p className="m-0 text-xs text-content-secondary">
Tune provider-specific behavior like reasoning, tool calling,
and web search.
</p>
</div>
{showProviderConfig ? (
<ChevronDownIcon className="mt-0.5 h-4 w-4 shrink-0 text-content-secondary" />
) : (
<ChevronRightIcon className="mt-0.5 h-4 w-4 shrink-0 text-content-secondary" />
)}
</button>
{showProviderConfig && (
<div className="pt-3">
<ModelConfigFields
provider={selectedProviderState.provider}
form={form}
@@ -552,33 +527,21 @@ export const ModelForm: FC<ModelFormProps> = ({
disabled={isSaving}
/>
</div>
)}
</div>
</CollapsibleSectionContent>
</CollapsibleSection>
{/* Advanced */}
<div className="border-0 border-t border-solid border-border pt-4">
<button
type="button"
onClick={() => setShowAdvanced((v) => !v)}
className="flex w-full cursor-pointer items-start justify-between border-0 bg-transparent p-0 text-left transition-colors hover:text-content-primary"
>
<div>
<h3 className="m-0 text-sm font-medium text-content-primary">
Advanced
</h3>
<p className="m-0 text-xs text-content-secondary">
Low-level parameters like temperature and penalties. Rarely
need changing.
</p>
</div>
{showAdvanced ? (
<ChevronDownIcon className="mt-0.5 h-4 w-4 shrink-0 text-content-secondary" />
) : (
<ChevronRightIcon className="mt-0.5 h-4 w-4 shrink-0 text-content-secondary" />
)}
</button>
{showAdvanced && (
<div className="grid grid-cols-2 gap-3 pt-3 sm:grid-cols-3">
<CollapsibleSection variant="inline" defaultOpen={false}>
<CollapsibleSectionHeader>
<CollapsibleSectionTitle as="h3">
Advanced
</CollapsibleSectionTitle>
<CollapsibleSectionDescription>
Low-level parameters like temperature and penalties. Rarely need
changing.
</CollapsibleSectionDescription>
</CollapsibleSectionHeader>
<CollapsibleSectionContent>
<div className="grid grid-cols-2 gap-3 sm:grid-cols-3">
<GeneralModelConfigFields
provider={selectedProviderState.provider}
form={form}
@@ -629,10 +592,11 @@ export const ModelForm: FC<ModelFormProps> = ({
)}
</div>
</div>
)}
</div>
</CollapsibleSectionContent>
</CollapsibleSection>
</div>
<div className="mt-auto py-6">
{" "}
<hr className="mb-4 border-0 border-t border-solid border-border" />
<div className="flex items-center justify-between">
{isEditing && editingModel && onDeleteModel ? (
@@ -0,0 +1,193 @@
import type { Meta, StoryObj } from "@storybook/react-vite";
import { expect, userEvent, within } from "storybook/test";
import { AdminBadge } from "./AdminBadge";
import {
CollapsibleSection,
CollapsibleSectionContent,
CollapsibleSectionDescription,
CollapsibleSectionHeader,
CollapsibleSectionTitle,
} from "./CollapsibleSection";
const meta: Meta<typeof CollapsibleSection> = {
title: "pages/AgentsPage/CollapsibleSection",
component: CollapsibleSection,
decorators: [
(Story) => (
<div style={{ maxWidth: 600 }}>
<Story />
</div>
),
],
};
export default meta;
type Story = StoryObj<typeof CollapsibleSection>;
const Placeholder = () => (
<p className="text-sm text-content-secondary">Placeholder content</p>
);
export const DefaultOpen: Story = {
render: () => (
<CollapsibleSection>
<CollapsibleSectionHeader>
<div className="flex items-center gap-2">
<CollapsibleSectionTitle>Default spend limit</CollapsibleSectionTitle>
<AdminBadge />
</div>
<CollapsibleSectionDescription>
The deployment-wide spending cap.
</CollapsibleSectionDescription>
</CollapsibleSectionHeader>
<CollapsibleSectionContent>
<Placeholder />
</CollapsibleSectionContent>
</CollapsibleSection>
),
play: async ({ canvasElement }) => {
const canvas = within(canvasElement);
expect(canvas.getByText("Placeholder content")).toBeInTheDocument();
const header = canvas.getByRole("button", {
name: /Default spend limit/i,
});
expect(header).toHaveAttribute("aria-expanded", "true");
await userEvent.click(header);
expect(header).toHaveAttribute("aria-expanded", "false");
expect(canvas.queryByText("Placeholder content")).not.toBeInTheDocument();
await userEvent.click(header);
expect(header).toHaveAttribute("aria-expanded", "true");
expect(canvas.getByText("Placeholder content")).toBeInTheDocument();
},
};
// Same section but started closed via defaultOpen={false}; the play
// function is the mirror image of DefaultOpen (open, then close again).
export const Collapsed: Story = {
  render: () => (
    <CollapsibleSection defaultOpen={false}>
      <CollapsibleSectionHeader>
        <div className="flex items-center gap-2">
          <CollapsibleSectionTitle>Default spend limit</CollapsibleSectionTitle>
          <AdminBadge />
        </div>
        <CollapsibleSectionDescription>
          The deployment-wide spending cap.
        </CollapsibleSectionDescription>
      </CollapsibleSectionHeader>
      <CollapsibleSectionContent>
        <Placeholder />
      </CollapsibleSectionContent>
    </CollapsibleSection>
  ),
  play: async ({ canvasElement }) => {
    const canvas = within(canvasElement);
    // Starts collapsed: content must not be in the DOM.
    expect(canvas.queryByText("Placeholder content")).not.toBeInTheDocument();
    const header = canvas.getByRole("button", {
      name: /Default spend limit/i,
    });
    expect(header).toHaveAttribute("aria-expanded", "false");
    // Expand, then collapse again.
    await userEvent.click(header);
    expect(header).toHaveAttribute("aria-expanded", "true");
    expect(canvas.getByText("Placeholder content")).toBeInTheDocument();
    await userEvent.click(header);
    expect(header).toHaveAttribute("aria-expanded", "false");
    expect(canvas.queryByText("Placeholder content")).not.toBeInTheDocument();
  },
};
// Section without an AdminBadge beside the title. Render-only story —
// no play function, it just documents the un-badged header layout.
export const NoBadge: Story = {
  render: () => (
    <CollapsibleSection>
      <CollapsibleSectionHeader>
        <CollapsibleSectionTitle>Group limits</CollapsibleSectionTitle>
        <CollapsibleSectionDescription>
          Override defaults for groups.
        </CollapsibleSectionDescription>
      </CollapsibleSectionHeader>
      <CollapsibleSectionContent>
        <Placeholder />
      </CollapsibleSectionContent>
    </CollapsibleSection>
  ),
};
// The "inline" visual variant (flush, border-top layout) with an h3
// heading, started collapsed. Play function verifies toggling works the
// same as for the card variant.
export const InlineVariant: Story = {
  render: () => (
    <CollapsibleSection variant="inline" defaultOpen={false}>
      <CollapsibleSectionHeader>
        <CollapsibleSectionTitle as="h3">Cost Tracking</CollapsibleSectionTitle>
        <CollapsibleSectionDescription>
          Set per-token pricing so Coder can track costs and enforce spending
          limits.
        </CollapsibleSectionDescription>
      </CollapsibleSectionHeader>
      <CollapsibleSectionContent>
        <Placeholder />
      </CollapsibleSectionContent>
    </CollapsibleSection>
  ),
  play: async ({ canvasElement }) => {
    const canvas = within(canvasElement);
    // Collapsed on mount.
    expect(canvas.queryByText("Placeholder content")).not.toBeInTheDocument();
    const header = canvas.getByRole("button", {
      name: /Cost Tracking/i,
    });
    expect(header).toHaveAttribute("aria-expanded", "false");
    // Expand then collapse via mouse clicks.
    await userEvent.click(header);
    expect(header).toHaveAttribute("aria-expanded", "true");
    expect(canvas.getByText("Placeholder content")).toBeInTheDocument();
    await userEvent.click(header);
    expect(header).toHaveAttribute("aria-expanded", "false");
    expect(canvas.queryByText("Placeholder content")).not.toBeInTheDocument();
  },
};
// Accessibility check: the header trigger must toggle via both the
// Enter key and the Space key once focused, not only via mouse clicks.
export const KeyboardToggle: Story = {
  render: () => (
    <CollapsibleSection>
      <CollapsibleSectionHeader>
        <CollapsibleSectionTitle>Keyboard section</CollapsibleSectionTitle>
      </CollapsibleSectionHeader>
      <CollapsibleSectionContent>
        <Placeholder />
      </CollapsibleSectionContent>
    </CollapsibleSection>
  ),
  play: async ({ canvasElement }) => {
    const canvas = within(canvasElement);
    const header = canvas.getByRole("button", {
      name: /Keyboard section/i,
    });
    // Open by default.
    expect(header).toHaveAttribute("aria-expanded", "true");
    expect(canvas.getByText("Placeholder content")).toBeInTheDocument();
    // Enter toggles closed, then open again.
    header.focus();
    await userEvent.keyboard("{Enter}");
    expect(header).toHaveAttribute("aria-expanded", "false");
    expect(canvas.queryByText("Placeholder content")).not.toBeInTheDocument();
    await userEvent.keyboard("{Enter}");
    expect(header).toHaveAttribute("aria-expanded", "true");
    expect(canvas.getByText("Placeholder content")).toBeInTheDocument();
    // Space must behave identically.
    await userEvent.keyboard(" ");
    expect(header).toHaveAttribute("aria-expanded", "false");
    expect(canvas.queryByText("Placeholder content")).not.toBeInTheDocument();
    await userEvent.keyboard(" ");
    expect(header).toHaveAttribute("aria-expanded", "true");
    expect(canvas.getByText("Placeholder content")).toBeInTheDocument();
  },
};
@@ -0,0 +1,211 @@
import { cva, type VariantProps } from "class-variance-authority";
import { ChevronDownIcon } from "lucide-react";
import {
type ComponentProps,
createContext,
type FC,
type ReactNode,
useContext,
useState,
} from "react";
import {
Collapsible,
CollapsibleContent,
CollapsibleTrigger,
} from "#/components/Collapsible/Collapsible";
import { cn } from "#/utils/cn";
// ---------------------------------------------------------------------------
// Variant styles
// ---------------------------------------------------------------------------

// Outer wrapper: "card" draws a rounded border box, "inline" renders
// flush with just a top border separating it from preceding content.
const wrapperStyles = cva("", {
  variants: {
    variant: {
      card: "rounded-lg border border-solid border-border-default",
      inline: "border-0 border-t border-solid border-border-default pt-4",
    },
  },
  defaultVariants: { variant: "card" },
});

// Header button: full-width unstyled button with a visible focus ring;
// padding differs between the boxed card and the flush inline layout.
const triggerStyles = cva(
  "flex w-full cursor-pointer items-start justify-between gap-4 border-0 bg-transparent text-left focus-visible:outline-none focus-visible:ring-2 focus-visible:ring-content-link",
  {
    variants: {
      variant: {
        card: "rounded-lg px-6 py-5",
        inline: "rounded-md p-0",
      },
    },
    defaultVariants: { variant: "card" },
  },
);

// Heading text: larger and bolder for cards, compact for inline.
const titleStyles = cva("m-0 text-content-primary", {
  variants: {
    variant: {
      card: "text-lg font-semibold",
      inline: "text-sm font-medium",
    },
  },
  defaultVariants: { variant: "card" },
});

// Muted description text rendered below the title.
const descriptionStyles = cva("m-0 text-content-secondary", {
  variants: {
    variant: {
      card: "mt-1 text-sm",
      inline: "text-xs",
    },
  },
  defaultVariants: { variant: "card" },
});

// Expanded body: cards separate content from the header with a border;
// inline sections only add top padding.
const contentStyles = cva("", {
  variants: {
    variant: {
      card: "border-0 border-t border-solid border-border-default px-6 pb-5 pt-4",
      inline: "pt-3",
    },
  },
  defaultVariants: { variant: "card" },
});
// ---------------------------------------------------------------------------
// Context
// ---------------------------------------------------------------------------

type Variant = "card" | "inline";

// Shared by the root with the sub-components so each one can pick the
// matching variant styles, and so the header chevron can track `open`.
interface CollapsibleSectionContextValue {
  variant: Variant;
  open: boolean;
}

// Defaults mirror an open "card" section so sub-components rendered
// outside a provider still get sensible styling.
const CollapsibleSectionContext = createContext<CollapsibleSectionContextValue>(
  {
    variant: "card",
    open: true,
  },
);

const useCollapsibleSection = () => useContext(CollapsibleSectionContext);
// ---------------------------------------------------------------------------
// Root
// ---------------------------------------------------------------------------

interface CollapsibleSectionProps extends VariantProps<typeof wrapperStyles> {
  defaultOpen?: boolean;
  /** Controlled open state. */
  open?: boolean;
  onOpenChange?: (open: boolean) => void;
  children: ReactNode;
}

/**
 * Root of a collapsible settings section. Works uncontrolled (via
 * `defaultOpen`, open by default) or controlled (via `open` plus
 * `onOpenChange`). The resolved variant and open state are shared with
 * the sub-components through context.
 */
export const CollapsibleSection: FC<CollapsibleSectionProps> = ({
  defaultOpen,
  open: openProp,
  onOpenChange: onOpenChangeProp,
  variant = "card",
  children,
}) => {
  const [internalOpen, setInternalOpen] = useState(defaultOpen ?? true);

  // A defined `open` prop switches the component into controlled mode;
  // internal state is then ignored.
  const controlled = openProp !== undefined;
  const open = controlled ? openProp : internalOpen;
  const handleOpenChange = controlled ? onOpenChangeProp : setInternalOpen;
  const resolvedVariant: Variant = variant ?? "card";

  return (
    <CollapsibleSectionContext.Provider
      value={{ variant: resolvedVariant, open }}
    >
      <Collapsible open={open} onOpenChange={handleOpenChange}>
        <div className={wrapperStyles({ variant: resolvedVariant })}>
          {children}
        </div>
      </Collapsible>
    </CollapsibleSectionContext.Provider>
  );
};
// ---------------------------------------------------------------------------
// Header (trigger)
// ---------------------------------------------------------------------------

interface CollapsibleSectionHeaderProps {
  children: ReactNode;
}

/**
 * Clickable header that toggles the section. Renders the consumer's
 * title/description on the left and a chevron on the right that rotates
 * 180° while the section is open.
 */
export const CollapsibleSectionHeader: FC<CollapsibleSectionHeaderProps> = ({
  children,
}) => {
  const section = useCollapsibleSection();
  const chevronClassName = cn(
    "h-4 w-4 shrink-0 text-content-secondary transition-transform duration-200",
    section.open && "rotate-180",
  );

  return (
    <CollapsibleTrigger className={triggerStyles({ variant: section.variant })}>
      <div className="min-w-0 flex-1">{children}</div>
      <ChevronDownIcon className={chevronClassName} />
    </CollapsibleTrigger>
  );
};
// ---------------------------------------------------------------------------
// Title — renders the heading element at whatever level the consumer picks.
// ---------------------------------------------------------------------------

type HeadingTag = "h1" | "h2" | "h3" | "h4" | "h5" | "h6";

interface CollapsibleSectionTitleProps extends ComponentProps<HeadingTag> {
  as?: HeadingTag;
}

/**
 * Section heading. Defaults to an <h2>, but callers can pick any heading
 * level via `as` to keep the document outline correct. Custom classes
 * are merged after the variant styles.
 */
export const CollapsibleSectionTitle: FC<CollapsibleSectionTitleProps> = ({
  as = "h2",
  className,
  ...rest
}) => {
  const { variant } = useCollapsibleSection();
  const Heading = as;
  return (
    <Heading className={cn(titleStyles({ variant }), className)} {...rest} />
  );
};
// ---------------------------------------------------------------------------
// Description
// ---------------------------------------------------------------------------

/**
 * Muted explanatory text rendered under the title inside the header.
 * Accepts all <p> props; custom classes are merged after the variant
 * styles.
 */
export const CollapsibleSectionDescription: FC<ComponentProps<"p">> = (
  props,
) => {
  const { variant } = useCollapsibleSection();
  const { className, ...rest } = props;
  return (
    <p {...rest} className={cn(descriptionStyles({ variant }), className)} />
  );
};
// ---------------------------------------------------------------------------
// Content
// ---------------------------------------------------------------------------

interface CollapsibleSectionContentProps {
  children: ReactNode;
}

/**
 * Section body, shown while the section is open; the inner div applies
 * the variant-specific spacing around the consumer's content.
 */
export const CollapsibleSectionContent: FC<CollapsibleSectionContentProps> = (
  props,
) => {
  const { variant } = useCollapsibleSection();
  return (
    <CollapsibleContent>
      <div className={contentStyles({ variant })}>{props.children}</div>
    </CollapsibleContent>
  );
};
@@ -989,7 +989,7 @@ export const MCPServerAdminPanel: FC<MCPServerAdminPanelProps> = ({
}
return (
<div className="flex min-h-full flex-col space-y-3">
<div className="flex min-h-full flex-col space-y-6">
{!isFormView ? (
<ServerList
servers={servers}
@@ -21,7 +21,9 @@ import {
} from "#/components/Table/Table";
import { cn } from "#/utils/cn";
import { formatCostMicros } from "#/utils/currency";
import { AdminBadge } from "./AdminBadge";
import { PrStateIcon } from "./GitPanel/GitPanel";
import { SectionHeader } from "./SectionHeader";
dayjs.extend(relativeTime);
@@ -295,19 +297,16 @@ export const PRInsightsView: FC<PRInsightsViewProps> = ({
const isEmpty = summary.total_prs_created === 0;
return (
<div className="space-y-8">
<div className="space-y-6">
{/* ── Header ── */}
<div className="flex items-end justify-between">
<div>
<h2 className="m-0 text-xl font-semibold tracking-tight text-content-primary">
Pull Request Insights
</h2>
<p className="m-0 mt-1 text-[13px] text-content-secondary">
Code changes detected by Agents.
</p>
</div>
<TimeRangeFilter value={timeRange} onChange={onTimeRangeChange} />
</div>
<SectionHeader
label="Pull Request Insights"
description="Code changes detected by Agents."
badge={<AdminBadge />}
action={
<TimeRangeFilter value={timeRange} onChange={onTimeRangeChange} />
}
/>
{isEmpty ? (
<EmptyState />
@@ -355,7 +354,7 @@ export const PRInsightsView: FC<PRInsightsViewProps> = ({
</section>
{/* ── Model breakdown + Recent PRs side by side ── */}
<div className="grid grid-cols-1 gap-6 lg:grid-cols-2">
<div className="grid grid-cols-1 gap-6">
{/* ── Model performance (simplified) ── */}
{by_model.length > 0 && (
<section>
@@ -15,8 +15,8 @@ export const SectionHeader: FC<SectionHeaderProps> = ({
}) => (
<>
<div className="flex items-start justify-between gap-4">
<div>
<div className="flex w-full items-center gap-2">
<div className="min-w-0 flex-1">
<div className="flex items-center gap-2">
<h2 className="m-0 text-lg font-medium text-content-primary">
{label}
</h2>
@@ -78,11 +78,8 @@ export const Default: Story = {
expect(canvas.getByText("Claude Sonnet")).toBeInTheDocument();
expect(canvas.queryByText("GPT-3.5 (Disabled)")).not.toBeInTheDocument();
// No footer visible when nothing is dirty
expect(
canvas.queryByRole("button", { name: /Save/i }),
).not.toBeInTheDocument();
// Save button should be disabled when nothing is dirty.
expect(canvas.getByRole("button", { name: /Save/i })).toBeDisabled();
// Type a value to make the footer appear
await userEvent.type(gpt4oInput, "95");
await waitFor(() => {
@@ -163,11 +160,9 @@ export const CancelChanges: Story = {
const cancelButton = await canvas.findByRole("button", { name: /Cancel/i });
await userEvent.click(cancelButton);
// Footer should disappear after cancel
// Save button should be disabled after cancel.
await waitFor(() => {
expect(
canvas.queryByRole("button", { name: /Save/i }),
).not.toBeInTheDocument();
expect(canvas.getByRole("button", { name: /Save/i })).toBeDisabled();
});
// Input should be cleared back to empty (no override)
@@ -194,10 +189,8 @@ export const InvalidDraftShowsFooter: Story = {
// Cancel button should be visible so user can discard the edit
expect(canvas.getByRole("button", { name: /Cancel/i })).toBeInTheDocument();
// Save button should NOT be visible (nothing valid to save)
expect(
canvas.queryByRole("button", { name: /Save/i }),
).not.toBeInTheDocument();
// Save button should be disabled (nothing valid to save).
expect(canvas.getByRole("button", { name: /Save/i })).toBeDisabled();
},
};
@@ -9,7 +9,6 @@ import {
Table,
TableBody,
TableCell,
TableFooter,
TableHead,
TableHeader,
TableRow,
@@ -33,6 +32,7 @@ interface UserCompactionThresholdSettingsProps {
thresholdPercent: number,
) => Promise<unknown>;
onResetThreshold: (modelConfigId: string) => Promise<unknown>;
hideHeader?: boolean;
}
const parseThresholdDraft = (value: string): number | null => {
@@ -60,6 +60,7 @@ export const UserCompactionThresholdSettings: FC<
thresholdsError,
onSaveThreshold,
onResetThreshold,
hideHeader,
}) => {
const [drafts, setDrafts] = useState<Record<string, string>>({});
const [rowErrors, setRowErrors] = useState<Record<string, string>>({});
@@ -172,19 +173,23 @@ export const UserCompactionThresholdSettings: FC<
};
const hasAnyPending = pendingModels.size > 0;
const hasAnyErrors = Object.keys(rowErrors).length > 0;
const hasAnyDrafts = Object.keys(drafts).length > 0;
const headerBlock = !hideHeader ? (
<>
<h3 className="m-0 text-[13px] font-semibold text-content-primary">
Context Compaction
</h3>
<p className="!mt-0.5 m-0 text-xs text-content-secondary">
Control when conversation context is automatically summarized for each
model. Setting 100% means the conversation will never auto-compact.
</p>
</>
) : null;
if (isThresholdsLoading) {
return (
<div className="space-y-2">
<h3 className="m-0 text-[13px] font-semibold text-content-primary">
Context Compaction
</h3>
<p className="!mt-0.5 m-0 text-xs text-content-secondary">
Control when conversation context is automatically summarized for each
model. Setting 100% means the conversation will never auto-compact.
</p>
{headerBlock}
<div className="flex items-center gap-2 text-sm text-content-secondary">
<Spinner loading className="h-4 w-4" />
Loading thresholds...
@@ -196,13 +201,7 @@ export const UserCompactionThresholdSettings: FC<
if (thresholdsError != null) {
return (
<div className="space-y-2">
<h3 className="m-0 text-[13px] font-semibold text-content-primary">
Context Compaction
</h3>
<p className="!mt-0.5 m-0 text-xs text-content-secondary">
Control when conversation context is automatically summarized for each
model. Setting 100% means the conversation will never auto-compact.
</p>
{headerBlock}
<p className="m-0 text-xs text-content-destructive">
{getErrorMessage(
thresholdsError,
@@ -215,13 +214,7 @@ export const UserCompactionThresholdSettings: FC<
return (
<div className="space-y-2">
<h3 className="m-0 text-[13px] font-semibold text-content-primary">
Context Compaction
</h3>
<p className="!mt-0.5 m-0 text-xs text-content-secondary">
Control when conversation context is automatically summarized for each
model. Setting 100% means the conversation will never auto-compact.
</p>
{headerBlock}
{isLoadingModelConfigs ? (
<div className="flex items-center gap-2 text-sm text-content-secondary">
<Spinner loading className="h-4 w-4" />
@@ -240,165 +233,163 @@ export const UserCompactionThresholdSettings: FC<
models before compaction thresholds can be set.
</p>
) : (
<Table>
<TableHeader>
<TableRow>
<TableHead>Model</TableHead>
<TableHead className="w-0 whitespace-nowrap">Default</TableHead>
<TableHead className="w-0 whitespace-nowrap">Threshold</TableHead>
</TableRow>
</TableHeader>
<TableBody>
{enabledModelConfigs.map((modelConfig) => {
const existingOverride = overridesByModelID.get(modelConfig.id);
const hasOverride = overridesByModelID.has(modelConfig.id);
const draftValue =
drafts[modelConfig.id] ??
(existingOverride !== undefined
? String(existingOverride)
: "");
const parsedDraftValue = parseThresholdDraft(draftValue);
const isThisModelMutating = pendingModels.has(modelConfig.id);
const isInvalid =
draftValue.length > 0 && parsedDraftValue === null;
// Only warn when user-typed, not when loaded from
// the server.
const isDraftDisablingCompaction =
draftValue === "100" && drafts[modelConfig.id] !== undefined;
const rowError = rowErrors[modelConfig.id];
const modelName = modelConfig.display_name || modelConfig.model;
return (
<TableRow key={modelConfig.id}>
<TableCell className="text-[13px] font-medium text-content-primary">
{modelName}
{rowError && (
<p
aria-live="polite"
className="m-0 mt-0.5 text-2xs font-normal text-content-destructive"
>
{rowError}
</p>
)}
</TableCell>
<TableCell className="w-0 whitespace-nowrap tabular-nums">
{modelConfig.compression_threshold}%
</TableCell>
<TableCell className="w-0 whitespace-nowrap">
<div className="flex items-center gap-1.5">
<Tooltip>
<TooltipTrigger asChild>
<Input
aria-label={`${modelName} compaction threshold`}
aria-invalid={isInvalid || undefined}
type="number"
min={0}
max={100}
inputMode="numeric"
className={cn(
"h-7 w-16 px-2 text-xs tabular-nums",
isInvalid &&
"border-content-destructive focus:ring-content-destructive/30",
)}
value={draftValue}
placeholder={String(
modelConfig.compression_threshold,
)}
onChange={(event) => {
setDrafts((currentDrafts) => ({
...currentDrafts,
[modelConfig.id]: event.target.value,
}));
clearRowError(modelConfig.id);
}}
disabled={isThisModelMutating}
/>
</TooltipTrigger>
{(isInvalid || isDraftDisablingCompaction) && (
<TooltipContent>
{isInvalid
? "Enter a whole number between 0 and 100."
: "Setting 100% will disable auto-compaction for this model."}
</TooltipContent>
)}
</Tooltip>
<span className="text-xs text-content-secondary">%</span>
<Tooltip>
<TooltipTrigger asChild>
<Button
size="icon"
variant="subtle"
className={cn(
"size-7",
hasOverride
? "opacity-100"
: "pointer-events-none opacity-0",
)}
aria-label={`Reset ${modelName} to default`}
aria-hidden={!hasOverride}
tabIndex={hasOverride ? 0 : -1}
disabled={isThisModelMutating || !hasOverride}
onClick={() => handleReset(modelConfig.id)}
>
<RotateCcwIcon className="size-3.5" />
</Button>
</TooltipTrigger>
{hasOverride && (
<TooltipContent>
Reset to default (
{modelConfig.compression_threshold}%)
</TooltipContent>
)}
</Tooltip>
</div>
{isInvalid && (
<span className="sr-only" aria-live="polite">
Enter a whole number between 0 and 100.
</span>
)}
{isDraftDisablingCompaction && (
<span className="sr-only" aria-live="polite">
Setting 100% will disable auto-compaction for this
model.
</span>
)}
</TableCell>
</TableRow>
);
})}
</TableBody>
{(dirtyRows.length > 0 || hasAnyErrors || hasAnyDrafts) && (
<TableFooter className="bg-transparent">
<TableRow className="border-0">
<TableCell colSpan={3} className="border-0 p-0">
<div className="flex items-center justify-end gap-2 px-3 py-1.5">
<Button
size="sm"
variant="outline"
type="button"
onClick={handleCancelAll}
disabled={hasAnyPending}
>
Cancel
</Button>
{dirtyRows.length > 0 && (
<Button
size="sm"
type="button"
disabled={hasAnyPending}
onClick={handleSaveAll}
>
{hasAnyPending
? "Saving..."
: `Save ${dirtyRows.length} ${dirtyRows.length === 1 ? "change" : "changes"}`}
</Button>
)}
</div>
</TableCell>
<>
<Table>
<TableHeader>
<TableRow>
<TableHead>Model</TableHead>
<TableHead className="w-0 whitespace-nowrap">Default</TableHead>
<TableHead className="w-0 whitespace-nowrap">
Threshold
</TableHead>
</TableRow>
</TableFooter>
)}
</Table>
</TableHeader>
<TableBody>
{enabledModelConfigs.map((modelConfig) => {
const existingOverride = overridesByModelID.get(modelConfig.id);
const hasOverride = overridesByModelID.has(modelConfig.id);
const draftValue =
drafts[modelConfig.id] ??
(existingOverride !== undefined
? String(existingOverride)
: "");
const parsedDraftValue = parseThresholdDraft(draftValue);
const isThisModelMutating = pendingModels.has(modelConfig.id);
const isInvalid =
draftValue.length > 0 && parsedDraftValue === null;
// Only warn when user-typed, not when loaded from
// the server.
const isDraftDisablingCompaction =
draftValue === "100" && drafts[modelConfig.id] !== undefined;
const rowError = rowErrors[modelConfig.id];
const modelName = modelConfig.display_name || modelConfig.model;
return (
<TableRow key={modelConfig.id}>
<TableCell className="text-[13px] font-medium text-content-primary">
{modelName}
{rowError && (
<p
aria-live="polite"
className="m-0 mt-0.5 text-2xs font-normal text-content-destructive"
>
{rowError}
</p>
)}
</TableCell>
<TableCell className="w-0 whitespace-nowrap tabular-nums">
{modelConfig.compression_threshold}%
</TableCell>
<TableCell className="w-0 whitespace-nowrap">
<div className="flex items-center gap-1.5">
<Tooltip>
<TooltipTrigger asChild>
<Input
aria-label={`${modelName} compaction threshold`}
aria-invalid={isInvalid || undefined}
type="number"
min={0}
max={100}
inputMode="numeric"
className={cn(
"h-7 w-16 px-2 text-xs tabular-nums",
isInvalid &&
"border-content-destructive focus:ring-content-destructive/30",
)}
value={draftValue}
placeholder={String(
modelConfig.compression_threshold,
)}
onChange={(event) => {
setDrafts((currentDrafts) => ({
...currentDrafts,
[modelConfig.id]: event.target.value,
}));
clearRowError(modelConfig.id);
}}
disabled={isThisModelMutating}
/>
</TooltipTrigger>
{(isInvalid || isDraftDisablingCompaction) && (
<TooltipContent>
{isInvalid
? "Enter a whole number between 0 and 100."
: "Setting 100% will disable auto-compaction for this model."}
</TooltipContent>
)}
</Tooltip>
<span className="text-xs text-content-secondary">
%
</span>
<Tooltip>
<TooltipTrigger asChild>
<Button
size="icon"
variant="subtle"
className={cn(
"size-7",
hasOverride
? "opacity-100"
: "pointer-events-none opacity-0",
)}
aria-label={`Reset ${modelName} to default`}
aria-hidden={!hasOverride}
tabIndex={hasOverride ? 0 : -1}
disabled={isThisModelMutating || !hasOverride}
onClick={() => handleReset(modelConfig.id)}
>
<RotateCcwIcon className="size-3.5" />
</Button>
</TooltipTrigger>
{hasOverride && (
<TooltipContent>
Reset to default (
{modelConfig.compression_threshold}%)
</TooltipContent>
)}
</Tooltip>
</div>
{isInvalid && (
<span className="sr-only" aria-live="polite">
Enter a whole number between 0 and 100.
</span>
)}
{isDraftDisablingCompaction && (
<span className="sr-only" aria-live="polite">
Setting 100% will disable auto-compaction for this
model.
</span>
)}
</TableCell>
</TableRow>
);
})}
</TableBody>
</Table>
<div className="flex justify-end gap-2 pt-2">
<Button
size="sm"
variant="outline"
type="button"
onClick={handleCancelAll}
disabled={hasAnyPending}
>
Cancel
</Button>
<Button
size="sm"
type="button"
disabled={hasAnyPending || dirtyRows.length === 0}
onClick={handleSaveAll}
>
{hasAnyPending
? "Saving..."
: dirtyRows.length > 0
? `Save ${dirtyRows.length} ${dirtyRows.length === 1 ? "change" : "changes"}`
: "Save"}
</Button>
</div>
</>
)}
</div>
);
@@ -0,0 +1,35 @@
import type { DateRangeValue } from "#/components/DateRangePicker/DateRangePicker";
/**
 * Returns true when the given date falls exactly on midnight
 * (00:00:00.000) in local time. Date-range pickers use midnight of the
 * *following* day as the exclusive upper bound for a full-day
 * selection; detecting it lets call sites subtract 1 ms (or 1 day) so
 * the UI shows the inclusive end date instead.
 */
function isMidnight(date: Date): boolean {
  const atHourZero = date.getHours() === 0 && date.getMinutes() === 0;
  const atSecondZero = date.getSeconds() === 0 && date.getMilliseconds() === 0;
  return atHourZero && atSecondZero;
}

/**
 * When the user picks an explicit date range whose end boundary is
 * midnight of the following day, adjust it by 1 ms so the
 * DateRangePicker highlights the inclusive end date. Otherwise the
 * range is returned unchanged.
 */
export function toInclusiveDateRange(
  dateRange: DateRangeValue,
  hasExplicitDateRange: boolean,
): DateRangeValue {
  const { startDate, endDate } = dateRange;
  if (!hasExplicitDateRange || !isMidnight(endDate)) {
    return dateRange;
  }
  return {
    startDate,
    endDate: new Date(endDate.getTime() - 1),
  };
}
+3 -33
View File
@@ -71,7 +71,6 @@ describe("AuditPage", () => {
const getAuditLogsSpy = vi.spyOn(API, "getAuditLogs").mockResolvedValue({
audit_logs: [MockAuditLog, MockAuditLog2],
count: 2,
count_cap: 0,
});
// When
@@ -91,7 +90,6 @@ describe("AuditPage", () => {
vi.spyOn(API, "getAuditLogs").mockResolvedValue({
audit_logs: [MockAuditLog],
count: 1,
count_cap: 0,
});
await renderPage();
@@ -116,7 +114,6 @@ describe("AuditPage", () => {
vi.spyOn(API, "getAuditLogs").mockResolvedValue({
audit_logs: [MockAuditLog],
count: 1,
count_cap: 0,
});
await renderPage();
@@ -143,11 +140,9 @@ describe("AuditPage", () => {
describe("Filtering", () => {
it("filters by URL", async () => {
const getAuditLogsSpy = vi.spyOn(API, "getAuditLogs").mockResolvedValue({
audit_logs: [MockAuditLog],
count: 1,
count_cap: 0,
});
const getAuditLogsSpy = vi
.spyOn(API, "getAuditLogs")
.mockResolvedValue({ audit_logs: [MockAuditLog], count: 1 });
const query = "resource_type:workspace action:create";
await renderPage({ filter: query });
@@ -178,29 +173,4 @@ describe("AuditPage", () => {
);
});
});
describe("Capped count", () => {
it("shows capped count indicator and navigates to next page with correct offset", async () => {
vi.spyOn(API, "getAuditLogs").mockResolvedValue({
audit_logs: [MockAuditLog, MockAuditLog2],
count: 2001,
count_cap: 2000,
});
const user = userEvent.setup();
await renderPage();
await screen.findByText(/2,000\+/);
await user.click(screen.getByRole("button", { name: /next page/i }));
await waitFor(() =>
expect(API.getAuditLogs).toHaveBeenLastCalledWith<[AuditLogsRequest]>({
limit: DEFAULT_RECORDS_PER_PAGE,
offset: DEFAULT_RECORDS_PER_PAGE,
q: "",
}),
);
});
});
});
@@ -69,7 +69,6 @@ describe("ConnectionLogPage", () => {
MockDisconnectedSSHConnectionLog,
],
count: 2,
count_cap: 0,
});
// When
@@ -96,7 +95,6 @@ describe("ConnectionLogPage", () => {
.mockResolvedValue({
connection_logs: [MockConnectedSSHConnectionLog],
count: 1,
count_cap: 0,
});
const query = "type:ssh status:ongoing";
+42 -50
View File
@@ -12,8 +12,6 @@ import (
"sync"
"time"
"golang.org/x/sync/singleflight"
"github.com/cenkalti/backoff/v4"
"github.com/google/uuid"
"github.com/tailscale/wireguard-go/tun"
@@ -465,8 +463,6 @@ type Conn struct {
trafficStats *connstats.Statistics
lastNetInfo *tailcfg.NetInfo
awaitReachableGroup singleflight.Group
}
func (c *Conn) GetNetInfo() *tailcfg.NetInfo {
@@ -603,60 +599,56 @@ func (c *Conn) DERPMap() *tailcfg.DERPMap {
// address is reachable. It's the callers responsibility to provide
// a timeout, otherwise this function will block forever.
func (c *Conn) AwaitReachable(ctx context.Context, ip netip.Addr) bool {
result, _, _ := c.awaitReachableGroup.Do(ip.String(), func() (interface{}, error) {
ctx, cancel := context.WithCancel(ctx)
defer cancel() // Cancel all pending pings on exit.
ctx, cancel := context.WithCancel(ctx)
defer cancel() // Cancel all pending pings on exit.
completedCtx, completed := context.WithCancel(context.Background())
defer completed()
completedCtx, completed := context.WithCancel(context.Background())
defer completed()
run := func() {
// Safety timeout, initially we'll have around 10-20 goroutines
// running in parallel. The exponential backoff will converge
// around ~1 ping / 30s, this means we'll have around 10-20
// goroutines pending towards the end as well.
ctx, cancel := context.WithTimeout(ctx, 5*time.Minute)
defer cancel()
run := func() {
// Safety timeout, initially we'll have around 10-20 goroutines
// running in parallel. The exponential backoff will converge
// around ~1 ping / 30s, this means we'll have around 10-20
// goroutines pending towards the end as well.
ctx, cancel := context.WithTimeout(ctx, 5*time.Minute)
defer cancel()
// For reachability, we use TSMP ping, which pings at the IP layer,
// and therefore requires that wireguard and the netstack are up.
// If we don't wait for wireguard to be up, we could miss a
// handshake, and it might take 5 seconds for the handshake to be
// retried. A 5s initial round trip can set us up for poor TCP
// performance, since the initial round-trip-time sets the initial
// retransmit timeout.
_, _, _, err := c.pingWithType(ctx, ip, tailcfg.PingTSMP)
if err == nil {
completed()
}
// For reachability, we use TSMP ping, which pings at the IP layer, and
// therefore requires that wireguard and the netstack are up. If we
// don't wait for wireguard to be up, we could miss a handshake, and it
// might take 5 seconds for the handshake to be retried. A 5s initial
// round trip can set us up for poor TCP performance, since the initial
// round-trip-time sets the initial retransmit timeout.
_, _, _, err := c.pingWithType(ctx, ip, tailcfg.PingTSMP)
if err == nil {
completed()
}
}
eb := backoff.NewExponentialBackOff()
eb.MaxElapsedTime = 0
eb.InitialInterval = 50 * time.Millisecond
eb.MaxInterval = 5 * time.Second
// Consume the first interval since
// we'll fire off a ping immediately.
_ = eb.NextBackOff()
eb := backoff.NewExponentialBackOff()
eb.MaxElapsedTime = 0
eb.InitialInterval = 50 * time.Millisecond
eb.MaxInterval = 30 * time.Second
// Consume the first interval since
// we'll fire off a ping immediately.
_ = eb.NextBackOff()
t := backoff.NewTicker(eb)
defer t.Stop()
t := backoff.NewTicker(eb)
defer t.Stop()
go run()
for {
select {
case <-completedCtx.Done():
return true, nil
case <-t.C:
// Pings can take a while, so we can run multiple
// in parallel to return ASAP.
go run()
case <-ctx.Done():
return false, nil
}
go run()
for {
select {
case <-completedCtx.Done():
return true
case <-t.C:
// Pings can take a while, so we can run multiple
// in parallel to return ASAP.
go run()
case <-ctx.Done():
return false
}
})
return result.(bool)
}
}
// Closed is a channel that ends when the connection has