Compare commits
3 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 37b0cbbcb7 | |||
| 669328ebd9 | |||
| de846d8ea1 |
Generated
-6
@@ -14175,9 +14175,6 @@ const docTemplate = `{
|
||||
},
|
||||
"count": {
|
||||
"type": "integer"
|
||||
},
|
||||
"count_cap": {
|
||||
"type": "integer"
|
||||
}
|
||||
}
|
||||
},
|
||||
@@ -14499,9 +14496,6 @@ const docTemplate = `{
|
||||
},
|
||||
"count": {
|
||||
"type": "integer"
|
||||
},
|
||||
"count_cap": {
|
||||
"type": "integer"
|
||||
}
|
||||
}
|
||||
},
|
||||
|
||||
Generated
-6
@@ -12739,9 +12739,6 @@
|
||||
},
|
||||
"count": {
|
||||
"type": "integer"
|
||||
},
|
||||
"count_cap": {
|
||||
"type": "integer"
|
||||
}
|
||||
}
|
||||
},
|
||||
@@ -13042,9 +13039,6 @@
|
||||
},
|
||||
"count": {
|
||||
"type": "integer"
|
||||
},
|
||||
"count_cap": {
|
||||
"type": "integer"
|
||||
}
|
||||
}
|
||||
},
|
||||
|
||||
+1
-8
@@ -26,11 +26,6 @@ import (
|
||||
"github.com/coder/coder/v2/codersdk"
|
||||
)
|
||||
|
||||
// Limit the count query to avoid a slow sequential scan due to joins
|
||||
// on a large table. Set to 0 to disable capping (but also see the note
|
||||
// in the SQL query).
|
||||
const auditLogCountCap = 2000
|
||||
|
||||
// @Summary Get audit logs
|
||||
// @ID get-audit-logs
|
||||
// @Security CoderSessionToken
|
||||
@@ -71,7 +66,7 @@ func (api *API) auditLogs(rw http.ResponseWriter, r *http.Request) {
|
||||
countFilter.Username = ""
|
||||
}
|
||||
|
||||
countFilter.CountCap = auditLogCountCap
|
||||
// Use the same filters to count the number of audit logs
|
||||
count, err := api.Database.CountAuditLogs(ctx, countFilter)
|
||||
if dbauthz.IsNotAuthorizedError(err) {
|
||||
httpapi.Forbidden(rw)
|
||||
@@ -86,7 +81,6 @@ func (api *API) auditLogs(rw http.ResponseWriter, r *http.Request) {
|
||||
httpapi.Write(ctx, rw, http.StatusOK, codersdk.AuditLogResponse{
|
||||
AuditLogs: []codersdk.AuditLog{},
|
||||
Count: 0,
|
||||
CountCap: auditLogCountCap,
|
||||
})
|
||||
return
|
||||
}
|
||||
@@ -104,7 +98,6 @@ func (api *API) auditLogs(rw http.ResponseWriter, r *http.Request) {
|
||||
httpapi.Write(ctx, rw, http.StatusOK, codersdk.AuditLogResponse{
|
||||
AuditLogs: api.convertAuditLogs(ctx, dblogs),
|
||||
Count: count,
|
||||
CountCap: auditLogCountCap,
|
||||
})
|
||||
}
|
||||
|
||||
|
||||
@@ -2155,12 +2155,17 @@ func (q *querier) DeleteUserChatProviderKey(ctx context.Context, arg database.De
|
||||
return q.db.DeleteUserChatProviderKey(ctx, arg)
|
||||
}
|
||||
|
||||
func (q *querier) DeleteUserSecretByUserIDAndName(ctx context.Context, arg database.DeleteUserSecretByUserIDAndNameParams) error {
|
||||
obj := rbac.ResourceUserSecret.WithOwner(arg.UserID.String())
|
||||
if err := q.authorizeContext(ctx, policy.ActionDelete, obj); err != nil {
|
||||
func (q *querier) DeleteUserSecret(ctx context.Context, id uuid.UUID) error {
|
||||
// First get the secret to check ownership
|
||||
secret, err := q.GetUserSecret(ctx, id)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return q.db.DeleteUserSecretByUserIDAndName(ctx, arg)
|
||||
|
||||
if err := q.authorizeContext(ctx, policy.ActionDelete, secret); err != nil {
|
||||
return err
|
||||
}
|
||||
return q.db.DeleteUserSecret(ctx, id)
|
||||
}
|
||||
|
||||
func (q *querier) DeleteWebpushSubscriptionByUserIDAndEndpoint(ctx context.Context, arg database.DeleteWebpushSubscriptionByUserIDAndEndpointParams) error {
|
||||
@@ -4123,6 +4128,19 @@ func (q *querier) GetUserNotificationPreferences(ctx context.Context, userID uui
|
||||
return q.db.GetUserNotificationPreferences(ctx, userID)
|
||||
}
|
||||
|
||||
func (q *querier) GetUserSecret(ctx context.Context, id uuid.UUID) (database.UserSecret, error) {
|
||||
// First get the secret to check ownership
|
||||
secret, err := q.db.GetUserSecret(ctx, id)
|
||||
if err != nil {
|
||||
return database.UserSecret{}, err
|
||||
}
|
||||
|
||||
if err := q.authorizeContext(ctx, policy.ActionRead, secret); err != nil {
|
||||
return database.UserSecret{}, err
|
||||
}
|
||||
return secret, nil
|
||||
}
|
||||
|
||||
func (q *querier) GetUserSecretByUserIDAndName(ctx context.Context, arg database.GetUserSecretByUserIDAndNameParams) (database.UserSecret, error) {
|
||||
obj := rbac.ResourceUserSecret.WithOwner(arg.UserID.String())
|
||||
if err := q.authorizeContext(ctx, policy.ActionRead, obj); err != nil {
|
||||
@@ -5506,7 +5524,7 @@ func (q *querier) ListUserChatCompactionThresholds(ctx context.Context, userID u
|
||||
return q.db.ListUserChatCompactionThresholds(ctx, userID)
|
||||
}
|
||||
|
||||
func (q *querier) ListUserSecrets(ctx context.Context, userID uuid.UUID) ([]database.ListUserSecretsRow, error) {
|
||||
func (q *querier) ListUserSecrets(ctx context.Context, userID uuid.UUID) ([]database.UserSecret, error) {
|
||||
obj := rbac.ResourceUserSecret.WithOwner(userID.String())
|
||||
if err := q.authorizeContext(ctx, policy.ActionRead, obj); err != nil {
|
||||
return nil, err
|
||||
@@ -5514,16 +5532,6 @@ func (q *querier) ListUserSecrets(ctx context.Context, userID uuid.UUID) ([]data
|
||||
return q.db.ListUserSecrets(ctx, userID)
|
||||
}
|
||||
|
||||
func (q *querier) ListUserSecretsWithValues(ctx context.Context, userID uuid.UUID) ([]database.UserSecret, error) {
|
||||
// This query returns decrypted secret values and must only be called
|
||||
// from system contexts (provisioner, agent manifest). REST API
|
||||
// handlers should use ListUserSecrets (metadata only).
|
||||
if err := q.authorizeContext(ctx, policy.ActionRead, rbac.ResourceSystem); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return q.db.ListUserSecretsWithValues(ctx, userID)
|
||||
}
|
||||
|
||||
func (q *querier) ListWorkspaceAgentPortShares(ctx context.Context, workspaceID uuid.UUID) ([]database.WorkspaceAgentPortShare, error) {
|
||||
workspace, err := q.db.GetWorkspaceByID(ctx, workspaceID)
|
||||
if err != nil {
|
||||
@@ -5774,15 +5782,15 @@ func (q *querier) UpdateChatByID(ctx context.Context, arg database.UpdateChatByI
|
||||
return q.db.UpdateChatByID(ctx, arg)
|
||||
}
|
||||
|
||||
func (q *querier) UpdateChatHeartbeats(ctx context.Context, arg database.UpdateChatHeartbeatsParams) ([]uuid.UUID, error) {
|
||||
// The batch heartbeat is a system-level operation filtered by
|
||||
// worker_id. Authorization is enforced by the AsChatd context
|
||||
// at the call site rather than per-row, because checking each
|
||||
// row individually would defeat the purpose of batching.
|
||||
if err := q.authorizeContext(ctx, policy.ActionUpdate, rbac.ResourceChat); err != nil {
|
||||
return nil, err
|
||||
func (q *querier) UpdateChatHeartbeat(ctx context.Context, arg database.UpdateChatHeartbeatParams) (int64, error) {
|
||||
chat, err := q.db.GetChatByID(ctx, arg.ID)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return q.db.UpdateChatHeartbeats(ctx, arg)
|
||||
if err := q.authorizeContext(ctx, policy.ActionUpdate, chat); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return q.db.UpdateChatHeartbeat(ctx, arg)
|
||||
}
|
||||
|
||||
func (q *querier) UpdateChatLabelsByID(ctx context.Context, arg database.UpdateChatLabelsByIDParams) (database.Chat, error) {
|
||||
@@ -6624,12 +6632,17 @@ func (q *querier) UpdateUserRoles(ctx context.Context, arg database.UpdateUserRo
|
||||
return q.db.UpdateUserRoles(ctx, arg)
|
||||
}
|
||||
|
||||
func (q *querier) UpdateUserSecretByUserIDAndName(ctx context.Context, arg database.UpdateUserSecretByUserIDAndNameParams) (database.UserSecret, error) {
|
||||
obj := rbac.ResourceUserSecret.WithOwner(arg.UserID.String())
|
||||
if err := q.authorizeContext(ctx, policy.ActionUpdate, obj); err != nil {
|
||||
func (q *querier) UpdateUserSecret(ctx context.Context, arg database.UpdateUserSecretParams) (database.UserSecret, error) {
|
||||
// First get the secret to check ownership
|
||||
secret, err := q.db.GetUserSecret(ctx, arg.ID)
|
||||
if err != nil {
|
||||
return database.UserSecret{}, err
|
||||
}
|
||||
return q.db.UpdateUserSecretByUserIDAndName(ctx, arg)
|
||||
|
||||
if err := q.authorizeContext(ctx, policy.ActionUpdate, secret); err != nil {
|
||||
return database.UserSecret{}, err
|
||||
}
|
||||
return q.db.UpdateUserSecret(ctx, arg)
|
||||
}
|
||||
|
||||
func (q *querier) UpdateUserStatus(ctx context.Context, arg database.UpdateUserStatusParams) (database.User, error) {
|
||||
|
||||
@@ -842,15 +842,15 @@ func (s *MethodTestSuite) TestChats() {
|
||||
dbm.EXPECT().UpdateChatStatusPreserveUpdatedAt(gomock.Any(), arg).Return(chat, nil).AnyTimes()
|
||||
check.Args(arg).Asserts(chat, policy.ActionUpdate).Returns(chat)
|
||||
}))
|
||||
s.Run("UpdateChatHeartbeats", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
|
||||
resultID := uuid.New()
|
||||
arg := database.UpdateChatHeartbeatsParams{
|
||||
IDs: []uuid.UUID{resultID},
|
||||
s.Run("UpdateChatHeartbeat", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
|
||||
chat := testutil.Fake(s.T(), faker, database.Chat{})
|
||||
arg := database.UpdateChatHeartbeatParams{
|
||||
ID: chat.ID,
|
||||
WorkerID: uuid.New(),
|
||||
Now: time.Now(),
|
||||
}
|
||||
dbm.EXPECT().UpdateChatHeartbeats(gomock.Any(), arg).Return([]uuid.UUID{resultID}, nil).AnyTimes()
|
||||
check.Args(arg).Asserts(rbac.ResourceChat, policy.ActionUpdate).Returns([]uuid.UUID{resultID})
|
||||
dbm.EXPECT().GetChatByID(gomock.Any(), chat.ID).Return(chat, nil).AnyTimes()
|
||||
dbm.EXPECT().UpdateChatHeartbeat(gomock.Any(), arg).Return(int64(1), nil).AnyTimes()
|
||||
check.Args(arg).Asserts(chat, policy.ActionUpdate).Returns(int64(1))
|
||||
}))
|
||||
s.Run("UpdateChatMessageByID", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
|
||||
chat := testutil.Fake(s.T(), faker, database.Chat{})
|
||||
@@ -5346,20 +5346,19 @@ func (s *MethodTestSuite) TestUserSecrets() {
|
||||
Asserts(rbac.ResourceUserSecret.WithOwner(user.ID.String()), policy.ActionRead).
|
||||
Returns(secret)
|
||||
}))
|
||||
s.Run("GetUserSecret", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
|
||||
secret := testutil.Fake(s.T(), faker, database.UserSecret{})
|
||||
dbm.EXPECT().GetUserSecret(gomock.Any(), secret.ID).Return(secret, nil).AnyTimes()
|
||||
check.Args(secret.ID).
|
||||
Asserts(secret, policy.ActionRead).
|
||||
Returns(secret)
|
||||
}))
|
||||
s.Run("ListUserSecrets", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
|
||||
user := testutil.Fake(s.T(), faker, database.User{})
|
||||
row := testutil.Fake(s.T(), faker, database.ListUserSecretsRow{UserID: user.ID})
|
||||
dbm.EXPECT().ListUserSecrets(gomock.Any(), user.ID).Return([]database.ListUserSecretsRow{row}, nil).AnyTimes()
|
||||
secret := testutil.Fake(s.T(), faker, database.UserSecret{UserID: user.ID})
|
||||
dbm.EXPECT().ListUserSecrets(gomock.Any(), user.ID).Return([]database.UserSecret{secret}, nil).AnyTimes()
|
||||
check.Args(user.ID).
|
||||
Asserts(rbac.ResourceUserSecret.WithOwner(user.ID.String()), policy.ActionRead).
|
||||
Returns([]database.ListUserSecretsRow{row})
|
||||
}))
|
||||
s.Run("ListUserSecretsWithValues", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
|
||||
user := testutil.Fake(s.T(), faker, database.User{})
|
||||
secret := testutil.Fake(s.T(), faker, database.UserSecret{UserID: user.ID})
|
||||
dbm.EXPECT().ListUserSecretsWithValues(gomock.Any(), user.ID).Return([]database.UserSecret{secret}, nil).AnyTimes()
|
||||
check.Args(user.ID).
|
||||
Asserts(rbac.ResourceSystem, policy.ActionRead).
|
||||
Returns([]database.UserSecret{secret})
|
||||
}))
|
||||
s.Run("CreateUserSecret", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
|
||||
@@ -5371,21 +5370,22 @@ func (s *MethodTestSuite) TestUserSecrets() {
|
||||
Asserts(rbac.ResourceUserSecret.WithOwner(user.ID.String()), policy.ActionCreate).
|
||||
Returns(ret)
|
||||
}))
|
||||
s.Run("UpdateUserSecretByUserIDAndName", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
|
||||
user := testutil.Fake(s.T(), faker, database.User{})
|
||||
updated := testutil.Fake(s.T(), faker, database.UserSecret{UserID: user.ID})
|
||||
arg := database.UpdateUserSecretByUserIDAndNameParams{UserID: user.ID, Name: "test"}
|
||||
dbm.EXPECT().UpdateUserSecretByUserIDAndName(gomock.Any(), arg).Return(updated, nil).AnyTimes()
|
||||
s.Run("UpdateUserSecret", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
|
||||
secret := testutil.Fake(s.T(), faker, database.UserSecret{})
|
||||
updated := testutil.Fake(s.T(), faker, database.UserSecret{ID: secret.ID})
|
||||
arg := database.UpdateUserSecretParams{ID: secret.ID}
|
||||
dbm.EXPECT().GetUserSecret(gomock.Any(), secret.ID).Return(secret, nil).AnyTimes()
|
||||
dbm.EXPECT().UpdateUserSecret(gomock.Any(), arg).Return(updated, nil).AnyTimes()
|
||||
check.Args(arg).
|
||||
Asserts(rbac.ResourceUserSecret.WithOwner(user.ID.String()), policy.ActionUpdate).
|
||||
Asserts(secret, policy.ActionUpdate).
|
||||
Returns(updated)
|
||||
}))
|
||||
s.Run("DeleteUserSecretByUserIDAndName", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
|
||||
user := testutil.Fake(s.T(), faker, database.User{})
|
||||
arg := database.DeleteUserSecretByUserIDAndNameParams{UserID: user.ID, Name: "test"}
|
||||
dbm.EXPECT().DeleteUserSecretByUserIDAndName(gomock.Any(), arg).Return(nil).AnyTimes()
|
||||
check.Args(arg).
|
||||
Asserts(rbac.ResourceUserSecret.WithOwner(user.ID.String()), policy.ActionDelete).
|
||||
s.Run("DeleteUserSecret", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
|
||||
secret := testutil.Fake(s.T(), faker, database.UserSecret{})
|
||||
dbm.EXPECT().GetUserSecret(gomock.Any(), secret.ID).Return(secret, nil).AnyTimes()
|
||||
dbm.EXPECT().DeleteUserSecret(gomock.Any(), secret.ID).Return(nil).AnyTimes()
|
||||
check.Args(secret.ID).
|
||||
Asserts(secret, policy.ActionRead, secret, policy.ActionDelete).
|
||||
Returns()
|
||||
}))
|
||||
}
|
||||
|
||||
@@ -1597,7 +1597,6 @@ func UserSecret(t testing.TB, db database.Store, seed database.UserSecret) datab
|
||||
Name: takeFirst(seed.Name, "secret-name"),
|
||||
Description: takeFirst(seed.Description, "secret description"),
|
||||
Value: takeFirst(seed.Value, "secret value"),
|
||||
ValueKeyID: seed.ValueKeyID,
|
||||
EnvName: takeFirst(seed.EnvName, "SECRET_ENV_NAME"),
|
||||
FilePath: takeFirst(seed.FilePath, "~/secret/file/path"),
|
||||
})
|
||||
|
||||
@@ -712,11 +712,11 @@ func (m queryMetricsStore) DeleteUserChatProviderKey(ctx context.Context, arg da
|
||||
return r0
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) DeleteUserSecretByUserIDAndName(ctx context.Context, arg database.DeleteUserSecretByUserIDAndNameParams) error {
|
||||
func (m queryMetricsStore) DeleteUserSecret(ctx context.Context, id uuid.UUID) error {
|
||||
start := time.Now()
|
||||
r0 := m.s.DeleteUserSecretByUserIDAndName(ctx, arg)
|
||||
m.queryLatencies.WithLabelValues("DeleteUserSecretByUserIDAndName").Observe(time.Since(start).Seconds())
|
||||
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "DeleteUserSecretByUserIDAndName").Inc()
|
||||
r0 := m.s.DeleteUserSecret(ctx, id)
|
||||
m.queryLatencies.WithLabelValues("DeleteUserSecret").Observe(time.Since(start).Seconds())
|
||||
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "DeleteUserSecret").Inc()
|
||||
return r0
|
||||
}
|
||||
|
||||
@@ -2624,6 +2624,14 @@ func (m queryMetricsStore) GetUserNotificationPreferences(ctx context.Context, u
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) GetUserSecret(ctx context.Context, id uuid.UUID) (database.UserSecret, error) {
|
||||
start := time.Now()
|
||||
r0, r1 := m.s.GetUserSecret(ctx, id)
|
||||
m.queryLatencies.WithLabelValues("GetUserSecret").Observe(time.Since(start).Seconds())
|
||||
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "GetUserSecret").Inc()
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) GetUserSecretByUserIDAndName(ctx context.Context, arg database.GetUserSecretByUserIDAndNameParams) (database.UserSecret, error) {
|
||||
start := time.Now()
|
||||
r0, r1 := m.s.GetUserSecretByUserIDAndName(ctx, arg)
|
||||
@@ -3912,7 +3920,7 @@ func (m queryMetricsStore) ListUserChatCompactionThresholds(ctx context.Context,
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) ListUserSecrets(ctx context.Context, userID uuid.UUID) ([]database.ListUserSecretsRow, error) {
|
||||
func (m queryMetricsStore) ListUserSecrets(ctx context.Context, userID uuid.UUID) ([]database.UserSecret, error) {
|
||||
start := time.Now()
|
||||
r0, r1 := m.s.ListUserSecrets(ctx, userID)
|
||||
m.queryLatencies.WithLabelValues("ListUserSecrets").Observe(time.Since(start).Seconds())
|
||||
@@ -3920,14 +3928,6 @@ func (m queryMetricsStore) ListUserSecrets(ctx context.Context, userID uuid.UUID
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) ListUserSecretsWithValues(ctx context.Context, userID uuid.UUID) ([]database.UserSecret, error) {
|
||||
start := time.Now()
|
||||
r0, r1 := m.s.ListUserSecretsWithValues(ctx, userID)
|
||||
m.queryLatencies.WithLabelValues("ListUserSecretsWithValues").Observe(time.Since(start).Seconds())
|
||||
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "ListUserSecretsWithValues").Inc()
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) ListWorkspaceAgentPortShares(ctx context.Context, workspaceID uuid.UUID) ([]database.WorkspaceAgentPortShare, error) {
|
||||
start := time.Now()
|
||||
r0, r1 := m.s.ListWorkspaceAgentPortShares(ctx, workspaceID)
|
||||
@@ -4136,11 +4136,11 @@ func (m queryMetricsStore) UpdateChatByID(ctx context.Context, arg database.Upda
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) UpdateChatHeartbeats(ctx context.Context, arg database.UpdateChatHeartbeatsParams) ([]uuid.UUID, error) {
|
||||
func (m queryMetricsStore) UpdateChatHeartbeat(ctx context.Context, arg database.UpdateChatHeartbeatParams) (int64, error) {
|
||||
start := time.Now()
|
||||
r0, r1 := m.s.UpdateChatHeartbeats(ctx, arg)
|
||||
m.queryLatencies.WithLabelValues("UpdateChatHeartbeats").Observe(time.Since(start).Seconds())
|
||||
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "UpdateChatHeartbeats").Inc()
|
||||
r0, r1 := m.s.UpdateChatHeartbeat(ctx, arg)
|
||||
m.queryLatencies.WithLabelValues("UpdateChatHeartbeat").Observe(time.Since(start).Seconds())
|
||||
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "UpdateChatHeartbeat").Inc()
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
@@ -4696,11 +4696,11 @@ func (m queryMetricsStore) UpdateUserRoles(ctx context.Context, arg database.Upd
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) UpdateUserSecretByUserIDAndName(ctx context.Context, arg database.UpdateUserSecretByUserIDAndNameParams) (database.UserSecret, error) {
|
||||
func (m queryMetricsStore) UpdateUserSecret(ctx context.Context, arg database.UpdateUserSecretParams) (database.UserSecret, error) {
|
||||
start := time.Now()
|
||||
r0, r1 := m.s.UpdateUserSecretByUserIDAndName(ctx, arg)
|
||||
m.queryLatencies.WithLabelValues("UpdateUserSecretByUserIDAndName").Observe(time.Since(start).Seconds())
|
||||
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "UpdateUserSecretByUserIDAndName").Inc()
|
||||
r0, r1 := m.s.UpdateUserSecret(ctx, arg)
|
||||
m.queryLatencies.WithLabelValues("UpdateUserSecret").Observe(time.Since(start).Seconds())
|
||||
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "UpdateUserSecret").Inc()
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
|
||||
@@ -1199,18 +1199,18 @@ func (mr *MockStoreMockRecorder) DeleteUserChatProviderKey(ctx, arg any) *gomock
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteUserChatProviderKey", reflect.TypeOf((*MockStore)(nil).DeleteUserChatProviderKey), ctx, arg)
|
||||
}
|
||||
|
||||
// DeleteUserSecretByUserIDAndName mocks base method.
|
||||
func (m *MockStore) DeleteUserSecretByUserIDAndName(ctx context.Context, arg database.DeleteUserSecretByUserIDAndNameParams) error {
|
||||
// DeleteUserSecret mocks base method.
|
||||
func (m *MockStore) DeleteUserSecret(ctx context.Context, id uuid.UUID) error {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "DeleteUserSecretByUserIDAndName", ctx, arg)
|
||||
ret := m.ctrl.Call(m, "DeleteUserSecret", ctx, id)
|
||||
ret0, _ := ret[0].(error)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// DeleteUserSecretByUserIDAndName indicates an expected call of DeleteUserSecretByUserIDAndName.
|
||||
func (mr *MockStoreMockRecorder) DeleteUserSecretByUserIDAndName(ctx, arg any) *gomock.Call {
|
||||
// DeleteUserSecret indicates an expected call of DeleteUserSecret.
|
||||
func (mr *MockStoreMockRecorder) DeleteUserSecret(ctx, id any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteUserSecretByUserIDAndName", reflect.TypeOf((*MockStore)(nil).DeleteUserSecretByUserIDAndName), ctx, arg)
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteUserSecret", reflect.TypeOf((*MockStore)(nil).DeleteUserSecret), ctx, id)
|
||||
}
|
||||
|
||||
// DeleteWebpushSubscriptionByUserIDAndEndpoint mocks base method.
|
||||
@@ -4907,6 +4907,21 @@ func (mr *MockStoreMockRecorder) GetUserNotificationPreferences(ctx, userID any)
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetUserNotificationPreferences", reflect.TypeOf((*MockStore)(nil).GetUserNotificationPreferences), ctx, userID)
|
||||
}
|
||||
|
||||
// GetUserSecret mocks base method.
|
||||
func (m *MockStore) GetUserSecret(ctx context.Context, id uuid.UUID) (database.UserSecret, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "GetUserSecret", ctx, id)
|
||||
ret0, _ := ret[0].(database.UserSecret)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// GetUserSecret indicates an expected call of GetUserSecret.
|
||||
func (mr *MockStoreMockRecorder) GetUserSecret(ctx, id any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetUserSecret", reflect.TypeOf((*MockStore)(nil).GetUserSecret), ctx, id)
|
||||
}
|
||||
|
||||
// GetUserSecretByUserIDAndName mocks base method.
|
||||
func (m *MockStore) GetUserSecretByUserIDAndName(ctx context.Context, arg database.GetUserSecretByUserIDAndNameParams) (database.UserSecret, error) {
|
||||
m.ctrl.T.Helper()
|
||||
@@ -7397,10 +7412,10 @@ func (mr *MockStoreMockRecorder) ListUserChatCompactionThresholds(ctx, userID an
|
||||
}
|
||||
|
||||
// ListUserSecrets mocks base method.
|
||||
func (m *MockStore) ListUserSecrets(ctx context.Context, userID uuid.UUID) ([]database.ListUserSecretsRow, error) {
|
||||
func (m *MockStore) ListUserSecrets(ctx context.Context, userID uuid.UUID) ([]database.UserSecret, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "ListUserSecrets", ctx, userID)
|
||||
ret0, _ := ret[0].([]database.ListUserSecretsRow)
|
||||
ret0, _ := ret[0].([]database.UserSecret)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
@@ -7411,21 +7426,6 @@ func (mr *MockStoreMockRecorder) ListUserSecrets(ctx, userID any) *gomock.Call {
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ListUserSecrets", reflect.TypeOf((*MockStore)(nil).ListUserSecrets), ctx, userID)
|
||||
}
|
||||
|
||||
// ListUserSecretsWithValues mocks base method.
|
||||
func (m *MockStore) ListUserSecretsWithValues(ctx context.Context, userID uuid.UUID) ([]database.UserSecret, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "ListUserSecretsWithValues", ctx, userID)
|
||||
ret0, _ := ret[0].([]database.UserSecret)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// ListUserSecretsWithValues indicates an expected call of ListUserSecretsWithValues.
|
||||
func (mr *MockStoreMockRecorder) ListUserSecretsWithValues(ctx, userID any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ListUserSecretsWithValues", reflect.TypeOf((*MockStore)(nil).ListUserSecretsWithValues), ctx, userID)
|
||||
}
|
||||
|
||||
// ListWorkspaceAgentPortShares mocks base method.
|
||||
func (m *MockStore) ListWorkspaceAgentPortShares(ctx context.Context, workspaceID uuid.UUID) ([]database.WorkspaceAgentPortShare, error) {
|
||||
m.ctrl.T.Helper()
|
||||
@@ -7835,19 +7835,19 @@ func (mr *MockStoreMockRecorder) UpdateChatByID(ctx, arg any) *gomock.Call {
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateChatByID", reflect.TypeOf((*MockStore)(nil).UpdateChatByID), ctx, arg)
|
||||
}
|
||||
|
||||
// UpdateChatHeartbeats mocks base method.
|
||||
func (m *MockStore) UpdateChatHeartbeats(ctx context.Context, arg database.UpdateChatHeartbeatsParams) ([]uuid.UUID, error) {
|
||||
// UpdateChatHeartbeat mocks base method.
|
||||
func (m *MockStore) UpdateChatHeartbeat(ctx context.Context, arg database.UpdateChatHeartbeatParams) (int64, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "UpdateChatHeartbeats", ctx, arg)
|
||||
ret0, _ := ret[0].([]uuid.UUID)
|
||||
ret := m.ctrl.Call(m, "UpdateChatHeartbeat", ctx, arg)
|
||||
ret0, _ := ret[0].(int64)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// UpdateChatHeartbeats indicates an expected call of UpdateChatHeartbeats.
|
||||
func (mr *MockStoreMockRecorder) UpdateChatHeartbeats(ctx, arg any) *gomock.Call {
|
||||
// UpdateChatHeartbeat indicates an expected call of UpdateChatHeartbeat.
|
||||
func (mr *MockStoreMockRecorder) UpdateChatHeartbeat(ctx, arg any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateChatHeartbeats", reflect.TypeOf((*MockStore)(nil).UpdateChatHeartbeats), ctx, arg)
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateChatHeartbeat", reflect.TypeOf((*MockStore)(nil).UpdateChatHeartbeat), ctx, arg)
|
||||
}
|
||||
|
||||
// UpdateChatLabelsByID mocks base method.
|
||||
@@ -8854,19 +8854,19 @@ func (mr *MockStoreMockRecorder) UpdateUserRoles(ctx, arg any) *gomock.Call {
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateUserRoles", reflect.TypeOf((*MockStore)(nil).UpdateUserRoles), ctx, arg)
|
||||
}
|
||||
|
||||
// UpdateUserSecretByUserIDAndName mocks base method.
|
||||
func (m *MockStore) UpdateUserSecretByUserIDAndName(ctx context.Context, arg database.UpdateUserSecretByUserIDAndNameParams) (database.UserSecret, error) {
|
||||
// UpdateUserSecret mocks base method.
|
||||
func (m *MockStore) UpdateUserSecret(ctx context.Context, arg database.UpdateUserSecretParams) (database.UserSecret, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "UpdateUserSecretByUserIDAndName", ctx, arg)
|
||||
ret := m.ctrl.Call(m, "UpdateUserSecret", ctx, arg)
|
||||
ret0, _ := ret[0].(database.UserSecret)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// UpdateUserSecretByUserIDAndName indicates an expected call of UpdateUserSecretByUserIDAndName.
|
||||
func (mr *MockStoreMockRecorder) UpdateUserSecretByUserIDAndName(ctx, arg any) *gomock.Call {
|
||||
// UpdateUserSecret indicates an expected call of UpdateUserSecret.
|
||||
func (mr *MockStoreMockRecorder) UpdateUserSecret(ctx, arg any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateUserSecretByUserIDAndName", reflect.TypeOf((*MockStore)(nil).UpdateUserSecretByUserIDAndName), ctx, arg)
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateUserSecret", reflect.TypeOf((*MockStore)(nil).UpdateUserSecret), ctx, arg)
|
||||
}
|
||||
|
||||
// UpdateUserStatus mocks base method.
|
||||
|
||||
@@ -584,7 +584,6 @@ func (q *sqlQuerier) CountAuthorizedAuditLogs(ctx context.Context, arg CountAudi
|
||||
arg.DateTo,
|
||||
arg.BuildReason,
|
||||
arg.RequestID,
|
||||
arg.CountCap,
|
||||
)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
@@ -721,7 +720,6 @@ func (q *sqlQuerier) CountAuthorizedConnectionLogs(ctx context.Context, arg Coun
|
||||
arg.WorkspaceID,
|
||||
arg.ConnectionID,
|
||||
arg.Status,
|
||||
arg.CountCap,
|
||||
)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
|
||||
@@ -145,13 +145,5 @@ func extractWhereClause(query string) string {
|
||||
// Remove SQL comments
|
||||
whereClause = regexp.MustCompile(`(?m)--.*$`).ReplaceAllString(whereClause, "")
|
||||
|
||||
// Normalize indentation so subquery wrapping doesn't cause
|
||||
// mismatches.
|
||||
lines := strings.Split(whereClause, "\n")
|
||||
for i, line := range lines {
|
||||
lines[i] = strings.TrimLeft(line, " \t")
|
||||
}
|
||||
whereClause = strings.Join(lines, "\n")
|
||||
|
||||
return strings.TrimSpace(whereClause)
|
||||
}
|
||||
|
||||
@@ -81,8 +81,8 @@ func newMsgQueue(ctx context.Context, l Listener, le ListenerWithErr) *msgQueue
|
||||
}
|
||||
|
||||
func (q *msgQueue) run() {
|
||||
var batch [maxDrainBatch]msgOrErr
|
||||
for {
|
||||
// wait until there is something on the queue or we are closed
|
||||
q.cond.L.Lock()
|
||||
for q.size == 0 && !q.closed {
|
||||
q.cond.Wait()
|
||||
@@ -91,32 +91,28 @@ func (q *msgQueue) run() {
|
||||
q.cond.L.Unlock()
|
||||
return
|
||||
}
|
||||
// Drain up to maxDrainBatch items while holding the lock.
|
||||
n := min(q.size, maxDrainBatch)
|
||||
for i := range n {
|
||||
batch[i] = q.q[q.front]
|
||||
q.front = (q.front + 1) % BufferSize
|
||||
}
|
||||
q.size -= n
|
||||
item := q.q[q.front]
|
||||
q.front = (q.front + 1) % BufferSize
|
||||
q.size--
|
||||
q.cond.L.Unlock()
|
||||
|
||||
// Dispatch each message individually without holding the lock.
|
||||
for i := range n {
|
||||
item := batch[i]
|
||||
if item.err == nil {
|
||||
if q.l != nil {
|
||||
q.l(q.ctx, item.msg)
|
||||
continue
|
||||
}
|
||||
if q.le != nil {
|
||||
q.le(q.ctx, item.msg, nil)
|
||||
continue
|
||||
}
|
||||
// process item without holding lock
|
||||
if item.err == nil {
|
||||
// real message
|
||||
if q.l != nil {
|
||||
q.l(q.ctx, item.msg)
|
||||
continue
|
||||
}
|
||||
if q.le != nil {
|
||||
q.le(q.ctx, nil, item.err)
|
||||
q.le(q.ctx, item.msg, nil)
|
||||
continue
|
||||
}
|
||||
// unhittable
|
||||
continue
|
||||
}
|
||||
// if the listener wants errors, send it.
|
||||
if q.le != nil {
|
||||
q.le(q.ctx, nil, item.err)
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -237,12 +233,6 @@ type PGPubsub struct {
|
||||
// for a subscriber before dropping messages.
|
||||
const BufferSize = 2048
|
||||
|
||||
// maxDrainBatch is the maximum number of messages to drain from the ring
|
||||
// buffer per iteration. Batching amortizes the cost of mutex
|
||||
// acquire/release and cond.Wait across many messages, improving drain
|
||||
// throughput during bursts.
|
||||
const maxDrainBatch = 256
|
||||
|
||||
// Subscribe calls the listener when an event matching the name is received.
|
||||
func (p *PGPubsub) Subscribe(event string, listener Listener) (cancel func(), err error) {
|
||||
return p.subscribeQueue(event, newMsgQueue(context.Background(), listener, nil))
|
||||
|
||||
@@ -152,7 +152,7 @@ type sqlcQuerier interface {
|
||||
DeleteTask(ctx context.Context, arg DeleteTaskParams) (uuid.UUID, error)
|
||||
DeleteUserChatCompactionThreshold(ctx context.Context, arg DeleteUserChatCompactionThresholdParams) error
|
||||
DeleteUserChatProviderKey(ctx context.Context, arg DeleteUserChatProviderKeyParams) error
|
||||
DeleteUserSecretByUserIDAndName(ctx context.Context, arg DeleteUserSecretByUserIDAndNameParams) error
|
||||
DeleteUserSecret(ctx context.Context, id uuid.UUID) error
|
||||
DeleteWebpushSubscriptionByUserIDAndEndpoint(ctx context.Context, arg DeleteWebpushSubscriptionByUserIDAndEndpointParams) error
|
||||
DeleteWebpushSubscriptions(ctx context.Context, ids []uuid.UUID) error
|
||||
DeleteWorkspaceACLByID(ctx context.Context, id uuid.UUID) error
|
||||
@@ -598,6 +598,7 @@ type sqlcQuerier interface {
|
||||
GetUserLinkByUserIDLoginType(ctx context.Context, arg GetUserLinkByUserIDLoginTypeParams) (UserLink, error)
|
||||
GetUserLinksByUserID(ctx context.Context, userID uuid.UUID) ([]UserLink, error)
|
||||
GetUserNotificationPreferences(ctx context.Context, userID uuid.UUID) ([]NotificationPreference, error)
|
||||
GetUserSecret(ctx context.Context, id uuid.UUID) (UserSecret, error)
|
||||
GetUserSecretByUserIDAndName(ctx context.Context, arg GetUserSecretByUserIDAndNameParams) (UserSecret, error)
|
||||
// GetUserStatusCounts returns the count of users in each status over time.
|
||||
// The time range is inclusively defined by the start_time and end_time parameters.
|
||||
@@ -817,13 +818,7 @@ type sqlcQuerier interface {
|
||||
ListProvisionerKeysByOrganizationExcludeReserved(ctx context.Context, organizationID uuid.UUID) ([]ProvisionerKey, error)
|
||||
ListTasks(ctx context.Context, arg ListTasksParams) ([]Task, error)
|
||||
ListUserChatCompactionThresholds(ctx context.Context, userID uuid.UUID) ([]UserConfig, error)
|
||||
// Returns metadata only (no value or value_key_id) for the
|
||||
// REST API list and get endpoints.
|
||||
ListUserSecrets(ctx context.Context, userID uuid.UUID) ([]ListUserSecretsRow, error)
|
||||
// Returns all columns including the secret value. Used by the
|
||||
// provisioner (build-time injection) and the agent manifest
|
||||
// (runtime injection).
|
||||
ListUserSecretsWithValues(ctx context.Context, userID uuid.UUID) ([]UserSecret, error)
|
||||
ListUserSecrets(ctx context.Context, userID uuid.UUID) ([]UserSecret, error)
|
||||
ListWorkspaceAgentPortShares(ctx context.Context, workspaceID uuid.UUID) ([]WorkspaceAgentPortShare, error)
|
||||
MarkAllInboxNotificationsAsRead(ctx context.Context, arg MarkAllInboxNotificationsAsReadParams) error
|
||||
OIDCClaimFieldValues(ctx context.Context, arg OIDCClaimFieldValuesParams) ([]string, error)
|
||||
@@ -875,11 +870,9 @@ type sqlcQuerier interface {
|
||||
UpdateAPIKeyByID(ctx context.Context, arg UpdateAPIKeyByIDParams) error
|
||||
UpdateChatBuildAgentBinding(ctx context.Context, arg UpdateChatBuildAgentBindingParams) (Chat, error)
|
||||
UpdateChatByID(ctx context.Context, arg UpdateChatByIDParams) (Chat, error)
|
||||
// Bumps the heartbeat timestamp for the given set of chat IDs,
|
||||
// provided they are still running and owned by the specified
|
||||
// worker. Returns the IDs that were actually updated so the
|
||||
// caller can detect stolen or completed chats via set-difference.
|
||||
UpdateChatHeartbeats(ctx context.Context, arg UpdateChatHeartbeatsParams) ([]uuid.UUID, error)
|
||||
// Bumps the heartbeat timestamp for a running chat so that other
|
||||
// replicas know the worker is still alive.
|
||||
UpdateChatHeartbeat(ctx context.Context, arg UpdateChatHeartbeatParams) (int64, error)
|
||||
UpdateChatLabelsByID(ctx context.Context, arg UpdateChatLabelsByIDParams) (Chat, error)
|
||||
// Updates the cached injected context parts (AGENTS.md +
|
||||
// skills) on the chat row. Called only when context changes
|
||||
@@ -962,7 +955,7 @@ type sqlcQuerier interface {
|
||||
UpdateUserProfile(ctx context.Context, arg UpdateUserProfileParams) (User, error)
|
||||
UpdateUserQuietHoursSchedule(ctx context.Context, arg UpdateUserQuietHoursScheduleParams) (User, error)
|
||||
UpdateUserRoles(ctx context.Context, arg UpdateUserRolesParams) (User, error)
|
||||
UpdateUserSecretByUserIDAndName(ctx context.Context, arg UpdateUserSecretByUserIDAndNameParams) (UserSecret, error)
|
||||
UpdateUserSecret(ctx context.Context, arg UpdateUserSecretParams) (UserSecret, error)
|
||||
UpdateUserStatus(ctx context.Context, arg UpdateUserStatusParams) (User, error)
|
||||
UpdateUserTaskNotificationAlertDismissed(ctx context.Context, arg UpdateUserTaskNotificationAlertDismissedParams) (bool, error)
|
||||
UpdateUserTerminalFont(ctx context.Context, arg UpdateUserTerminalFontParams) (UserConfig, error)
|
||||
|
||||
@@ -7339,7 +7339,13 @@ func TestUserSecretsCRUDOperations(t *testing.T) {
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, secretID, createdSecret.ID)
|
||||
|
||||
// 2. READ by UserID and Name
|
||||
// 2. READ by ID
|
||||
readSecret, err := db.GetUserSecret(ctx, createdSecret.ID)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, createdSecret.ID, readSecret.ID)
|
||||
assert.Equal(t, "workflow-secret", readSecret.Name)
|
||||
|
||||
// 3. READ by UserID and Name
|
||||
readByNameParams := database.GetUserSecretByUserIDAndNameParams{
|
||||
UserID: testUser.ID,
|
||||
Name: "workflow-secret",
|
||||
@@ -7347,43 +7353,33 @@ func TestUserSecretsCRUDOperations(t *testing.T) {
|
||||
readByNameSecret, err := db.GetUserSecretByUserIDAndName(ctx, readByNameParams)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, createdSecret.ID, readByNameSecret.ID)
|
||||
assert.Equal(t, "workflow-secret", readByNameSecret.Name)
|
||||
|
||||
// 3. LIST (metadata only)
|
||||
// 4. LIST
|
||||
secrets, err := db.ListUserSecrets(ctx, testUser.ID)
|
||||
require.NoError(t, err)
|
||||
require.Len(t, secrets, 1)
|
||||
assert.Equal(t, createdSecret.ID, secrets[0].ID)
|
||||
|
||||
// 4. LIST with values
|
||||
secretsWithValues, err := db.ListUserSecretsWithValues(ctx, testUser.ID)
|
||||
require.NoError(t, err)
|
||||
require.Len(t, secretsWithValues, 1)
|
||||
assert.Equal(t, "workflow-value", secretsWithValues[0].Value)
|
||||
|
||||
// 5. UPDATE (partial - only description)
|
||||
updateParams := database.UpdateUserSecretByUserIDAndNameParams{
|
||||
UserID: testUser.ID,
|
||||
Name: "workflow-secret",
|
||||
UpdateDescription: true,
|
||||
Description: "Updated workflow description",
|
||||
// 5. UPDATE
|
||||
updateParams := database.UpdateUserSecretParams{
|
||||
ID: createdSecret.ID,
|
||||
Description: "Updated workflow description",
|
||||
Value: "updated-workflow-value",
|
||||
EnvName: "UPDATED_WORKFLOW_ENV",
|
||||
FilePath: "/updated/workflow/path",
|
||||
}
|
||||
|
||||
updatedSecret, err := db.UpdateUserSecretByUserIDAndName(ctx, updateParams)
|
||||
updatedSecret, err := db.UpdateUserSecret(ctx, updateParams)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "Updated workflow description", updatedSecret.Description)
|
||||
assert.Equal(t, "workflow-value", updatedSecret.Value) // Value unchanged
|
||||
assert.Equal(t, "WORKFLOW_ENV", updatedSecret.EnvName) // EnvName unchanged
|
||||
assert.Equal(t, "updated-workflow-value", updatedSecret.Value)
|
||||
|
||||
// 6. DELETE
|
||||
err = db.DeleteUserSecretByUserIDAndName(ctx, database.DeleteUserSecretByUserIDAndNameParams{
|
||||
UserID: testUser.ID,
|
||||
Name: "workflow-secret",
|
||||
})
|
||||
err = db.DeleteUserSecret(ctx, createdSecret.ID)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Verify deletion
|
||||
_, err = db.GetUserSecretByUserIDAndName(ctx, readByNameParams)
|
||||
_, err = db.GetUserSecret(ctx, createdSecret.ID)
|
||||
require.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "no rows in result set")
|
||||
|
||||
@@ -7453,13 +7449,9 @@ func TestUserSecretsCRUDOperations(t *testing.T) {
|
||||
})
|
||||
|
||||
// Verify both secrets exist
|
||||
_, err = db.GetUserSecretByUserIDAndName(ctx, database.GetUserSecretByUserIDAndNameParams{
|
||||
UserID: testUser.ID, Name: secret1.Name,
|
||||
})
|
||||
_, err = db.GetUserSecret(ctx, secret1.ID)
|
||||
require.NoError(t, err)
|
||||
_, err = db.GetUserSecretByUserIDAndName(ctx, database.GetUserSecretByUserIDAndNameParams{
|
||||
UserID: testUser.ID, Name: secret2.Name,
|
||||
})
|
||||
_, err = db.GetUserSecret(ctx, secret2.ID)
|
||||
require.NoError(t, err)
|
||||
})
|
||||
}
|
||||
@@ -7482,14 +7474,14 @@ func TestUserSecretsAuthorization(t *testing.T) {
|
||||
org := dbgen.Organization(t, db, database.Organization{})
|
||||
|
||||
// Create secrets for users
|
||||
_ = dbgen.UserSecret(t, db, database.UserSecret{
|
||||
user1Secret := dbgen.UserSecret(t, db, database.UserSecret{
|
||||
UserID: user1.ID,
|
||||
Name: "user1-secret",
|
||||
Description: "User 1's secret",
|
||||
Value: "user1-value",
|
||||
})
|
||||
|
||||
_ = dbgen.UserSecret(t, db, database.UserSecret{
|
||||
user2Secret := dbgen.UserSecret(t, db, database.UserSecret{
|
||||
UserID: user2.ID,
|
||||
Name: "user2-secret",
|
||||
Description: "User 2's secret",
|
||||
@@ -7499,8 +7491,7 @@ func TestUserSecretsAuthorization(t *testing.T) {
|
||||
testCases := []struct {
|
||||
name string
|
||||
subject rbac.Subject
|
||||
lookupUserID uuid.UUID
|
||||
lookupName string
|
||||
secretID uuid.UUID
|
||||
expectedAccess bool
|
||||
}{
|
||||
{
|
||||
@@ -7510,8 +7501,7 @@ func TestUserSecretsAuthorization(t *testing.T) {
|
||||
Roles: rbac.RoleIdentifiers{rbac.RoleMember()},
|
||||
Scope: rbac.ScopeAll,
|
||||
},
|
||||
lookupUserID: user1.ID,
|
||||
lookupName: "user1-secret",
|
||||
secretID: user1Secret.ID,
|
||||
expectedAccess: true,
|
||||
},
|
||||
{
|
||||
@@ -7521,8 +7511,7 @@ func TestUserSecretsAuthorization(t *testing.T) {
|
||||
Roles: rbac.RoleIdentifiers{rbac.RoleMember()},
|
||||
Scope: rbac.ScopeAll,
|
||||
},
|
||||
lookupUserID: user2.ID,
|
||||
lookupName: "user2-secret",
|
||||
secretID: user2Secret.ID,
|
||||
expectedAccess: false,
|
||||
},
|
||||
{
|
||||
@@ -7532,8 +7521,7 @@ func TestUserSecretsAuthorization(t *testing.T) {
|
||||
Roles: rbac.RoleIdentifiers{rbac.RoleOwner()},
|
||||
Scope: rbac.ScopeAll,
|
||||
},
|
||||
lookupUserID: user1.ID,
|
||||
lookupName: "user1-secret",
|
||||
secretID: user1Secret.ID,
|
||||
expectedAccess: false,
|
||||
},
|
||||
{
|
||||
@@ -7543,8 +7531,7 @@ func TestUserSecretsAuthorization(t *testing.T) {
|
||||
Roles: rbac.RoleIdentifiers{rbac.ScopedRoleOrgAdmin(org.ID)},
|
||||
Scope: rbac.ScopeAll,
|
||||
},
|
||||
lookupUserID: user1.ID,
|
||||
lookupName: "user1-secret",
|
||||
secretID: user1Secret.ID,
|
||||
expectedAccess: false,
|
||||
},
|
||||
}
|
||||
@@ -7556,10 +7543,8 @@ func TestUserSecretsAuthorization(t *testing.T) {
|
||||
|
||||
authCtx := dbauthz.As(ctx, tc.subject)
|
||||
|
||||
_, err := authDB.GetUserSecretByUserIDAndName(authCtx, database.GetUserSecretByUserIDAndNameParams{
|
||||
UserID: tc.lookupUserID,
|
||||
Name: tc.lookupName,
|
||||
})
|
||||
// Test GetUserSecret
|
||||
_, err := authDB.GetUserSecret(authCtx, tc.secretID)
|
||||
|
||||
if tc.expectedAccess {
|
||||
require.NoError(t, err, "expected access to be granted")
|
||||
|
||||
+259
-362
@@ -2275,105 +2275,93 @@ func (q *sqlQuerier) UpdateAPIKeyByID(ctx context.Context, arg UpdateAPIKeyByIDP
|
||||
}
|
||||
|
||||
const countAuditLogs = `-- name: CountAuditLogs :one
|
||||
SELECT COUNT(*) FROM (
|
||||
SELECT 1
|
||||
FROM audit_logs
|
||||
LEFT JOIN users ON audit_logs.user_id = users.id
|
||||
LEFT JOIN organizations ON audit_logs.organization_id = organizations.id
|
||||
-- First join on workspaces to get the initial workspace create
|
||||
-- to workspace build 1 id. This is because the first create is
|
||||
-- is a different audit log than subsequent starts.
|
||||
LEFT JOIN workspaces ON audit_logs.resource_type = 'workspace'
|
||||
AND audit_logs.resource_id = workspaces.id
|
||||
-- Get the reason from the build if the resource type
|
||||
-- is a workspace_build
|
||||
LEFT JOIN workspace_builds wb_build ON audit_logs.resource_type = 'workspace_build'
|
||||
AND audit_logs.resource_id = wb_build.id
|
||||
-- Get the reason from the build #1 if this is the first
|
||||
-- workspace create.
|
||||
LEFT JOIN workspace_builds wb_workspace ON audit_logs.resource_type = 'workspace'
|
||||
AND audit_logs.action = 'create'
|
||||
AND workspaces.id = wb_workspace.workspace_id
|
||||
AND wb_workspace.build_number = 1
|
||||
WHERE
|
||||
-- Filter resource_type
|
||||
CASE
|
||||
WHEN $1::text != '' THEN resource_type = $1::resource_type
|
||||
ELSE true
|
||||
END
|
||||
-- Filter resource_id
|
||||
AND CASE
|
||||
WHEN $2::uuid != '00000000-0000-0000-0000-000000000000'::uuid THEN resource_id = $2
|
||||
ELSE true
|
||||
END
|
||||
-- Filter organization_id
|
||||
AND CASE
|
||||
WHEN $3::uuid != '00000000-0000-0000-0000-000000000000'::uuid THEN audit_logs.organization_id = $3
|
||||
ELSE true
|
||||
END
|
||||
-- Filter by resource_target
|
||||
AND CASE
|
||||
WHEN $4::text != '' THEN resource_target = $4
|
||||
ELSE true
|
||||
END
|
||||
-- Filter action
|
||||
AND CASE
|
||||
WHEN $5::text != '' THEN action = $5::audit_action
|
||||
ELSE true
|
||||
END
|
||||
-- Filter by user_id
|
||||
AND CASE
|
||||
WHEN $6::uuid != '00000000-0000-0000-0000-000000000000'::uuid THEN user_id = $6
|
||||
ELSE true
|
||||
END
|
||||
-- Filter by username
|
||||
AND CASE
|
||||
WHEN $7::text != '' THEN user_id = (
|
||||
SELECT id
|
||||
FROM users
|
||||
WHERE lower(username) = lower($7)
|
||||
AND deleted = false
|
||||
)
|
||||
ELSE true
|
||||
END
|
||||
-- Filter by user_email
|
||||
AND CASE
|
||||
WHEN $8::text != '' THEN users.email = $8
|
||||
ELSE true
|
||||
END
|
||||
-- Filter by date_from
|
||||
AND CASE
|
||||
WHEN $9::timestamp with time zone != '0001-01-01 00:00:00Z' THEN "time" >= $9
|
||||
ELSE true
|
||||
END
|
||||
-- Filter by date_to
|
||||
AND CASE
|
||||
WHEN $10::timestamp with time zone != '0001-01-01 00:00:00Z' THEN "time" <= $10
|
||||
ELSE true
|
||||
END
|
||||
-- Filter by build_reason
|
||||
AND CASE
|
||||
WHEN $11::text != '' THEN COALESCE(wb_build.reason::text, wb_workspace.reason::text) = $11
|
||||
ELSE true
|
||||
END
|
||||
-- Filter request_id
|
||||
AND CASE
|
||||
WHEN $12::uuid != '00000000-0000-0000-0000-000000000000'::uuid THEN audit_logs.request_id = $12
|
||||
ELSE true
|
||||
END
|
||||
-- Authorize Filter clause will be injected below in CountAuthorizedAuditLogs
|
||||
-- @authorize_filter
|
||||
-- Avoid a slow scan on a large table with joins. The caller
|
||||
-- passes the count cap and we add 1 so the frontend can detect
|
||||
-- capping and show "... of N+". A cap of 0 means no limit (NULLIF
|
||||
-- -> NULL + 1 = NULL).
|
||||
-- NOTE: Parameterizing this so that we can easily change from,
|
||||
-- e.g., 2000 to 5000. However, use literal NULL (or no LIMIT)
|
||||
-- here if disabling the capping on a large table permanently.
|
||||
-- This way the PG planner can plan parallel execution for
|
||||
-- potential large wins.
|
||||
LIMIT NULLIF($13::int, 0) + 1
|
||||
) AS limited_count
|
||||
SELECT COUNT(*)
|
||||
FROM audit_logs
|
||||
LEFT JOIN users ON audit_logs.user_id = users.id
|
||||
LEFT JOIN organizations ON audit_logs.organization_id = organizations.id
|
||||
-- First join on workspaces to get the initial workspace create
|
||||
-- to workspace build 1 id. This is because the first create is
|
||||
-- is a different audit log than subsequent starts.
|
||||
LEFT JOIN workspaces ON audit_logs.resource_type = 'workspace'
|
||||
AND audit_logs.resource_id = workspaces.id
|
||||
-- Get the reason from the build if the resource type
|
||||
-- is a workspace_build
|
||||
LEFT JOIN workspace_builds wb_build ON audit_logs.resource_type = 'workspace_build'
|
||||
AND audit_logs.resource_id = wb_build.id
|
||||
-- Get the reason from the build #1 if this is the first
|
||||
-- workspace create.
|
||||
LEFT JOIN workspace_builds wb_workspace ON audit_logs.resource_type = 'workspace'
|
||||
AND audit_logs.action = 'create'
|
||||
AND workspaces.id = wb_workspace.workspace_id
|
||||
AND wb_workspace.build_number = 1
|
||||
WHERE
|
||||
-- Filter resource_type
|
||||
CASE
|
||||
WHEN $1::text != '' THEN resource_type = $1::resource_type
|
||||
ELSE true
|
||||
END
|
||||
-- Filter resource_id
|
||||
AND CASE
|
||||
WHEN $2::uuid != '00000000-0000-0000-0000-000000000000'::uuid THEN resource_id = $2
|
||||
ELSE true
|
||||
END
|
||||
-- Filter organization_id
|
||||
AND CASE
|
||||
WHEN $3::uuid != '00000000-0000-0000-0000-000000000000'::uuid THEN audit_logs.organization_id = $3
|
||||
ELSE true
|
||||
END
|
||||
-- Filter by resource_target
|
||||
AND CASE
|
||||
WHEN $4::text != '' THEN resource_target = $4
|
||||
ELSE true
|
||||
END
|
||||
-- Filter action
|
||||
AND CASE
|
||||
WHEN $5::text != '' THEN action = $5::audit_action
|
||||
ELSE true
|
||||
END
|
||||
-- Filter by user_id
|
||||
AND CASE
|
||||
WHEN $6::uuid != '00000000-0000-0000-0000-000000000000'::uuid THEN user_id = $6
|
||||
ELSE true
|
||||
END
|
||||
-- Filter by username
|
||||
AND CASE
|
||||
WHEN $7::text != '' THEN user_id = (
|
||||
SELECT id
|
||||
FROM users
|
||||
WHERE lower(username) = lower($7)
|
||||
AND deleted = false
|
||||
)
|
||||
ELSE true
|
||||
END
|
||||
-- Filter by user_email
|
||||
AND CASE
|
||||
WHEN $8::text != '' THEN users.email = $8
|
||||
ELSE true
|
||||
END
|
||||
-- Filter by date_from
|
||||
AND CASE
|
||||
WHEN $9::timestamp with time zone != '0001-01-01 00:00:00Z' THEN "time" >= $9
|
||||
ELSE true
|
||||
END
|
||||
-- Filter by date_to
|
||||
AND CASE
|
||||
WHEN $10::timestamp with time zone != '0001-01-01 00:00:00Z' THEN "time" <= $10
|
||||
ELSE true
|
||||
END
|
||||
-- Filter by build_reason
|
||||
AND CASE
|
||||
WHEN $11::text != '' THEN COALESCE(wb_build.reason::text, wb_workspace.reason::text) = $11
|
||||
ELSE true
|
||||
END
|
||||
-- Filter request_id
|
||||
AND CASE
|
||||
WHEN $12::uuid != '00000000-0000-0000-0000-000000000000'::uuid THEN audit_logs.request_id = $12
|
||||
ELSE true
|
||||
END
|
||||
-- Authorize Filter clause will be injected below in CountAuthorizedAuditLogs
|
||||
-- @authorize_filter
|
||||
`
|
||||
|
||||
type CountAuditLogsParams struct {
|
||||
@@ -2389,7 +2377,6 @@ type CountAuditLogsParams struct {
|
||||
DateTo time.Time `db:"date_to" json:"date_to"`
|
||||
BuildReason string `db:"build_reason" json:"build_reason"`
|
||||
RequestID uuid.UUID `db:"request_id" json:"request_id"`
|
||||
CountCap int32 `db:"count_cap" json:"count_cap"`
|
||||
}
|
||||
|
||||
func (q *sqlQuerier) CountAuditLogs(ctx context.Context, arg CountAuditLogsParams) (int64, error) {
|
||||
@@ -2406,7 +2393,6 @@ func (q *sqlQuerier) CountAuditLogs(ctx context.Context, arg CountAuditLogsParam
|
||||
arg.DateTo,
|
||||
arg.BuildReason,
|
||||
arg.RequestID,
|
||||
arg.CountCap,
|
||||
)
|
||||
var count int64
|
||||
err := row.Scan(&count)
|
||||
@@ -6615,49 +6601,30 @@ func (q *sqlQuerier) UpdateChatByID(ctx context.Context, arg UpdateChatByIDParam
|
||||
return i, err
|
||||
}
|
||||
|
||||
const updateChatHeartbeats = `-- name: UpdateChatHeartbeats :many
|
||||
const updateChatHeartbeat = `-- name: UpdateChatHeartbeat :execrows
|
||||
UPDATE
|
||||
chats
|
||||
SET
|
||||
heartbeat_at = $1::timestamptz
|
||||
heartbeat_at = NOW()
|
||||
WHERE
|
||||
id = ANY($2::uuid[])
|
||||
AND worker_id = $3::uuid
|
||||
id = $1::uuid
|
||||
AND worker_id = $2::uuid
|
||||
AND status = 'running'::chat_status
|
||||
RETURNING id
|
||||
`
|
||||
|
||||
type UpdateChatHeartbeatsParams struct {
|
||||
Now time.Time `db:"now" json:"now"`
|
||||
IDs []uuid.UUID `db:"ids" json:"ids"`
|
||||
WorkerID uuid.UUID `db:"worker_id" json:"worker_id"`
|
||||
type UpdateChatHeartbeatParams struct {
|
||||
ID uuid.UUID `db:"id" json:"id"`
|
||||
WorkerID uuid.UUID `db:"worker_id" json:"worker_id"`
|
||||
}
|
||||
|
||||
// Bumps the heartbeat timestamp for the given set of chat IDs,
|
||||
// provided they are still running and owned by the specified
|
||||
// worker. Returns the IDs that were actually updated so the
|
||||
// caller can detect stolen or completed chats via set-difference.
|
||||
func (q *sqlQuerier) UpdateChatHeartbeats(ctx context.Context, arg UpdateChatHeartbeatsParams) ([]uuid.UUID, error) {
|
||||
rows, err := q.db.QueryContext(ctx, updateChatHeartbeats, arg.Now, pq.Array(arg.IDs), arg.WorkerID)
|
||||
// Bumps the heartbeat timestamp for a running chat so that other
|
||||
// replicas know the worker is still alive.
|
||||
func (q *sqlQuerier) UpdateChatHeartbeat(ctx context.Context, arg UpdateChatHeartbeatParams) (int64, error) {
|
||||
result, err := q.db.ExecContext(ctx, updateChatHeartbeat, arg.ID, arg.WorkerID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return 0, err
|
||||
}
|
||||
defer rows.Close()
|
||||
var items []uuid.UUID
|
||||
for rows.Next() {
|
||||
var id uuid.UUID
|
||||
if err := rows.Scan(&id); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
items = append(items, id)
|
||||
}
|
||||
if err := rows.Close(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := rows.Err(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return items, nil
|
||||
return result.RowsAffected()
|
||||
}
|
||||
|
||||
const updateChatLabelsByID = `-- name: UpdateChatLabelsByID :one
|
||||
@@ -7604,113 +7571,110 @@ func (q *sqlQuerier) BatchUpsertConnectionLogs(ctx context.Context, arg BatchUps
|
||||
}
|
||||
|
||||
const countConnectionLogs = `-- name: CountConnectionLogs :one
|
||||
SELECT COUNT(*) AS count FROM (
|
||||
SELECT 1
|
||||
FROM
|
||||
connection_logs
|
||||
JOIN users AS workspace_owner ON
|
||||
connection_logs.workspace_owner_id = workspace_owner.id
|
||||
LEFT JOIN users ON
|
||||
connection_logs.user_id = users.id
|
||||
JOIN organizations ON
|
||||
connection_logs.organization_id = organizations.id
|
||||
WHERE
|
||||
-- Filter organization_id
|
||||
CASE
|
||||
WHEN $1 :: uuid != '00000000-0000-0000-0000-000000000000'::uuid THEN
|
||||
connection_logs.organization_id = $1
|
||||
ELSE true
|
||||
END
|
||||
-- Filter by workspace owner username
|
||||
AND CASE
|
||||
WHEN $2 :: text != '' THEN
|
||||
workspace_owner_id = (
|
||||
SELECT id FROM users
|
||||
WHERE lower(username) = lower($2) AND deleted = false
|
||||
)
|
||||
ELSE true
|
||||
END
|
||||
-- Filter by workspace_owner_id
|
||||
AND CASE
|
||||
WHEN $3 :: uuid != '00000000-0000-0000-0000-000000000000'::uuid THEN
|
||||
workspace_owner_id = $3
|
||||
ELSE true
|
||||
END
|
||||
-- Filter by workspace_owner_email
|
||||
AND CASE
|
||||
WHEN $4 :: text != '' THEN
|
||||
workspace_owner_id = (
|
||||
SELECT id FROM users
|
||||
WHERE email = $4 AND deleted = false
|
||||
)
|
||||
ELSE true
|
||||
END
|
||||
-- Filter by type
|
||||
AND CASE
|
||||
WHEN $5 :: text != '' THEN
|
||||
type = $5 :: connection_type
|
||||
ELSE true
|
||||
END
|
||||
-- Filter by user_id
|
||||
AND CASE
|
||||
WHEN $6 :: uuid != '00000000-0000-0000-0000-000000000000'::uuid THEN
|
||||
user_id = $6
|
||||
ELSE true
|
||||
END
|
||||
-- Filter by username
|
||||
AND CASE
|
||||
WHEN $7 :: text != '' THEN
|
||||
user_id = (
|
||||
SELECT id FROM users
|
||||
WHERE lower(username) = lower($7) AND deleted = false
|
||||
)
|
||||
ELSE true
|
||||
END
|
||||
-- Filter by user_email
|
||||
AND CASE
|
||||
WHEN $8 :: text != '' THEN
|
||||
users.email = $8
|
||||
ELSE true
|
||||
END
|
||||
-- Filter by connected_after
|
||||
AND CASE
|
||||
WHEN $9 :: timestamp with time zone != '0001-01-01 00:00:00Z' THEN
|
||||
connect_time >= $9
|
||||
ELSE true
|
||||
END
|
||||
-- Filter by connected_before
|
||||
AND CASE
|
||||
WHEN $10 :: timestamp with time zone != '0001-01-01 00:00:00Z' THEN
|
||||
connect_time <= $10
|
||||
ELSE true
|
||||
END
|
||||
-- Filter by workspace_id
|
||||
AND CASE
|
||||
WHEN $11 :: uuid != '00000000-0000-0000-0000-000000000000'::uuid THEN
|
||||
connection_logs.workspace_id = $11
|
||||
ELSE true
|
||||
END
|
||||
-- Filter by connection_id
|
||||
AND CASE
|
||||
WHEN $12 :: uuid != '00000000-0000-0000-0000-000000000000'::uuid THEN
|
||||
connection_logs.connection_id = $12
|
||||
ELSE true
|
||||
END
|
||||
-- Filter by whether the session has a disconnect_time
|
||||
AND CASE
|
||||
WHEN $13 :: text != '' THEN
|
||||
(($13 = 'ongoing' AND disconnect_time IS NULL) OR
|
||||
($13 = 'completed' AND disconnect_time IS NOT NULL)) AND
|
||||
-- Exclude web events, since we don't know their close time.
|
||||
"type" NOT IN ('workspace_app', 'port_forwarding')
|
||||
ELSE true
|
||||
END
|
||||
-- Authorize Filter clause will be injected below in
|
||||
-- CountAuthorizedConnectionLogs
|
||||
-- @authorize_filter
|
||||
-- NOTE: See the CountAuditLogs LIMIT note.
|
||||
LIMIT NULLIF($14::int, 0) + 1
|
||||
) AS limited_count
|
||||
SELECT
|
||||
COUNT(*) AS count
|
||||
FROM
|
||||
connection_logs
|
||||
JOIN users AS workspace_owner ON
|
||||
connection_logs.workspace_owner_id = workspace_owner.id
|
||||
LEFT JOIN users ON
|
||||
connection_logs.user_id = users.id
|
||||
JOIN organizations ON
|
||||
connection_logs.organization_id = organizations.id
|
||||
WHERE
|
||||
-- Filter organization_id
|
||||
CASE
|
||||
WHEN $1 :: uuid != '00000000-0000-0000-0000-000000000000'::uuid THEN
|
||||
connection_logs.organization_id = $1
|
||||
ELSE true
|
||||
END
|
||||
-- Filter by workspace owner username
|
||||
AND CASE
|
||||
WHEN $2 :: text != '' THEN
|
||||
workspace_owner_id = (
|
||||
SELECT id FROM users
|
||||
WHERE lower(username) = lower($2) AND deleted = false
|
||||
)
|
||||
ELSE true
|
||||
END
|
||||
-- Filter by workspace_owner_id
|
||||
AND CASE
|
||||
WHEN $3 :: uuid != '00000000-0000-0000-0000-000000000000'::uuid THEN
|
||||
workspace_owner_id = $3
|
||||
ELSE true
|
||||
END
|
||||
-- Filter by workspace_owner_email
|
||||
AND CASE
|
||||
WHEN $4 :: text != '' THEN
|
||||
workspace_owner_id = (
|
||||
SELECT id FROM users
|
||||
WHERE email = $4 AND deleted = false
|
||||
)
|
||||
ELSE true
|
||||
END
|
||||
-- Filter by type
|
||||
AND CASE
|
||||
WHEN $5 :: text != '' THEN
|
||||
type = $5 :: connection_type
|
||||
ELSE true
|
||||
END
|
||||
-- Filter by user_id
|
||||
AND CASE
|
||||
WHEN $6 :: uuid != '00000000-0000-0000-0000-000000000000'::uuid THEN
|
||||
user_id = $6
|
||||
ELSE true
|
||||
END
|
||||
-- Filter by username
|
||||
AND CASE
|
||||
WHEN $7 :: text != '' THEN
|
||||
user_id = (
|
||||
SELECT id FROM users
|
||||
WHERE lower(username) = lower($7) AND deleted = false
|
||||
)
|
||||
ELSE true
|
||||
END
|
||||
-- Filter by user_email
|
||||
AND CASE
|
||||
WHEN $8 :: text != '' THEN
|
||||
users.email = $8
|
||||
ELSE true
|
||||
END
|
||||
-- Filter by connected_after
|
||||
AND CASE
|
||||
WHEN $9 :: timestamp with time zone != '0001-01-01 00:00:00Z' THEN
|
||||
connect_time >= $9
|
||||
ELSE true
|
||||
END
|
||||
-- Filter by connected_before
|
||||
AND CASE
|
||||
WHEN $10 :: timestamp with time zone != '0001-01-01 00:00:00Z' THEN
|
||||
connect_time <= $10
|
||||
ELSE true
|
||||
END
|
||||
-- Filter by workspace_id
|
||||
AND CASE
|
||||
WHEN $11 :: uuid != '00000000-0000-0000-0000-000000000000'::uuid THEN
|
||||
connection_logs.workspace_id = $11
|
||||
ELSE true
|
||||
END
|
||||
-- Filter by connection_id
|
||||
AND CASE
|
||||
WHEN $12 :: uuid != '00000000-0000-0000-0000-000000000000'::uuid THEN
|
||||
connection_logs.connection_id = $12
|
||||
ELSE true
|
||||
END
|
||||
-- Filter by whether the session has a disconnect_time
|
||||
AND CASE
|
||||
WHEN $13 :: text != '' THEN
|
||||
(($13 = 'ongoing' AND disconnect_time IS NULL) OR
|
||||
($13 = 'completed' AND disconnect_time IS NOT NULL)) AND
|
||||
-- Exclude web events, since we don't know their close time.
|
||||
"type" NOT IN ('workspace_app', 'port_forwarding')
|
||||
ELSE true
|
||||
END
|
||||
-- Authorize Filter clause will be injected below in
|
||||
-- CountAuthorizedConnectionLogs
|
||||
-- @authorize_filter
|
||||
`
|
||||
|
||||
type CountConnectionLogsParams struct {
|
||||
@@ -7727,7 +7691,6 @@ type CountConnectionLogsParams struct {
|
||||
WorkspaceID uuid.UUID `db:"workspace_id" json:"workspace_id"`
|
||||
ConnectionID uuid.UUID `db:"connection_id" json:"connection_id"`
|
||||
Status string `db:"status" json:"status"`
|
||||
CountCap int32 `db:"count_cap" json:"count_cap"`
|
||||
}
|
||||
|
||||
func (q *sqlQuerier) CountConnectionLogs(ctx context.Context, arg CountConnectionLogsParams) (int64, error) {
|
||||
@@ -7745,7 +7708,6 @@ func (q *sqlQuerier) CountConnectionLogs(ctx context.Context, arg CountConnectio
|
||||
arg.WorkspaceID,
|
||||
arg.ConnectionID,
|
||||
arg.Status,
|
||||
arg.CountCap,
|
||||
)
|
||||
var count int64
|
||||
err := row.Scan(&count)
|
||||
@@ -22639,30 +22601,21 @@ INSERT INTO user_secrets (
|
||||
name,
|
||||
description,
|
||||
value,
|
||||
value_key_id,
|
||||
env_name,
|
||||
file_path
|
||||
) VALUES (
|
||||
$1,
|
||||
$2,
|
||||
$3,
|
||||
$4,
|
||||
$5,
|
||||
$6,
|
||||
$7,
|
||||
$8
|
||||
$1, $2, $3, $4, $5, $6, $7
|
||||
) RETURNING id, user_id, name, description, value, env_name, file_path, created_at, updated_at, value_key_id
|
||||
`
|
||||
|
||||
type CreateUserSecretParams struct {
|
||||
ID uuid.UUID `db:"id" json:"id"`
|
||||
UserID uuid.UUID `db:"user_id" json:"user_id"`
|
||||
Name string `db:"name" json:"name"`
|
||||
Description string `db:"description" json:"description"`
|
||||
Value string `db:"value" json:"value"`
|
||||
ValueKeyID sql.NullString `db:"value_key_id" json:"value_key_id"`
|
||||
EnvName string `db:"env_name" json:"env_name"`
|
||||
FilePath string `db:"file_path" json:"file_path"`
|
||||
ID uuid.UUID `db:"id" json:"id"`
|
||||
UserID uuid.UUID `db:"user_id" json:"user_id"`
|
||||
Name string `db:"name" json:"name"`
|
||||
Description string `db:"description" json:"description"`
|
||||
Value string `db:"value" json:"value"`
|
||||
EnvName string `db:"env_name" json:"env_name"`
|
||||
FilePath string `db:"file_path" json:"file_path"`
|
||||
}
|
||||
|
||||
func (q *sqlQuerier) CreateUserSecret(ctx context.Context, arg CreateUserSecretParams) (UserSecret, error) {
|
||||
@@ -22672,7 +22625,6 @@ func (q *sqlQuerier) CreateUserSecret(ctx context.Context, arg CreateUserSecretP
|
||||
arg.Name,
|
||||
arg.Description,
|
||||
arg.Value,
|
||||
arg.ValueKeyID,
|
||||
arg.EnvName,
|
||||
arg.FilePath,
|
||||
)
|
||||
@@ -22692,24 +22644,41 @@ func (q *sqlQuerier) CreateUserSecret(ctx context.Context, arg CreateUserSecretP
|
||||
return i, err
|
||||
}
|
||||
|
||||
const deleteUserSecretByUserIDAndName = `-- name: DeleteUserSecretByUserIDAndName :exec
|
||||
const deleteUserSecret = `-- name: DeleteUserSecret :exec
|
||||
DELETE FROM user_secrets
|
||||
WHERE user_id = $1 AND name = $2
|
||||
WHERE id = $1
|
||||
`
|
||||
|
||||
type DeleteUserSecretByUserIDAndNameParams struct {
|
||||
UserID uuid.UUID `db:"user_id" json:"user_id"`
|
||||
Name string `db:"name" json:"name"`
|
||||
}
|
||||
|
||||
func (q *sqlQuerier) DeleteUserSecretByUserIDAndName(ctx context.Context, arg DeleteUserSecretByUserIDAndNameParams) error {
|
||||
_, err := q.db.ExecContext(ctx, deleteUserSecretByUserIDAndName, arg.UserID, arg.Name)
|
||||
func (q *sqlQuerier) DeleteUserSecret(ctx context.Context, id uuid.UUID) error {
|
||||
_, err := q.db.ExecContext(ctx, deleteUserSecret, id)
|
||||
return err
|
||||
}
|
||||
|
||||
const getUserSecret = `-- name: GetUserSecret :one
|
||||
SELECT id, user_id, name, description, value, env_name, file_path, created_at, updated_at, value_key_id FROM user_secrets
|
||||
WHERE id = $1
|
||||
`
|
||||
|
||||
func (q *sqlQuerier) GetUserSecret(ctx context.Context, id uuid.UUID) (UserSecret, error) {
|
||||
row := q.db.QueryRowContext(ctx, getUserSecret, id)
|
||||
var i UserSecret
|
||||
err := row.Scan(
|
||||
&i.ID,
|
||||
&i.UserID,
|
||||
&i.Name,
|
||||
&i.Description,
|
||||
&i.Value,
|
||||
&i.EnvName,
|
||||
&i.FilePath,
|
||||
&i.CreatedAt,
|
||||
&i.UpdatedAt,
|
||||
&i.ValueKeyID,
|
||||
)
|
||||
return i, err
|
||||
}
|
||||
|
||||
const getUserSecretByUserIDAndName = `-- name: GetUserSecretByUserIDAndName :one
|
||||
SELECT id, user_id, name, description, value, env_name, file_path, created_at, updated_at, value_key_id
|
||||
FROM user_secrets
|
||||
SELECT id, user_id, name, description, value, env_name, file_path, created_at, updated_at, value_key_id FROM user_secrets
|
||||
WHERE user_id = $1 AND name = $2
|
||||
`
|
||||
|
||||
@@ -22737,76 +22706,17 @@ func (q *sqlQuerier) GetUserSecretByUserIDAndName(ctx context.Context, arg GetUs
|
||||
}
|
||||
|
||||
const listUserSecrets = `-- name: ListUserSecrets :many
|
||||
SELECT
|
||||
id, user_id, name, description,
|
||||
env_name, file_path,
|
||||
created_at, updated_at
|
||||
FROM user_secrets
|
||||
SELECT id, user_id, name, description, value, env_name, file_path, created_at, updated_at, value_key_id FROM user_secrets
|
||||
WHERE user_id = $1
|
||||
ORDER BY name ASC
|
||||
`
|
||||
|
||||
type ListUserSecretsRow struct {
|
||||
ID uuid.UUID `db:"id" json:"id"`
|
||||
UserID uuid.UUID `db:"user_id" json:"user_id"`
|
||||
Name string `db:"name" json:"name"`
|
||||
Description string `db:"description" json:"description"`
|
||||
EnvName string `db:"env_name" json:"env_name"`
|
||||
FilePath string `db:"file_path" json:"file_path"`
|
||||
CreatedAt time.Time `db:"created_at" json:"created_at"`
|
||||
UpdatedAt time.Time `db:"updated_at" json:"updated_at"`
|
||||
}
|
||||
|
||||
// Returns metadata only (no value or value_key_id) for the
|
||||
// REST API list and get endpoints.
|
||||
func (q *sqlQuerier) ListUserSecrets(ctx context.Context, userID uuid.UUID) ([]ListUserSecretsRow, error) {
|
||||
func (q *sqlQuerier) ListUserSecrets(ctx context.Context, userID uuid.UUID) ([]UserSecret, error) {
|
||||
rows, err := q.db.QueryContext(ctx, listUserSecrets, userID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
var items []ListUserSecretsRow
|
||||
for rows.Next() {
|
||||
var i ListUserSecretsRow
|
||||
if err := rows.Scan(
|
||||
&i.ID,
|
||||
&i.UserID,
|
||||
&i.Name,
|
||||
&i.Description,
|
||||
&i.EnvName,
|
||||
&i.FilePath,
|
||||
&i.CreatedAt,
|
||||
&i.UpdatedAt,
|
||||
); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
items = append(items, i)
|
||||
}
|
||||
if err := rows.Close(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := rows.Err(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return items, nil
|
||||
}
|
||||
|
||||
const listUserSecretsWithValues = `-- name: ListUserSecretsWithValues :many
|
||||
SELECT id, user_id, name, description, value, env_name, file_path, created_at, updated_at, value_key_id
|
||||
FROM user_secrets
|
||||
WHERE user_id = $1
|
||||
ORDER BY name ASC
|
||||
`
|
||||
|
||||
// Returns all columns including the secret value. Used by the
|
||||
// provisioner (build-time injection) and the agent manifest
|
||||
// (runtime injection).
|
||||
func (q *sqlQuerier) ListUserSecretsWithValues(ctx context.Context, userID uuid.UUID) ([]UserSecret, error) {
|
||||
rows, err := q.db.QueryContext(ctx, listUserSecretsWithValues, userID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
var items []UserSecret
|
||||
for rows.Next() {
|
||||
var i UserSecret
|
||||
@@ -22835,46 +22745,33 @@ func (q *sqlQuerier) ListUserSecretsWithValues(ctx context.Context, userID uuid.
|
||||
return items, nil
|
||||
}
|
||||
|
||||
const updateUserSecretByUserIDAndName = `-- name: UpdateUserSecretByUserIDAndName :one
|
||||
const updateUserSecret = `-- name: UpdateUserSecret :one
|
||||
UPDATE user_secrets
|
||||
SET
|
||||
value = CASE WHEN $1::bool THEN $2 ELSE value END,
|
||||
value_key_id = CASE WHEN $1::bool THEN $3 ELSE value_key_id END,
|
||||
description = CASE WHEN $4::bool THEN $5 ELSE description END,
|
||||
env_name = CASE WHEN $6::bool THEN $7 ELSE env_name END,
|
||||
file_path = CASE WHEN $8::bool THEN $9 ELSE file_path END,
|
||||
updated_at = CURRENT_TIMESTAMP
|
||||
WHERE user_id = $10 AND name = $11
|
||||
description = $2,
|
||||
value = $3,
|
||||
env_name = $4,
|
||||
file_path = $5,
|
||||
updated_at = CURRENT_TIMESTAMP
|
||||
WHERE id = $1
|
||||
RETURNING id, user_id, name, description, value, env_name, file_path, created_at, updated_at, value_key_id
|
||||
`
|
||||
|
||||
type UpdateUserSecretByUserIDAndNameParams struct {
|
||||
UpdateValue bool `db:"update_value" json:"update_value"`
|
||||
Value string `db:"value" json:"value"`
|
||||
ValueKeyID sql.NullString `db:"value_key_id" json:"value_key_id"`
|
||||
UpdateDescription bool `db:"update_description" json:"update_description"`
|
||||
Description string `db:"description" json:"description"`
|
||||
UpdateEnvName bool `db:"update_env_name" json:"update_env_name"`
|
||||
EnvName string `db:"env_name" json:"env_name"`
|
||||
UpdateFilePath bool `db:"update_file_path" json:"update_file_path"`
|
||||
FilePath string `db:"file_path" json:"file_path"`
|
||||
UserID uuid.UUID `db:"user_id" json:"user_id"`
|
||||
Name string `db:"name" json:"name"`
|
||||
type UpdateUserSecretParams struct {
|
||||
ID uuid.UUID `db:"id" json:"id"`
|
||||
Description string `db:"description" json:"description"`
|
||||
Value string `db:"value" json:"value"`
|
||||
EnvName string `db:"env_name" json:"env_name"`
|
||||
FilePath string `db:"file_path" json:"file_path"`
|
||||
}
|
||||
|
||||
func (q *sqlQuerier) UpdateUserSecretByUserIDAndName(ctx context.Context, arg UpdateUserSecretByUserIDAndNameParams) (UserSecret, error) {
|
||||
row := q.db.QueryRowContext(ctx, updateUserSecretByUserIDAndName,
|
||||
arg.UpdateValue,
|
||||
arg.Value,
|
||||
arg.ValueKeyID,
|
||||
arg.UpdateDescription,
|
||||
func (q *sqlQuerier) UpdateUserSecret(ctx context.Context, arg UpdateUserSecretParams) (UserSecret, error) {
|
||||
row := q.db.QueryRowContext(ctx, updateUserSecret,
|
||||
arg.ID,
|
||||
arg.Description,
|
||||
arg.UpdateEnvName,
|
||||
arg.Value,
|
||||
arg.EnvName,
|
||||
arg.UpdateFilePath,
|
||||
arg.FilePath,
|
||||
arg.UserID,
|
||||
arg.Name,
|
||||
)
|
||||
var i UserSecret
|
||||
err := row.Scan(
|
||||
|
||||
@@ -149,105 +149,94 @@ VALUES (
|
||||
RETURNING *;
|
||||
|
||||
-- name: CountAuditLogs :one
|
||||
SELECT COUNT(*) FROM (
|
||||
SELECT 1
|
||||
FROM audit_logs
|
||||
LEFT JOIN users ON audit_logs.user_id = users.id
|
||||
LEFT JOIN organizations ON audit_logs.organization_id = organizations.id
|
||||
-- First join on workspaces to get the initial workspace create
|
||||
-- to workspace build 1 id. This is because the first create is
|
||||
-- is a different audit log than subsequent starts.
|
||||
LEFT JOIN workspaces ON audit_logs.resource_type = 'workspace'
|
||||
AND audit_logs.resource_id = workspaces.id
|
||||
-- Get the reason from the build if the resource type
|
||||
-- is a workspace_build
|
||||
LEFT JOIN workspace_builds wb_build ON audit_logs.resource_type = 'workspace_build'
|
||||
AND audit_logs.resource_id = wb_build.id
|
||||
-- Get the reason from the build #1 if this is the first
|
||||
-- workspace create.
|
||||
LEFT JOIN workspace_builds wb_workspace ON audit_logs.resource_type = 'workspace'
|
||||
AND audit_logs.action = 'create'
|
||||
AND workspaces.id = wb_workspace.workspace_id
|
||||
AND wb_workspace.build_number = 1
|
||||
WHERE
|
||||
-- Filter resource_type
|
||||
CASE
|
||||
WHEN @resource_type::text != '' THEN resource_type = @resource_type::resource_type
|
||||
ELSE true
|
||||
END
|
||||
-- Filter resource_id
|
||||
AND CASE
|
||||
WHEN @resource_id::uuid != '00000000-0000-0000-0000-000000000000'::uuid THEN resource_id = @resource_id
|
||||
ELSE true
|
||||
END
|
||||
-- Filter organization_id
|
||||
AND CASE
|
||||
WHEN @organization_id::uuid != '00000000-0000-0000-0000-000000000000'::uuid THEN audit_logs.organization_id = @organization_id
|
||||
ELSE true
|
||||
END
|
||||
-- Filter by resource_target
|
||||
AND CASE
|
||||
WHEN @resource_target::text != '' THEN resource_target = @resource_target
|
||||
ELSE true
|
||||
END
|
||||
-- Filter action
|
||||
AND CASE
|
||||
WHEN @action::text != '' THEN action = @action::audit_action
|
||||
ELSE true
|
||||
END
|
||||
-- Filter by user_id
|
||||
AND CASE
|
||||
WHEN @user_id::uuid != '00000000-0000-0000-0000-000000000000'::uuid THEN user_id = @user_id
|
||||
ELSE true
|
||||
END
|
||||
-- Filter by username
|
||||
AND CASE
|
||||
WHEN @username::text != '' THEN user_id = (
|
||||
SELECT id
|
||||
FROM users
|
||||
WHERE lower(username) = lower(@username)
|
||||
AND deleted = false
|
||||
)
|
||||
ELSE true
|
||||
END
|
||||
-- Filter by user_email
|
||||
AND CASE
|
||||
WHEN @email::text != '' THEN users.email = @email
|
||||
ELSE true
|
||||
END
|
||||
-- Filter by date_from
|
||||
AND CASE
|
||||
WHEN @date_from::timestamp with time zone != '0001-01-01 00:00:00Z' THEN "time" >= @date_from
|
||||
ELSE true
|
||||
END
|
||||
-- Filter by date_to
|
||||
AND CASE
|
||||
WHEN @date_to::timestamp with time zone != '0001-01-01 00:00:00Z' THEN "time" <= @date_to
|
||||
ELSE true
|
||||
END
|
||||
-- Filter by build_reason
|
||||
AND CASE
|
||||
WHEN @build_reason::text != '' THEN COALESCE(wb_build.reason::text, wb_workspace.reason::text) = @build_reason
|
||||
ELSE true
|
||||
END
|
||||
-- Filter request_id
|
||||
AND CASE
|
||||
WHEN @request_id::uuid != '00000000-0000-0000-0000-000000000000'::uuid THEN audit_logs.request_id = @request_id
|
||||
ELSE true
|
||||
END
|
||||
-- Authorize Filter clause will be injected below in CountAuthorizedAuditLogs
|
||||
-- @authorize_filter
|
||||
-- Avoid a slow scan on a large table with joins. The caller
|
||||
-- passes the count cap and we add 1 so the frontend can detect
|
||||
-- capping and show "... of N+". A cap of 0 means no limit (NULLIF
|
||||
-- -> NULL + 1 = NULL).
|
||||
-- NOTE: Parameterizing this so that we can easily change from,
|
||||
-- e.g., 2000 to 5000. However, use literal NULL (or no LIMIT)
|
||||
-- here if disabling the capping on a large table permanently.
|
||||
-- This way the PG planner can plan parallel execution for
|
||||
-- potential large wins.
|
||||
LIMIT NULLIF(@count_cap::int, 0) + 1
|
||||
) AS limited_count;
|
||||
SELECT COUNT(*)
|
||||
FROM audit_logs
|
||||
LEFT JOIN users ON audit_logs.user_id = users.id
|
||||
LEFT JOIN organizations ON audit_logs.organization_id = organizations.id
|
||||
-- First join on workspaces to get the initial workspace create
|
||||
-- to workspace build 1 id. This is because the first create is
|
||||
-- is a different audit log than subsequent starts.
|
||||
LEFT JOIN workspaces ON audit_logs.resource_type = 'workspace'
|
||||
AND audit_logs.resource_id = workspaces.id
|
||||
-- Get the reason from the build if the resource type
|
||||
-- is a workspace_build
|
||||
LEFT JOIN workspace_builds wb_build ON audit_logs.resource_type = 'workspace_build'
|
||||
AND audit_logs.resource_id = wb_build.id
|
||||
-- Get the reason from the build #1 if this is the first
|
||||
-- workspace create.
|
||||
LEFT JOIN workspace_builds wb_workspace ON audit_logs.resource_type = 'workspace'
|
||||
AND audit_logs.action = 'create'
|
||||
AND workspaces.id = wb_workspace.workspace_id
|
||||
AND wb_workspace.build_number = 1
|
||||
WHERE
|
||||
-- Filter resource_type
|
||||
CASE
|
||||
WHEN @resource_type::text != '' THEN resource_type = @resource_type::resource_type
|
||||
ELSE true
|
||||
END
|
||||
-- Filter resource_id
|
||||
AND CASE
|
||||
WHEN @resource_id::uuid != '00000000-0000-0000-0000-000000000000'::uuid THEN resource_id = @resource_id
|
||||
ELSE true
|
||||
END
|
||||
-- Filter organization_id
|
||||
AND CASE
|
||||
WHEN @organization_id::uuid != '00000000-0000-0000-0000-000000000000'::uuid THEN audit_logs.organization_id = @organization_id
|
||||
ELSE true
|
||||
END
|
||||
-- Filter by resource_target
|
||||
AND CASE
|
||||
WHEN @resource_target::text != '' THEN resource_target = @resource_target
|
||||
ELSE true
|
||||
END
|
||||
-- Filter action
|
||||
AND CASE
|
||||
WHEN @action::text != '' THEN action = @action::audit_action
|
||||
ELSE true
|
||||
END
|
||||
-- Filter by user_id
|
||||
AND CASE
|
||||
WHEN @user_id::uuid != '00000000-0000-0000-0000-000000000000'::uuid THEN user_id = @user_id
|
||||
ELSE true
|
||||
END
|
||||
-- Filter by username
|
||||
AND CASE
|
||||
WHEN @username::text != '' THEN user_id = (
|
||||
SELECT id
|
||||
FROM users
|
||||
WHERE lower(username) = lower(@username)
|
||||
AND deleted = false
|
||||
)
|
||||
ELSE true
|
||||
END
|
||||
-- Filter by user_email
|
||||
AND CASE
|
||||
WHEN @email::text != '' THEN users.email = @email
|
||||
ELSE true
|
||||
END
|
||||
-- Filter by date_from
|
||||
AND CASE
|
||||
WHEN @date_from::timestamp with time zone != '0001-01-01 00:00:00Z' THEN "time" >= @date_from
|
||||
ELSE true
|
||||
END
|
||||
-- Filter by date_to
|
||||
AND CASE
|
||||
WHEN @date_to::timestamp with time zone != '0001-01-01 00:00:00Z' THEN "time" <= @date_to
|
||||
ELSE true
|
||||
END
|
||||
-- Filter by build_reason
|
||||
AND CASE
|
||||
WHEN @build_reason::text != '' THEN COALESCE(wb_build.reason::text, wb_workspace.reason::text) = @build_reason
|
||||
ELSE true
|
||||
END
|
||||
-- Filter request_id
|
||||
AND CASE
|
||||
WHEN @request_id::uuid != '00000000-0000-0000-0000-000000000000'::uuid THEN audit_logs.request_id = @request_id
|
||||
ELSE true
|
||||
END
|
||||
-- Authorize Filter clause will be injected below in CountAuthorizedAuditLogs
|
||||
-- @authorize_filter
|
||||
;
|
||||
|
||||
-- name: DeleteOldAuditLogConnectionEvents :exec
|
||||
DELETE FROM audit_logs
|
||||
|
||||
@@ -674,20 +674,17 @@ WHERE
|
||||
status = 'running'::chat_status
|
||||
AND heartbeat_at < @stale_threshold::timestamptz;
|
||||
|
||||
-- name: UpdateChatHeartbeats :many
|
||||
-- Bumps the heartbeat timestamp for the given set of chat IDs,
|
||||
-- provided they are still running and owned by the specified
|
||||
-- worker. Returns the IDs that were actually updated so the
|
||||
-- caller can detect stolen or completed chats via set-difference.
|
||||
-- name: UpdateChatHeartbeat :execrows
|
||||
-- Bumps the heartbeat timestamp for a running chat so that other
|
||||
-- replicas know the worker is still alive.
|
||||
UPDATE
|
||||
chats
|
||||
SET
|
||||
heartbeat_at = @now::timestamptz
|
||||
heartbeat_at = NOW()
|
||||
WHERE
|
||||
id = ANY(@ids::uuid[])
|
||||
id = @id::uuid
|
||||
AND worker_id = @worker_id::uuid
|
||||
AND status = 'running'::chat_status
|
||||
RETURNING id;
|
||||
AND status = 'running'::chat_status;
|
||||
|
||||
-- name: GetChatDiffStatusByChatID :one
|
||||
SELECT
|
||||
|
||||
@@ -133,113 +133,111 @@ OFFSET
|
||||
@offset_opt;
|
||||
|
||||
-- name: CountConnectionLogs :one
|
||||
SELECT COUNT(*) AS count FROM (
|
||||
SELECT 1
|
||||
FROM
|
||||
connection_logs
|
||||
JOIN users AS workspace_owner ON
|
||||
connection_logs.workspace_owner_id = workspace_owner.id
|
||||
LEFT JOIN users ON
|
||||
connection_logs.user_id = users.id
|
||||
JOIN organizations ON
|
||||
connection_logs.organization_id = organizations.id
|
||||
WHERE
|
||||
-- Filter organization_id
|
||||
CASE
|
||||
WHEN @organization_id :: uuid != '00000000-0000-0000-0000-000000000000'::uuid THEN
|
||||
connection_logs.organization_id = @organization_id
|
||||
ELSE true
|
||||
END
|
||||
-- Filter by workspace owner username
|
||||
AND CASE
|
||||
WHEN @workspace_owner :: text != '' THEN
|
||||
workspace_owner_id = (
|
||||
SELECT id FROM users
|
||||
WHERE lower(username) = lower(@workspace_owner) AND deleted = false
|
||||
)
|
||||
ELSE true
|
||||
END
|
||||
-- Filter by workspace_owner_id
|
||||
AND CASE
|
||||
WHEN @workspace_owner_id :: uuid != '00000000-0000-0000-0000-000000000000'::uuid THEN
|
||||
workspace_owner_id = @workspace_owner_id
|
||||
ELSE true
|
||||
END
|
||||
-- Filter by workspace_owner_email
|
||||
AND CASE
|
||||
WHEN @workspace_owner_email :: text != '' THEN
|
||||
workspace_owner_id = (
|
||||
SELECT id FROM users
|
||||
WHERE email = @workspace_owner_email AND deleted = false
|
||||
)
|
||||
ELSE true
|
||||
END
|
||||
-- Filter by type
|
||||
AND CASE
|
||||
WHEN @type :: text != '' THEN
|
||||
type = @type :: connection_type
|
||||
ELSE true
|
||||
END
|
||||
-- Filter by user_id
|
||||
AND CASE
|
||||
WHEN @user_id :: uuid != '00000000-0000-0000-0000-000000000000'::uuid THEN
|
||||
user_id = @user_id
|
||||
ELSE true
|
||||
END
|
||||
-- Filter by username
|
||||
AND CASE
|
||||
WHEN @username :: text != '' THEN
|
||||
user_id = (
|
||||
SELECT id FROM users
|
||||
WHERE lower(username) = lower(@username) AND deleted = false
|
||||
)
|
||||
ELSE true
|
||||
END
|
||||
-- Filter by user_email
|
||||
AND CASE
|
||||
WHEN @user_email :: text != '' THEN
|
||||
users.email = @user_email
|
||||
ELSE true
|
||||
END
|
||||
-- Filter by connected_after
|
||||
AND CASE
|
||||
WHEN @connected_after :: timestamp with time zone != '0001-01-01 00:00:00Z' THEN
|
||||
connect_time >= @connected_after
|
||||
ELSE true
|
||||
END
|
||||
-- Filter by connected_before
|
||||
AND CASE
|
||||
WHEN @connected_before :: timestamp with time zone != '0001-01-01 00:00:00Z' THEN
|
||||
connect_time <= @connected_before
|
||||
ELSE true
|
||||
END
|
||||
-- Filter by workspace_id
|
||||
AND CASE
|
||||
WHEN @workspace_id :: uuid != '00000000-0000-0000-0000-000000000000'::uuid THEN
|
||||
connection_logs.workspace_id = @workspace_id
|
||||
ELSE true
|
||||
END
|
||||
-- Filter by connection_id
|
||||
AND CASE
|
||||
WHEN @connection_id :: uuid != '00000000-0000-0000-0000-000000000000'::uuid THEN
|
||||
connection_logs.connection_id = @connection_id
|
||||
ELSE true
|
||||
END
|
||||
-- Filter by whether the session has a disconnect_time
|
||||
AND CASE
|
||||
WHEN @status :: text != '' THEN
|
||||
((@status = 'ongoing' AND disconnect_time IS NULL) OR
|
||||
(@status = 'completed' AND disconnect_time IS NOT NULL)) AND
|
||||
-- Exclude web events, since we don't know their close time.
|
||||
"type" NOT IN ('workspace_app', 'port_forwarding')
|
||||
ELSE true
|
||||
END
|
||||
-- Authorize Filter clause will be injected below in
|
||||
-- CountAuthorizedConnectionLogs
|
||||
-- @authorize_filter
|
||||
-- NOTE: See the CountAuditLogs LIMIT note.
|
||||
LIMIT NULLIF(@count_cap::int, 0) + 1
|
||||
) AS limited_count;
|
||||
SELECT
|
||||
COUNT(*) AS count
|
||||
FROM
|
||||
connection_logs
|
||||
JOIN users AS workspace_owner ON
|
||||
connection_logs.workspace_owner_id = workspace_owner.id
|
||||
LEFT JOIN users ON
|
||||
connection_logs.user_id = users.id
|
||||
JOIN organizations ON
|
||||
connection_logs.organization_id = organizations.id
|
||||
WHERE
|
||||
-- Filter organization_id
|
||||
CASE
|
||||
WHEN @organization_id :: uuid != '00000000-0000-0000-0000-000000000000'::uuid THEN
|
||||
connection_logs.organization_id = @organization_id
|
||||
ELSE true
|
||||
END
|
||||
-- Filter by workspace owner username
|
||||
AND CASE
|
||||
WHEN @workspace_owner :: text != '' THEN
|
||||
workspace_owner_id = (
|
||||
SELECT id FROM users
|
||||
WHERE lower(username) = lower(@workspace_owner) AND deleted = false
|
||||
)
|
||||
ELSE true
|
||||
END
|
||||
-- Filter by workspace_owner_id
|
||||
AND CASE
|
||||
WHEN @workspace_owner_id :: uuid != '00000000-0000-0000-0000-000000000000'::uuid THEN
|
||||
workspace_owner_id = @workspace_owner_id
|
||||
ELSE true
|
||||
END
|
||||
-- Filter by workspace_owner_email
|
||||
AND CASE
|
||||
WHEN @workspace_owner_email :: text != '' THEN
|
||||
workspace_owner_id = (
|
||||
SELECT id FROM users
|
||||
WHERE email = @workspace_owner_email AND deleted = false
|
||||
)
|
||||
ELSE true
|
||||
END
|
||||
-- Filter by type
|
||||
AND CASE
|
||||
WHEN @type :: text != '' THEN
|
||||
type = @type :: connection_type
|
||||
ELSE true
|
||||
END
|
||||
-- Filter by user_id
|
||||
AND CASE
|
||||
WHEN @user_id :: uuid != '00000000-0000-0000-0000-000000000000'::uuid THEN
|
||||
user_id = @user_id
|
||||
ELSE true
|
||||
END
|
||||
-- Filter by username
|
||||
AND CASE
|
||||
WHEN @username :: text != '' THEN
|
||||
user_id = (
|
||||
SELECT id FROM users
|
||||
WHERE lower(username) = lower(@username) AND deleted = false
|
||||
)
|
||||
ELSE true
|
||||
END
|
||||
-- Filter by user_email
|
||||
AND CASE
|
||||
WHEN @user_email :: text != '' THEN
|
||||
users.email = @user_email
|
||||
ELSE true
|
||||
END
|
||||
-- Filter by connected_after
|
||||
AND CASE
|
||||
WHEN @connected_after :: timestamp with time zone != '0001-01-01 00:00:00Z' THEN
|
||||
connect_time >= @connected_after
|
||||
ELSE true
|
||||
END
|
||||
-- Filter by connected_before
|
||||
AND CASE
|
||||
WHEN @connected_before :: timestamp with time zone != '0001-01-01 00:00:00Z' THEN
|
||||
connect_time <= @connected_before
|
||||
ELSE true
|
||||
END
|
||||
-- Filter by workspace_id
|
||||
AND CASE
|
||||
WHEN @workspace_id :: uuid != '00000000-0000-0000-0000-000000000000'::uuid THEN
|
||||
connection_logs.workspace_id = @workspace_id
|
||||
ELSE true
|
||||
END
|
||||
-- Filter by connection_id
|
||||
AND CASE
|
||||
WHEN @connection_id :: uuid != '00000000-0000-0000-0000-000000000000'::uuid THEN
|
||||
connection_logs.connection_id = @connection_id
|
||||
ELSE true
|
||||
END
|
||||
-- Filter by whether the session has a disconnect_time
|
||||
AND CASE
|
||||
WHEN @status :: text != '' THEN
|
||||
((@status = 'ongoing' AND disconnect_time IS NULL) OR
|
||||
(@status = 'completed' AND disconnect_time IS NOT NULL)) AND
|
||||
-- Exclude web events, since we don't know their close time.
|
||||
"type" NOT IN ('workspace_app', 'port_forwarding')
|
||||
ELSE true
|
||||
END
|
||||
-- Authorize Filter clause will be injected below in
|
||||
-- CountAuthorizedConnectionLogs
|
||||
-- @authorize_filter
|
||||
;
|
||||
|
||||
-- name: DeleteOldConnectionLogs :execrows
|
||||
WITH old_logs AS (
|
||||
|
||||
@@ -1,26 +1,14 @@
|
||||
-- name: GetUserSecretByUserIDAndName :one
|
||||
SELECT *
|
||||
FROM user_secrets
|
||||
WHERE user_id = @user_id AND name = @name;
|
||||
SELECT * FROM user_secrets
|
||||
WHERE user_id = $1 AND name = $2;
|
||||
|
||||
-- name: GetUserSecret :one
|
||||
SELECT * FROM user_secrets
|
||||
WHERE id = $1;
|
||||
|
||||
-- name: ListUserSecrets :many
|
||||
-- Returns metadata only (no value or value_key_id) for the
|
||||
-- REST API list and get endpoints.
|
||||
SELECT
|
||||
id, user_id, name, description,
|
||||
env_name, file_path,
|
||||
created_at, updated_at
|
||||
FROM user_secrets
|
||||
WHERE user_id = @user_id
|
||||
ORDER BY name ASC;
|
||||
|
||||
-- name: ListUserSecretsWithValues :many
|
||||
-- Returns all columns including the secret value. Used by the
|
||||
-- provisioner (build-time injection) and the agent manifest
|
||||
-- (runtime injection).
|
||||
SELECT *
|
||||
FROM user_secrets
|
||||
WHERE user_id = @user_id
|
||||
SELECT * FROM user_secrets
|
||||
WHERE user_id = $1
|
||||
ORDER BY name ASC;
|
||||
|
||||
-- name: CreateUserSecret :one
|
||||
@@ -30,32 +18,23 @@ INSERT INTO user_secrets (
|
||||
name,
|
||||
description,
|
||||
value,
|
||||
value_key_id,
|
||||
env_name,
|
||||
file_path
|
||||
) VALUES (
|
||||
@id,
|
||||
@user_id,
|
||||
@name,
|
||||
@description,
|
||||
@value,
|
||||
@value_key_id,
|
||||
@env_name,
|
||||
@file_path
|
||||
$1, $2, $3, $4, $5, $6, $7
|
||||
) RETURNING *;
|
||||
|
||||
-- name: UpdateUserSecretByUserIDAndName :one
|
||||
-- name: UpdateUserSecret :one
|
||||
UPDATE user_secrets
|
||||
SET
|
||||
value = CASE WHEN @update_value::bool THEN @value ELSE value END,
|
||||
value_key_id = CASE WHEN @update_value::bool THEN @value_key_id ELSE value_key_id END,
|
||||
description = CASE WHEN @update_description::bool THEN @description ELSE description END,
|
||||
env_name = CASE WHEN @update_env_name::bool THEN @env_name ELSE env_name END,
|
||||
file_path = CASE WHEN @update_file_path::bool THEN @file_path ELSE file_path END,
|
||||
updated_at = CURRENT_TIMESTAMP
|
||||
WHERE user_id = @user_id AND name = @name
|
||||
description = $2,
|
||||
value = $3,
|
||||
env_name = $4,
|
||||
file_path = $5,
|
||||
updated_at = CURRENT_TIMESTAMP
|
||||
WHERE id = $1
|
||||
RETURNING *;
|
||||
|
||||
-- name: DeleteUserSecretByUserIDAndName :exec
|
||||
-- name: DeleteUserSecret :exec
|
||||
DELETE FROM user_secrets
|
||||
WHERE user_id = @user_id AND name = @name;
|
||||
WHERE id = $1;
|
||||
|
||||
@@ -298,40 +298,6 @@ neq(input.object.owner, "");
|
||||
ExpectedSQL: p("'' = 'org-id'"),
|
||||
VariableConverter: regosql.ChatConverter(),
|
||||
},
|
||||
{
|
||||
Name: "AuditLogUUID",
|
||||
Queries: []string{
|
||||
`"8c0b9bdc-a013-4b14-a49b-5747bc335708" = input.object.org_owner`,
|
||||
`input.object.org_owner != ""`,
|
||||
`neq(input.object.org_owner, "8c0b9bdc-a013-4b14-a49b-5747bc335708")`,
|
||||
`input.object.org_owner in {"8c0b9bdc-a013-4b14-a49b-5747bc335708", "05f58202-4bfc-43ce-9ba4-5ff6e0174a71"}`,
|
||||
`"read" in input.object.acl_group_list[input.object.org_owner]`,
|
||||
},
|
||||
ExpectedSQL: p(
|
||||
p("audit_logs.organization_id = '8c0b9bdc-a013-4b14-a49b-5747bc335708'::uuid") + " OR " +
|
||||
p("audit_logs.organization_id IS NOT NULL") + " OR " +
|
||||
p("audit_logs.organization_id != '8c0b9bdc-a013-4b14-a49b-5747bc335708'::uuid") + " OR " +
|
||||
p("audit_logs.organization_id = ANY(ARRAY ['05f58202-4bfc-43ce-9ba4-5ff6e0174a71'::uuid,'8c0b9bdc-a013-4b14-a49b-5747bc335708'::uuid])") + " OR " +
|
||||
"(false)"),
|
||||
VariableConverter: regosql.AuditLogConverter(),
|
||||
},
|
||||
{
|
||||
Name: "ConnectionLogUUID",
|
||||
Queries: []string{
|
||||
`"8c0b9bdc-a013-4b14-a49b-5747bc335708" = input.object.org_owner`,
|
||||
`input.object.org_owner != ""`,
|
||||
`neq(input.object.org_owner, "8c0b9bdc-a013-4b14-a49b-5747bc335708")`,
|
||||
`input.object.org_owner in {"8c0b9bdc-a013-4b14-a49b-5747bc335708"}`,
|
||||
`"read" in input.object.acl_group_list[input.object.org_owner]`,
|
||||
},
|
||||
ExpectedSQL: p(
|
||||
p("connection_logs.organization_id = '8c0b9bdc-a013-4b14-a49b-5747bc335708'::uuid") + " OR " +
|
||||
p("connection_logs.organization_id IS NOT NULL") + " OR " +
|
||||
p("connection_logs.organization_id != '8c0b9bdc-a013-4b14-a49b-5747bc335708'::uuid") + " OR " +
|
||||
p("connection_logs.organization_id = ANY(ARRAY ['8c0b9bdc-a013-4b14-a49b-5747bc335708'::uuid])") + " OR " +
|
||||
"(false)"),
|
||||
VariableConverter: regosql.ConnectionLogConverter(),
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
|
||||
@@ -53,7 +53,7 @@ func WorkspaceConverter() *sqltypes.VariableConverter {
|
||||
func AuditLogConverter() *sqltypes.VariableConverter {
|
||||
matcher := sqltypes.NewVariableConverter().RegisterMatcher(
|
||||
resourceIDMatcher(),
|
||||
sqltypes.UUIDVarMatcher("audit_logs.organization_id", []string{"input", "object", "org_owner"}),
|
||||
sqltypes.StringVarMatcher("COALESCE(audit_logs.organization_id :: text, '')", []string{"input", "object", "org_owner"}),
|
||||
// Audit logs have no user owner, only owner by an organization.
|
||||
sqltypes.AlwaysFalse(userOwnerMatcher()),
|
||||
)
|
||||
@@ -67,7 +67,7 @@ func AuditLogConverter() *sqltypes.VariableConverter {
|
||||
func ConnectionLogConverter() *sqltypes.VariableConverter {
|
||||
matcher := sqltypes.NewVariableConverter().RegisterMatcher(
|
||||
resourceIDMatcher(),
|
||||
sqltypes.UUIDVarMatcher("connection_logs.organization_id", []string{"input", "object", "org_owner"}),
|
||||
sqltypes.StringVarMatcher("COALESCE(connection_logs.organization_id :: text, '')", []string{"input", "object", "org_owner"}),
|
||||
// Connection logs have no user owner, only owner by an organization.
|
||||
sqltypes.AlwaysFalse(userOwnerMatcher()),
|
||||
)
|
||||
|
||||
@@ -1,114 +0,0 @@
|
||||
package sqltypes
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/open-policy-agent/opa/ast"
|
||||
"golang.org/x/xerrors"
|
||||
)
|
||||
|
||||
var (
|
||||
_ VariableMatcher = astUUIDVar{}
|
||||
_ Node = astUUIDVar{}
|
||||
_ SupportsEquality = astUUIDVar{}
|
||||
)
|
||||
|
||||
// astUUIDVar is a variable that represents a UUID column. Unlike
|
||||
// astStringVar it emits native UUID comparisons (column = 'val'::uuid)
|
||||
// instead of text-based ones (COALESCE(column::text, ”) = 'val').
|
||||
// This allows PostgreSQL to use indexes on UUID columns.
|
||||
type astUUIDVar struct {
|
||||
Source RegoSource
|
||||
FieldPath []string
|
||||
ColumnString string
|
||||
}
|
||||
|
||||
func UUIDVarMatcher(sqlColumn string, regoPath []string) VariableMatcher {
|
||||
return astUUIDVar{FieldPath: regoPath, ColumnString: sqlColumn}
|
||||
}
|
||||
|
||||
func (astUUIDVar) UseAs() Node { return astUUIDVar{} }
|
||||
|
||||
func (u astUUIDVar) ConvertVariable(rego ast.Ref) (Node, bool) {
|
||||
left, err := RegoVarPath(u.FieldPath, rego)
|
||||
if err == nil && len(left) == 0 {
|
||||
return astUUIDVar{
|
||||
Source: RegoSource(rego.String()),
|
||||
FieldPath: u.FieldPath,
|
||||
ColumnString: u.ColumnString,
|
||||
}, true
|
||||
}
|
||||
|
||||
return nil, false
|
||||
}
|
||||
|
||||
func (u astUUIDVar) SQLString(_ *SQLGenerator) string {
|
||||
return u.ColumnString
|
||||
}
|
||||
|
||||
// EqualsSQLString handles equality comparisons for UUID columns.
|
||||
// Rego always produces string literals, so we accept AstString and
|
||||
// cast the literal to ::uuid in the output SQL. This lets PG use
|
||||
// native UUID indexes instead of falling back to text comparisons.
|
||||
// nolint:revive
|
||||
func (u astUUIDVar) EqualsSQLString(cfg *SQLGenerator, not bool, other Node) (string, error) {
|
||||
switch other.UseAs().(type) {
|
||||
case AstString:
|
||||
// The other side is a rego string literal like
|
||||
// "8c0b9bdc-a013-4b14-a49b-5747bc335708". Emit a comparison
|
||||
// that casts the literal to uuid so PG can use indexes:
|
||||
// column = 'val'::uuid
|
||||
// instead of the text-based:
|
||||
// 'val' = COALESCE(column::text, '')
|
||||
s, ok := other.(AstString)
|
||||
if !ok {
|
||||
return "", xerrors.Errorf("expected AstString, got %T", other)
|
||||
}
|
||||
if s.Value == "" {
|
||||
// Empty string in rego means "no value". Compare the
|
||||
// column against NULL since UUID columns represent
|
||||
// absent values as NULL, not empty strings.
|
||||
op := "IS NULL"
|
||||
if not {
|
||||
op = "IS NOT NULL"
|
||||
}
|
||||
return fmt.Sprintf("%s %s", u.ColumnString, op), nil
|
||||
}
|
||||
return fmt.Sprintf("%s %s '%s'::uuid",
|
||||
u.ColumnString, equalsOp(not), s.Value), nil
|
||||
case astUUIDVar:
|
||||
return basicSQLEquality(cfg, not, u, other), nil
|
||||
default:
|
||||
return "", xerrors.Errorf("unsupported equality: %T %s %T",
|
||||
u, equalsOp(not), other)
|
||||
}
|
||||
}
|
||||
|
||||
// ContainedInSQL implements SupportsContainedIn so that a UUID column
|
||||
// can appear in membership checks like `col = ANY(ARRAY[...])`. The
|
||||
// array elements are rego strings, so we cast each to ::uuid.
|
||||
func (u astUUIDVar) ContainedInSQL(_ *SQLGenerator, haystack Node) (string, error) {
|
||||
arr, ok := haystack.(ASTArray)
|
||||
if !ok {
|
||||
return "", xerrors.Errorf("unsupported containedIn: %T in %T", u, haystack)
|
||||
}
|
||||
|
||||
if len(arr.Value) == 0 {
|
||||
return "false", nil
|
||||
}
|
||||
|
||||
// Build ARRAY['uuid1'::uuid, 'uuid2'::uuid, ...]
|
||||
values := make([]string, 0, len(arr.Value))
|
||||
for _, v := range arr.Value {
|
||||
s, ok := v.(AstString)
|
||||
if !ok {
|
||||
return "", xerrors.Errorf("expected AstString array element, got %T", v)
|
||||
}
|
||||
values = append(values, fmt.Sprintf("'%s'::uuid", s.Value))
|
||||
}
|
||||
|
||||
return fmt.Sprintf("%s = ANY(ARRAY [%s])",
|
||||
u.ColumnString,
|
||||
strings.Join(values, ",")), nil
|
||||
}
|
||||
@@ -66,7 +66,7 @@ func AuditLogs(ctx context.Context, db database.Store, query string) (database.G
|
||||
}
|
||||
|
||||
// Prepare the count filter, which uses the same parameters as the GetAuditLogsOffsetParams.
|
||||
// nolint:exhaustruct // UserID and CountCap are not obtained from the query parameters.
|
||||
// nolint:exhaustruct // UserID is not obtained from the query parameters.
|
||||
countFilter := database.CountAuditLogsParams{
|
||||
RequestID: filter.RequestID,
|
||||
ResourceID: filter.ResourceID,
|
||||
@@ -123,7 +123,6 @@ func ConnectionLogs(ctx context.Context, db database.Store, query string, apiKey
|
||||
}
|
||||
|
||||
// This MUST be kept in sync with the above
|
||||
// nolint:exhaustruct // CountCap is not obtained from the query parameters.
|
||||
countFilter := database.CountConnectionLogsParams{
|
||||
OrganizationID: filter.OrganizationID,
|
||||
WorkspaceOwner: filter.WorkspaceOwner,
|
||||
|
||||
+17
-35
@@ -19,7 +19,6 @@ import (
|
||||
"github.com/google/uuid"
|
||||
"github.com/prometheus/client_golang/prometheus"
|
||||
"go.opentelemetry.io/otel/trace"
|
||||
"golang.org/x/sync/singleflight"
|
||||
"golang.org/x/xerrors"
|
||||
"tailscale.com/derp"
|
||||
"tailscale.com/tailcfg"
|
||||
@@ -390,7 +389,6 @@ type MultiAgentController struct {
|
||||
// connections to the destination
|
||||
tickets map[uuid.UUID]map[uuid.UUID]struct{}
|
||||
coordination *tailnet.BasicCoordination
|
||||
sendGroup singleflight.Group
|
||||
|
||||
cancel context.CancelFunc
|
||||
expireOldAgentsDone chan struct{}
|
||||
@@ -420,44 +418,28 @@ func (m *MultiAgentController) New(client tailnet.CoordinatorClient) tailnet.Clo
|
||||
|
||||
func (m *MultiAgentController) ensureAgent(agentID uuid.UUID) error {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
|
||||
_, ok := m.connectionTimes[agentID]
|
||||
if ok {
|
||||
m.connectionTimes[agentID] = time.Now()
|
||||
m.mu.Unlock()
|
||||
return nil
|
||||
}
|
||||
m.mu.Unlock()
|
||||
|
||||
m.logger.Debug(context.Background(),
|
||||
"subscribing to agent", slog.F("agent_id", agentID))
|
||||
|
||||
_, err, _ := m.sendGroup.Do(agentID.String(), func() (interface{}, error) {
|
||||
m.mu.Lock()
|
||||
coord := m.coordination
|
||||
m.mu.Unlock()
|
||||
if coord == nil {
|
||||
return nil, xerrors.New("no active coordination")
|
||||
// If we don't have the agent, subscribe.
|
||||
if !ok {
|
||||
m.logger.Debug(context.Background(),
|
||||
"subscribing to agent", slog.F("agent_id", agentID))
|
||||
if m.coordination != nil {
|
||||
err := m.coordination.Client.Send(&proto.CoordinateRequest{
|
||||
AddTunnel: &proto.CoordinateRequest_Tunnel{Id: agentID[:]},
|
||||
})
|
||||
if err != nil {
|
||||
err = xerrors.Errorf("subscribe agent: %w", err)
|
||||
m.coordination.SendErr(err)
|
||||
_ = m.coordination.Client.Close()
|
||||
m.coordination = nil
|
||||
return err
|
||||
}
|
||||
}
|
||||
err := coord.Client.Send(&proto.CoordinateRequest{
|
||||
AddTunnel: &proto.CoordinateRequest_Tunnel{Id: agentID[:]},
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
m.mu.Lock()
|
||||
m.tickets[agentID] = map[uuid.UUID]struct{}{}
|
||||
m.mu.Unlock()
|
||||
return nil, nil
|
||||
})
|
||||
if err != nil {
|
||||
m.logger.Error(context.Background(), "ensureAgent send failed",
|
||||
slog.F("agent_id", agentID), slog.Error(err))
|
||||
return xerrors.Errorf("send AddTunnel: %w", err)
|
||||
}
|
||||
|
||||
m.mu.Lock()
|
||||
m.connectionTimes[agentID] = time.Now()
|
||||
m.mu.Unlock()
|
||||
return nil
|
||||
}
|
||||
|
||||
|
||||
@@ -730,10 +730,7 @@ func (s *Server) workspaceAgentPTY(rw http.ResponseWriter, r *http.Request) {
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
log := s.Logger.With(
|
||||
slog.F("agent_id", appToken.AgentID),
|
||||
slog.F("workspace_id", appToken.WorkspaceID),
|
||||
)
|
||||
log := s.Logger.With(slog.F("agent_id", appToken.AgentID))
|
||||
log.Debug(ctx, "resolved PTY request")
|
||||
|
||||
values := r.URL.Query()
|
||||
@@ -768,21 +765,19 @@ func (s *Server) workspaceAgentPTY(rw http.ResponseWriter, r *http.Request) {
|
||||
})
|
||||
return
|
||||
}
|
||||
go httpapi.HeartbeatClose(ctx, s.Logger, cancel, conn)
|
||||
|
||||
ctx, wsNetConn := WebsocketNetConn(ctx, conn, websocket.MessageBinary)
|
||||
defer wsNetConn.Close() // Also closes conn.
|
||||
|
||||
go httpapi.HeartbeatClose(ctx, log, cancel, conn)
|
||||
|
||||
dialStart := time.Now()
|
||||
|
||||
agentConn, release, err := s.AgentProvider.AgentConn(ctx, appToken.AgentID)
|
||||
if err != nil {
|
||||
log.Debug(ctx, "dial workspace agent", slog.Error(err), slog.F("elapsed_ms", time.Since(dialStart).Milliseconds()))
|
||||
log.Debug(ctx, "dial workspace agent", slog.Error(err))
|
||||
_ = conn.Close(websocket.StatusInternalError, httpapi.WebsocketCloseSprintf("dial workspace agent: %s", err))
|
||||
return
|
||||
}
|
||||
defer release()
|
||||
log.Debug(ctx, "dialed workspace agent", slog.F("elapsed_ms", time.Since(dialStart).Milliseconds()))
|
||||
log.Debug(ctx, "dialed workspace agent")
|
||||
// #nosec G115 - Safe conversion for terminal height/width which are expected to be within uint16 range (0-65535)
|
||||
ptNetConn, err := agentConn.ReconnectingPTY(ctx, reconnect, uint16(height), uint16(width), r.URL.Query().Get("command"), func(arp *workspacesdk.AgentReconnectingPTYInit) {
|
||||
arp.Container = container
|
||||
@@ -790,12 +785,12 @@ func (s *Server) workspaceAgentPTY(rw http.ResponseWriter, r *http.Request) {
|
||||
arp.BackendType = backendType
|
||||
})
|
||||
if err != nil {
|
||||
log.Debug(ctx, "dial reconnecting pty server in workspace agent", slog.Error(err), slog.F("elapsed_ms", time.Since(dialStart).Milliseconds()))
|
||||
log.Debug(ctx, "dial reconnecting pty server in workspace agent", slog.Error(err))
|
||||
_ = conn.Close(websocket.StatusInternalError, httpapi.WebsocketCloseSprintf("dial: %s", err))
|
||||
return
|
||||
}
|
||||
defer ptNetConn.Close()
|
||||
log.Debug(ctx, "obtained PTY", slog.F("elapsed_ms", time.Since(dialStart).Milliseconds()))
|
||||
log.Debug(ctx, "obtained PTY")
|
||||
|
||||
report := newStatsReportFromSignedToken(*appToken)
|
||||
s.collectStats(report)
|
||||
@@ -805,7 +800,7 @@ func (s *Server) workspaceAgentPTY(rw http.ResponseWriter, r *http.Request) {
|
||||
}()
|
||||
|
||||
agentssh.Bicopy(ctx, wsNetConn, ptNetConn)
|
||||
log.Debug(ctx, "pty Bicopy finished", slog.F("elapsed_ms", time.Since(dialStart).Milliseconds()))
|
||||
log.Debug(ctx, "pty Bicopy finished")
|
||||
}
|
||||
|
||||
func (s *Server) collectStats(stats StatsReport) {
|
||||
|
||||
+28
-124
@@ -7,7 +7,6 @@ import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"maps"
|
||||
"net/http"
|
||||
"slices"
|
||||
"strconv"
|
||||
@@ -152,12 +151,6 @@ type Server struct {
|
||||
inFlightChatStaleAfter time.Duration
|
||||
chatHeartbeatInterval time.Duration
|
||||
|
||||
// heartbeatMu guards heartbeatRegistry.
|
||||
heartbeatMu sync.Mutex
|
||||
// heartbeatRegistry maps chat IDs to their cancel functions
|
||||
// and workspace state for the centralized heartbeat loop.
|
||||
heartbeatRegistry map[uuid.UUID]*heartbeatEntry
|
||||
|
||||
// wakeCh is signaled by SendMessage, EditMessage, CreateChat,
|
||||
// and PromoteQueued so the run loop calls processOnce
|
||||
// immediately instead of waiting for the next ticker.
|
||||
@@ -713,17 +706,6 @@ type chatStreamState struct {
|
||||
bufferRetainedAt time.Time
|
||||
}
|
||||
|
||||
// heartbeatEntry tracks a single chat's cancel function and workspace
|
||||
// state for the centralized heartbeat loop. Instead of spawning a
|
||||
// per-chat goroutine, processChat registers an entry here and the
|
||||
// single heartbeatLoop goroutine handles all chats.
|
||||
type heartbeatEntry struct {
|
||||
cancelWithCause context.CancelCauseFunc
|
||||
chatID uuid.UUID
|
||||
workspaceID uuid.NullUUID
|
||||
logger slog.Logger
|
||||
}
|
||||
|
||||
// resetDropCounters zeroes the rate-limiting state for both buffer
|
||||
// and subscriber drop warnings. The caller must hold s.mu.
|
||||
func (s *chatStreamState) resetDropCounters() {
|
||||
@@ -2438,8 +2420,8 @@ func New(cfg Config) *Server {
|
||||
clock: clk,
|
||||
recordingSem: make(chan struct{}, maxConcurrentRecordingUploads),
|
||||
wakeCh: make(chan struct{}, 1),
|
||||
heartbeatRegistry: make(map[uuid.UUID]*heartbeatEntry),
|
||||
}
|
||||
|
||||
//nolint:gocritic // The chat processor uses a scoped chatd context.
|
||||
ctx = dbauthz.AsChatd(ctx)
|
||||
|
||||
@@ -2479,9 +2461,6 @@ func (p *Server) start(ctx context.Context) {
|
||||
// to handle chats orphaned by crashed or redeployed workers.
|
||||
p.recoverStaleChats(ctx)
|
||||
|
||||
// Single heartbeat loop for all chats on this replica.
|
||||
go p.heartbeatLoop(ctx)
|
||||
|
||||
acquireTicker := p.clock.NewTicker(
|
||||
p.pendingChatAcquireInterval,
|
||||
"chatd",
|
||||
@@ -2751,97 +2730,6 @@ func (p *Server) cleanupStreamIfIdle(chatID uuid.UUID, state *chatStreamState) {
|
||||
p.workspaceMCPToolsCache.Delete(chatID)
|
||||
}
|
||||
|
||||
// registerHeartbeat enrolls a chat in the centralized batch
|
||||
// heartbeat loop. Must be called after chatCtx is created.
|
||||
func (p *Server) registerHeartbeat(entry *heartbeatEntry) {
|
||||
p.heartbeatMu.Lock()
|
||||
defer p.heartbeatMu.Unlock()
|
||||
if _, exists := p.heartbeatRegistry[entry.chatID]; exists {
|
||||
p.logger.Warn(context.Background(),
|
||||
"duplicate heartbeat registration, skipping",
|
||||
slog.F("chat_id", entry.chatID))
|
||||
return
|
||||
}
|
||||
p.heartbeatRegistry[entry.chatID] = entry
|
||||
}
|
||||
|
||||
// unregisterHeartbeat removes a chat from the centralized
|
||||
// heartbeat loop when chat processing finishes.
|
||||
func (p *Server) unregisterHeartbeat(chatID uuid.UUID) {
|
||||
p.heartbeatMu.Lock()
|
||||
defer p.heartbeatMu.Unlock()
|
||||
delete(p.heartbeatRegistry, chatID)
|
||||
}
|
||||
|
||||
// heartbeatLoop runs in a single goroutine, issuing one batch
|
||||
// heartbeat query per interval for all registered chats.
|
||||
func (p *Server) heartbeatLoop(ctx context.Context) {
|
||||
ticker := p.clock.NewTicker(p.chatHeartbeatInterval, "chatd", "batch-heartbeat")
|
||||
defer ticker.Stop()
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case <-ticker.C:
|
||||
p.heartbeatTick(ctx)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// heartbeatTick issues a single batch UPDATE for all running chats
|
||||
// owned by this worker. Chats missing from the result set are
|
||||
// interrupted (stolen by another replica or already completed).
|
||||
func (p *Server) heartbeatTick(ctx context.Context) {
|
||||
// Snapshot the registry under the lock.
|
||||
p.heartbeatMu.Lock()
|
||||
snapshot := maps.Clone(p.heartbeatRegistry)
|
||||
p.heartbeatMu.Unlock()
|
||||
|
||||
if len(snapshot) == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
// Collect the IDs we believe we own.
|
||||
ids := slices.Collect(maps.Keys(snapshot))
|
||||
|
||||
//nolint:gocritic // AsChatd provides narrowly-scoped daemon
|
||||
// access for batch-updating heartbeats.
|
||||
chatdCtx := dbauthz.AsChatd(ctx)
|
||||
updatedIDs, err := p.db.UpdateChatHeartbeats(chatdCtx, database.UpdateChatHeartbeatsParams{
|
||||
IDs: ids,
|
||||
WorkerID: p.workerID,
|
||||
Now: p.clock.Now(),
|
||||
})
|
||||
if err != nil {
|
||||
p.logger.Error(ctx, "batch heartbeat failed", slog.Error(err))
|
||||
return
|
||||
}
|
||||
|
||||
// Build a set of IDs that were successfully updated.
|
||||
updated := make(map[uuid.UUID]struct{}, len(updatedIDs))
|
||||
for _, id := range updatedIDs {
|
||||
updated[id] = struct{}{}
|
||||
}
|
||||
|
||||
// Interrupt registered chats that were not in the result
|
||||
// (stolen by another replica or already completed).
|
||||
for id, entry := range snapshot {
|
||||
if _, ok := updated[id]; !ok {
|
||||
entry.logger.Warn(ctx, "chat not in batch heartbeat result, interrupting")
|
||||
entry.cancelWithCause(chatloop.ErrInterrupted)
|
||||
continue
|
||||
}
|
||||
// Bump workspace usage for surviving chats.
|
||||
newWsID := p.trackWorkspaceUsage(ctx, entry.chatID, entry.workspaceID, entry.logger)
|
||||
// Update workspace ID in the registry for next tick.
|
||||
p.heartbeatMu.Lock()
|
||||
if current, exists := p.heartbeatRegistry[id]; exists {
|
||||
current.workspaceID = newWsID
|
||||
}
|
||||
p.heartbeatMu.Unlock()
|
||||
}
|
||||
}
|
||||
|
||||
func (p *Server) Subscribe(
|
||||
ctx context.Context,
|
||||
chatID uuid.UUID,
|
||||
@@ -3687,17 +3575,33 @@ func (p *Server) processChat(ctx context.Context, chat database.Chat) {
|
||||
}
|
||||
}()
|
||||
|
||||
// Register with the centralized heartbeat loop instead of
|
||||
// running a per-chat goroutine. The loop issues a single batch
|
||||
// UPDATE for all chats on this worker and detects stolen chats
|
||||
// via set-difference.
|
||||
p.registerHeartbeat(&heartbeatEntry{
|
||||
cancelWithCause: cancel,
|
||||
chatID: chat.ID,
|
||||
workspaceID: chat.WorkspaceID,
|
||||
logger: logger,
|
||||
})
|
||||
defer p.unregisterHeartbeat(chat.ID)
|
||||
// Periodically update the heartbeat so other replicas know this
|
||||
// worker is still alive. The goroutine stops when chatCtx is
|
||||
// canceled (either by completion or interruption).
|
||||
go func() {
|
||||
ticker := p.clock.NewTicker(p.chatHeartbeatInterval, "chatd", "heartbeat")
|
||||
defer ticker.Stop()
|
||||
for {
|
||||
select {
|
||||
case <-chatCtx.Done():
|
||||
return
|
||||
case <-ticker.C:
|
||||
rows, err := p.db.UpdateChatHeartbeat(chatCtx, database.UpdateChatHeartbeatParams{
|
||||
ID: chat.ID,
|
||||
WorkerID: p.workerID,
|
||||
})
|
||||
if err != nil {
|
||||
logger.Warn(chatCtx, "failed to update chat heartbeat", slog.Error(err))
|
||||
continue
|
||||
}
|
||||
if rows == 0 {
|
||||
cancel(chatloop.ErrInterrupted)
|
||||
return
|
||||
}
|
||||
chat.WorkspaceID = p.trackWorkspaceUsage(chatCtx, chat.ID, chat.WorkspaceID, logger)
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
// Start buffering stream events BEFORE publishing the running
|
||||
// status. This closes a race where a subscriber sees
|
||||
|
||||
@@ -21,7 +21,6 @@ import (
|
||||
dbpubsub "github.com/coder/coder/v2/coderd/database/pubsub"
|
||||
coderdpubsub "github.com/coder/coder/v2/coderd/pubsub"
|
||||
"github.com/coder/coder/v2/coderd/x/chatd/chaterror"
|
||||
"github.com/coder/coder/v2/coderd/x/chatd/chatloop"
|
||||
"github.com/coder/coder/v2/coderd/x/chatd/chatprovider"
|
||||
"github.com/coder/coder/v2/coderd/x/chatd/chattest"
|
||||
"github.com/coder/coder/v2/coderd/x/chatd/chattool"
|
||||
@@ -2072,7 +2071,6 @@ func TestProcessChat_IgnoresStaleControlNotification(t *testing.T) {
|
||||
workerID: workerID,
|
||||
chatHeartbeatInterval: time.Minute,
|
||||
configCache: newChatConfigCache(ctx, db, clock),
|
||||
heartbeatRegistry: make(map[uuid.UUID]*heartbeatEntry),
|
||||
}
|
||||
|
||||
// Publish a stale "pending" notification on the control channel
|
||||
@@ -2135,130 +2133,3 @@ func TestProcessChat_IgnoresStaleControlNotification(t *testing.T) {
|
||||
require.Equal(t, database.ChatStatusError, finalStatus,
|
||||
"processChat should have reached runChat (error), not been interrupted (waiting)")
|
||||
}
|
||||
|
||||
// TestHeartbeatTick_StolenChatIsInterrupted verifies that when the
|
||||
// batch heartbeat UPDATE does not return a registered chat's ID
|
||||
// (because another replica stole it or it was completed), the
|
||||
// heartbeat tick cancels that chat's context with ErrInterrupted
|
||||
// while leaving surviving chats untouched.
|
||||
func TestHeartbeatTick_StolenChatIsInterrupted(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctx := testutil.Context(t, testutil.WaitShort)
|
||||
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true})
|
||||
ctrl := gomock.NewController(t)
|
||||
db := dbmock.NewMockStore(ctrl)
|
||||
clock := quartz.NewMock(t)
|
||||
|
||||
workerID := uuid.New()
|
||||
|
||||
server := &Server{
|
||||
db: db,
|
||||
logger: logger,
|
||||
clock: clock,
|
||||
workerID: workerID,
|
||||
chatHeartbeatInterval: time.Minute,
|
||||
heartbeatRegistry: make(map[uuid.UUID]*heartbeatEntry),
|
||||
}
|
||||
|
||||
// Create three chats with independent cancel functions.
|
||||
chat1 := uuid.New()
|
||||
chat2 := uuid.New()
|
||||
chat3 := uuid.New()
|
||||
|
||||
_, cancel1 := context.WithCancelCause(ctx)
|
||||
_, cancel2 := context.WithCancelCause(ctx)
|
||||
ctx3, cancel3 := context.WithCancelCause(ctx)
|
||||
|
||||
server.registerHeartbeat(&heartbeatEntry{
|
||||
cancelWithCause: cancel1,
|
||||
chatID: chat1,
|
||||
logger: logger,
|
||||
})
|
||||
server.registerHeartbeat(&heartbeatEntry{
|
||||
cancelWithCause: cancel2,
|
||||
chatID: chat2,
|
||||
logger: logger,
|
||||
})
|
||||
server.registerHeartbeat(&heartbeatEntry{
|
||||
cancelWithCause: cancel3,
|
||||
chatID: chat3,
|
||||
logger: logger,
|
||||
})
|
||||
|
||||
// The batch UPDATE returns only chat1 and chat2 —
|
||||
// chat3 was "stolen" by another replica.
|
||||
db.EXPECT().UpdateChatHeartbeats(gomock.Any(), gomock.Any()).DoAndReturn(
|
||||
func(_ context.Context, params database.UpdateChatHeartbeatsParams) ([]uuid.UUID, error) {
|
||||
require.Equal(t, workerID, params.WorkerID)
|
||||
require.Len(t, params.IDs, 3)
|
||||
// Return only chat1 and chat2 as surviving.
|
||||
return []uuid.UUID{chat1, chat2}, nil
|
||||
},
|
||||
)
|
||||
|
||||
server.heartbeatTick(ctx)
|
||||
|
||||
// chat3's context should be canceled with ErrInterrupted.
|
||||
require.ErrorIs(t, context.Cause(ctx3), chatloop.ErrInterrupted,
|
||||
"stolen chat should be interrupted")
|
||||
|
||||
// chat3 should have been removed from the registry by
|
||||
// unregister (in production this happens via defer in
|
||||
// processChat). The heartbeat tick itself does not
|
||||
// unregister — it only cancels. Verify the entry is
|
||||
// still present (processChat's defer would clean it up).
|
||||
server.heartbeatMu.Lock()
|
||||
_, chat1Exists := server.heartbeatRegistry[chat1]
|
||||
_, chat2Exists := server.heartbeatRegistry[chat2]
|
||||
_, chat3Exists := server.heartbeatRegistry[chat3]
|
||||
server.heartbeatMu.Unlock()
|
||||
|
||||
require.True(t, chat1Exists, "surviving chat1 should remain registered")
|
||||
require.True(t, chat2Exists, "surviving chat2 should remain registered")
|
||||
require.True(t, chat3Exists,
|
||||
"stolen chat3 should still be in registry (processChat defer removes it)")
|
||||
}
|
||||
|
||||
// TestHeartbeatTick_DBErrorDoesNotInterruptChats verifies that a
|
||||
// transient database failure causes the tick to log and return
|
||||
// without canceling any registered chats.
|
||||
func TestHeartbeatTick_DBErrorDoesNotInterruptChats(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctx := testutil.Context(t, testutil.WaitShort)
|
||||
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true})
|
||||
ctrl := gomock.NewController(t)
|
||||
db := dbmock.NewMockStore(ctrl)
|
||||
clock := quartz.NewMock(t)
|
||||
|
||||
server := &Server{
|
||||
db: db,
|
||||
logger: logger,
|
||||
clock: clock,
|
||||
workerID: uuid.New(),
|
||||
chatHeartbeatInterval: time.Minute,
|
||||
heartbeatRegistry: make(map[uuid.UUID]*heartbeatEntry),
|
||||
}
|
||||
|
||||
chatID := uuid.New()
|
||||
chatCtx, cancel := context.WithCancelCause(ctx)
|
||||
|
||||
server.registerHeartbeat(&heartbeatEntry{
|
||||
cancelWithCause: cancel,
|
||||
chatID: chatID,
|
||||
logger: logger,
|
||||
})
|
||||
|
||||
// Simulate a transient DB error.
|
||||
db.EXPECT().UpdateChatHeartbeats(gomock.Any(), gomock.Any()).Return(
|
||||
nil, xerrors.New("connection reset"),
|
||||
)
|
||||
|
||||
server.heartbeatTick(ctx)
|
||||
|
||||
// Chat should NOT be interrupted — the tick logged and
|
||||
// returned early.
|
||||
require.NoError(t, chatCtx.Err(),
|
||||
"chat context should not be canceled on transient DB error")
|
||||
}
|
||||
|
||||
@@ -474,7 +474,7 @@ func TestArchiveChatInterruptsActiveProcessing(t *testing.T) {
|
||||
require.Equal(t, 1, userMessages, "expected queued message to stay queued after archive")
|
||||
}
|
||||
|
||||
func TestUpdateChatHeartbeatsRequiresOwnership(t *testing.T) {
|
||||
func TestUpdateChatHeartbeatRequiresOwnership(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
db, ps := dbtestutil.NewDB(t)
|
||||
@@ -501,24 +501,19 @@ func TestUpdateChatHeartbeatsRequiresOwnership(t *testing.T) {
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
// Wrong worker_id should return no IDs.
|
||||
ids, err := db.UpdateChatHeartbeats(ctx, database.UpdateChatHeartbeatsParams{
|
||||
IDs: []uuid.UUID{chat.ID},
|
||||
rows, err := db.UpdateChatHeartbeat(ctx, database.UpdateChatHeartbeatParams{
|
||||
ID: chat.ID,
|
||||
WorkerID: uuid.New(),
|
||||
Now: time.Now(),
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.Empty(t, ids)
|
||||
require.Equal(t, int64(0), rows)
|
||||
|
||||
// Correct worker_id should return the chat's ID.
|
||||
ids, err = db.UpdateChatHeartbeats(ctx, database.UpdateChatHeartbeatsParams{
|
||||
IDs: []uuid.UUID{chat.ID},
|
||||
rows, err = db.UpdateChatHeartbeat(ctx, database.UpdateChatHeartbeatParams{
|
||||
ID: chat.ID,
|
||||
WorkerID: workerID,
|
||||
Now: time.Now(),
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.Len(t, ids, 1)
|
||||
require.Equal(t, chat.ID, ids[0])
|
||||
require.Equal(t, int64(1), rows)
|
||||
}
|
||||
|
||||
func TestSendMessageQueueBehaviorQueuesWhenBusy(t *testing.T) {
|
||||
|
||||
@@ -9,7 +9,6 @@ import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"slices"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
@@ -50,11 +49,10 @@ const connectTimeout = 10 * time.Second
|
||||
const toolCallTimeout = 60 * time.Second
|
||||
|
||||
// ConnectAll connects to all configured MCP servers, discovers
|
||||
// their tools, and returns them as fantasy.AgentTool values.
|
||||
// Tools are sorted by their prefixed name so callers
|
||||
// receive a deterministic order. It skips servers that fail to
|
||||
// connect and logs warnings. The returned cleanup function
|
||||
// must be called to close all connections.
|
||||
// their tools, and returns them as fantasy.AgentTool values. It
|
||||
// skips servers that fail to connect and logs warnings. The
|
||||
// returned cleanup function must be called to close all
|
||||
// connections.
|
||||
func ConnectAll(
|
||||
ctx context.Context,
|
||||
logger slog.Logger,
|
||||
@@ -110,9 +108,7 @@ func ConnectAll(
|
||||
}
|
||||
|
||||
mu.Lock()
|
||||
if mcpClient != nil {
|
||||
clients = append(clients, mcpClient)
|
||||
}
|
||||
clients = append(clients, mcpClient)
|
||||
tools = append(tools, serverTools...)
|
||||
mu.Unlock()
|
||||
return nil
|
||||
@@ -123,31 +119,6 @@ func ConnectAll(
|
||||
// discarded.
|
||||
_ = eg.Wait()
|
||||
|
||||
// Sort tools by prefixed name for deterministic ordering
|
||||
// regardless of goroutine completion order. Ties, possible
|
||||
// when the __ separator produces ambiguous prefixed names,
|
||||
// are broken by config ID. Stable prompt construction
|
||||
// depends on consistent tool ordering.
|
||||
slices.SortFunc(tools, func(a, b fantasy.AgentTool) int {
|
||||
// All tools in this slice are mcpToolWrapper values
|
||||
// created by connectOne above, so these checked
|
||||
// assertions should always succeed. The config ID
|
||||
// tiebreaker resolves the __ separator ambiguity
|
||||
// documented at the top of this file.
|
||||
aTool, ok := a.(MCPToolIdentifier)
|
||||
if !ok {
|
||||
panic(fmt.Sprintf("unexpected tool type %T", a))
|
||||
}
|
||||
bTool, ok := b.(MCPToolIdentifier)
|
||||
if !ok {
|
||||
panic(fmt.Sprintf("unexpected tool type %T", b))
|
||||
}
|
||||
return cmp.Or(
|
||||
cmp.Compare(a.Info().Name, b.Info().Name),
|
||||
cmp.Compare(aTool.MCPServerConfigID().String(), bTool.MCPServerConfigID().String()),
|
||||
)
|
||||
})
|
||||
|
||||
return tools, cleanup
|
||||
}
|
||||
|
||||
|
||||
@@ -63,17 +63,6 @@ func greetTool() mcpserver.ServerTool {
|
||||
}
|
||||
}
|
||||
|
||||
// makeTool returns a ServerTool with the given name and a
|
||||
// no-op handler that always returns "ok".
|
||||
func makeTool(name string) mcpserver.ServerTool {
|
||||
return mcpserver.ServerTool{
|
||||
Tool: mcp.NewTool(name),
|
||||
Handler: func(_ context.Context, _ mcp.CallToolRequest) (*mcp.CallToolResult, error) {
|
||||
return mcp.NewToolResultText("ok"), nil
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// makeConfig builds a database.MCPServerConfig suitable for tests.
|
||||
func makeConfig(slug, url string) database.MCPServerConfig {
|
||||
return database.MCPServerConfig{
|
||||
@@ -209,121 +198,6 @@ func TestConnectAll_MultipleServers(t *testing.T) {
|
||||
assert.Contains(t, names, "beta__greet")
|
||||
}
|
||||
|
||||
func TestConnectAll_NoToolsAfterFiltering(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := context.Background()
|
||||
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true})
|
||||
|
||||
ts := newTestMCPServer(t, echoTool())
|
||||
|
||||
cfg := makeConfig("filtered", ts.URL)
|
||||
cfg.ToolAllowList = []string{"greet"}
|
||||
|
||||
tools, cleanup := mcpclient.ConnectAll(
|
||||
ctx,
|
||||
logger,
|
||||
[]database.MCPServerConfig{cfg},
|
||||
nil,
|
||||
)
|
||||
|
||||
require.Empty(t, tools)
|
||||
assert.NotPanics(t, cleanup)
|
||||
}
|
||||
|
||||
func TestConnectAll_DeterministicOrder(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
t.Run("AcrossServers", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := context.Background()
|
||||
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true})
|
||||
|
||||
ts1 := newTestMCPServer(t, makeTool("zebra"))
|
||||
ts2 := newTestMCPServer(t, makeTool("alpha"))
|
||||
ts3 := newTestMCPServer(t, makeTool("middle"))
|
||||
|
||||
tools, cleanup := mcpclient.ConnectAll(
|
||||
ctx,
|
||||
logger,
|
||||
[]database.MCPServerConfig{
|
||||
makeConfig("srv3", ts3.URL),
|
||||
makeConfig("srv1", ts1.URL),
|
||||
makeConfig("srv2", ts2.URL),
|
||||
},
|
||||
nil,
|
||||
)
|
||||
t.Cleanup(cleanup)
|
||||
|
||||
require.Len(t, tools, 3)
|
||||
// Sorted by full prefixed name (slug__tool), so slug
|
||||
// order determines the sequence, not the tool name.
|
||||
assert.Equal(t,
|
||||
[]string{"srv1__zebra", "srv2__alpha", "srv3__middle"},
|
||||
toolNames(tools),
|
||||
)
|
||||
})
|
||||
|
||||
t.Run("WithMultiToolServer", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := context.Background()
|
||||
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true})
|
||||
|
||||
multi := newTestMCPServer(t, makeTool("zeta"), makeTool("beta"))
|
||||
other := newTestMCPServer(t, makeTool("gamma"))
|
||||
|
||||
tools, cleanup := mcpclient.ConnectAll(
|
||||
ctx,
|
||||
logger,
|
||||
[]database.MCPServerConfig{
|
||||
makeConfig("zzz", multi.URL),
|
||||
makeConfig("aaa", other.URL),
|
||||
},
|
||||
nil,
|
||||
)
|
||||
t.Cleanup(cleanup)
|
||||
|
||||
require.Len(t, tools, 3)
|
||||
assert.Equal(t,
|
||||
[]string{"aaa__gamma", "zzz__beta", "zzz__zeta"},
|
||||
toolNames(tools),
|
||||
)
|
||||
})
|
||||
|
||||
t.Run("TiebreakByConfigID", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := context.Background()
|
||||
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true})
|
||||
|
||||
ts1 := newTestMCPServer(t, makeTool("b__z"))
|
||||
ts2 := newTestMCPServer(t, makeTool("z"))
|
||||
|
||||
// Use fixed UUIDs so the tiebreaker order is
|
||||
// predictable. Both servers produce the same prefixed
|
||||
// name, a__b__z, due to the __ separator ambiguity.
|
||||
cfg1 := makeConfig("a", ts1.URL)
|
||||
cfg1.ID = uuid.MustParse("00000000-0000-0000-0000-000000000002")
|
||||
|
||||
cfg2 := makeConfig("a__b", ts2.URL)
|
||||
cfg2.ID = uuid.MustParse("00000000-0000-0000-0000-000000000001")
|
||||
|
||||
tools, cleanup := mcpclient.ConnectAll(
|
||||
ctx,
|
||||
logger,
|
||||
[]database.MCPServerConfig{cfg1, cfg2},
|
||||
nil,
|
||||
)
|
||||
t.Cleanup(cleanup)
|
||||
|
||||
require.Len(t, tools, 2)
|
||||
assert.Equal(t, []string{"a__b__z", "a__b__z"}, toolNames(tools))
|
||||
|
||||
id0 := tools[0].(mcpclient.MCPToolIdentifier).MCPServerConfigID()
|
||||
id1 := tools[1].(mcpclient.MCPToolIdentifier).MCPServerConfigID()
|
||||
assert.Equal(t, cfg2.ID, id0, "lower config ID should sort first")
|
||||
assert.Equal(t, cfg1.ID, id1, "higher config ID should sort second")
|
||||
})
|
||||
}
|
||||
|
||||
func TestConnectAll_AuthHeaders(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := context.Background()
|
||||
|
||||
@@ -212,7 +212,6 @@ type AuditLogsRequest struct {
|
||||
type AuditLogResponse struct {
|
||||
AuditLogs []AuditLog `json:"audit_logs"`
|
||||
Count int64 `json:"count"`
|
||||
CountCap int64 `json:"count_cap"`
|
||||
}
|
||||
|
||||
type CreateTestAuditLogRequest struct {
|
||||
|
||||
@@ -96,7 +96,6 @@ type ConnectionLogsRequest struct {
|
||||
type ConnectionLogResponse struct {
|
||||
ConnectionLogs []ConnectionLog `json:"connection_logs"`
|
||||
Count int64 `json:"count"`
|
||||
CountCap int64 `json:"count_cap"`
|
||||
}
|
||||
|
||||
func (c *Client) ConnectionLogs(ctx context.Context, req ConnectionLogsRequest) (ConnectionLogResponse, error) {
|
||||
|
||||
Generated
+1
-2
@@ -90,8 +90,7 @@ curl -X GET http://coder-server:8080/api/v2/audit?limit=0 \
|
||||
"user_agent": "string"
|
||||
}
|
||||
],
|
||||
"count": 0,
|
||||
"count_cap": 0
|
||||
"count": 0
|
||||
}
|
||||
```
|
||||
|
||||
|
||||
Generated
+1
-2
@@ -291,8 +291,7 @@ curl -X GET http://coder-server:8080/api/v2/connectionlog?limit=0 \
|
||||
"workspace_owner_username": "string"
|
||||
}
|
||||
],
|
||||
"count": 0,
|
||||
"count_cap": 0
|
||||
"count": 0
|
||||
}
|
||||
```
|
||||
|
||||
|
||||
Generated
+2
-6
@@ -1740,8 +1740,7 @@
|
||||
"user_agent": "string"
|
||||
}
|
||||
],
|
||||
"count": 0,
|
||||
"count_cap": 0
|
||||
"count": 0
|
||||
}
|
||||
```
|
||||
|
||||
@@ -1751,7 +1750,6 @@
|
||||
|--------------|-------------------------------------------------|----------|--------------|-------------|
|
||||
| `audit_logs` | array of [codersdk.AuditLog](#codersdkauditlog) | false | | |
|
||||
| `count` | integer | false | | |
|
||||
| `count_cap` | integer | false | | |
|
||||
|
||||
## codersdk.AuthMethod
|
||||
|
||||
@@ -2175,8 +2173,7 @@ AuthorizationObject can represent a "set" of objects, such as: all workspaces in
|
||||
"workspace_owner_username": "string"
|
||||
}
|
||||
],
|
||||
"count": 0,
|
||||
"count_cap": 0
|
||||
"count": 0
|
||||
}
|
||||
```
|
||||
|
||||
@@ -2186,7 +2183,6 @@ AuthorizationObject can represent a "set" of objects, such as: all workspaces in
|
||||
|-------------------|-----------------------------------------------------------|----------|--------------|-------------|
|
||||
| `connection_logs` | array of [codersdk.ConnectionLog](#codersdkconnectionlog) | false | | |
|
||||
| `count` | integer | false | | |
|
||||
| `count_cap` | integer | false | | |
|
||||
|
||||
## codersdk.ConnectionLogSSHInfo
|
||||
|
||||
|
||||
@@ -16,9 +16,6 @@ import (
|
||||
"github.com/coder/coder/v2/codersdk"
|
||||
)
|
||||
|
||||
// NOTE: See the auditLogCountCap note.
|
||||
const connectionLogCountCap = 2000
|
||||
|
||||
// @Summary Get connection logs
|
||||
// @ID get-connection-logs
|
||||
// @Security CoderSessionToken
|
||||
@@ -52,7 +49,6 @@ func (api *API) connectionLogs(rw http.ResponseWriter, r *http.Request) {
|
||||
// #nosec G115 - Safe conversion as pagination limit is expected to be within int32 range
|
||||
filter.LimitOpt = int32(page.Limit)
|
||||
|
||||
countFilter.CountCap = connectionLogCountCap
|
||||
count, err := api.Database.CountConnectionLogs(ctx, countFilter)
|
||||
if dbauthz.IsNotAuthorizedError(err) {
|
||||
httpapi.Forbidden(rw)
|
||||
@@ -67,7 +63,6 @@ func (api *API) connectionLogs(rw http.ResponseWriter, r *http.Request) {
|
||||
httpapi.Write(ctx, rw, http.StatusOK, codersdk.ConnectionLogResponse{
|
||||
ConnectionLogs: []codersdk.ConnectionLog{},
|
||||
Count: 0,
|
||||
CountCap: connectionLogCountCap,
|
||||
})
|
||||
return
|
||||
}
|
||||
@@ -85,7 +80,6 @@ func (api *API) connectionLogs(rw http.ResponseWriter, r *http.Request) {
|
||||
httpapi.Write(ctx, rw, http.StatusOK, codersdk.ConnectionLogResponse{
|
||||
ConnectionLogs: convertConnectionLogs(dblogs),
|
||||
Count: count,
|
||||
CountCap: connectionLogCountCap,
|
||||
})
|
||||
}
|
||||
|
||||
|
||||
@@ -12,7 +12,6 @@ import (
|
||||
|
||||
"github.com/cenkalti/backoff/v4"
|
||||
"github.com/google/uuid"
|
||||
"golang.org/x/sync/singleflight"
|
||||
"golang.org/x/xerrors"
|
||||
gProto "google.golang.org/protobuf/proto"
|
||||
|
||||
@@ -34,9 +33,9 @@ const (
|
||||
eventReadyForHandshake = "tailnet_ready_for_handshake"
|
||||
HeartbeatPeriod = time.Second * 2
|
||||
MissedHeartbeats = 3
|
||||
numQuerierWorkers = 40
|
||||
numQuerierWorkers = 10
|
||||
numBinderWorkers = 10
|
||||
numTunnelerWorkers = 20
|
||||
numTunnelerWorkers = 10
|
||||
numHandshakerWorkers = 5
|
||||
dbMaxBackoff = 10 * time.Second
|
||||
cleanupPeriod = time.Hour
|
||||
@@ -771,9 +770,6 @@ func (m *mapper) bestToUpdate(best map[uuid.UUID]mapping) *proto.CoordinateRespo
|
||||
|
||||
for k := range m.sent {
|
||||
if _, ok := best[k]; !ok {
|
||||
m.logger.Debug(m.ctx, "peer no longer in best mappings, sending DISCONNECTED",
|
||||
slog.F("peer_id", k),
|
||||
)
|
||||
resp.PeerUpdates = append(resp.PeerUpdates, &proto.CoordinateResponse_PeerUpdate{
|
||||
Id: agpl.UUIDToByteSlice(k),
|
||||
Kind: proto.CoordinateResponse_PeerUpdate_DISCONNECTED,
|
||||
@@ -824,8 +820,6 @@ type querier struct {
|
||||
mu sync.Mutex
|
||||
mappers map[mKey]*mapper
|
||||
healthy bool
|
||||
|
||||
resyncGroup singleflight.Group
|
||||
}
|
||||
|
||||
func newQuerier(ctx context.Context,
|
||||
@@ -964,7 +958,7 @@ func (q *querier) cleanupConn(c *connIO) {
|
||||
|
||||
// maxBatchSize is the maximum number of keys to process in a single batch
|
||||
// query.
|
||||
const maxBatchSize = 200
|
||||
const maxBatchSize = 50
|
||||
|
||||
func (q *querier) peerUpdateWorker() {
|
||||
defer q.wg.Done()
|
||||
@@ -1213,13 +1207,8 @@ func (q *querier) subscribe() {
|
||||
func (q *querier) listenPeer(_ context.Context, msg []byte, err error) {
|
||||
if xerrors.Is(err, pubsub.ErrDroppedMessages) {
|
||||
q.logger.Warn(q.ctx, "pubsub may have dropped peer updates")
|
||||
// Schedule a full resync asynchronously so we don't block the
|
||||
// pubsub drain goroutine. Singleflight coalesces concurrent
|
||||
// resync requests.
|
||||
go q.resyncGroup.Do("resync", func() (any, error) {
|
||||
q.resyncPeerMappings()
|
||||
return nil, nil
|
||||
})
|
||||
// we need to schedule a full resync of peer mappings
|
||||
q.resyncPeerMappings()
|
||||
return
|
||||
}
|
||||
if err != nil {
|
||||
@@ -1245,13 +1234,8 @@ func (q *querier) listenPeer(_ context.Context, msg []byte, err error) {
|
||||
func (q *querier) listenTunnel(_ context.Context, msg []byte, err error) {
|
||||
if xerrors.Is(err, pubsub.ErrDroppedMessages) {
|
||||
q.logger.Warn(q.ctx, "pubsub may have dropped tunnel updates")
|
||||
// Schedule a full resync asynchronously so we don't block the
|
||||
// pubsub drain goroutine. Singleflight coalesces concurrent
|
||||
// resync requests.
|
||||
go q.resyncGroup.Do("resync", func() (any, error) {
|
||||
q.resyncPeerMappings()
|
||||
return nil, nil
|
||||
})
|
||||
// we need to schedule a full resync of peer mappings
|
||||
q.resyncPeerMappings()
|
||||
return
|
||||
}
|
||||
if err != nil {
|
||||
@@ -1617,10 +1601,6 @@ func (h *heartbeats) filter(mappings []mapping) []mapping {
|
||||
// the only mapping available for it. Newer mappings will take
|
||||
// precedence.
|
||||
m.kind = proto.CoordinateResponse_PeerUpdate_LOST
|
||||
h.logger.Debug(h.ctx, "mapping rewritten to LOST due to missed heartbeats",
|
||||
slog.F("peer_id", m.peer),
|
||||
slog.F("coordinator_id", m.coordinator),
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -76,11 +76,11 @@ replace github.com/aquasecurity/trivy => github.com/coder/trivy v0.0.0-202603091
|
||||
// https://github.com/spf13/afero/pull/487
|
||||
replace github.com/spf13/afero => github.com/aslilac/afero v0.0.0-20250403163713-f06e86036696
|
||||
|
||||
// Forked from coder/fantasy (cj/go1.25 branch) which adds:
|
||||
// Forked from kylecarbs/fantasy (cj/go1.25 branch) which adds:
|
||||
// 1) Anthropic computer use + thinking effort
|
||||
// 2) Go 1.25 downgrade for Windows CI compat
|
||||
// 3) ibetitsmike/fantasy#4 — skip ephemeral replay items when store=false
|
||||
replace charm.land/fantasy => github.com/coder/fantasy v0.0.0-20260325145725-112927d9b6d8
|
||||
replace charm.land/fantasy => github.com/kylecarbs/fantasy v0.0.0-20260325145725-112927d9b6d8
|
||||
|
||||
replace github.com/charmbracelet/anthropic-sdk-go => github.com/kylecarbs/anthropic-sdk-go v0.0.0-20260223140439-63879b0b8dab
|
||||
|
||||
|
||||
@@ -322,8 +322,6 @@ github.com/coder/bubbletea v1.2.2-0.20241212190825-007a1cdb2c41 h1:SBN/DA63+ZHwu
|
||||
github.com/coder/bubbletea v1.2.2-0.20241212190825-007a1cdb2c41/go.mod h1:I9ULxr64UaOSUv7hcb3nX4kowodJCVS7vt7VVJk/kW4=
|
||||
github.com/coder/clistat v1.2.1 h1:P9/10njXMyj5cWzIU5wkRsSy5LVQH49+tcGMsAgWX0w=
|
||||
github.com/coder/clistat v1.2.1/go.mod h1:m7SC0uj88eEERgvF8Kn6+w6XF21BeSr+15f7GoLAw0A=
|
||||
github.com/coder/fantasy v0.0.0-20260325145725-112927d9b6d8 h1:n+6v+yT1B6V4oSGPmXFh7mul1E+RzG9rnqp50Vb7M/w=
|
||||
github.com/coder/fantasy v0.0.0-20260325145725-112927d9b6d8/go.mod h1:ktfNX0xDpIKeggZbP/j5IYJci6pyMOR3WmZSfz9XLYw=
|
||||
github.com/coder/flog v1.1.0 h1:kbAes1ai8fIS5OeV+QAnKBQE22ty1jRF/mcAwHpLBa4=
|
||||
github.com/coder/flog v1.1.0/go.mod h1:UQlQvrkJBvnRGo69Le8E24Tcl5SJleAAR7gYEHzAmdQ=
|
||||
github.com/coder/go-httpstat v0.0.0-20230801153223-321c88088322 h1:m0lPZjlQ7vdVpRBPKfYIFlmgevoTkBxB10wv6l2gOaU=
|
||||
@@ -815,6 +813,8 @@ github.com/kylecarbs/anthropic-sdk-go v0.0.0-20260223140439-63879b0b8dab h1:5UMY
|
||||
github.com/kylecarbs/anthropic-sdk-go v0.0.0-20260223140439-63879b0b8dab/go.mod h1:hqlYqR7uPKOKfnNeicUbZp0Ps0GeYFlKYtwh5HGDCx8=
|
||||
github.com/kylecarbs/chroma/v2 v2.0.0-20240401211003-9e036e0631f3 h1:Z9/bo5PSeMutpdiKYNt/TTSfGM1Ll0naj3QzYX9VxTc=
|
||||
github.com/kylecarbs/chroma/v2 v2.0.0-20240401211003-9e036e0631f3/go.mod h1:BUGjjsD+ndS6eX37YgTchSEG+Jg9Jv1GiZs9sqPqztk=
|
||||
github.com/kylecarbs/fantasy v0.0.0-20260325145725-112927d9b6d8 h1:fZ0208U3B438fDSHCc/GNioPIyaFqn6eBsQTO61QtrI=
|
||||
github.com/kylecarbs/fantasy v0.0.0-20260325145725-112927d9b6d8/go.mod h1:ktfNX0xDpIKeggZbP/j5IYJci6pyMOR3WmZSfz9XLYw=
|
||||
github.com/kylecarbs/openai-go/v3 v3.0.0-20260319113850-9477dcaedcae h1:xlFZNX4nnxpj9Cf6mTwD3pirXGNtBJ/6COsf9iZmsL0=
|
||||
github.com/kylecarbs/openai-go/v3 v3.0.0-20260319113850-9477dcaedcae/go.mod h1:cdufnVK14cWcT9qA1rRtrXx4FTRsgbDPW7Ia7SS5cZo=
|
||||
github.com/kylecarbs/spinner v1.18.2-0.20220329160715-20702b5af89e h1:OP0ZMFeZkUnOzTFRfpuK3m7Kp4fNvC6qN+exwj7aI4M=
|
||||
|
||||
+35
-44
@@ -68,17 +68,17 @@ func runRelease(ctx context.Context, inv *serpent.Invocation, executor ReleaseEx
|
||||
return xerrors.Errorf("detecting branch: %w", err)
|
||||
}
|
||||
|
||||
// Match release branches (release/X.Y). RCs are tagged
|
||||
// from main, not from release branches.
|
||||
branchRe := regexp.MustCompile(`^release/(\d+)\.(\d+)$`)
|
||||
// Match standard release branches (release/2.32) and RC
|
||||
// branches (release/2.32-rc.0).
|
||||
branchRe := regexp.MustCompile(`^release/(\d+)\.(\d+)(?:-rc\.(\d+))?$`)
|
||||
m := branchRe.FindStringSubmatch(currentBranch)
|
||||
if m == nil {
|
||||
warnf(w, "Current branch %q is not a release branch (release/X.Y).", currentBranch)
|
||||
warnf(w, "Current branch %q is not a release branch (release/X.Y or release/X.Y-rc.N).", currentBranch)
|
||||
branchInput, err := cliui.Prompt(inv, cliui.PromptOptions{
|
||||
Text: "Enter the release branch to use (e.g. release/2.21)",
|
||||
Text: "Enter the release branch to use (e.g. release/2.21 or release/2.21-rc.0)",
|
||||
Validate: func(s string) error {
|
||||
if !branchRe.MatchString(s) {
|
||||
return xerrors.New("must be in format release/X.Y (e.g. release/2.21)")
|
||||
return xerrors.New("must be in format release/X.Y or release/X.Y-rc.N (e.g. release/2.21 or release/2.21-rc.0)")
|
||||
}
|
||||
return nil
|
||||
},
|
||||
@@ -91,6 +91,10 @@ func runRelease(ctx context.Context, inv *serpent.Invocation, executor ReleaseEx
|
||||
}
|
||||
branchMajor, _ := strconv.Atoi(m[1])
|
||||
branchMinor, _ := strconv.Atoi(m[2])
|
||||
branchRC := -1 // -1 means not an RC branch.
|
||||
if m[3] != "" {
|
||||
branchRC, _ = strconv.Atoi(m[3])
|
||||
}
|
||||
successf(w, "Using release branch: %s", currentBranch)
|
||||
|
||||
// --- Fetch & sync check ---
|
||||
@@ -134,44 +138,31 @@ func runRelease(ctx context.Context, inv *serpent.Invocation, executor ReleaseEx
|
||||
}
|
||||
}
|
||||
|
||||
// changelogBaseRef is the git ref used as the starting point
|
||||
// for release notes generation. When a tag already exists in
|
||||
// this minor series we use it directly. For the first release
|
||||
// on a new minor no matching tag exists, so we compute the
|
||||
// merge-base with the previous minor's release branch instead.
|
||||
// This works even when that branch has no tags yet (it was
|
||||
// just cut and pushed). As a last resort we fall back to the
|
||||
// latest reachable tag from a previous minor.
|
||||
var changelogBaseRef string
|
||||
if prevVersion != nil {
|
||||
changelogBaseRef = prevVersion.String()
|
||||
} else {
|
||||
prevReleaseBranch := fmt.Sprintf("release/%d.%d", branchMajor, branchMinor-1)
|
||||
if err := gitRun("fetch", "--quiet", "origin", prevReleaseBranch); err != nil {
|
||||
warnf(w, "Could not fetch %s: %v", prevReleaseBranch, err)
|
||||
}
|
||||
if mb, mbErr := gitOutput("merge-base", "HEAD", "origin/"+prevReleaseBranch); mbErr == nil && mb != "" {
|
||||
changelogBaseRef = mb
|
||||
infof(w, "Using merge-base with %s as changelog base: %s", prevReleaseBranch, mb[:12])
|
||||
} else {
|
||||
// No previous release branch found; fall back to
|
||||
// the latest reachable tag from a previous minor.
|
||||
for _, t := range mergedTags {
|
||||
if t.Major == branchMajor && t.Minor < branchMinor {
|
||||
changelogBaseRef = t.String()
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
var suggested version
|
||||
if prevVersion == nil {
|
||||
infof(w, "No previous release tag found on this branch.")
|
||||
suggested = version{Major: branchMajor, Minor: branchMinor, Patch: 0}
|
||||
if branchRC >= 0 {
|
||||
suggested.Pre = fmt.Sprintf("rc.%d", branchRC)
|
||||
}
|
||||
} else {
|
||||
infof(w, "Previous release tag: %s", prevVersion.String())
|
||||
suggested = version{Major: prevVersion.Major, Minor: prevVersion.Minor, Patch: prevVersion.Patch + 1}
|
||||
if branchRC >= 0 {
|
||||
// On an RC branch, suggest the next RC for
|
||||
// the same base version.
|
||||
nextRC := 0
|
||||
if prevVersion.IsRC() {
|
||||
nextRC = prevVersion.rcNumber() + 1
|
||||
}
|
||||
suggested = version{
|
||||
Major: prevVersion.Major,
|
||||
Minor: prevVersion.Minor,
|
||||
Patch: prevVersion.Patch,
|
||||
Pre: fmt.Sprintf("rc.%d", nextRC),
|
||||
}
|
||||
} else {
|
||||
suggested = version{Major: prevVersion.Major, Minor: prevVersion.Minor, Patch: prevVersion.Patch + 1}
|
||||
}
|
||||
}
|
||||
|
||||
fmt.Fprintln(w)
|
||||
@@ -375,8 +366,8 @@ func runRelease(ctx context.Context, inv *serpent.Invocation, executor ReleaseEx
|
||||
infof(w, "Generating release notes...")
|
||||
|
||||
commitRange := "HEAD"
|
||||
if changelogBaseRef != "" {
|
||||
commitRange = changelogBaseRef + "..HEAD"
|
||||
if prevVersion != nil {
|
||||
commitRange = prevVersion.String() + "..HEAD"
|
||||
}
|
||||
|
||||
commits, err := commitLog(commitRange)
|
||||
@@ -482,16 +473,16 @@ func runRelease(ctx context.Context, inv *serpent.Invocation, executor ReleaseEx
|
||||
}
|
||||
if !hasContent {
|
||||
prevStr := "the beginning of time"
|
||||
if changelogBaseRef != "" {
|
||||
prevStr = changelogBaseRef
|
||||
if prevVersion != nil {
|
||||
prevStr = prevVersion.String()
|
||||
}
|
||||
fmt.Fprintf(¬es, "\n_No changes since %s._\n", prevStr)
|
||||
}
|
||||
|
||||
// Compare link.
|
||||
if changelogBaseRef != "" {
|
||||
if prevVersion != nil {
|
||||
fmt.Fprintf(¬es, "\nCompare: [`%s...%s`](https://github.com/%s/%s/compare/%s...%s)\n",
|
||||
changelogBaseRef, newVersion, owner, repo, changelogBaseRef, newVersion)
|
||||
prevVersion, newVersion, owner, repo, prevVersion, newVersion)
|
||||
}
|
||||
|
||||
// Container image.
|
||||
|
||||
Generated
-2
@@ -913,7 +913,6 @@ export interface AuditLog {
|
||||
export interface AuditLogResponse {
|
||||
readonly audit_logs: readonly AuditLog[];
|
||||
readonly count: number;
|
||||
readonly count_cap: number;
|
||||
}
|
||||
|
||||
// From codersdk/audit.go
|
||||
@@ -2270,7 +2269,6 @@ export interface ConnectionLog {
|
||||
export interface ConnectionLogResponse {
|
||||
readonly connection_logs: readonly ConnectionLog[];
|
||||
readonly count: number;
|
||||
readonly count_cap: number;
|
||||
}
|
||||
|
||||
// From codersdk/connectionlog.go
|
||||
|
||||
@@ -35,10 +35,10 @@ export const AvatarData: FC<AvatarDataProps> = ({
|
||||
}
|
||||
|
||||
return (
|
||||
<div className="flex items-center w-full gap-3">
|
||||
<div className="flex items-center w-full min-w-0 gap-3">
|
||||
{avatar}
|
||||
|
||||
<div className="flex flex-col w-full">
|
||||
<div className="flex flex-col w-full min-w-0">
|
||||
<span className="text-sm font-semibold text-content-primary">
|
||||
{title}
|
||||
</span>
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
import type { Interpolation, Theme } from "@emotion/react";
|
||||
import type { FC, HTMLAttributes } from "react";
|
||||
import type { LogLevel } from "#/api/typesGenerated";
|
||||
import { cn } from "#/utils/cn";
|
||||
import { MONOSPACE_FONT_FAMILY } from "#/theme/constants";
|
||||
|
||||
const DEFAULT_LOG_LINE_SIDE_PADDING = 24;
|
||||
|
||||
@@ -16,40 +17,65 @@ type LogLineProps = {
|
||||
level: LogLevel;
|
||||
} & HTMLAttributes<HTMLPreElement>;
|
||||
|
||||
export const LogLine: FC<LogLineProps> = ({ level, className, ...props }) => {
|
||||
export const LogLine: FC<LogLineProps> = ({ level, ...divProps }) => {
|
||||
return (
|
||||
<pre
|
||||
{...props}
|
||||
className={cn(
|
||||
"logs-line",
|
||||
"m-0 break-all flex items-center h-auto",
|
||||
"text-[13px] text-content-primary font-mono",
|
||||
level === "error" &&
|
||||
"bg-surface-error text-content-error [&_.dashed-line]:bg-border-error",
|
||||
level === "debug" &&
|
||||
"bg-surface-sky text-content-sky [&_.dashed-line]:bg-border-sky",
|
||||
level === "warn" &&
|
||||
"bg-surface-warning text-content-warning [&_.dashed-line]:bg-border-warning",
|
||||
className,
|
||||
)}
|
||||
style={{
|
||||
padding: `0 var(--log-line-side-padding, ${DEFAULT_LOG_LINE_SIDE_PADDING}px)`,
|
||||
}}
|
||||
css={styles.line}
|
||||
className={`${level} ${divProps.className} logs-line`}
|
||||
{...divProps}
|
||||
/>
|
||||
);
|
||||
};
|
||||
|
||||
export const LogLinePrefix: FC<HTMLAttributes<HTMLSpanElement>> = ({
|
||||
className,
|
||||
...props
|
||||
}) => {
|
||||
return (
|
||||
<pre
|
||||
className={cn(
|
||||
"select-none m-0 inline-block text-content-secondary mr-6",
|
||||
className,
|
||||
)}
|
||||
{...props}
|
||||
/>
|
||||
);
|
||||
export const LogLinePrefix: FC<HTMLAttributes<HTMLSpanElement>> = (props) => {
|
||||
return <pre css={styles.prefix} {...props} />;
|
||||
};
|
||||
|
||||
const styles = {
|
||||
line: (theme) => ({
|
||||
margin: 0,
|
||||
wordBreak: "break-all",
|
||||
display: "flex",
|
||||
alignItems: "center",
|
||||
fontSize: 13,
|
||||
color: theme.palette.text.primary,
|
||||
fontFamily: MONOSPACE_FONT_FAMILY,
|
||||
height: "auto",
|
||||
padding: `0 var(--log-line-side-padding, ${DEFAULT_LOG_LINE_SIDE_PADDING}px)`,
|
||||
|
||||
"&.error": {
|
||||
backgroundColor: theme.roles.error.background,
|
||||
color: theme.roles.error.text,
|
||||
|
||||
"& .dashed-line": {
|
||||
backgroundColor: theme.roles.error.outline,
|
||||
},
|
||||
},
|
||||
|
||||
"&.debug": {
|
||||
backgroundColor: theme.roles.notice.background,
|
||||
color: theme.roles.notice.text,
|
||||
|
||||
"& .dashed-line": {
|
||||
backgroundColor: theme.roles.notice.outline,
|
||||
},
|
||||
},
|
||||
|
||||
"&.warn": {
|
||||
backgroundColor: theme.roles.warning.background,
|
||||
color: theme.roles.warning.text,
|
||||
|
||||
"& .dashed-line": {
|
||||
backgroundColor: theme.roles.warning.outline,
|
||||
},
|
||||
},
|
||||
}),
|
||||
|
||||
prefix: (theme) => ({
|
||||
userSelect: "none",
|
||||
margin: 0,
|
||||
display: "inline-block",
|
||||
color: theme.palette.text.secondary,
|
||||
marginRight: 24,
|
||||
}),
|
||||
} satisfies Record<string, Interpolation<Theme>>;
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
import type { Interpolation, Theme } from "@emotion/react";
|
||||
import dayjs from "dayjs";
|
||||
import type { FC } from "react";
|
||||
import { cn } from "#/utils/cn";
|
||||
import { type Line, LogLine, LogLinePrefix } from "./LogLine";
|
||||
|
||||
export const DEFAULT_LOG_LINE_SIDE_PADDING = 24;
|
||||
@@ -17,17 +17,7 @@ export const Logs: FC<LogsProps> = ({
|
||||
className = "",
|
||||
}) => {
|
||||
return (
|
||||
<div
|
||||
className={cn(
|
||||
"logs-container",
|
||||
"min-h-40 py-2 rounded-lg overflow-x-auto bg-surface-primary",
|
||||
"[&:not(:last-child)]:border-0",
|
||||
"[&:not(:last-child)]:border-solid",
|
||||
"[&:not(:last-child)]:border-b-border",
|
||||
"[&:not(:last-child)]:rounded-none",
|
||||
className,
|
||||
)}
|
||||
>
|
||||
<div css={styles.root} className={`${className} logs-container`}>
|
||||
<div className="min-w-fit">
|
||||
{lines.map((line) => (
|
||||
<LogLine key={line.id} level={line.level}>
|
||||
@@ -43,3 +33,18 @@ export const Logs: FC<LogsProps> = ({
|
||||
</div>
|
||||
);
|
||||
};
|
||||
|
||||
const styles = {
|
||||
root: (theme) => ({
|
||||
minHeight: 156,
|
||||
padding: "8px 0",
|
||||
borderRadius: 8,
|
||||
overflowX: "auto",
|
||||
background: theme.palette.background.default,
|
||||
|
||||
"&:not(:last-child)": {
|
||||
borderBottom: `1px solid ${theme.palette.divider}`,
|
||||
borderRadius: 0,
|
||||
},
|
||||
}),
|
||||
} satisfies Record<string, Interpolation<Theme>>;
|
||||
|
||||
@@ -7,7 +7,6 @@ type PaginationHeaderProps = {
|
||||
limit: number;
|
||||
totalRecords: number | undefined;
|
||||
currentOffsetStart: number | undefined;
|
||||
countIsCapped?: boolean;
|
||||
|
||||
// Temporary escape hatch until Workspaces can be switched over to using
|
||||
// PaginationContainer
|
||||
@@ -19,7 +18,6 @@ export const PaginationAmount: FC<PaginationHeaderProps> = ({
|
||||
limit,
|
||||
totalRecords,
|
||||
currentOffsetStart,
|
||||
countIsCapped,
|
||||
className,
|
||||
}) => {
|
||||
const theme = useTheme();
|
||||
@@ -54,16 +52,10 @@ export const PaginationAmount: FC<PaginationHeaderProps> = ({
|
||||
<strong>
|
||||
{(
|
||||
currentOffsetStart +
|
||||
(countIsCapped
|
||||
? limit - 1
|
||||
: Math.min(limit - 1, totalRecords - currentOffsetStart))
|
||||
Math.min(limit - 1, totalRecords - currentOffsetStart)
|
||||
).toLocaleString()}
|
||||
</strong>{" "}
|
||||
of{" "}
|
||||
<strong>
|
||||
{totalRecords.toLocaleString()}
|
||||
{countIsCapped && "+"}
|
||||
</strong>{" "}
|
||||
of <strong>{totalRecords.toLocaleString()}</strong>{" "}
|
||||
{paginationUnitLabel}
|
||||
</div>
|
||||
)}
|
||||
|
||||
@@ -18,7 +18,6 @@ export const mockPaginationResultBase: ResultBase = {
|
||||
limit: 25,
|
||||
hasNextPage: false,
|
||||
hasPreviousPage: false,
|
||||
countIsCapped: false,
|
||||
goToPreviousPage: () => {},
|
||||
goToNextPage: () => {},
|
||||
goToFirstPage: () => {},
|
||||
@@ -34,7 +33,6 @@ export const mockInitialRenderResult: PaginationResult = {
|
||||
hasPreviousPage: false,
|
||||
totalRecords: undefined,
|
||||
totalPages: undefined,
|
||||
countIsCapped: false,
|
||||
};
|
||||
|
||||
export const mockSuccessResult: PaginationResult = {
|
||||
|
||||
@@ -94,7 +94,7 @@ export const FirstPageWithTonsOfData: Story = {
|
||||
currentPage: 2,
|
||||
currentOffsetStart: 1000,
|
||||
totalRecords: 123_456,
|
||||
totalPages: 4939,
|
||||
totalPages: 1235,
|
||||
hasPreviousPage: false,
|
||||
hasNextPage: true,
|
||||
isPlaceholderData: false,
|
||||
@@ -135,54 +135,3 @@ export const SecondPageWithData: Story = {
|
||||
children: <div>New data for page 2</div>,
|
||||
},
|
||||
};
|
||||
|
||||
export const CappedCountFirstPage: Story = {
|
||||
args: {
|
||||
query: {
|
||||
...mockPaginationResultBase,
|
||||
isSuccess: true,
|
||||
currentPage: 1,
|
||||
currentOffsetStart: 1,
|
||||
totalRecords: 2000,
|
||||
totalPages: 80,
|
||||
hasPreviousPage: false,
|
||||
hasNextPage: true,
|
||||
isPlaceholderData: false,
|
||||
countIsCapped: true,
|
||||
},
|
||||
},
|
||||
};
|
||||
|
||||
export const CappedCountMiddlePage: Story = {
|
||||
args: {
|
||||
query: {
|
||||
...mockPaginationResultBase,
|
||||
isSuccess: true,
|
||||
currentPage: 3,
|
||||
currentOffsetStart: 51,
|
||||
totalRecords: 2000,
|
||||
totalPages: 80,
|
||||
hasPreviousPage: true,
|
||||
hasNextPage: true,
|
||||
isPlaceholderData: false,
|
||||
countIsCapped: true,
|
||||
},
|
||||
},
|
||||
};
|
||||
|
||||
export const CappedCountBeyondKnownPages: Story = {
|
||||
args: {
|
||||
query: {
|
||||
...mockPaginationResultBase,
|
||||
isSuccess: true,
|
||||
currentPage: 85,
|
||||
currentOffsetStart: 2101,
|
||||
totalRecords: 2000,
|
||||
totalPages: 85,
|
||||
hasPreviousPage: true,
|
||||
hasNextPage: true,
|
||||
isPlaceholderData: false,
|
||||
countIsCapped: true,
|
||||
},
|
||||
},
|
||||
};
|
||||
|
||||
@@ -27,14 +27,12 @@ export const PaginationContainer: FC<PaginationProps> = ({
|
||||
totalRecords={query.totalRecords}
|
||||
currentOffsetStart={query.currentOffsetStart}
|
||||
paginationUnitLabel={paginationUnitLabel}
|
||||
countIsCapped={query.countIsCapped}
|
||||
className="justify-end"
|
||||
/>
|
||||
|
||||
{query.isSuccess && (
|
||||
<PaginationWidgetBase
|
||||
totalRecords={query.totalRecords}
|
||||
totalPages={query.totalPages}
|
||||
currentPage={query.currentPage}
|
||||
pageSize={query.limit}
|
||||
onPageChange={query.onPageChange}
|
||||
|
||||
@@ -12,10 +12,6 @@ export type PaginationWidgetBaseProps = {
|
||||
|
||||
hasPreviousPage?: boolean;
|
||||
hasNextPage?: boolean;
|
||||
/** Override the computed totalPages.
|
||||
* Used when, e.g., the row count is capped and the user navigates beyond
|
||||
* the known range, so totalPages stays at least as high as currentPage. */
|
||||
totalPages?: number;
|
||||
};
|
||||
|
||||
export const PaginationWidgetBase: FC<PaginationWidgetBaseProps> = ({
|
||||
@@ -25,9 +21,8 @@ export const PaginationWidgetBase: FC<PaginationWidgetBaseProps> = ({
|
||||
onPageChange,
|
||||
hasPreviousPage,
|
||||
hasNextPage,
|
||||
totalPages: totalPagesProp,
|
||||
}) => {
|
||||
const totalPages = totalPagesProp ?? Math.ceil(totalRecords / pageSize);
|
||||
const totalPages = Math.ceil(totalRecords / pageSize);
|
||||
|
||||
if (totalPages < 2) {
|
||||
return null;
|
||||
|
||||
@@ -0,0 +1 @@
|
||||
export { useTabOverflowKebabMenu } from "./useTabOverflowKebabMenu";
|
||||
@@ -1,274 +0,0 @@
|
||||
import {
|
||||
type RefObject,
|
||||
useCallback,
|
||||
useEffect,
|
||||
useRef,
|
||||
useState,
|
||||
} from "react";
|
||||
|
||||
type TabValue = {
|
||||
value: string;
|
||||
};
|
||||
|
||||
type UseKebabMenuOptions<T extends TabValue> = {
|
||||
tabs: readonly T[];
|
||||
enabled: boolean;
|
||||
isActive: boolean;
|
||||
overflowTriggerWidth?: number;
|
||||
};
|
||||
|
||||
type UseKebabMenuResult<T extends TabValue> = {
|
||||
containerRef: RefObject<HTMLDivElement | null>;
|
||||
visibleTabs: T[];
|
||||
overflowTabs: T[];
|
||||
getTabMeasureProps: (tabValue: string) => Record<string, string>;
|
||||
};
|
||||
|
||||
const ALWAYS_VISIBLE_TABS_COUNT = 1;
|
||||
const DATA_ATTR_TAB_VALUE = "data-tab-overflow-item-value";
|
||||
|
||||
/**
|
||||
* Splits tabs into visible and overflow groups based on container width.
|
||||
*
|
||||
* Tabs must render with `getTabMeasureProps()` so this hook can measure
|
||||
* trigger widths from the DOM.
|
||||
*/
|
||||
export const useKebabMenu = <T extends TabValue>({
|
||||
tabs,
|
||||
enabled,
|
||||
isActive,
|
||||
overflowTriggerWidth = 44,
|
||||
}: UseKebabMenuOptions<T>): UseKebabMenuResult<T> => {
|
||||
const containerRef = useRef<HTMLDivElement>(null);
|
||||
const tabsRef = useRef<readonly T[]>(tabs);
|
||||
tabsRef.current = tabs;
|
||||
const previousTabsRef = useRef<readonly T[]>(tabs);
|
||||
const availableWidthRef = useRef<number | null>(null);
|
||||
// Width cache prevents oscillation when overflow tabs are not mounted.
|
||||
const tabWidthByValueRef = useRef<Record<string, number>>({});
|
||||
const [overflowTabValues, setTabValues] = useState<string[]>([]);
|
||||
|
||||
const recalculateOverflow = useCallback(
|
||||
(availableWidth: number) => {
|
||||
if (!enabled || !isActive) {
|
||||
// Keep this update idempotent to avoid render loops.
|
||||
setTabValues((currentValues) => {
|
||||
if (currentValues.length === 0) {
|
||||
return currentValues;
|
||||
}
|
||||
return [];
|
||||
});
|
||||
return;
|
||||
}
|
||||
|
||||
const container = containerRef.current;
|
||||
if (!container) {
|
||||
return;
|
||||
}
|
||||
const currentTabs = tabsRef.current;
|
||||
|
||||
const tabWidthByValue = measureTabWidths({
|
||||
tabs: currentTabs,
|
||||
container,
|
||||
previousTabWidthByValue: tabWidthByValueRef.current,
|
||||
});
|
||||
tabWidthByValueRef.current = tabWidthByValue;
|
||||
|
||||
const nextOverflowValues = calculateTabValues({
|
||||
tabs: currentTabs,
|
||||
availableWidth,
|
||||
tabWidthByValue,
|
||||
overflowTriggerWidth,
|
||||
});
|
||||
|
||||
setTabValues((currentValues) => {
|
||||
// Avoid state updates when the computed overflow did not change.
|
||||
if (areStringArraysEqual(currentValues, nextOverflowValues)) {
|
||||
return currentValues;
|
||||
}
|
||||
return nextOverflowValues;
|
||||
});
|
||||
},
|
||||
[enabled, isActive, overflowTriggerWidth],
|
||||
);
|
||||
|
||||
useEffect(() => {
|
||||
if (previousTabsRef.current === tabs) {
|
||||
// No change in tabs, no need to recalculate.
|
||||
return;
|
||||
}
|
||||
previousTabsRef.current = tabs;
|
||||
if (availableWidthRef.current === null) {
|
||||
// First mount, no width available yet.
|
||||
return;
|
||||
}
|
||||
recalculateOverflow(availableWidthRef.current);
|
||||
}, [recalculateOverflow, tabs]);
|
||||
|
||||
useEffect(() => {
|
||||
const container = containerRef.current;
|
||||
if (!container) {
|
||||
return;
|
||||
}
|
||||
|
||||
// Recompute whenever ResizeObserver reports a container width change.
|
||||
const observer = new ResizeObserver(([entry]) => {
|
||||
if (!entry) {
|
||||
return;
|
||||
}
|
||||
availableWidthRef.current = entry.contentRect.width;
|
||||
recalculateOverflow(entry.contentRect.width);
|
||||
});
|
||||
observer.observe(container);
|
||||
return () => observer.disconnect();
|
||||
}, [recalculateOverflow]);
|
||||
|
||||
const overflowTabValuesSet = new Set(overflowTabValues);
|
||||
const { visibleTabs, overflowTabs } = tabs.reduce<{
|
||||
visibleTabs: T[];
|
||||
overflowTabs: T[];
|
||||
}>(
|
||||
(tabGroups, tab) => {
|
||||
if (overflowTabValuesSet.has(tab.value)) {
|
||||
tabGroups.overflowTabs.push(tab);
|
||||
} else {
|
||||
tabGroups.visibleTabs.push(tab);
|
||||
}
|
||||
return tabGroups;
|
||||
},
|
||||
{ visibleTabs: [], overflowTabs: [] },
|
||||
);
|
||||
|
||||
const getTabMeasureProps = (tabValue: string) => {
|
||||
return { [DATA_ATTR_TAB_VALUE]: tabValue };
|
||||
};
|
||||
|
||||
return {
|
||||
containerRef,
|
||||
visibleTabs,
|
||||
overflowTabs,
|
||||
getTabMeasureProps,
|
||||
};
|
||||
};
|
||||
|
||||
const calculateTabValues = <T extends TabValue>({
|
||||
tabs,
|
||||
availableWidth,
|
||||
tabWidthByValue,
|
||||
overflowTriggerWidth,
|
||||
}: {
|
||||
tabs: readonly T[];
|
||||
availableWidth: number;
|
||||
tabWidthByValue: Readonly<Record<string, number>>;
|
||||
overflowTriggerWidth: number;
|
||||
}): string[] => {
|
||||
const tabWidthByValueMap = new Map<string, number>();
|
||||
for (const tab of tabs) {
|
||||
tabWidthByValueMap.set(tab.value, tabWidthByValue[tab.value] ?? 0);
|
||||
}
|
||||
|
||||
const firstOptionalTabIndex = Math.min(
|
||||
ALWAYS_VISIBLE_TABS_COUNT,
|
||||
tabs.length,
|
||||
);
|
||||
if (firstOptionalTabIndex >= tabs.length) {
|
||||
return [];
|
||||
}
|
||||
|
||||
const alwaysVisibleTabs = tabs.slice(0, firstOptionalTabIndex);
|
||||
const optionalTabs = tabs.slice(firstOptionalTabIndex);
|
||||
const alwaysVisibleWidth = alwaysVisibleTabs.reduce((total, tab) => {
|
||||
return total + (tabWidthByValueMap.get(tab.value) ?? 0);
|
||||
}, 0);
|
||||
const firstTabIndex = findFirstTabIndex({
|
||||
optionalTabs,
|
||||
optionalTabWidths: optionalTabs.map((tab) => {
|
||||
return tabWidthByValueMap.get(tab.value) ?? 0;
|
||||
}),
|
||||
startingUsedWidth: alwaysVisibleWidth,
|
||||
availableWidth,
|
||||
overflowTriggerWidth,
|
||||
});
|
||||
|
||||
if (firstTabIndex === -1) {
|
||||
return [];
|
||||
}
|
||||
|
||||
return optionalTabs
|
||||
.slice(firstTabIndex)
|
||||
.map((overflowTab) => overflowTab.value);
|
||||
};
|
||||
|
||||
const measureTabWidths = <T extends TabValue>({
|
||||
tabs,
|
||||
container,
|
||||
previousTabWidthByValue,
|
||||
}: {
|
||||
tabs: readonly T[];
|
||||
container: HTMLDivElement;
|
||||
previousTabWidthByValue: Readonly<Record<string, number>>;
|
||||
}): Record<string, number> => {
|
||||
const nextTabWidthByValue = { ...previousTabWidthByValue };
|
||||
for (const tab of tabs) {
|
||||
const tabElement = container.querySelector<HTMLElement>(
|
||||
`[${DATA_ATTR_TAB_VALUE}="${tab.value}"]`,
|
||||
);
|
||||
if (tabElement) {
|
||||
nextTabWidthByValue[tab.value] = tabElement.offsetWidth;
|
||||
}
|
||||
}
|
||||
return nextTabWidthByValue;
|
||||
};
|
||||
|
||||
const findFirstTabIndex = ({
|
||||
optionalTabs,
|
||||
optionalTabWidths,
|
||||
startingUsedWidth,
|
||||
availableWidth,
|
||||
overflowTriggerWidth,
|
||||
}: {
|
||||
optionalTabs: readonly TabValue[];
|
||||
optionalTabWidths: readonly number[];
|
||||
startingUsedWidth: number;
|
||||
availableWidth: number;
|
||||
overflowTriggerWidth: number;
|
||||
}): number => {
|
||||
const result = optionalTabs.reduce(
|
||||
(acc, _tab, index) => {
|
||||
if (acc.firstTabIndex !== -1) {
|
||||
return acc;
|
||||
}
|
||||
|
||||
const tabWidth = optionalTabWidths[index] ?? 0;
|
||||
const hasMoreTabs = index < optionalTabs.length - 1;
|
||||
// Reserve kebab trigger width whenever additional tabs remain.
|
||||
const widthNeeded =
|
||||
acc.usedWidth + tabWidth + (hasMoreTabs ? overflowTriggerWidth : 0);
|
||||
|
||||
if (widthNeeded <= availableWidth) {
|
||||
return {
|
||||
usedWidth: acc.usedWidth + tabWidth,
|
||||
firstTabIndex: -1,
|
||||
};
|
||||
}
|
||||
|
||||
return {
|
||||
usedWidth: acc.usedWidth,
|
||||
firstTabIndex: index,
|
||||
};
|
||||
},
|
||||
{ usedWidth: startingUsedWidth, firstTabIndex: -1 },
|
||||
);
|
||||
|
||||
return result.firstTabIndex;
|
||||
};
|
||||
|
||||
const areStringArraysEqual = (
|
||||
left: readonly string[],
|
||||
right: readonly string[],
|
||||
): boolean => {
|
||||
return (
|
||||
left.length === right.length &&
|
||||
left.every((value, index) => value === right[index])
|
||||
);
|
||||
};
|
||||
@@ -0,0 +1,157 @@
|
||||
import {
|
||||
type RefObject,
|
||||
useCallback,
|
||||
useEffect,
|
||||
useLayoutEffect,
|
||||
useMemo,
|
||||
useRef,
|
||||
useState,
|
||||
} from "react";
|
||||
|
||||
type TabLike = {
|
||||
value: string;
|
||||
};
|
||||
|
||||
type UseTabOverflowKebabMenuOptions<TTab extends TabLike> = {
|
||||
tabs: readonly TTab[];
|
||||
enabled: boolean;
|
||||
isActive: boolean;
|
||||
alwaysVisibleTabsCount?: number;
|
||||
overflowTriggerWidthPx?: number;
|
||||
};
|
||||
|
||||
type UseTabOverflowKebabMenuResult<TTab extends TabLike> = {
|
||||
containerRef: RefObject<HTMLDivElement | null>;
|
||||
visibleTabs: TTab[];
|
||||
overflowTabs: TTab[];
|
||||
getTabMeasureProps: (tabValue: string) => Record<string, string>;
|
||||
};
|
||||
|
||||
const DATA_ATTR_TAB_VALUE = "data-tab-overflow-item-value";
|
||||
|
||||
export const useTabOverflowKebabMenu = <TTab extends TabLike>({
|
||||
tabs,
|
||||
enabled,
|
||||
isActive,
|
||||
alwaysVisibleTabsCount = 1,
|
||||
overflowTriggerWidthPx = 44,
|
||||
}: UseTabOverflowKebabMenuOptions<TTab>): UseTabOverflowKebabMenuResult<TTab> => {
|
||||
const containerRef = useRef<HTMLDivElement>(null);
|
||||
const tabWidthByValueRef = useRef<Record<string, number>>({});
|
||||
const [overflowTabValues, setOverflowTabValues] = useState<string[]>([]);
|
||||
|
||||
const recalculateOverflow = useCallback(() => {
|
||||
if (!enabled) {
|
||||
setOverflowTabValues([]);
|
||||
return;
|
||||
}
|
||||
|
||||
const container = containerRef.current;
|
||||
if (!container) {
|
||||
return;
|
||||
}
|
||||
|
||||
for (const tab of tabs) {
|
||||
const tabElement = container.querySelector<HTMLElement>(
|
||||
`[${DATA_ATTR_TAB_VALUE}="${tab.value}"]`,
|
||||
);
|
||||
if (tabElement) {
|
||||
tabWidthByValueRef.current[tab.value] = tabElement.offsetWidth;
|
||||
}
|
||||
}
|
||||
|
||||
const alwaysVisibleTabs = tabs.slice(0, alwaysVisibleTabsCount);
|
||||
const optionalTabs = tabs.slice(alwaysVisibleTabsCount);
|
||||
if (optionalTabs.length === 0) {
|
||||
setOverflowTabValues([]);
|
||||
return;
|
||||
}
|
||||
|
||||
const alwaysVisibleWidth = alwaysVisibleTabs.reduce((total, tab) => {
|
||||
return total + (tabWidthByValueRef.current[tab.value] ?? 0);
|
||||
}, 0);
|
||||
|
||||
const availableWidth = container.clientWidth;
|
||||
let usedWidth = alwaysVisibleWidth;
|
||||
const nextOverflowValues: string[] = [];
|
||||
|
||||
for (let i = 0; i < optionalTabs.length; i++) {
|
||||
const tab = optionalTabs[i];
|
||||
const tabWidth = tabWidthByValueRef.current[tab.value] ?? 0;
|
||||
const hasMoreTabsAfterCurrent = i < optionalTabs.length - 1;
|
||||
const widthNeeded =
|
||||
usedWidth +
|
||||
tabWidth +
|
||||
(hasMoreTabsAfterCurrent ? overflowTriggerWidthPx : 0);
|
||||
|
||||
if (widthNeeded <= availableWidth) {
|
||||
usedWidth += tabWidth;
|
||||
continue;
|
||||
}
|
||||
|
||||
nextOverflowValues.push(
|
||||
...optionalTabs.slice(i).map((overflowTab) => overflowTab.value),
|
||||
);
|
||||
break;
|
||||
}
|
||||
|
||||
setOverflowTabValues((currentValues) => {
|
||||
if (
|
||||
currentValues.length === nextOverflowValues.length &&
|
||||
currentValues.every(
|
||||
(value, index) => value === nextOverflowValues[index],
|
||||
)
|
||||
) {
|
||||
return currentValues;
|
||||
}
|
||||
return nextOverflowValues;
|
||||
});
|
||||
}, [alwaysVisibleTabsCount, enabled, overflowTriggerWidthPx, tabs]);
|
||||
|
||||
useLayoutEffect(() => {
|
||||
if (!isActive) {
|
||||
return;
|
||||
}
|
||||
recalculateOverflow();
|
||||
}, [isActive, recalculateOverflow]);
|
||||
|
||||
useEffect(() => {
|
||||
if (!isActive) {
|
||||
return;
|
||||
}
|
||||
const container = containerRef.current;
|
||||
if (!container) {
|
||||
return;
|
||||
}
|
||||
const observer = new ResizeObserver(() => {
|
||||
recalculateOverflow();
|
||||
});
|
||||
observer.observe(container);
|
||||
return () => observer.disconnect();
|
||||
}, [isActive, recalculateOverflow]);
|
||||
|
||||
const overflowTabValuesSet = useMemo(
|
||||
() => new Set(overflowTabValues),
|
||||
[overflowTabValues],
|
||||
);
|
||||
|
||||
const visibleTabs = useMemo(
|
||||
() => tabs.filter((tab) => !overflowTabValuesSet.has(tab.value)),
|
||||
[tabs, overflowTabValuesSet],
|
||||
);
|
||||
const overflowTabs = useMemo(
|
||||
() => tabs.filter((tab) => overflowTabValuesSet.has(tab.value)),
|
||||
[tabs, overflowTabValuesSet],
|
||||
);
|
||||
|
||||
const getTabMeasureProps = useCallback((tabValue: string) => {
|
||||
return { [DATA_ATTR_TAB_VALUE]: tabValue };
|
||||
}, []);
|
||||
|
||||
return {
|
||||
containerRef,
|
||||
visibleTabs,
|
||||
overflowTabs,
|
||||
getTabMeasureProps,
|
||||
};
|
||||
};
|
||||
@@ -258,78 +258,6 @@ describe(usePaginatedQuery.name, () => {
|
||||
});
|
||||
});
|
||||
|
||||
describe("Capped count behavior", () => {
|
||||
const mockQueryKey = vi.fn(() => ["mock"]);
|
||||
|
||||
// Returns count 2001 (capped) with items on pages up to page 84
|
||||
// (84 * 25 = 2100 items total).
|
||||
const mockCappedQueryFn = vi.fn(({ pageNumber, limit }) => {
|
||||
const totalItems = 2100;
|
||||
const offset = (pageNumber - 1) * limit;
|
||||
// Returns 0 items when the requested page is past the end, simulating
|
||||
// an empty server response.
|
||||
const itemsOnPage = Math.max(0, Math.min(limit, totalItems - offset));
|
||||
return Promise.resolve({
|
||||
data: new Array(itemsOnPage).fill(pageNumber),
|
||||
count: 2001,
|
||||
count_cap: 2000,
|
||||
});
|
||||
});
|
||||
|
||||
it("Caps totalRecords at 2000 when count exceeds cap", async () => {
|
||||
const { result } = await render({
|
||||
queryKey: mockQueryKey,
|
||||
queryFn: mockCappedQueryFn,
|
||||
});
|
||||
|
||||
await waitFor(() => expect(result.current.isSuccess).toBe(true));
|
||||
expect(result.current.totalRecords).toBe(2000);
|
||||
});
|
||||
|
||||
it("hasNextPage is true when count is capped", async () => {
|
||||
const { result } = await render(
|
||||
{ queryKey: mockQueryKey, queryFn: mockCappedQueryFn },
|
||||
"/?page=80",
|
||||
);
|
||||
|
||||
await waitFor(() => expect(result.current.isSuccess).toBe(true));
|
||||
expect(result.current.hasNextPage).toBe(true);
|
||||
});
|
||||
|
||||
it("hasPreviousPage is true when count is capped and page is beyond cap", async () => {
|
||||
const { result } = await render(
|
||||
{ queryKey: mockQueryKey, queryFn: mockCappedQueryFn },
|
||||
"/?page=83",
|
||||
);
|
||||
|
||||
await waitFor(() => expect(result.current.isSuccess).toBe(true));
|
||||
expect(result.current.hasPreviousPage).toBe(true);
|
||||
});
|
||||
|
||||
it("Does not redirect to last page when count is capped and page is valid", async () => {
|
||||
const { result } = await render(
|
||||
{ queryKey: mockQueryKey, queryFn: mockCappedQueryFn },
|
||||
"/?page=83",
|
||||
);
|
||||
|
||||
await waitFor(() => expect(result.current.isSuccess).toBe(true));
|
||||
// Should stay on page 83 — not redirect to page 80.
|
||||
expect(result.current.currentPage).toBe(83);
|
||||
});
|
||||
|
||||
it("Redirects to last known page when navigating beyond actual data", async () => {
|
||||
const { result } = await render(
|
||||
{ queryKey: mockQueryKey, queryFn: mockCappedQueryFn },
|
||||
"/?page=999",
|
||||
);
|
||||
|
||||
// Page 999 has no items. Should redirect to page 81
|
||||
// (ceil(2001 / 25) = 81), the last page guaranteed to
|
||||
// have data.
|
||||
await waitFor(() => expect(result.current.currentPage).toBe(81));
|
||||
});
|
||||
});
|
||||
|
||||
describe("Passing in searchParams property", () => {
|
||||
const mockQueryKey = vi.fn(() => ["mock"]);
|
||||
const mockQueryFn = vi.fn(({ pageNumber, limit }) =>
|
||||
|
||||
@@ -144,44 +144,16 @@ export function usePaginatedQuery<
|
||||
placeholderData: keepPreviousData,
|
||||
});
|
||||
|
||||
const count = query.data?.count;
|
||||
const countCap = query.data?.count_cap;
|
||||
const countIsCapped =
|
||||
countCap !== undefined &&
|
||||
countCap > 0 &&
|
||||
count !== undefined &&
|
||||
count > countCap;
|
||||
const totalRecords = countIsCapped ? countCap : count;
|
||||
let totalPages =
|
||||
totalRecords !== undefined
|
||||
? Math.max(
|
||||
Math.ceil(totalRecords / limit),
|
||||
// True count is not known; let them navigate forward
|
||||
// until they hit an empty page (checked below).
|
||||
countIsCapped ? currentPage : 0,
|
||||
)
|
||||
: undefined;
|
||||
|
||||
// When the true count is unknown, the user can navigate past
|
||||
// all actual data. If that happens, we need to redirect (via
|
||||
// updatePageIfInvalid) to the last page guaranteed to be not
|
||||
// empty.
|
||||
const pageIsEmpty =
|
||||
query.data != null &&
|
||||
!Object.values(query.data).some((v) => Array.isArray(v) && v.length > 0);
|
||||
if (pageIsEmpty) {
|
||||
totalPages = count !== undefined ? Math.ceil(count / limit) : 1;
|
||||
}
|
||||
const totalRecords = query.data?.count;
|
||||
const totalPages =
|
||||
totalRecords !== undefined ? Math.ceil(totalRecords / limit) : undefined;
|
||||
|
||||
const hasNextPage =
|
||||
totalRecords !== undefined &&
|
||||
((countIsCapped && !pageIsEmpty) ||
|
||||
limit + currentPageOffset < totalRecords);
|
||||
totalRecords !== undefined && limit + currentPageOffset < totalRecords;
|
||||
const hasPreviousPage =
|
||||
totalRecords !== undefined &&
|
||||
currentPage > 1 &&
|
||||
((countIsCapped && !pageIsEmpty) ||
|
||||
currentPageOffset - limit < totalRecords);
|
||||
currentPageOffset - limit < totalRecords;
|
||||
|
||||
const queryClient = useQueryClient();
|
||||
const prefetchPage = useEffectEvent((newPage: number) => {
|
||||
@@ -252,14 +224,10 @@ export function usePaginatedQuery<
|
||||
});
|
||||
|
||||
useEffect(() => {
|
||||
if (
|
||||
!query.isFetching &&
|
||||
totalPages !== undefined &&
|
||||
currentPage > totalPages
|
||||
) {
|
||||
if (!query.isFetching && totalPages !== undefined) {
|
||||
void updatePageIfInvalid(totalPages);
|
||||
}
|
||||
}, [updatePageIfInvalid, query.isFetching, totalPages, currentPage]);
|
||||
}, [updatePageIfInvalid, query.isFetching, totalPages]);
|
||||
|
||||
const onPageChange = (newPage: number) => {
|
||||
// Page 1 is the only page that can be safely navigated to without knowing
|
||||
@@ -268,12 +236,7 @@ export function usePaginatedQuery<
|
||||
return;
|
||||
}
|
||||
|
||||
// If the true count is unknown, we allow navigating past the
|
||||
// known page range.
|
||||
const upperBound = countIsCapped
|
||||
? Number.MAX_SAFE_INTEGER
|
||||
: (totalPages ?? 1);
|
||||
const cleanedInput = clamp(Math.trunc(newPage), 1, upperBound);
|
||||
const cleanedInput = clamp(Math.trunc(newPage), 1, totalPages ?? 1);
|
||||
if (Number.isNaN(cleanedInput)) {
|
||||
return;
|
||||
}
|
||||
@@ -311,7 +274,6 @@ export function usePaginatedQuery<
|
||||
totalRecords: totalRecords as number,
|
||||
totalPages: totalPages as number,
|
||||
currentOffsetStart: currentPageOffset + 1,
|
||||
countIsCapped,
|
||||
}
|
||||
: {
|
||||
isSuccess: false,
|
||||
@@ -320,7 +282,6 @@ export function usePaginatedQuery<
|
||||
totalRecords: undefined,
|
||||
totalPages: undefined,
|
||||
currentOffsetStart: undefined,
|
||||
countIsCapped: false as const,
|
||||
}),
|
||||
};
|
||||
|
||||
@@ -362,7 +323,6 @@ export type PaginationResultInfo = {
|
||||
totalRecords: undefined;
|
||||
totalPages: undefined;
|
||||
currentOffsetStart: undefined;
|
||||
countIsCapped: false;
|
||||
}
|
||||
| {
|
||||
isSuccess: true;
|
||||
@@ -371,7 +331,6 @@ export type PaginationResultInfo = {
|
||||
totalRecords: number;
|
||||
totalPages: number;
|
||||
currentOffsetStart: number;
|
||||
countIsCapped: boolean;
|
||||
}
|
||||
);
|
||||
|
||||
@@ -458,7 +417,6 @@ type QueryPageParamsWithPayload<TPayload = never> = QueryPageParams & {
|
||||
*/
|
||||
export type PaginatedData = {
|
||||
count: number;
|
||||
count_cap?: number;
|
||||
};
|
||||
|
||||
/**
|
||||
|
||||
@@ -2,7 +2,7 @@ import type { Meta, StoryObj } from "@storybook/react-vite";
|
||||
import { expect, spyOn, userEvent, waitFor, within } from "storybook/test";
|
||||
import { API } from "#/api/api";
|
||||
import { workspaceAgentContainersKey } from "#/api/queries/workspaces";
|
||||
import type { WorkspaceAgentLogSource } from "#/api/typesGenerated";
|
||||
import type * as TypesGen from "#/api/typesGenerated";
|
||||
import { getPreferredProxy } from "#/contexts/ProxyContext";
|
||||
import { chromatic } from "#/testHelpers/chromatic";
|
||||
import * as M from "#/testHelpers/entities";
|
||||
@@ -76,8 +76,6 @@ const defaultAgentMetadata = [
|
||||
},
|
||||
];
|
||||
|
||||
const fixedLogTimestamp = "2021-05-05T00:00:00.000Z";
|
||||
|
||||
const logs = [
|
||||
"\x1b[91mCloning Git repository...",
|
||||
"\x1b[2;37;41mStarting Docker Daemon...",
|
||||
@@ -89,10 +87,10 @@ const logs = [
|
||||
level: "info",
|
||||
output: line,
|
||||
source_id: M.MockWorkspaceAgentLogSource.id,
|
||||
created_at: fixedLogTimestamp,
|
||||
created_at: new Date().toISOString(),
|
||||
}));
|
||||
|
||||
const installScriptLogSource: WorkspaceAgentLogSource = {
|
||||
const installScriptLogSource: TypesGen.WorkspaceAgentLogSource = {
|
||||
...M.MockWorkspaceAgentLogSource,
|
||||
id: "f2ee4b8d-b09d-4f4e-a1f1-5e4adf7d53bb",
|
||||
display_name: "Install Script",
|
||||
@@ -104,24 +102,60 @@ const tabbedLogs = [
|
||||
level: "info",
|
||||
output: "startup: preparing workspace",
|
||||
source_id: M.MockWorkspaceAgentLogSource.id,
|
||||
created_at: fixedLogTimestamp,
|
||||
created_at: new Date().toISOString(),
|
||||
},
|
||||
{
|
||||
id: 101,
|
||||
level: "info",
|
||||
output: "install: pnpm install",
|
||||
source_id: installScriptLogSource.id,
|
||||
created_at: fixedLogTimestamp,
|
||||
created_at: new Date().toISOString(),
|
||||
},
|
||||
{
|
||||
id: 102,
|
||||
level: "info",
|
||||
output: "install: setup complete",
|
||||
source_id: installScriptLogSource.id,
|
||||
created_at: fixedLogTimestamp,
|
||||
created_at: new Date().toISOString(),
|
||||
},
|
||||
];
|
||||
|
||||
const overflowLogSources: TypesGen.WorkspaceAgentLogSource[] = [
|
||||
M.MockWorkspaceAgentLogSource,
|
||||
{
|
||||
...M.MockWorkspaceAgentLogSource,
|
||||
id: "58f5db69-5f78-496f-bce1-0686f5525aa1",
|
||||
display_name: "code-server",
|
||||
icon: "/icon/code.svg",
|
||||
},
|
||||
{
|
||||
...M.MockWorkspaceAgentLogSource,
|
||||
id: "f39d758c-bce2-4f41-8d70-58fdb1f0f729",
|
||||
display_name: "Install and start AgentAPI",
|
||||
icon: "/icon/claude.svg",
|
||||
},
|
||||
{
|
||||
...M.MockWorkspaceAgentLogSource,
|
||||
id: "bf7529b8-1787-4a20-b54f-eb894680e48f",
|
||||
display_name: "Mux",
|
||||
icon: "/icon/mux.svg",
|
||||
},
|
||||
{
|
||||
...M.MockWorkspaceAgentLogSource,
|
||||
id: "0d6ebde6-c534-4551-9f91-bfd98bfb04f4",
|
||||
display_name: "Portable Desktop",
|
||||
icon: "/icon/portable-desktop.svg",
|
||||
},
|
||||
];
|
||||
|
||||
const overflowLogs = overflowLogSources.map((source, index) => ({
|
||||
id: 200 + index,
|
||||
level: "info",
|
||||
output: `${source.display_name}: line`,
|
||||
source_id: source.id,
|
||||
created_at: new Date().toISOString(),
|
||||
}));
|
||||
|
||||
const meta: Meta<typeof AgentRow> = {
|
||||
title: "components/AgentRow",
|
||||
component: AgentRow,
|
||||
@@ -404,3 +438,44 @@ export const LogsTabs: Story = {
|
||||
await expect(canvas.getByText("install: pnpm install")).toBeVisible();
|
||||
},
|
||||
};
|
||||
|
||||
export const LogsTabsOverflow: Story = {
|
||||
args: {
|
||||
agent: {
|
||||
...M.MockWorkspaceAgentReady,
|
||||
logs_length: overflowLogs.length,
|
||||
log_sources: overflowLogSources,
|
||||
},
|
||||
},
|
||||
parameters: {
|
||||
webSocket: [
|
||||
{
|
||||
event: "message",
|
||||
data: JSON.stringify(overflowLogs),
|
||||
},
|
||||
],
|
||||
},
|
||||
render: (args) => (
|
||||
<div className="max-w-[320px]">
|
||||
<AgentRow {...args} />
|
||||
</div>
|
||||
),
|
||||
play: async ({ canvasElement }) => {
|
||||
const canvas = within(canvasElement);
|
||||
const page = within(canvasElement.ownerDocument.body);
|
||||
await userEvent.click(canvas.getByRole("button", { name: "Logs" }));
|
||||
await userEvent.click(
|
||||
canvas.getByRole("button", { name: "More log tabs" }),
|
||||
);
|
||||
const overflowItems = await page.findAllByRole("menuitemradio");
|
||||
const selectedItem = overflowItems[0];
|
||||
const selectedSource = selectedItem.textContent;
|
||||
if (!selectedSource) {
|
||||
throw new Error("Overflow menu item must have text content.");
|
||||
}
|
||||
await userEvent.click(selectedItem);
|
||||
await waitFor(() =>
|
||||
expect(canvas.getByText(`${selectedSource}: line`)).toBeVisible(),
|
||||
);
|
||||
},
|
||||
};
|
||||
|
||||
@@ -8,9 +8,10 @@ import {
|
||||
} from "lucide-react";
|
||||
import {
|
||||
type FC,
|
||||
type ReactNode,
|
||||
useCallback,
|
||||
useEffect,
|
||||
useLayoutEffect,
|
||||
useMemo,
|
||||
useRef,
|
||||
useState,
|
||||
} from "react";
|
||||
@@ -41,7 +42,7 @@ import {
|
||||
TabsList,
|
||||
TabsTrigger,
|
||||
} from "#/components/Tabs/Tabs";
|
||||
import { useKebabMenu } from "#/components/Tabs/utils/useKebabMenu";
|
||||
import { useTabOverflowKebabMenu } from "#/components/Tabs/utils";
|
||||
import { useProxy } from "#/contexts/ProxyContext";
|
||||
import { useClipboard } from "#/hooks/useClipboard";
|
||||
import { useFeatureVisibility } from "#/modules/dashboard/useFeatureVisibility";
|
||||
@@ -161,7 +162,7 @@ export const AgentRow: FC<AgentRowProps> = ({
|
||||
// This is a bit of a hack on the react-window API to get the scroll position.
|
||||
// If we're scrolled to the bottom, we want to keep the list scrolled to the bottom.
|
||||
// This makes it feel similar to a terminal that auto-scrolls downwards!
|
||||
const handleLogScroll = (props: ListOnScrollProps) => {
|
||||
const handleLogScroll = useCallback((props: ListOnScrollProps) => {
|
||||
if (
|
||||
props.scrollOffset === 0 ||
|
||||
props.scrollUpdateWasRequested ||
|
||||
@@ -178,7 +179,7 @@ export const AgentRow: FC<AgentRowProps> = ({
|
||||
logListDivRef.current.scrollHeight -
|
||||
(props.scrollOffset + parent.clientHeight);
|
||||
setBottomOfLogs(distanceFromBottom < AGENT_LOG_LINE_HEIGHT);
|
||||
};
|
||||
}, []);
|
||||
|
||||
const devcontainers = useAgentContainers(agent);
|
||||
|
||||
@@ -210,56 +211,59 @@ export const AgentRow: FC<AgentRowProps> = ({
|
||||
);
|
||||
|
||||
const [selectedLogTab, setSelectedLogTab] = useState("all");
|
||||
const sourceLogTabs = agent.log_sources
|
||||
.filter((logSource) => {
|
||||
// Remove the logSources that have no entries.
|
||||
return agentLogs.some(
|
||||
(log) =>
|
||||
log.source_id === logSource.id && (log.output?.length ?? 0) > 0,
|
||||
);
|
||||
})
|
||||
.map((logSource) => ({
|
||||
// Show the icon for the log source if it has one.
|
||||
// In the startup script case, we show a bespoke play icon.
|
||||
startIcon: logSource.icon ? (
|
||||
<ExternalImage
|
||||
src={logSource.icon}
|
||||
alt=""
|
||||
className="size-icon-xs shrink-0"
|
||||
/>
|
||||
) : logSource.display_name === STARTUP_SCRIPT_DISPLAY_NAME ? (
|
||||
<PlayIcon className="size-icon-xs shrink-0" />
|
||||
) : null,
|
||||
title: logSource.display_name,
|
||||
value: logSource.id,
|
||||
}));
|
||||
const startupScriptLogTab = sourceLogTabs.find(
|
||||
(tab) => tab.title === STARTUP_SCRIPT_DISPLAY_NAME,
|
||||
);
|
||||
const sortedSourceLogTabs = sourceLogTabs
|
||||
.filter((tab) => tab !== startupScriptLogTab)
|
||||
.sort((a, b) => a.title.localeCompare(b.title));
|
||||
const logTabs: {
|
||||
startIcon?: ReactNode;
|
||||
title: string;
|
||||
value: string;
|
||||
}[] = [
|
||||
{
|
||||
title: "All Logs",
|
||||
value: "all",
|
||||
},
|
||||
...(startupScriptLogTab ? [startupScriptLogTab] : []),
|
||||
...sortedSourceLogTabs,
|
||||
];
|
||||
const logTabs = useMemo(() => {
|
||||
const sourceLogTabs = agent.log_sources
|
||||
.filter((logSource) => {
|
||||
// Remove the logSources that have no entries.
|
||||
return agentLogs.some(
|
||||
(log) =>
|
||||
log.source_id === logSource.id && (log.output?.length ?? 0) > 0,
|
||||
);
|
||||
})
|
||||
.map((logSource) => ({
|
||||
// Show the icon for the log source if it has one.
|
||||
// In the startup script case, we show a bespoke play icon.
|
||||
startIcon: logSource.icon ? (
|
||||
<ExternalImage
|
||||
src={logSource.icon}
|
||||
alt=""
|
||||
className="size-icon-xs shrink-0"
|
||||
/>
|
||||
) : logSource.display_name === STARTUP_SCRIPT_DISPLAY_NAME ? (
|
||||
<PlayIcon className="size-icon-xs shrink-0" />
|
||||
) : null,
|
||||
title: logSource.display_name,
|
||||
value: logSource.id,
|
||||
}));
|
||||
const startupScriptLogTab = sourceLogTabs.find(
|
||||
(tab) => tab.title === STARTUP_SCRIPT_DISPLAY_NAME,
|
||||
);
|
||||
const sortedSourceLogTabs = sourceLogTabs
|
||||
.filter((tab) => tab !== startupScriptLogTab)
|
||||
.sort((a, b) => a.title.localeCompare(b.title));
|
||||
return [
|
||||
{
|
||||
title: "All Logs",
|
||||
value: "all",
|
||||
},
|
||||
...(startupScriptLogTab ? [startupScriptLogTab] : []),
|
||||
...sortedSourceLogTabs,
|
||||
] as {
|
||||
startIcon?: React.ReactNode;
|
||||
title: string;
|
||||
value: string;
|
||||
}[];
|
||||
}, [agent.log_sources, agentLogs]);
|
||||
const {
|
||||
containerRef: logTabsListContainerRef,
|
||||
visibleTabs: visibleLogTabs,
|
||||
overflowTabs: overflowLogTabs,
|
||||
getTabMeasureProps,
|
||||
} = useKebabMenu({
|
||||
} = useTabOverflowKebabMenu({
|
||||
tabs: logTabs,
|
||||
enabled: true,
|
||||
isActive: showLogs,
|
||||
alwaysVisibleTabsCount: 1,
|
||||
});
|
||||
const overflowLogTabValuesSet = new Set(
|
||||
overflowLogTabs.map((tab) => tab.value),
|
||||
@@ -275,29 +279,16 @@ export const AgentRow: FC<AgentRowProps> = ({
|
||||
level: log.level,
|
||||
sourceId: log.source_id,
|
||||
}));
|
||||
const allLogsText = agentLogs.map((log) => log.output).join("\n");
|
||||
const selectedLogsText = selectedLogs.map((log) => log.output).join("\n");
|
||||
const hasSelectedLogs = selectedLogs.length > 0;
|
||||
const hasAnyLogs = agentLogs.length > 0;
|
||||
const { showCopiedSuccess, copyToClipboard } = useClipboard();
|
||||
const downloadableLogSets = logTabs
|
||||
.filter((tab) => tab.value !== "all")
|
||||
.map((tab) => {
|
||||
const logsText = agentLogs
|
||||
.filter((log) => log.source_id === tab.value)
|
||||
.map((log) => log.output)
|
||||
.join("\n");
|
||||
const filenameSuffix = tab.title
|
||||
.toLowerCase()
|
||||
.replaceAll(/[^a-z0-9]+/g, "-")
|
||||
.replaceAll(/(^-|-$)/g, "");
|
||||
return {
|
||||
label: tab.title,
|
||||
filenameSuffix: filenameSuffix || tab.value,
|
||||
logsText,
|
||||
startIcon: tab.startIcon,
|
||||
};
|
||||
});
|
||||
const selectedLogTabTitle =
|
||||
logTabs.find((tab) => tab.value === selectedLogTab)?.title ?? "Logs";
|
||||
const sanitizedTabTitle = selectedLogTabTitle
|
||||
.toLowerCase()
|
||||
.replaceAll(/[^a-z0-9]+/g, "-")
|
||||
.replaceAll(/(^-|-$)/g, "");
|
||||
const logFilenameSuffix = sanitizedTabTitle || "logs";
|
||||
|
||||
return (
|
||||
<div
|
||||
@@ -556,9 +547,9 @@ export const AgentRow: FC<AgentRowProps> = ({
|
||||
</Button>
|
||||
<DownloadSelectedAgentLogsButton
|
||||
agentName={agent.name}
|
||||
logSets={downloadableLogSets}
|
||||
allLogsText={allLogsText}
|
||||
disabled={!hasAnyLogs}
|
||||
filenameSuffix={logFilenameSuffix}
|
||||
logsText={selectedLogsText}
|
||||
disabled={!hasSelectedLogs}
|
||||
/>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
@@ -1,27 +1,14 @@
|
||||
import { saveAs } from "file-saver";
|
||||
import { ChevronDownIcon, DownloadIcon, PackageIcon } from "lucide-react";
|
||||
import { type FC, type ReactNode, useState } from "react";
|
||||
import { DownloadIcon } from "lucide-react";
|
||||
import { type FC, useState } from "react";
|
||||
import { toast } from "sonner";
|
||||
import { getErrorDetail } from "#/api/errors";
|
||||
import { Button } from "#/components/Button/Button";
|
||||
import {
|
||||
DropdownMenu,
|
||||
DropdownMenuContent,
|
||||
DropdownMenuItem,
|
||||
DropdownMenuTrigger,
|
||||
} from "#/components/DropdownMenu/DropdownMenu";
|
||||
|
||||
type DownloadableLogSet = {
|
||||
label: string;
|
||||
filenameSuffix: string;
|
||||
logsText: string;
|
||||
startIcon?: ReactNode;
|
||||
};
|
||||
|
||||
type DownloadSelectedAgentLogsButtonProps = {
|
||||
agentName: string;
|
||||
logSets: readonly DownloadableLogSet[];
|
||||
allLogsText: string;
|
||||
filenameSuffix: string;
|
||||
logsText: string;
|
||||
disabled?: boolean;
|
||||
download?: (file: Blob, filename: string) => void | Promise<void>;
|
||||
};
|
||||
@@ -30,13 +17,13 @@ export const DownloadSelectedAgentLogsButton: FC<
|
||||
DownloadSelectedAgentLogsButtonProps
|
||||
> = ({
|
||||
agentName,
|
||||
logSets,
|
||||
allLogsText,
|
||||
filenameSuffix,
|
||||
logsText,
|
||||
disabled = false,
|
||||
download = saveAs,
|
||||
}) => {
|
||||
const [isDownloading, setIsDownloading] = useState(false);
|
||||
const downloadLogs = async (logsText: string, filenameSuffix: string) => {
|
||||
const handleDownload = async () => {
|
||||
try {
|
||||
setIsDownloading(true);
|
||||
const file = new Blob([logsText], { type: "text/plain" });
|
||||
@@ -50,40 +37,15 @@ export const DownloadSelectedAgentLogsButton: FC<
|
||||
}
|
||||
};
|
||||
|
||||
const hasAllLogs = allLogsText.length > 0;
|
||||
|
||||
return (
|
||||
<DropdownMenu>
|
||||
<DropdownMenuTrigger asChild>
|
||||
<Button variant="subtle" size="sm" disabled={disabled || isDownloading}>
|
||||
<DownloadIcon />
|
||||
{isDownloading ? "Downloading..." : "Download logs"}
|
||||
<ChevronDownIcon className="size-icon-sm" />
|
||||
</Button>
|
||||
</DropdownMenuTrigger>
|
||||
<DropdownMenuContent align="end">
|
||||
<DropdownMenuItem
|
||||
disabled={!hasAllLogs}
|
||||
onSelect={() => {
|
||||
downloadLogs(allLogsText, "all-logs");
|
||||
}}
|
||||
>
|
||||
<PackageIcon />
|
||||
Download all logs
|
||||
</DropdownMenuItem>
|
||||
{logSets.map((logSet) => (
|
||||
<DropdownMenuItem
|
||||
key={logSet.filenameSuffix}
|
||||
disabled={logSet.logsText.length === 0}
|
||||
onSelect={() => {
|
||||
downloadLogs(logSet.logsText, logSet.filenameSuffix);
|
||||
}}
|
||||
>
|
||||
{logSet.startIcon}
|
||||
<span>Download {logSet.label}</span>
|
||||
</DropdownMenuItem>
|
||||
))}
|
||||
</DropdownMenuContent>
|
||||
</DropdownMenu>
|
||||
<Button
|
||||
variant="subtle"
|
||||
size="sm"
|
||||
disabled={disabled || isDownloading}
|
||||
onClick={handleDownload}
|
||||
>
|
||||
<DownloadIcon />
|
||||
{isDownloading ? "Downloading..." : "Download logs"}
|
||||
</Button>
|
||||
);
|
||||
};
|
||||
|
||||
@@ -1,19 +1,22 @@
|
||||
import dayjs, { type Dayjs } from "dayjs";
|
||||
import { type FC, useState } from "react";
|
||||
import { useQuery } from "react-query";
|
||||
import { useSearchParams } from "react-router";
|
||||
import { chatCostSummary } from "#/api/queries/chats";
|
||||
import type { DateRangeValue } from "#/components/DateRangePicker/DateRangePicker";
|
||||
import { ScrollArea } from "#/components/ScrollArea/ScrollArea";
|
||||
import { useAuthContext } from "#/contexts/auth/AuthProvider";
|
||||
import { AgentAnalyticsPageView } from "./AgentAnalyticsPageView";
|
||||
import { AgentPageHeader } from "./components/AgentPageHeader";
|
||||
|
||||
const createDateRange = (now?: Dayjs) => {
|
||||
const startDateSearchParam = "startDate";
|
||||
const endDateSearchParam = "endDate";
|
||||
|
||||
const getDefaultDateRange = (now?: Dayjs): DateRangeValue => {
|
||||
const end = now ?? dayjs();
|
||||
const start = end.subtract(30, "day");
|
||||
return {
|
||||
startDate: start.toISOString(),
|
||||
endDate: end.toISOString(),
|
||||
rangeLabel: `${start.format("MMM D")} – ${end.format("MMM D, YYYY")}`,
|
||||
startDate: end.subtract(30, "day").toDate(),
|
||||
endDate: end.toDate(),
|
||||
};
|
||||
};
|
||||
|
||||
@@ -24,13 +27,44 @@ interface AgentAnalyticsPageProps {
|
||||
|
||||
const AgentAnalyticsPage: FC<AgentAnalyticsPageProps> = ({ now }) => {
|
||||
const { user } = useAuthContext();
|
||||
const [anchor] = useState<Dayjs>(() => dayjs());
|
||||
const dateRange = createDateRange(now ?? anchor);
|
||||
|
||||
const [searchParams, setSearchParams] = useSearchParams();
|
||||
const startDateParam = searchParams.get(startDateSearchParam)?.trim() ?? "";
|
||||
const endDateParam = searchParams.get(endDateSearchParam)?.trim() ?? "";
|
||||
const [defaultDateRange] = useState(() => getDefaultDateRange(now));
|
||||
let dateRange = defaultDateRange;
|
||||
let hasExplicitDateRange = false;
|
||||
|
||||
if (startDateParam && endDateParam) {
|
||||
const parsedStartDate = new Date(startDateParam);
|
||||
const parsedEndDate = new Date(endDateParam);
|
||||
|
||||
if (
|
||||
!Number.isNaN(parsedStartDate.getTime()) &&
|
||||
!Number.isNaN(parsedEndDate.getTime()) &&
|
||||
parsedStartDate.getTime() <= parsedEndDate.getTime()
|
||||
) {
|
||||
dateRange = {
|
||||
startDate: parsedStartDate,
|
||||
endDate: parsedEndDate,
|
||||
};
|
||||
hasExplicitDateRange = true;
|
||||
}
|
||||
}
|
||||
|
||||
const onDateRangeChange = (value: DateRangeValue) => {
|
||||
setSearchParams((prev) => {
|
||||
const next = new URLSearchParams(prev);
|
||||
next.set(startDateSearchParam, value.startDate.toISOString());
|
||||
next.set(endDateSearchParam, value.endDate.toISOString());
|
||||
return next;
|
||||
});
|
||||
};
|
||||
|
||||
const summaryQuery = useQuery({
|
||||
...chatCostSummary(user?.id ?? "me", {
|
||||
start_date: dateRange.startDate,
|
||||
end_date: dateRange.endDate,
|
||||
start_date: dateRange.startDate.toISOString(),
|
||||
end_date: dateRange.endDate.toISOString(),
|
||||
}),
|
||||
enabled: Boolean(user?.id),
|
||||
});
|
||||
@@ -43,7 +77,9 @@ const AgentAnalyticsPage: FC<AgentAnalyticsPageProps> = ({ now }) => {
|
||||
isLoading={summaryQuery.isLoading}
|
||||
error={summaryQuery.error}
|
||||
onRetry={() => void summaryQuery.refetch()}
|
||||
rangeLabel={dateRange.rangeLabel}
|
||||
dateRange={dateRange}
|
||||
onDateRangeChange={onDateRangeChange}
|
||||
hasExplicitDateRange={hasExplicitDateRange}
|
||||
/>
|
||||
</ScrollArea>
|
||||
);
|
||||
|
||||
@@ -1,15 +1,21 @@
|
||||
import { BarChart3Icon } from "lucide-react";
|
||||
import type { FC } from "react";
|
||||
import type { ChatCostSummary } from "#/api/typesGenerated";
|
||||
import {
|
||||
DateRangePicker,
|
||||
type DateRangeValue,
|
||||
} from "#/components/DateRangePicker/DateRangePicker";
|
||||
import { ChatCostSummaryView } from "./components/ChatCostSummaryView";
|
||||
import { SectionHeader } from "./components/SectionHeader";
|
||||
import { toInclusiveDateRange } from "./utils/dateRange";
|
||||
|
||||
interface AgentAnalyticsPageViewProps {
|
||||
summary: ChatCostSummary | undefined;
|
||||
isLoading: boolean;
|
||||
error: unknown;
|
||||
onRetry: () => void;
|
||||
rangeLabel: string;
|
||||
dateRange: DateRangeValue;
|
||||
onDateRangeChange: (value: DateRangeValue) => void;
|
||||
hasExplicitDateRange: boolean;
|
||||
}
|
||||
|
||||
export const AgentAnalyticsPageView: FC<AgentAnalyticsPageViewProps> = ({
|
||||
@@ -17,8 +23,15 @@ export const AgentAnalyticsPageView: FC<AgentAnalyticsPageViewProps> = ({
|
||||
isLoading,
|
||||
error,
|
||||
onRetry,
|
||||
rangeLabel,
|
||||
dateRange,
|
||||
onDateRangeChange,
|
||||
hasExplicitDateRange,
|
||||
}) => {
|
||||
const displayDateRange = toInclusiveDateRange(
|
||||
dateRange,
|
||||
hasExplicitDateRange,
|
||||
);
|
||||
|
||||
return (
|
||||
<div className="flex flex-col p-4 pt-8">
|
||||
<div className="mx-auto w-full max-w-3xl">
|
||||
@@ -26,10 +39,10 @@ export const AgentAnalyticsPageView: FC<AgentAnalyticsPageViewProps> = ({
|
||||
label="Analytics"
|
||||
description="Review your personal Coder Agents usage and cost breakdowns."
|
||||
action={
|
||||
<div className="flex items-center gap-2 text-xs text-content-secondary">
|
||||
<BarChart3Icon className="h-4 w-4" />
|
||||
<span>{rangeLabel}</span>
|
||||
</div>
|
||||
<DateRangePicker
|
||||
value={displayDateRange}
|
||||
onChange={onDateRangeChange}
|
||||
/>
|
||||
}
|
||||
/>
|
||||
|
||||
|
||||
@@ -646,24 +646,17 @@ const AgentChatPage: FC = () => {
|
||||
const isRegenerateTitleDisabled = isArchived || isRegeneratingThisChat;
|
||||
const chatLastModelConfigID = chatRecord?.last_model_config_id;
|
||||
|
||||
// Destructure mutation results directly so the React Compiler
|
||||
// tracks stable primitives/functions instead of the whole result
|
||||
// object (TanStack Query v5 recreates it every render via object
|
||||
// spread). Keeping no intermediate variable prevents future code
|
||||
// from accidentally closing over the unstable object.
|
||||
const { isPending: isSendPending, mutateAsync: sendMessage } = useMutation(
|
||||
const sendMutation = useMutation(
|
||||
createChatMessage(queryClient, agentId ?? ""),
|
||||
);
|
||||
const { isPending: isEditPending, mutateAsync: editMessage } = useMutation(
|
||||
editChatMessage(queryClient, agentId ?? ""),
|
||||
);
|
||||
const { isPending: isInterruptPending, mutateAsync: interrupt } = useMutation(
|
||||
const editMutation = useMutation(editChatMessage(queryClient, agentId ?? ""));
|
||||
const interruptMutation = useMutation(
|
||||
interruptChat(queryClient, agentId ?? ""),
|
||||
);
|
||||
const { mutateAsync: deleteQueuedMessage } = useMutation(
|
||||
const deleteQueuedMutation = useMutation(
|
||||
deleteChatQueuedMessage(queryClient, agentId ?? ""),
|
||||
);
|
||||
const { mutateAsync: promoteQueuedMessage } = useMutation(
|
||||
const promoteQueuedMutation = useMutation(
|
||||
promoteChatQueuedMessage(queryClient, agentId ?? ""),
|
||||
);
|
||||
|
||||
@@ -761,7 +754,9 @@ const AgentChatPage: FC = () => {
|
||||
hasUserFixableModelProviders,
|
||||
});
|
||||
const isSubmissionPending =
|
||||
isSendPending || isEditPending || isInterruptPending;
|
||||
sendMutation.isPending ||
|
||||
editMutation.isPending ||
|
||||
interruptMutation.isPending;
|
||||
const isInputDisabled = !hasModelOptions || isArchived;
|
||||
|
||||
const handleUsageLimitError = (error: unknown): void => {
|
||||
@@ -847,7 +842,7 @@ const AgentChatPage: FC = () => {
|
||||
setPendingEditMessageId(editedMessageID);
|
||||
scrollToBottomRef.current?.();
|
||||
try {
|
||||
await editMessage({
|
||||
await editMutation.mutateAsync({
|
||||
messageId: editedMessageID,
|
||||
req: request,
|
||||
});
|
||||
@@ -878,9 +873,9 @@ const AgentChatPage: FC = () => {
|
||||
// For queued sends the WebSocket status events handle
|
||||
// clearing; for non-queued sends we clear explicitly
|
||||
// below. Clearing eagerly causes a visible cutoff.
|
||||
let response: Awaited<ReturnType<typeof sendMessage>>;
|
||||
let response: Awaited<ReturnType<typeof sendMutation.mutateAsync>>;
|
||||
try {
|
||||
response = await sendMessage(request);
|
||||
response = await sendMutation.mutateAsync(request);
|
||||
} catch (error) {
|
||||
handleUsageLimitError(error);
|
||||
throw error;
|
||||
@@ -913,10 +908,10 @@ const AgentChatPage: FC = () => {
|
||||
};
|
||||
|
||||
const handleInterrupt = () => {
|
||||
if (!agentId || isInterruptPending) {
|
||||
if (!agentId || interruptMutation.isPending) {
|
||||
return;
|
||||
}
|
||||
void interrupt();
|
||||
void interruptMutation.mutateAsync();
|
||||
};
|
||||
|
||||
const handleDeleteQueuedMessage = async (id: number) => {
|
||||
@@ -925,7 +920,7 @@ const AgentChatPage: FC = () => {
|
||||
previousQueuedMessages.filter((message) => message.id !== id),
|
||||
);
|
||||
try {
|
||||
await deleteQueuedMessage(id);
|
||||
await deleteQueuedMutation.mutateAsync(id);
|
||||
} catch (error) {
|
||||
store.setQueuedMessages(previousQueuedMessages);
|
||||
throw error;
|
||||
@@ -946,7 +941,7 @@ const AgentChatPage: FC = () => {
|
||||
store.clearStreamError();
|
||||
store.setChatStatus("pending");
|
||||
try {
|
||||
const promotedMessage = await promoteQueuedMessage(id);
|
||||
const promotedMessage = await promoteQueuedMutation.mutateAsync(id);
|
||||
// Insert the promoted message into the store and cache
|
||||
// immediately so it appears in the timeline without
|
||||
// waiting for the WebSocket to deliver it.
|
||||
@@ -995,8 +990,7 @@ const AgentChatPage: FC = () => {
|
||||
? `ssh ${workspaceAgent.name}.${workspace.name}.${workspace.owner_name}.${sshConfigQuery.data.hostname_suffix}`
|
||||
: undefined;
|
||||
|
||||
// See mutation destructuring comment above (React Compiler).
|
||||
const { mutate: generateKey } = useMutation({
|
||||
const generateKeyMutation = useMutation({
|
||||
mutationFn: () => API.getApiKey(),
|
||||
});
|
||||
|
||||
@@ -1011,7 +1005,7 @@ const AgentChatPage: FC = () => {
|
||||
const repoRoots = Array.from(gitWatcher.repositories.keys()).sort();
|
||||
const folder = repoRoots[0] ?? workspaceAgent.expanded_directory;
|
||||
|
||||
generateKey(undefined, {
|
||||
generateKeyMutation.mutate(undefined, {
|
||||
onSuccess: ({ key }) => {
|
||||
location.href = getVSCodeHref(editor, {
|
||||
owner: workspace.owner_name,
|
||||
@@ -1147,7 +1141,7 @@ const AgentChatPage: FC = () => {
|
||||
compressionThreshold={compressionThreshold}
|
||||
isInputDisabled={isInputDisabled}
|
||||
isSubmissionPending={isSubmissionPending}
|
||||
isInterruptPending={isInterruptPending}
|
||||
isInterruptPending={interruptMutation.isPending}
|
||||
isSidebarCollapsed={isSidebarCollapsed}
|
||||
onToggleSidebarCollapsed={onToggleSidebarCollapsed}
|
||||
showSidebarPanel={showSidebarPanel}
|
||||
|
||||
@@ -372,7 +372,7 @@ export const DefaultAutostopNotVisibleToNonAdmin: Story = {
|
||||
const canvas = within(canvasElement);
|
||||
|
||||
// Personal Instructions should be visible.
|
||||
await canvas.findByText("Personal Instructions");
|
||||
await canvas.findByText("Instructions");
|
||||
|
||||
// Admin-only sections should not be present.
|
||||
expect(canvas.queryByText("Workspace Autostop Fallback")).toBeNull();
|
||||
@@ -412,7 +412,7 @@ export const InvisibleUnicodeWarningUserPrompt: Story = {
|
||||
play: async ({ canvasElement }) => {
|
||||
const canvas = within(canvasElement);
|
||||
|
||||
await canvas.findByText("Personal Instructions");
|
||||
await canvas.findByText("Instructions");
|
||||
|
||||
const alert = await canvas.findByText(/invisible Unicode/);
|
||||
expect(alert).toBeInTheDocument();
|
||||
@@ -454,7 +454,7 @@ export const NoWarningForCleanPrompt: Story = {
|
||||
play: async ({ canvasElement }) => {
|
||||
const canvas = within(canvasElement);
|
||||
|
||||
await canvas.findByText("Personal Instructions");
|
||||
await canvas.findByText("Instructions");
|
||||
await canvas.findByText("System Instructions");
|
||||
|
||||
expect(canvas.queryByText(/invisible Unicode/)).toBeNull();
|
||||
|
||||
@@ -9,6 +9,13 @@ import { Switch } from "#/components/Switch/Switch";
|
||||
import { cn } from "#/utils/cn";
|
||||
import { countInvisibleCharacters } from "#/utils/invisibleUnicode";
|
||||
import { AdminBadge } from "./components/AdminBadge";
|
||||
import {
|
||||
CollapsibleSection,
|
||||
CollapsibleSectionContent,
|
||||
CollapsibleSectionDescription,
|
||||
CollapsibleSectionHeader,
|
||||
CollapsibleSectionTitle,
|
||||
} from "./components/CollapsibleSection";
|
||||
import { DurationField } from "./components/DurationField/DurationField";
|
||||
import { SectionHeader } from "./components/SectionHeader";
|
||||
import { TextPreviewDialog } from "./components/TextPreviewDialog";
|
||||
@@ -246,142 +253,43 @@ export const AgentSettingsBehaviorPageView: FC<
|
||||
};
|
||||
|
||||
return (
|
||||
<>
|
||||
<div className="space-y-6">
|
||||
<SectionHeader
|
||||
label="Behavior"
|
||||
description="Custom instructions that shape how the agent responds in your conversations."
|
||||
/>
|
||||
{/* ── Personal prompt (always visible) ── */}
|
||||
<form
|
||||
className="space-y-2"
|
||||
onSubmit={(event) => void handleSaveUserPrompt(event)}
|
||||
>
|
||||
<h3 className="m-0 text-[13px] font-semibold text-content-primary">
|
||||
Personal Instructions
|
||||
</h3>
|
||||
<p className="!mt-0.5 m-0 text-xs text-content-secondary">
|
||||
Applied to all your conversations. Only visible to you.
|
||||
</p>
|
||||
<TextareaAutosize
|
||||
className={cn(
|
||||
textareaBaseClassName,
|
||||
isUserPromptOverflowing && textareaOverflowClassName,
|
||||
)}
|
||||
placeholder="Additional behavior, style, and tone preferences"
|
||||
value={userPromptDraft}
|
||||
onChange={(event) => setLocalUserEdit(event.target.value)}
|
||||
onHeightChange={(height) =>
|
||||
setIsUserPromptOverflowing(height >= textareaMaxHeight)
|
||||
}
|
||||
disabled={isPromptSaving}
|
||||
minRows={1}
|
||||
/>
|
||||
{userInvisibleCharCount > 0 && (
|
||||
<Alert severity="warning">
|
||||
<AlertDescription>
|
||||
This text contains {userInvisibleCharCount} invisible Unicode{" "}
|
||||
{userInvisibleCharCount !== 1 ? "characters" : "character"} that
|
||||
could hide content. These will be stripped on save.
|
||||
</AlertDescription>
|
||||
</Alert>
|
||||
)}
|
||||
<div className="flex justify-end gap-2">
|
||||
<Button
|
||||
size="sm"
|
||||
variant="outline"
|
||||
type="button"
|
||||
onClick={() => setLocalUserEdit("")}
|
||||
disabled={isPromptSaving || !userPromptDraft}
|
||||
>
|
||||
Clear
|
||||
</Button>
|
||||
<Button
|
||||
size="sm"
|
||||
type="submit"
|
||||
disabled={isPromptSaving || !isUserPromptDirty}
|
||||
>
|
||||
Save
|
||||
</Button>
|
||||
</div>
|
||||
{isSaveUserPromptError && (
|
||||
<p className="m-0 text-xs text-content-destructive">
|
||||
Failed to save personal instructions.
|
||||
</p>
|
||||
)}
|
||||
</form>
|
||||
|
||||
<hr className="my-5 border-0 border-t border-solid border-border" />
|
||||
<UserCompactionThresholdSettings
|
||||
modelConfigs={modelConfigsData ?? []}
|
||||
modelConfigsError={modelConfigsError}
|
||||
isLoadingModelConfigs={isLoadingModelConfigs}
|
||||
thresholds={thresholds}
|
||||
isThresholdsLoading={isThresholdsLoading}
|
||||
thresholdsError={thresholdsError}
|
||||
onSaveThreshold={onSaveThreshold}
|
||||
onResetThreshold={onResetThreshold}
|
||||
/>
|
||||
{/* ── Admin system prompt (admin only) ── */}
|
||||
{canSetSystemPrompt && (
|
||||
<>
|
||||
<hr className="my-5 border-0 border-t border-solid border-border" />
|
||||
{/* Personal prompt (always visible) */}
|
||||
<CollapsibleSection>
|
||||
<CollapsibleSectionHeader>
|
||||
<CollapsibleSectionTitle>Instructions</CollapsibleSectionTitle>
|
||||
<CollapsibleSectionDescription>Applied to all your conversations. Only visible to you.</CollapsibleSectionDescription>
|
||||
</CollapsibleSectionHeader>
|
||||
<CollapsibleSectionContent>
|
||||
<form
|
||||
className="space-y-2"
|
||||
onSubmit={(event) => void handleSaveSystemPrompt(event)}
|
||||
onSubmit={(event) => void handleSaveUserPrompt(event)}
|
||||
>
|
||||
<div className="flex items-center gap-2">
|
||||
<h3 className="m-0 text-[13px] font-semibold text-content-primary">
|
||||
System Instructions
|
||||
</h3>
|
||||
<AdminBadge />
|
||||
</div>
|
||||
<div className="flex items-center justify-between gap-4">
|
||||
<div className="flex min-w-0 items-center gap-2 text-xs font-medium text-content-primary">
|
||||
<span>Include Coder Agents default system prompt</span>
|
||||
<Button
|
||||
size="xs"
|
||||
variant="subtle"
|
||||
type="button"
|
||||
onClick={() => setShowDefaultPromptPreview(true)}
|
||||
disabled={!hasLoadedSystemPrompt}
|
||||
className="min-w-0 px-0 text-content-link hover:text-content-link"
|
||||
>
|
||||
Preview
|
||||
</Button>
|
||||
</div>
|
||||
<Switch
|
||||
checked={includeDefaultDraft}
|
||||
onCheckedChange={setLocalIncludeDefault}
|
||||
aria-label="Include Coder Agents default system prompt"
|
||||
disabled={isSystemPromptDisabled}
|
||||
/>
|
||||
</div>
|
||||
<p className="!mt-0.5 m-0 text-xs text-content-secondary">
|
||||
{includeDefaultDraft
|
||||
? "The built-in Coder Agents prompt is prepended. Additional instructions below are appended."
|
||||
: "Only the additional instructions below are used. When empty, no deployment-wide system prompt is sent."}
|
||||
</p>
|
||||
<TextareaAutosize
|
||||
className={cn(
|
||||
textareaBaseClassName,
|
||||
isSystemPromptOverflowing && textareaOverflowClassName,
|
||||
isUserPromptOverflowing && textareaOverflowClassName,
|
||||
)}
|
||||
placeholder="Additional instructions for all users"
|
||||
value={systemPromptDraft}
|
||||
onChange={(event) => setLocalEdit(event.target.value)}
|
||||
placeholder="Additional behavior, style, and tone preferences"
|
||||
value={userPromptDraft}
|
||||
onChange={(event) => setLocalUserEdit(event.target.value)}
|
||||
onHeightChange={(height) =>
|
||||
setIsSystemPromptOverflowing(height >= textareaMaxHeight)
|
||||
setIsUserPromptOverflowing(height >= textareaMaxHeight)
|
||||
}
|
||||
disabled={isSystemPromptDisabled}
|
||||
disabled={isPromptSaving}
|
||||
minRows={1}
|
||||
/>
|
||||
{systemInvisibleCharCount > 0 && (
|
||||
{userInvisibleCharCount > 0 && (
|
||||
<Alert severity="warning">
|
||||
<AlertDescription>
|
||||
This text contains {systemInvisibleCharCount} invisible
|
||||
Unicode{" "}
|
||||
{systemInvisibleCharCount !== 1 ? "characters" : "character"}{" "}
|
||||
that could hide content. These will be stripped on save.
|
||||
This text contains {userInvisibleCharCount} invisible Unicode{" "}
|
||||
{userInvisibleCharCount !== 1 ? "characters" : "character"} that
|
||||
could hide content. These will be stripped on save.
|
||||
</AlertDescription>
|
||||
</Alert>
|
||||
)}
|
||||
@@ -390,164 +298,304 @@ export const AgentSettingsBehaviorPageView: FC<
|
||||
size="sm"
|
||||
variant="outline"
|
||||
type="button"
|
||||
onClick={() => setLocalEdit("")}
|
||||
disabled={isSystemPromptDisabled || !systemPromptDraft}
|
||||
onClick={() => setLocalUserEdit("")}
|
||||
disabled={isPromptSaving || !userPromptDraft}
|
||||
>
|
||||
Clear
|
||||
</Button>
|
||||
<Button
|
||||
size="sm"
|
||||
type="submit"
|
||||
disabled={isSystemPromptDisabled || !isSystemPromptDirty}
|
||||
disabled={isPromptSaving || !isUserPromptDirty}
|
||||
>
|
||||
Save
|
||||
</Button>{" "}
|
||||
</Button>
|
||||
</div>
|
||||
{isSaveSystemPromptError && (
|
||||
{isSaveUserPromptError && (
|
||||
<p className="m-0 text-xs text-content-destructive">
|
||||
Failed to save system prompt.
|
||||
Failed to save personal instructions.
|
||||
</p>
|
||||
)}
|
||||
</form>
|
||||
<hr className="my-5 border-0 border-t border-solid border-border" />
|
||||
<div className="space-y-2">
|
||||
<div className="flex items-center gap-2">
|
||||
<h3 className="m-0 text-[13px] font-semibold text-content-primary">
|
||||
Virtual Desktop
|
||||
</h3>
|
||||
<AdminBadge />
|
||||
</div>
|
||||
<div className="flex items-center justify-between gap-4">
|
||||
<div className="!mt-0.5 m-0 flex-1 text-xs text-content-secondary">
|
||||
<p className="m-0">
|
||||
Allow agents to use a virtual, graphical desktop within
|
||||
workspaces. Requires the{" "}
|
||||
<Link
|
||||
href="https://registry.coder.com/modules/coder/portabledesktop"
|
||||
target="_blank"
|
||||
size="sm"
|
||||
>
|
||||
portabledesktop module
|
||||
</Link>{" "}
|
||||
to be installed in the workspace and the Anthropic provider to
|
||||
be configured.
|
||||
</p>
|
||||
<p className="mt-2 mb-0 font-semibold text-content-secondary">
|
||||
Warning: This is a work-in-progress feature, and you're likely
|
||||
to encounter bugs if you enable it.
|
||||
</p>
|
||||
</CollapsibleSectionContent>
|
||||
</CollapsibleSection>
|
||||
|
||||
{/* Context compaction */}
|
||||
<CollapsibleSection>
|
||||
<CollapsibleSectionHeader>
|
||||
<CollapsibleSectionTitle>Context compaction</CollapsibleSectionTitle>
|
||||
<CollapsibleSectionDescription>Control when conversation context is automatically summarized for each model. Setting 100% means the conversation will never auto-compact.</CollapsibleSectionDescription>
|
||||
</CollapsibleSectionHeader>
|
||||
<CollapsibleSectionContent>
|
||||
<UserCompactionThresholdSettings
|
||||
hideHeader
|
||||
modelConfigs={modelConfigsData ?? []}
|
||||
modelConfigsError={modelConfigsError}
|
||||
isLoadingModelConfigs={isLoadingModelConfigs}
|
||||
thresholds={thresholds}
|
||||
isThresholdsLoading={isThresholdsLoading}
|
||||
thresholdsError={thresholdsError}
|
||||
onSaveThreshold={onSaveThreshold}
|
||||
onResetThreshold={onResetThreshold}
|
||||
/>
|
||||
</CollapsibleSectionContent>
|
||||
</CollapsibleSection>
|
||||
|
||||
{/* Admin system prompt (admin only) */}
|
||||
{canSetSystemPrompt && (
|
||||
<>
|
||||
<CollapsibleSection>
|
||||
<CollapsibleSectionHeader>
|
||||
<div className="flex items-center gap-2">
|
||||
<CollapsibleSectionTitle>System Instructions</CollapsibleSectionTitle>
|
||||
<AdminBadge />
|
||||
</div>
|
||||
<Switch
|
||||
checked={desktopEnabled}
|
||||
onCheckedChange={(checked) =>
|
||||
onSaveDesktopEnabled({ enable_desktop: checked })
|
||||
}
|
||||
aria-label="Enable"
|
||||
disabled={isSavingDesktopEnabled}
|
||||
/>
|
||||
</div>
|
||||
{isSaveDesktopEnabledError && (
|
||||
<p className="m-0 text-xs text-content-destructive">
|
||||
Failed to save desktop setting.
|
||||
</p>
|
||||
)}
|
||||
</div>
|
||||
<hr className="my-5 border-0 border-t border-solid border-border" />
|
||||
<form
|
||||
className="space-y-2"
|
||||
onSubmit={(event) => void handleSaveChatWorkspaceTTL(event)}
|
||||
>
|
||||
<div className="flex items-center gap-2">
|
||||
<h3 className="m-0 text-[13px] font-semibold text-content-primary">
|
||||
Workspace Autostop Fallback
|
||||
</h3>
|
||||
<AdminBadge />
|
||||
</div>
|
||||
<div className="flex items-center justify-between gap-4">
|
||||
<p className="!mt-0.5 m-0 flex-1 text-xs text-content-secondary">
|
||||
Set a default autostop for agent-created workspaces that don't
|
||||
have one defined in their template. Template-defined autostop
|
||||
rules always take precedence. Active conversations will extend
|
||||
the stop time.
|
||||
</p>
|
||||
<Switch
|
||||
checked={isAutostopEnabled}
|
||||
onCheckedChange={handleToggleAutostop}
|
||||
aria-label="Enable default autostop"
|
||||
disabled={isSavingWorkspaceTTL || isWorkspaceTTLLoading}
|
||||
/>{" "}
|
||||
</div>
|
||||
{isAutostopEnabled && (
|
||||
<DurationField
|
||||
valueMs={ttlMs}
|
||||
onChange={handleTTLChange}
|
||||
label="Autostop Fallback"
|
||||
disabled={isSavingWorkspaceTTL || isWorkspaceTTLLoading}
|
||||
error={isTTLOverMax || isTTLZero}
|
||||
helperText={
|
||||
isTTLZero
|
||||
? "Duration must be greater than zero."
|
||||
: isTTLOverMax
|
||||
? "Must not exceed 30 days (720 hours)."
|
||||
: undefined
|
||||
}
|
||||
/>
|
||||
)}
|
||||
{isAutostopEnabled && (
|
||||
<div className="flex justify-end">
|
||||
<Button
|
||||
size="sm"
|
||||
type="submit"
|
||||
disabled={
|
||||
isSavingWorkspaceTTL ||
|
||||
!isTTLDirty ||
|
||||
isTTLOverMax ||
|
||||
isTTLZero
|
||||
<CollapsibleSectionDescription>
|
||||
{includeDefaultDraft
|
||||
? "The built-in Coder Agents prompt is prepended. Additional instructions below are appended."
|
||||
: "Only the additional instructions below are used. When empty, no deployment-wide system prompt is sent."}
|
||||
</CollapsibleSectionDescription>
|
||||
</CollapsibleSectionHeader>
|
||||
<CollapsibleSectionContent>
|
||||
<form
|
||||
className="space-y-2"
|
||||
onSubmit={(event) => void handleSaveSystemPrompt(event)}
|
||||
>
|
||||
<div className="flex items-center justify-between gap-4">
|
||||
<div className="flex min-w-0 items-center gap-2 text-xs font-medium text-content-primary">
|
||||
<span>Include Coder Agents default system prompt</span>
|
||||
<Button
|
||||
size="xs"
|
||||
variant="subtle"
|
||||
type="button"
|
||||
onClick={() => setShowDefaultPromptPreview(true)}
|
||||
disabled={!hasLoadedSystemPrompt}
|
||||
className="min-w-0 px-0 text-content-link hover:text-content-link"
|
||||
>
|
||||
Preview
|
||||
</Button>
|
||||
</div>
|
||||
<Switch
|
||||
checked={includeDefaultDraft}
|
||||
onCheckedChange={setLocalIncludeDefault}
|
||||
aria-label="Include Coder Agents default system prompt"
|
||||
disabled={isSystemPromptDisabled}
|
||||
/>
|
||||
</div>
|
||||
<TextareaAutosize
|
||||
className={cn(
|
||||
textareaBaseClassName,
|
||||
isSystemPromptOverflowing && textareaOverflowClassName,
|
||||
)}
|
||||
placeholder="Additional instructions for all users"
|
||||
value={systemPromptDraft}
|
||||
onChange={(event) => setLocalEdit(event.target.value)}
|
||||
onHeightChange={(height) =>
|
||||
setIsSystemPromptOverflowing(height >= textareaMaxHeight)
|
||||
}
|
||||
>
|
||||
Save
|
||||
</Button>
|
||||
disabled={isSystemPromptDisabled}
|
||||
minRows={1}
|
||||
/>
|
||||
{systemInvisibleCharCount > 0 && (
|
||||
<Alert severity="warning">
|
||||
<AlertDescription>
|
||||
This text contains {systemInvisibleCharCount} invisible
|
||||
Unicode{" "}
|
||||
{systemInvisibleCharCount !== 1
|
||||
? "characters"
|
||||
: "character"}{" "}
|
||||
that could hide content. These will be stripped on save.
|
||||
</AlertDescription>
|
||||
</Alert>
|
||||
)}
|
||||
<div className="flex justify-end gap-2">
|
||||
<Button
|
||||
size="sm"
|
||||
variant="outline"
|
||||
type="button"
|
||||
onClick={() => setLocalEdit("")}
|
||||
disabled={isSystemPromptDisabled || !systemPromptDraft}
|
||||
>
|
||||
Clear
|
||||
</Button>
|
||||
<Button
|
||||
size="sm"
|
||||
type="submit"
|
||||
disabled={isSystemPromptDisabled || !isSystemPromptDirty}
|
||||
>
|
||||
Save
|
||||
</Button>
|
||||
</div>
|
||||
{isSaveSystemPromptError && (
|
||||
<p className="m-0 text-xs text-content-destructive">
|
||||
Failed to save system prompt.
|
||||
</p>
|
||||
)}
|
||||
</form>
|
||||
</CollapsibleSectionContent>
|
||||
</CollapsibleSection>
|
||||
|
||||
<CollapsibleSection>
|
||||
<CollapsibleSectionHeader>
|
||||
<div className="flex items-center gap-2">
|
||||
<CollapsibleSectionTitle>Virtual Desktop</CollapsibleSectionTitle>
|
||||
<AdminBadge />
|
||||
</div>
|
||||
)}
|
||||
{isSaveWorkspaceTTLError && (
|
||||
<p className="m-0 text-xs text-content-destructive">
|
||||
Failed to save autostop setting.
|
||||
</p>
|
||||
)}
|
||||
{isWorkspaceTTLLoadError && (
|
||||
<p className="m-0 text-xs text-content-destructive">
|
||||
Failed to load autostop setting.
|
||||
</p>
|
||||
)}
|
||||
</form>
|
||||
</CollapsibleSectionHeader>
|
||||
<CollapsibleSectionContent>
|
||||
<div className="flex items-start gap-4">
|
||||
<div className="min-w-0 flex-1 space-y-1">
|
||||
<span className="text-sm font-medium text-content-primary">
|
||||
Enable virtual desktop
|
||||
</span>
|
||||
<p className="m-0 text-xs text-content-secondary">
|
||||
Allow agents to use a virtual, graphical desktop within
|
||||
workspaces. Requires the{" "}
|
||||
<Link
|
||||
href="https://registry.coder.com/modules/coder/portabledesktop"
|
||||
target="_blank"
|
||||
size="sm"
|
||||
>
|
||||
portabledesktop module
|
||||
</Link>{" "}
|
||||
to be installed in the workspace and the Anthropic provider to
|
||||
be configured.
|
||||
</p>
|
||||
<p className="m-0 text-xs text-content-secondary font-semibold">
|
||||
Warning: This is a work-in-progress feature, and you're likely
|
||||
to encounter bugs if you enable it.
|
||||
</p>
|
||||
</div>
|
||||
<Switch
|
||||
checked={desktopEnabled}
|
||||
onCheckedChange={(checked) =>
|
||||
onSaveDesktopEnabled({ enable_desktop: checked })
|
||||
}
|
||||
aria-label="Enable"
|
||||
disabled={isSavingDesktopEnabled}
|
||||
/>
|
||||
</div>{" "}
|
||||
{isSaveDesktopEnabledError && (
|
||||
<p className="m-0 text-xs text-content-destructive">
|
||||
Failed to save desktop setting.
|
||||
</p>
|
||||
)}
|
||||
</CollapsibleSectionContent>
|
||||
</CollapsibleSection>
|
||||
|
||||
<CollapsibleSection>
|
||||
<CollapsibleSectionHeader>
|
||||
<div className="flex items-center gap-2">
|
||||
<CollapsibleSectionTitle>Workspace Autostop Fallback</CollapsibleSectionTitle>
|
||||
<AdminBadge />
|
||||
</div>
|
||||
</CollapsibleSectionHeader>
|
||||
<CollapsibleSectionContent>
|
||||
<div className="flex items-start gap-4">
|
||||
<div className="min-w-0 flex-1 space-y-1">
|
||||
<span className="text-sm font-medium text-content-primary">
|
||||
Enable autostop fallback
|
||||
</span>
|
||||
<p className="m-0 text-xs text-content-secondary">
|
||||
Set a default autostop for agent-created workspaces that don't
|
||||
have one defined in their template. Template-defined autostop
|
||||
rules always take precedence. Active conversations will extend
|
||||
the stop time.
|
||||
</p>
|
||||
</div>
|
||||
<Switch
|
||||
checked={isAutostopEnabled}
|
||||
onCheckedChange={handleToggleAutostop}
|
||||
aria-label="Enable default autostop"
|
||||
disabled={isSavingWorkspaceTTL || isWorkspaceTTLLoading}
|
||||
/>
|
||||
</div>
|
||||
<form
|
||||
className="space-y-2 pt-4"
|
||||
onSubmit={(event) => void handleSaveChatWorkspaceTTL(event)}
|
||||
>
|
||||
{" "}
|
||||
{isAutostopEnabled && (
|
||||
<DurationField
|
||||
valueMs={ttlMs}
|
||||
onChange={handleTTLChange}
|
||||
label="Autostop Fallback"
|
||||
disabled={isSavingWorkspaceTTL || isWorkspaceTTLLoading}
|
||||
error={isTTLOverMax || isTTLZero}
|
||||
helperText={
|
||||
isTTLZero
|
||||
? "Duration must be greater than zero."
|
||||
: isTTLOverMax
|
||||
? "Must not exceed 30 days (720 hours)."
|
||||
: undefined
|
||||
}
|
||||
/>
|
||||
)}
|
||||
{isAutostopEnabled && (
|
||||
<div className="flex justify-end">
|
||||
<Button
|
||||
size="sm"
|
||||
type="submit"
|
||||
disabled={
|
||||
isSavingWorkspaceTTL ||
|
||||
!isTTLDirty ||
|
||||
isTTLOverMax ||
|
||||
isTTLZero
|
||||
}
|
||||
>
|
||||
Save
|
||||
</Button>
|
||||
</div>
|
||||
)}
|
||||
{isSaveWorkspaceTTLError && (
|
||||
<p className="m-0 text-xs text-content-destructive">
|
||||
Failed to save autostop setting.
|
||||
</p>
|
||||
)}
|
||||
{isWorkspaceTTLLoadError && (
|
||||
<p className="m-0 text-xs text-content-destructive">
|
||||
Failed to load autostop setting.
|
||||
</p>
|
||||
)}
|
||||
</form>
|
||||
</CollapsibleSectionContent>
|
||||
</CollapsibleSection>
|
||||
</>
|
||||
)}
|
||||
<hr className="my-5 border-0 border-t border-solid border-border" />
|
||||
{/* ── Kyleosophy toggle (always visible) ── */}
|
||||
<div className="space-y-2">
|
||||
<h3 className="m-0 text-[13px] font-semibold text-content-primary">
|
||||
Kyleosophy
|
||||
</h3>
|
||||
<div className="flex items-center justify-between gap-4">
|
||||
<p className="!mt-0.5 m-0 flex-1 text-xs text-content-secondary">
|
||||
Replace the standard completion chime. IYKYK.
|
||||
{kylesophyForced && (
|
||||
<span className="ml-1 font-semibold">
|
||||
Kyleosophy is mandatory on <code>dev.coder.com</code>.
|
||||
|
||||
{/* Kyleosophy toggle (always visible) */}
|
||||
<CollapsibleSection>
|
||||
<CollapsibleSectionHeader>
|
||||
<CollapsibleSectionTitle>Kyleosophy</CollapsibleSectionTitle>
|
||||
</CollapsibleSectionHeader>
|
||||
<CollapsibleSectionContent>
|
||||
<div className="flex items-start gap-4">
|
||||
<div className="min-w-0 flex-1 space-y-1">
|
||||
<span className="text-sm font-medium text-content-primary">
|
||||
Enable Kyleosophy
|
||||
</span>
|
||||
)}
|
||||
</p>
|
||||
<Switch
|
||||
checked={kylesophyEnabled}
|
||||
onCheckedChange={(checked) => {
|
||||
setKylesophyEnabled(checked);
|
||||
setLocalKylesophy(checked);
|
||||
}}
|
||||
aria-label="Enable Kyleosophy"
|
||||
disabled={kylesophyForced}
|
||||
/>
|
||||
</div>
|
||||
</div>
|
||||
<p className="m-0 text-xs text-content-secondary">
|
||||
{" "}
|
||||
Replace the standard completion chime. IYKYK.
|
||||
{kylesophyForced && (
|
||||
<span className="ml-1 font-semibold">
|
||||
Kyleosophy is mandatory on <code>dev.coder.com</code>.
|
||||
</span>
|
||||
)}
|
||||
</p>
|
||||
</div>
|
||||
<Switch
|
||||
checked={kylesophyEnabled}
|
||||
onCheckedChange={(checked) => {
|
||||
setKylesophyEnabled(checked);
|
||||
setLocalKylesophy(checked);
|
||||
}}
|
||||
aria-label="Enable Kyleosophy"
|
||||
disabled={kylesophyForced}
|
||||
/>
|
||||
</div>
|
||||
</CollapsibleSectionContent>
|
||||
</CollapsibleSection>
|
||||
{showDefaultPromptPreview && (
|
||||
<TextPreviewDialog
|
||||
content={defaultSystemPrompt}
|
||||
@@ -555,6 +603,6 @@ export const AgentSettingsBehaviorPageView: FC<
|
||||
onClose={() => setShowDefaultPromptPreview(false)}
|
||||
/>
|
||||
)}
|
||||
</>
|
||||
</div>
|
||||
);
|
||||
};
|
||||
|
||||
@@ -1,63 +0,0 @@
|
||||
import { BookOpenIcon, LoaderIcon, TriangleAlertIcon } from "lucide-react";
|
||||
import type React from "react";
|
||||
import { ScrollArea } from "#/components/ScrollArea/ScrollArea";
|
||||
import {
|
||||
Tooltip,
|
||||
TooltipContent,
|
||||
TooltipTrigger,
|
||||
} from "#/components/Tooltip/Tooltip";
|
||||
import { cn } from "#/utils/cn";
|
||||
import { Response } from "../Response";
|
||||
import { ToolCollapsible } from "./ToolCollapsible";
|
||||
import type { ToolStatus } from "./utils";
|
||||
|
||||
export const ReadSkillTool: React.FC<{
|
||||
label: string;
|
||||
body: string;
|
||||
status: ToolStatus;
|
||||
isError: boolean;
|
||||
errorMessage?: string;
|
||||
}> = ({ label, body, status, isError, errorMessage }) => {
|
||||
const hasContent = body.length > 0;
|
||||
const isRunning = status === "running";
|
||||
|
||||
return (
|
||||
<ToolCollapsible
|
||||
className="w-full"
|
||||
hasContent={hasContent}
|
||||
header={
|
||||
<>
|
||||
<BookOpenIcon className="h-4 w-4 shrink-0 text-content-secondary" />
|
||||
<span className={cn("text-sm", "text-content-secondary")}>
|
||||
{isRunning ? `Reading ${label}…` : `Read ${label}`}
|
||||
</span>
|
||||
{isError && (
|
||||
<Tooltip>
|
||||
<TooltipTrigger asChild>
|
||||
<TriangleAlertIcon className="h-3.5 w-3.5 shrink-0 text-content-secondary" />
|
||||
</TooltipTrigger>
|
||||
<TooltipContent>
|
||||
{errorMessage || "Failed to read skill"}
|
||||
</TooltipContent>
|
||||
</Tooltip>
|
||||
)}
|
||||
{isRunning && (
|
||||
<LoaderIcon className="h-3.5 w-3.5 shrink-0 animate-spin motion-reduce:animate-none text-content-secondary" />
|
||||
)}
|
||||
</>
|
||||
}
|
||||
>
|
||||
{body && (
|
||||
<ScrollArea
|
||||
className="mt-1.5 rounded-md border border-solid border-border-default"
|
||||
viewportClassName="max-h-64"
|
||||
scrollBarClassName="w-1.5"
|
||||
>
|
||||
<div className="px-3 py-2">
|
||||
<Response>{body}</Response>
|
||||
</div>
|
||||
</ScrollArea>
|
||||
)}
|
||||
</ToolCollapsible>
|
||||
);
|
||||
};
|
||||
@@ -1342,12 +1342,6 @@ export const ReadSkillCompleted: Story = {
|
||||
play: async ({ canvasElement }) => {
|
||||
const canvas = within(canvasElement);
|
||||
expect(canvas.getByText(/Read skill deep-review/)).toBeInTheDocument();
|
||||
// Expand the collapsible to verify markdown body renders.
|
||||
const toggle = canvas.getByRole("button");
|
||||
await userEvent.click(toggle);
|
||||
await waitFor(() => {
|
||||
expect(canvas.getByText("Deep Review Skill")).toBeInTheDocument();
|
||||
});
|
||||
},
|
||||
};
|
||||
|
||||
@@ -1399,12 +1393,6 @@ export const ReadSkillFileCompleted: Story = {
|
||||
expect(
|
||||
canvas.getByText(/Read deep-review\/roles\/security-reviewer\.md/),
|
||||
).toBeInTheDocument();
|
||||
// Expand the collapsible to verify markdown content renders.
|
||||
const toggle = canvas.getByRole("button");
|
||||
await userEvent.click(toggle);
|
||||
await waitFor(() => {
|
||||
expect(canvas.getByText("Security Reviewer Role")).toBeInTheDocument();
|
||||
});
|
||||
},
|
||||
};
|
||||
|
||||
|
||||
@@ -23,7 +23,6 @@ import { ListTemplatesTool } from "./ListTemplatesTool";
|
||||
import { ProcessOutputTool } from "./ProcessOutputTool";
|
||||
import { ProposePlanTool } from "./ProposePlanTool";
|
||||
import { ReadFileTool } from "./ReadFileTool";
|
||||
import { ReadSkillTool } from "./ReadSkillTool";
|
||||
import { ReadTemplateTool } from "./ReadTemplateTool";
|
||||
import { SubagentTool } from "./SubagentTool";
|
||||
import { ToolCollapsible } from "./ToolCollapsible";
|
||||
@@ -211,55 +210,6 @@ const ReadFileRenderer: FC<ToolRendererProps> = ({
|
||||
);
|
||||
};
|
||||
|
||||
const ReadSkillRenderer: FC<ToolRendererProps> = ({
|
||||
status,
|
||||
args,
|
||||
result,
|
||||
isError,
|
||||
}) => {
|
||||
const parsedArgs = parseArgs(args);
|
||||
const skillName = parsedArgs ? asString(parsedArgs.name) : "";
|
||||
const rec = asRecord(result);
|
||||
const body = rec ? asString(rec.body) : "";
|
||||
|
||||
return (
|
||||
<ReadSkillTool
|
||||
label={skillName ? `skill ${skillName}` : "skill"}
|
||||
body={body}
|
||||
status={status}
|
||||
isError={isError}
|
||||
errorMessage={rec ? asString(rec.error || rec.message) : undefined}
|
||||
/>
|
||||
);
|
||||
};
|
||||
|
||||
const ReadSkillFileRenderer: FC<ToolRendererProps> = ({
|
||||
status,
|
||||
args,
|
||||
result,
|
||||
isError,
|
||||
}) => {
|
||||
const parsedArgs = parseArgs(args);
|
||||
const skillName = parsedArgs ? asString(parsedArgs.name) : "";
|
||||
const filePath = parsedArgs ? asString(parsedArgs.path) : "";
|
||||
const label =
|
||||
skillName && filePath
|
||||
? `${skillName}/${filePath}`
|
||||
: skillName || filePath || "skill file";
|
||||
const rec = asRecord(result);
|
||||
const content = rec ? asString(rec.content) : "";
|
||||
|
||||
return (
|
||||
<ReadSkillTool
|
||||
label={label}
|
||||
body={content}
|
||||
status={status}
|
||||
isError={isError}
|
||||
errorMessage={rec ? asString(rec.error || rec.message) : undefined}
|
||||
/>
|
||||
);
|
||||
};
|
||||
|
||||
const WriteFileRenderer: FC<ToolRendererProps> = ({
|
||||
status,
|
||||
args,
|
||||
@@ -717,8 +667,6 @@ const toolRenderers: Record<string, FC<ToolRendererProps>> = {
|
||||
create_workspace: CreateWorkspaceRenderer,
|
||||
list_templates: ListTemplatesRenderer,
|
||||
read_template: ReadTemplateRenderer,
|
||||
read_skill: ReadSkillRenderer,
|
||||
read_skill_file: ReadSkillFileRenderer,
|
||||
spawn_agent: SubagentRenderer,
|
||||
wait_agent: SubagentRenderer,
|
||||
message_agent: SubagentRenderer,
|
||||
|
||||
+1
-4
@@ -255,7 +255,7 @@ export const ProviderAccordionCards: Story = {
|
||||
expect(body.queryByText("OpenAI")).not.toBeInTheDocument();
|
||||
|
||||
await userEvent.click(body.getByRole("button", { name: /OpenRouter/i }));
|
||||
await expect(await body.findByLabelText("Base URL")).toBeInTheDocument();
|
||||
await expect(body.getByLabelText("Base URL")).toBeInTheDocument();
|
||||
},
|
||||
};
|
||||
|
||||
@@ -462,9 +462,6 @@ export const CreateAndUpdateProvider: Story = {
|
||||
await waitFor(() => {
|
||||
expect(args.onCreateProvider).toHaveBeenCalledTimes(1);
|
||||
});
|
||||
await waitFor(() => {
|
||||
expect(body.getByRole("button", { name: "Save changes" })).toBeDisabled();
|
||||
});
|
||||
expect(args.onCreateProvider).toHaveBeenCalledWith(
|
||||
expect.objectContaining({
|
||||
provider: "openai",
|
||||
|
||||
@@ -285,7 +285,7 @@ export const ChatModelAdminPanel: FC<ChatModelAdminPanelProps> = ({
|
||||
const modelConfigsUnavailable = modelConfigsData === null;
|
||||
|
||||
return (
|
||||
<div className={cn("flex min-h-full flex-col space-y-3", className)}>
|
||||
<div className={cn("flex min-h-full flex-col space-y-6", className)}>
|
||||
{isLoading && (
|
||||
<div className="flex items-center gap-1.5 text-xs text-content-secondary">
|
||||
<Spinner className="h-4 w-4" loading />
|
||||
|
||||
@@ -1,11 +1,5 @@
|
||||
import { useFormik } from "formik";
|
||||
import {
|
||||
ChevronDownIcon,
|
||||
ChevronLeftIcon,
|
||||
ChevronRightIcon,
|
||||
InfoIcon,
|
||||
PencilIcon,
|
||||
} from "lucide-react";
|
||||
import { ChevronLeftIcon, InfoIcon, PencilIcon } from "lucide-react";
|
||||
import { type FC, useState } from "react";
|
||||
import * as Yup from "yup";
|
||||
import type * as TypesGen from "#/api/typesGenerated";
|
||||
@@ -41,6 +35,13 @@ import {
|
||||
} from "#/components/Tooltip/Tooltip";
|
||||
import { cn } from "#/utils/cn";
|
||||
import { getFormHelpers } from "#/utils/formUtils";
|
||||
import {
|
||||
CollapsibleSection,
|
||||
CollapsibleSectionContent,
|
||||
CollapsibleSectionDescription,
|
||||
CollapsibleSectionHeader,
|
||||
CollapsibleSectionTitle,
|
||||
} from "../CollapsibleSection";
|
||||
import type { ProviderState } from "./ChatModelAdminPanel";
|
||||
import {
|
||||
GeneralModelConfigFields,
|
||||
@@ -116,9 +117,6 @@ export const ModelForm: FC<ModelFormProps> = ({
|
||||
}) => {
|
||||
const isEditing = Boolean(editingModel);
|
||||
const isDefaultModel = isEditing && editingModel?.is_default === true;
|
||||
const [showAdvanced, setShowAdvanced] = useState(false);
|
||||
const [showPricing, setShowPricing] = useState(false);
|
||||
const [showProviderConfig, setShowProviderConfig] = useState(false);
|
||||
const [confirmingDelete, setConfirmingDelete] = useState(false);
|
||||
|
||||
const canManageModels = Boolean(
|
||||
@@ -488,29 +486,18 @@ export const ModelForm: FC<ModelFormProps> = ({
|
||||
</div>
|
||||
|
||||
{/* Usage Tracking */}
|
||||
<div className="border-0 border-t border-solid border-border pt-4">
|
||||
<button
|
||||
type="button"
|
||||
onClick={() => setShowPricing((v) => !v)}
|
||||
className="flex w-full cursor-pointer items-start justify-between border-0 bg-transparent p-0 text-left transition-colors hover:text-content-primary"
|
||||
>
|
||||
<div>
|
||||
<h3 className="m-0 text-sm font-medium text-content-primary">
|
||||
Cost Tracking{" "}
|
||||
</h3>
|
||||
<p className="m-0 text-xs text-content-secondary">
|
||||
Set per-token pricing so Coder can track costs and enforce
|
||||
spending limits.
|
||||
</p>
|
||||
</div>
|
||||
{showPricing ? (
|
||||
<ChevronDownIcon className="mt-0.5 h-4 w-4 shrink-0 text-content-secondary" />
|
||||
) : (
|
||||
<ChevronRightIcon className="mt-0.5 h-4 w-4 shrink-0 text-content-secondary" />
|
||||
)}
|
||||
</button>
|
||||
{showPricing && (
|
||||
<div className="grid grid-cols-2 gap-3 pt-3 sm:grid-cols-4">
|
||||
<CollapsibleSection variant="inline" defaultOpen={false}>
|
||||
<CollapsibleSectionHeader>
|
||||
<CollapsibleSectionTitle as="h3">
|
||||
Cost Tracking
|
||||
</CollapsibleSectionTitle>
|
||||
<CollapsibleSectionDescription>
|
||||
Set per-token pricing so Coder can track costs and enforce
|
||||
spending limits.
|
||||
</CollapsibleSectionDescription>
|
||||
</CollapsibleSectionHeader>
|
||||
<CollapsibleSectionContent>
|
||||
<div className="grid grid-cols-2 gap-3 sm:grid-cols-4">
|
||||
<PricingModelConfigFields
|
||||
provider={selectedProviderState.provider}
|
||||
form={form}
|
||||
@@ -518,33 +505,21 @@ export const ModelForm: FC<ModelFormProps> = ({
|
||||
disabled={isSaving}
|
||||
/>
|
||||
</div>
|
||||
)}
|
||||
</div>
|
||||
|
||||
</CollapsibleSectionContent>
|
||||
</CollapsibleSection>
|
||||
{/* Provider Configuration */}
|
||||
<div className="border-0 border-t border-solid border-border pt-4">
|
||||
<button
|
||||
type="button"
|
||||
onClick={() => setShowProviderConfig((v) => !v)}
|
||||
className="flex w-full cursor-pointer items-start justify-between border-0 bg-transparent p-0 text-left transition-colors hover:text-content-primary"
|
||||
>
|
||||
<CollapsibleSection variant="inline" defaultOpen={false}>
|
||||
<CollapsibleSectionHeader>
|
||||
<CollapsibleSectionTitle as="h3">
|
||||
Provider Configuration
|
||||
</CollapsibleSectionTitle>
|
||||
<CollapsibleSectionDescription>
|
||||
Tune provider-specific behavior like reasoning, tool calling,
|
||||
and web search.
|
||||
</CollapsibleSectionDescription>
|
||||
</CollapsibleSectionHeader>
|
||||
<CollapsibleSectionContent>
|
||||
<div>
|
||||
<h3 className="m-0 text-sm font-medium text-content-primary">
|
||||
Provider Configuration
|
||||
</h3>
|
||||
<p className="m-0 text-xs text-content-secondary">
|
||||
Tune provider-specific behavior like reasoning, tool calling,
|
||||
and web search.
|
||||
</p>
|
||||
</div>
|
||||
{showProviderConfig ? (
|
||||
<ChevronDownIcon className="mt-0.5 h-4 w-4 shrink-0 text-content-secondary" />
|
||||
) : (
|
||||
<ChevronRightIcon className="mt-0.5 h-4 w-4 shrink-0 text-content-secondary" />
|
||||
)}
|
||||
</button>
|
||||
{showProviderConfig && (
|
||||
<div className="pt-3">
|
||||
<ModelConfigFields
|
||||
provider={selectedProviderState.provider}
|
||||
form={form}
|
||||
@@ -552,33 +527,21 @@ export const ModelForm: FC<ModelFormProps> = ({
|
||||
disabled={isSaving}
|
||||
/>
|
||||
</div>
|
||||
)}
|
||||
</div>
|
||||
|
||||
</CollapsibleSectionContent>
|
||||
</CollapsibleSection>
|
||||
{/* Advanced */}
|
||||
<div className="border-0 border-t border-solid border-border pt-4">
|
||||
<button
|
||||
type="button"
|
||||
onClick={() => setShowAdvanced((v) => !v)}
|
||||
className="flex w-full cursor-pointer items-start justify-between border-0 bg-transparent p-0 text-left transition-colors hover:text-content-primary"
|
||||
>
|
||||
<div>
|
||||
<h3 className="m-0 text-sm font-medium text-content-primary">
|
||||
Advanced
|
||||
</h3>
|
||||
<p className="m-0 text-xs text-content-secondary">
|
||||
Low-level parameters like temperature and penalties. Rarely
|
||||
need changing.
|
||||
</p>
|
||||
</div>
|
||||
{showAdvanced ? (
|
||||
<ChevronDownIcon className="mt-0.5 h-4 w-4 shrink-0 text-content-secondary" />
|
||||
) : (
|
||||
<ChevronRightIcon className="mt-0.5 h-4 w-4 shrink-0 text-content-secondary" />
|
||||
)}
|
||||
</button>
|
||||
{showAdvanced && (
|
||||
<div className="grid grid-cols-2 gap-3 pt-3 sm:grid-cols-3">
|
||||
<CollapsibleSection variant="inline" defaultOpen={false}>
|
||||
<CollapsibleSectionHeader>
|
||||
<CollapsibleSectionTitle as="h3">
|
||||
Advanced
|
||||
</CollapsibleSectionTitle>
|
||||
<CollapsibleSectionDescription>
|
||||
Low-level parameters like temperature and penalties. Rarely need
|
||||
changing.
|
||||
</CollapsibleSectionDescription>
|
||||
</CollapsibleSectionHeader>
|
||||
<CollapsibleSectionContent>
|
||||
<div className="grid grid-cols-2 gap-3 sm:grid-cols-3">
|
||||
<GeneralModelConfigFields
|
||||
provider={selectedProviderState.provider}
|
||||
form={form}
|
||||
@@ -629,10 +592,11 @@ export const ModelForm: FC<ModelFormProps> = ({
|
||||
)}
|
||||
</div>
|
||||
</div>
|
||||
)}
|
||||
</div>
|
||||
</CollapsibleSectionContent>
|
||||
</CollapsibleSection>
|
||||
</div>
|
||||
<div className="mt-auto py-6">
|
||||
{" "}
|
||||
<hr className="mb-4 border-0 border-t border-solid border-border" />
|
||||
<div className="flex items-center justify-between">
|
||||
{isEditing && editingModel && onDeleteModel ? (
|
||||
|
||||
@@ -0,0 +1,193 @@
|
||||
import type { Meta, StoryObj } from "@storybook/react-vite";
|
||||
import { expect, userEvent, within } from "storybook/test";
|
||||
import { AdminBadge } from "./AdminBadge";
|
||||
import {
|
||||
CollapsibleSection,
|
||||
CollapsibleSectionContent,
|
||||
CollapsibleSectionDescription,
|
||||
CollapsibleSectionHeader,
|
||||
CollapsibleSectionTitle,
|
||||
} from "./CollapsibleSection";
|
||||
|
||||
const meta: Meta<typeof CollapsibleSection> = {
|
||||
title: "pages/AgentsPage/CollapsibleSection",
|
||||
component: CollapsibleSection,
|
||||
decorators: [
|
||||
(Story) => (
|
||||
<div style={{ maxWidth: 600 }}>
|
||||
<Story />
|
||||
</div>
|
||||
),
|
||||
],
|
||||
};
|
||||
export default meta;
|
||||
type Story = StoryObj<typeof CollapsibleSection>;
|
||||
|
||||
const Placeholder = () => (
|
||||
<p className="text-sm text-content-secondary">Placeholder content</p>
|
||||
);
|
||||
|
||||
export const DefaultOpen: Story = {
|
||||
render: () => (
|
||||
<CollapsibleSection>
|
||||
<CollapsibleSectionHeader>
|
||||
<div className="flex items-center gap-2">
|
||||
<CollapsibleSectionTitle>Default spend limit</CollapsibleSectionTitle>
|
||||
<AdminBadge />
|
||||
</div>
|
||||
<CollapsibleSectionDescription>
|
||||
The deployment-wide spending cap.
|
||||
</CollapsibleSectionDescription>
|
||||
</CollapsibleSectionHeader>
|
||||
<CollapsibleSectionContent>
|
||||
<Placeholder />
|
||||
</CollapsibleSectionContent>
|
||||
</CollapsibleSection>
|
||||
),
|
||||
play: async ({ canvasElement }) => {
|
||||
const canvas = within(canvasElement);
|
||||
|
||||
expect(canvas.getByText("Placeholder content")).toBeInTheDocument();
|
||||
|
||||
const header = canvas.getByRole("button", {
|
||||
name: /Default spend limit/i,
|
||||
});
|
||||
expect(header).toHaveAttribute("aria-expanded", "true");
|
||||
|
||||
await userEvent.click(header);
|
||||
expect(header).toHaveAttribute("aria-expanded", "false");
|
||||
expect(canvas.queryByText("Placeholder content")).not.toBeInTheDocument();
|
||||
|
||||
await userEvent.click(header);
|
||||
expect(header).toHaveAttribute("aria-expanded", "true");
|
||||
expect(canvas.getByText("Placeholder content")).toBeInTheDocument();
|
||||
},
|
||||
};
|
||||
|
||||
export const Collapsed: Story = {
|
||||
render: () => (
|
||||
<CollapsibleSection defaultOpen={false}>
|
||||
<CollapsibleSectionHeader>
|
||||
<div className="flex items-center gap-2">
|
||||
<CollapsibleSectionTitle>Default spend limit</CollapsibleSectionTitle>
|
||||
<AdminBadge />
|
||||
</div>
|
||||
<CollapsibleSectionDescription>
|
||||
The deployment-wide spending cap.
|
||||
</CollapsibleSectionDescription>
|
||||
</CollapsibleSectionHeader>
|
||||
<CollapsibleSectionContent>
|
||||
<Placeholder />
|
||||
</CollapsibleSectionContent>
|
||||
</CollapsibleSection>
|
||||
),
|
||||
play: async ({ canvasElement }) => {
|
||||
const canvas = within(canvasElement);
|
||||
|
||||
expect(canvas.queryByText("Placeholder content")).not.toBeInTheDocument();
|
||||
|
||||
const header = canvas.getByRole("button", {
|
||||
name: /Default spend limit/i,
|
||||
});
|
||||
expect(header).toHaveAttribute("aria-expanded", "false");
|
||||
|
||||
await userEvent.click(header);
|
||||
expect(header).toHaveAttribute("aria-expanded", "true");
|
||||
expect(canvas.getByText("Placeholder content")).toBeInTheDocument();
|
||||
|
||||
await userEvent.click(header);
|
||||
expect(header).toHaveAttribute("aria-expanded", "false");
|
||||
expect(canvas.queryByText("Placeholder content")).not.toBeInTheDocument();
|
||||
},
|
||||
};
|
||||
|
||||
export const NoBadge: Story = {
|
||||
render: () => (
|
||||
<CollapsibleSection>
|
||||
<CollapsibleSectionHeader>
|
||||
<CollapsibleSectionTitle>Group limits</CollapsibleSectionTitle>
|
||||
<CollapsibleSectionDescription>
|
||||
Override defaults for groups.
|
||||
</CollapsibleSectionDescription>
|
||||
</CollapsibleSectionHeader>
|
||||
<CollapsibleSectionContent>
|
||||
<Placeholder />
|
||||
</CollapsibleSectionContent>
|
||||
</CollapsibleSection>
|
||||
),
|
||||
};
|
||||
|
||||
export const InlineVariant: Story = {
|
||||
render: () => (
|
||||
<CollapsibleSection variant="inline" defaultOpen={false}>
|
||||
<CollapsibleSectionHeader>
|
||||
<CollapsibleSectionTitle as="h3">Cost Tracking</CollapsibleSectionTitle>
|
||||
<CollapsibleSectionDescription>
|
||||
Set per-token pricing so Coder can track costs and enforce spending
|
||||
limits.
|
||||
</CollapsibleSectionDescription>
|
||||
</CollapsibleSectionHeader>
|
||||
<CollapsibleSectionContent>
|
||||
<Placeholder />
|
||||
</CollapsibleSectionContent>
|
||||
</CollapsibleSection>
|
||||
),
|
||||
play: async ({ canvasElement }) => {
|
||||
const canvas = within(canvasElement);
|
||||
|
||||
expect(canvas.queryByText("Placeholder content")).not.toBeInTheDocument();
|
||||
|
||||
const header = canvas.getByRole("button", {
|
||||
name: /Cost Tracking/i,
|
||||
});
|
||||
expect(header).toHaveAttribute("aria-expanded", "false");
|
||||
|
||||
await userEvent.click(header);
|
||||
expect(header).toHaveAttribute("aria-expanded", "true");
|
||||
expect(canvas.getByText("Placeholder content")).toBeInTheDocument();
|
||||
|
||||
await userEvent.click(header);
|
||||
expect(header).toHaveAttribute("aria-expanded", "false");
|
||||
expect(canvas.queryByText("Placeholder content")).not.toBeInTheDocument();
|
||||
},
|
||||
};
|
||||
|
||||
export const KeyboardToggle: Story = {
|
||||
render: () => (
|
||||
<CollapsibleSection>
|
||||
<CollapsibleSectionHeader>
|
||||
<CollapsibleSectionTitle>Keyboard section</CollapsibleSectionTitle>
|
||||
</CollapsibleSectionHeader>
|
||||
<CollapsibleSectionContent>
|
||||
<Placeholder />
|
||||
</CollapsibleSectionContent>
|
||||
</CollapsibleSection>
|
||||
),
|
||||
play: async ({ canvasElement }) => {
|
||||
const canvas = within(canvasElement);
|
||||
|
||||
const header = canvas.getByRole("button", {
|
||||
name: /Keyboard section/i,
|
||||
});
|
||||
expect(header).toHaveAttribute("aria-expanded", "true");
|
||||
expect(canvas.getByText("Placeholder content")).toBeInTheDocument();
|
||||
|
||||
header.focus();
|
||||
|
||||
await userEvent.keyboard("{Enter}");
|
||||
expect(header).toHaveAttribute("aria-expanded", "false");
|
||||
expect(canvas.queryByText("Placeholder content")).not.toBeInTheDocument();
|
||||
|
||||
await userEvent.keyboard("{Enter}");
|
||||
expect(header).toHaveAttribute("aria-expanded", "true");
|
||||
expect(canvas.getByText("Placeholder content")).toBeInTheDocument();
|
||||
|
||||
await userEvent.keyboard(" ");
|
||||
expect(header).toHaveAttribute("aria-expanded", "false");
|
||||
expect(canvas.queryByText("Placeholder content")).not.toBeInTheDocument();
|
||||
|
||||
await userEvent.keyboard(" ");
|
||||
expect(header).toHaveAttribute("aria-expanded", "true");
|
||||
expect(canvas.getByText("Placeholder content")).toBeInTheDocument();
|
||||
},
|
||||
};
|
||||
@@ -0,0 +1,211 @@
|
||||
import { cva, type VariantProps } from "class-variance-authority";
|
||||
import { ChevronDownIcon } from "lucide-react";
|
||||
import {
|
||||
type ComponentProps,
|
||||
createContext,
|
||||
type FC,
|
||||
type ReactNode,
|
||||
useContext,
|
||||
useState,
|
||||
} from "react";
|
||||
import {
|
||||
Collapsible,
|
||||
CollapsibleContent,
|
||||
CollapsibleTrigger,
|
||||
} from "#/components/Collapsible/Collapsible";
|
||||
import { cn } from "#/utils/cn";
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Variant styles
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
const wrapperStyles = cva("", {
|
||||
variants: {
|
||||
variant: {
|
||||
card: "rounded-lg border border-solid border-border-default",
|
||||
inline: "border-0 border-t border-solid border-border-default pt-4",
|
||||
},
|
||||
},
|
||||
defaultVariants: { variant: "card" },
|
||||
});
|
||||
|
||||
const triggerStyles = cva(
|
||||
"flex w-full cursor-pointer items-start justify-between gap-4 border-0 bg-transparent text-left focus-visible:outline-none focus-visible:ring-2 focus-visible:ring-content-link",
|
||||
{
|
||||
variants: {
|
||||
variant: {
|
||||
card: "rounded-lg px-6 py-5",
|
||||
inline: "rounded-md p-0",
|
||||
},
|
||||
},
|
||||
defaultVariants: { variant: "card" },
|
||||
},
|
||||
);
|
||||
|
||||
const titleStyles = cva("m-0 text-content-primary", {
|
||||
variants: {
|
||||
variant: {
|
||||
card: "text-lg font-semibold",
|
||||
inline: "text-sm font-medium",
|
||||
},
|
||||
},
|
||||
defaultVariants: { variant: "card" },
|
||||
});
|
||||
|
||||
const descriptionStyles = cva("m-0 text-content-secondary", {
|
||||
variants: {
|
||||
variant: {
|
||||
card: "mt-1 text-sm",
|
||||
inline: "text-xs",
|
||||
},
|
||||
},
|
||||
defaultVariants: { variant: "card" },
|
||||
});
|
||||
|
||||
const contentStyles = cva("", {
|
||||
variants: {
|
||||
variant: {
|
||||
card: "border-0 border-t border-solid border-border-default px-6 pb-5 pt-4",
|
||||
inline: "pt-3",
|
||||
},
|
||||
},
|
||||
defaultVariants: { variant: "card" },
|
||||
});
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Context
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
type Variant = "card" | "inline";
|
||||
|
||||
interface CollapsibleSectionContextValue {
|
||||
variant: Variant;
|
||||
open: boolean;
|
||||
}
|
||||
|
||||
const CollapsibleSectionContext = createContext<CollapsibleSectionContextValue>(
|
||||
{
|
||||
variant: "card",
|
||||
open: true,
|
||||
},
|
||||
);
|
||||
|
||||
const useCollapsibleSection = () => useContext(CollapsibleSectionContext);
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Root
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
interface CollapsibleSectionProps extends VariantProps<typeof wrapperStyles> {
|
||||
defaultOpen?: boolean;
|
||||
/** Controlled open state. */
|
||||
open?: boolean;
|
||||
onOpenChange?: (open: boolean) => void;
|
||||
children: ReactNode;
|
||||
}
|
||||
|
||||
export const CollapsibleSection: FC<CollapsibleSectionProps> = ({
|
||||
defaultOpen,
|
||||
open: controlledOpen,
|
||||
onOpenChange: controlledOnOpenChange,
|
||||
variant = "card",
|
||||
children,
|
||||
}) => {
|
||||
const [uncontrolledOpen, setUncontrolledOpen] = useState(defaultOpen ?? true);
|
||||
const isControlled = controlledOpen !== undefined;
|
||||
const open = isControlled ? controlledOpen : uncontrolledOpen;
|
||||
const onOpenChange = isControlled
|
||||
? controlledOnOpenChange
|
||||
: setUncontrolledOpen;
|
||||
|
||||
return (
|
||||
<CollapsibleSectionContext.Provider
|
||||
value={{ variant: variant ?? "card", open }}
|
||||
>
|
||||
<Collapsible open={open} onOpenChange={onOpenChange}>
|
||||
<div className={wrapperStyles({ variant })}>{children}</div>
|
||||
</Collapsible>
|
||||
</CollapsibleSectionContext.Provider>
|
||||
);
|
||||
};
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Header (trigger)
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
interface CollapsibleSectionHeaderProps {
|
||||
children: ReactNode;
|
||||
}
|
||||
|
||||
export const CollapsibleSectionHeader: FC<CollapsibleSectionHeaderProps> = ({
|
||||
children,
|
||||
}) => {
|
||||
const { variant, open } = useCollapsibleSection();
|
||||
|
||||
return (
|
||||
<CollapsibleTrigger className={triggerStyles({ variant })}>
|
||||
<div className="min-w-0 flex-1">{children}</div>
|
||||
<ChevronDownIcon
|
||||
className={cn(
|
||||
"h-4 w-4 shrink-0 text-content-secondary transition-transform duration-200",
|
||||
open && "rotate-180",
|
||||
)}
|
||||
/>
|
||||
</CollapsibleTrigger>
|
||||
);
|
||||
};
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Title — renders the heading element at whatever level the consumer picks.
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
type HeadingTag = "h1" | "h2" | "h3" | "h4" | "h5" | "h6";
|
||||
|
||||
interface CollapsibleSectionTitleProps extends ComponentProps<HeadingTag> {
|
||||
as?: HeadingTag;
|
||||
}
|
||||
|
||||
export const CollapsibleSectionTitle: FC<CollapsibleSectionTitleProps> = ({
|
||||
as: Component = "h2",
|
||||
className,
|
||||
...props
|
||||
}) => {
|
||||
const { variant } = useCollapsibleSection();
|
||||
return (
|
||||
<Component className={cn(titleStyles({ variant }), className)} {...props} />
|
||||
);
|
||||
};
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Description
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
export const CollapsibleSectionDescription: FC<ComponentProps<"p">> = ({
|
||||
className,
|
||||
...props
|
||||
}) => {
|
||||
const { variant } = useCollapsibleSection();
|
||||
return (
|
||||
<p className={cn(descriptionStyles({ variant }), className)} {...props} />
|
||||
);
|
||||
};
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Content
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
interface CollapsibleSectionContentProps {
|
||||
children: ReactNode;
|
||||
}
|
||||
|
||||
export const CollapsibleSectionContent: FC<CollapsibleSectionContentProps> = ({
|
||||
children,
|
||||
}) => {
|
||||
const { variant } = useCollapsibleSection();
|
||||
|
||||
return (
|
||||
<CollapsibleContent>
|
||||
<div className={contentStyles({ variant })}>{children}</div>
|
||||
</CollapsibleContent>
|
||||
);
|
||||
};
|
||||
@@ -989,7 +989,7 @@ export const MCPServerAdminPanel: FC<MCPServerAdminPanelProps> = ({
|
||||
}
|
||||
|
||||
return (
|
||||
<div className="flex min-h-full flex-col space-y-3">
|
||||
<div className="flex min-h-full flex-col space-y-6">
|
||||
{!isFormView ? (
|
||||
<ServerList
|
||||
servers={servers}
|
||||
|
||||
@@ -21,7 +21,9 @@ import {
|
||||
} from "#/components/Table/Table";
|
||||
import { cn } from "#/utils/cn";
|
||||
import { formatCostMicros } from "#/utils/currency";
|
||||
import { AdminBadge } from "./AdminBadge";
|
||||
import { PrStateIcon } from "./GitPanel/GitPanel";
|
||||
import { SectionHeader } from "./SectionHeader";
|
||||
|
||||
dayjs.extend(relativeTime);
|
||||
|
||||
@@ -295,19 +297,16 @@ export const PRInsightsView: FC<PRInsightsViewProps> = ({
|
||||
const isEmpty = summary.total_prs_created === 0;
|
||||
|
||||
return (
|
||||
<div className="space-y-8">
|
||||
<div className="space-y-6">
|
||||
{/* ── Header ── */}
|
||||
<div className="flex items-end justify-between">
|
||||
<div>
|
||||
<h2 className="m-0 text-xl font-semibold tracking-tight text-content-primary">
|
||||
Pull Request Insights
|
||||
</h2>
|
||||
<p className="m-0 mt-1 text-[13px] text-content-secondary">
|
||||
Code changes detected by Agents.
|
||||
</p>
|
||||
</div>
|
||||
<TimeRangeFilter value={timeRange} onChange={onTimeRangeChange} />
|
||||
</div>
|
||||
<SectionHeader
|
||||
label="Pull Request Insights"
|
||||
description="Code changes detected by Agents."
|
||||
badge={<AdminBadge />}
|
||||
action={
|
||||
<TimeRangeFilter value={timeRange} onChange={onTimeRangeChange} />
|
||||
}
|
||||
/>
|
||||
|
||||
{isEmpty ? (
|
||||
<EmptyState />
|
||||
@@ -355,7 +354,7 @@ export const PRInsightsView: FC<PRInsightsViewProps> = ({
|
||||
</section>
|
||||
|
||||
{/* ── Model breakdown + Recent PRs side by side ── */}
|
||||
<div className="grid grid-cols-1 gap-6 lg:grid-cols-2">
|
||||
<div className="grid grid-cols-1 gap-6">
|
||||
{/* ── Model performance (simplified) ── */}
|
||||
{by_model.length > 0 && (
|
||||
<section>
|
||||
|
||||
@@ -15,8 +15,8 @@ export const SectionHeader: FC<SectionHeaderProps> = ({
|
||||
}) => (
|
||||
<>
|
||||
<div className="flex items-start justify-between gap-4">
|
||||
<div>
|
||||
<div className="flex w-full items-center gap-2">
|
||||
<div className="min-w-0 flex-1">
|
||||
<div className="flex items-center gap-2">
|
||||
<h2 className="m-0 text-lg font-medium text-content-primary">
|
||||
{label}
|
||||
</h2>
|
||||
|
||||
@@ -78,11 +78,8 @@ export const Default: Story = {
|
||||
expect(canvas.getByText("Claude Sonnet")).toBeInTheDocument();
|
||||
expect(canvas.queryByText("GPT-3.5 (Disabled)")).not.toBeInTheDocument();
|
||||
|
||||
// No footer visible when nothing is dirty
|
||||
expect(
|
||||
canvas.queryByRole("button", { name: /Save/i }),
|
||||
).not.toBeInTheDocument();
|
||||
|
||||
// Save button should be disabled when nothing is dirty.
|
||||
expect(canvas.getByRole("button", { name: /Save/i })).toBeDisabled();
|
||||
// Type a value to make the footer appear
|
||||
await userEvent.type(gpt4oInput, "95");
|
||||
await waitFor(() => {
|
||||
@@ -163,11 +160,9 @@ export const CancelChanges: Story = {
|
||||
const cancelButton = await canvas.findByRole("button", { name: /Cancel/i });
|
||||
await userEvent.click(cancelButton);
|
||||
|
||||
// Footer should disappear after cancel
|
||||
// Save button should be disabled after cancel.
|
||||
await waitFor(() => {
|
||||
expect(
|
||||
canvas.queryByRole("button", { name: /Save/i }),
|
||||
).not.toBeInTheDocument();
|
||||
expect(canvas.getByRole("button", { name: /Save/i })).toBeDisabled();
|
||||
});
|
||||
|
||||
// Input should be cleared back to empty (no override)
|
||||
@@ -194,10 +189,8 @@ export const InvalidDraftShowsFooter: Story = {
|
||||
// Cancel button should be visible so user can discard the edit
|
||||
expect(canvas.getByRole("button", { name: /Cancel/i })).toBeInTheDocument();
|
||||
|
||||
// Save button should NOT be visible (nothing valid to save)
|
||||
expect(
|
||||
canvas.queryByRole("button", { name: /Save/i }),
|
||||
).not.toBeInTheDocument();
|
||||
// Save button should be disabled (nothing valid to save).
|
||||
expect(canvas.getByRole("button", { name: /Save/i })).toBeDisabled();
|
||||
},
|
||||
};
|
||||
|
||||
|
||||
@@ -9,7 +9,6 @@ import {
|
||||
Table,
|
||||
TableBody,
|
||||
TableCell,
|
||||
TableFooter,
|
||||
TableHead,
|
||||
TableHeader,
|
||||
TableRow,
|
||||
@@ -33,6 +32,7 @@ interface UserCompactionThresholdSettingsProps {
|
||||
thresholdPercent: number,
|
||||
) => Promise<unknown>;
|
||||
onResetThreshold: (modelConfigId: string) => Promise<unknown>;
|
||||
hideHeader?: boolean;
|
||||
}
|
||||
|
||||
const parseThresholdDraft = (value: string): number | null => {
|
||||
@@ -60,6 +60,7 @@ export const UserCompactionThresholdSettings: FC<
|
||||
thresholdsError,
|
||||
onSaveThreshold,
|
||||
onResetThreshold,
|
||||
hideHeader,
|
||||
}) => {
|
||||
const [drafts, setDrafts] = useState<Record<string, string>>({});
|
||||
const [rowErrors, setRowErrors] = useState<Record<string, string>>({});
|
||||
@@ -172,19 +173,23 @@ export const UserCompactionThresholdSettings: FC<
|
||||
};
|
||||
|
||||
const hasAnyPending = pendingModels.size > 0;
|
||||
const hasAnyErrors = Object.keys(rowErrors).length > 0;
|
||||
const hasAnyDrafts = Object.keys(drafts).length > 0;
|
||||
|
||||
const headerBlock = !hideHeader ? (
|
||||
<>
|
||||
<h3 className="m-0 text-[13px] font-semibold text-content-primary">
|
||||
Context Compaction
|
||||
</h3>
|
||||
<p className="!mt-0.5 m-0 text-xs text-content-secondary">
|
||||
Control when conversation context is automatically summarized for each
|
||||
model. Setting 100% means the conversation will never auto-compact.
|
||||
</p>
|
||||
</>
|
||||
) : null;
|
||||
|
||||
if (isThresholdsLoading) {
|
||||
return (
|
||||
<div className="space-y-2">
|
||||
<h3 className="m-0 text-[13px] font-semibold text-content-primary">
|
||||
Context Compaction
|
||||
</h3>
|
||||
<p className="!mt-0.5 m-0 text-xs text-content-secondary">
|
||||
Control when conversation context is automatically summarized for each
|
||||
model. Setting 100% means the conversation will never auto-compact.
|
||||
</p>
|
||||
{headerBlock}
|
||||
<div className="flex items-center gap-2 text-sm text-content-secondary">
|
||||
<Spinner loading className="h-4 w-4" />
|
||||
Loading thresholds...
|
||||
@@ -196,13 +201,7 @@ export const UserCompactionThresholdSettings: FC<
|
||||
if (thresholdsError != null) {
|
||||
return (
|
||||
<div className="space-y-2">
|
||||
<h3 className="m-0 text-[13px] font-semibold text-content-primary">
|
||||
Context Compaction
|
||||
</h3>
|
||||
<p className="!mt-0.5 m-0 text-xs text-content-secondary">
|
||||
Control when conversation context is automatically summarized for each
|
||||
model. Setting 100% means the conversation will never auto-compact.
|
||||
</p>
|
||||
{headerBlock}
|
||||
<p className="m-0 text-xs text-content-destructive">
|
||||
{getErrorMessage(
|
||||
thresholdsError,
|
||||
@@ -215,13 +214,7 @@ export const UserCompactionThresholdSettings: FC<
|
||||
|
||||
return (
|
||||
<div className="space-y-2">
|
||||
<h3 className="m-0 text-[13px] font-semibold text-content-primary">
|
||||
Context Compaction
|
||||
</h3>
|
||||
<p className="!mt-0.5 m-0 text-xs text-content-secondary">
|
||||
Control when conversation context is automatically summarized for each
|
||||
model. Setting 100% means the conversation will never auto-compact.
|
||||
</p>
|
||||
{headerBlock}
|
||||
{isLoadingModelConfigs ? (
|
||||
<div className="flex items-center gap-2 text-sm text-content-secondary">
|
||||
<Spinner loading className="h-4 w-4" />
|
||||
@@ -240,165 +233,163 @@ export const UserCompactionThresholdSettings: FC<
|
||||
models before compaction thresholds can be set.
|
||||
</p>
|
||||
) : (
|
||||
<Table>
|
||||
<TableHeader>
|
||||
<TableRow>
|
||||
<TableHead>Model</TableHead>
|
||||
<TableHead className="w-0 whitespace-nowrap">Default</TableHead>
|
||||
<TableHead className="w-0 whitespace-nowrap">Threshold</TableHead>
|
||||
</TableRow>
|
||||
</TableHeader>
|
||||
<TableBody>
|
||||
{enabledModelConfigs.map((modelConfig) => {
|
||||
const existingOverride = overridesByModelID.get(modelConfig.id);
|
||||
const hasOverride = overridesByModelID.has(modelConfig.id);
|
||||
const draftValue =
|
||||
drafts[modelConfig.id] ??
|
||||
(existingOverride !== undefined
|
||||
? String(existingOverride)
|
||||
: "");
|
||||
const parsedDraftValue = parseThresholdDraft(draftValue);
|
||||
const isThisModelMutating = pendingModels.has(modelConfig.id);
|
||||
const isInvalid =
|
||||
draftValue.length > 0 && parsedDraftValue === null;
|
||||
// Only warn when user-typed, not when loaded from
|
||||
// the server.
|
||||
const isDraftDisablingCompaction =
|
||||
draftValue === "100" && drafts[modelConfig.id] !== undefined;
|
||||
const rowError = rowErrors[modelConfig.id];
|
||||
const modelName = modelConfig.display_name || modelConfig.model;
|
||||
|
||||
return (
|
||||
<TableRow key={modelConfig.id}>
|
||||
<TableCell className="text-[13px] font-medium text-content-primary">
|
||||
{modelName}
|
||||
{rowError && (
|
||||
<p
|
||||
aria-live="polite"
|
||||
className="m-0 mt-0.5 text-2xs font-normal text-content-destructive"
|
||||
>
|
||||
{rowError}
|
||||
</p>
|
||||
)}
|
||||
</TableCell>
|
||||
<TableCell className="w-0 whitespace-nowrap tabular-nums">
|
||||
{modelConfig.compression_threshold}%
|
||||
</TableCell>
|
||||
<TableCell className="w-0 whitespace-nowrap">
|
||||
<div className="flex items-center gap-1.5">
|
||||
<Tooltip>
|
||||
<TooltipTrigger asChild>
|
||||
<Input
|
||||
aria-label={`${modelName} compaction threshold`}
|
||||
aria-invalid={isInvalid || undefined}
|
||||
type="number"
|
||||
min={0}
|
||||
max={100}
|
||||
inputMode="numeric"
|
||||
className={cn(
|
||||
"h-7 w-16 px-2 text-xs tabular-nums",
|
||||
isInvalid &&
|
||||
"border-content-destructive focus:ring-content-destructive/30",
|
||||
)}
|
||||
value={draftValue}
|
||||
placeholder={String(
|
||||
modelConfig.compression_threshold,
|
||||
)}
|
||||
onChange={(event) => {
|
||||
setDrafts((currentDrafts) => ({
|
||||
...currentDrafts,
|
||||
[modelConfig.id]: event.target.value,
|
||||
}));
|
||||
clearRowError(modelConfig.id);
|
||||
}}
|
||||
disabled={isThisModelMutating}
|
||||
/>
|
||||
</TooltipTrigger>
|
||||
{(isInvalid || isDraftDisablingCompaction) && (
|
||||
<TooltipContent>
|
||||
{isInvalid
|
||||
? "Enter a whole number between 0 and 100."
|
||||
: "Setting 100% will disable auto-compaction for this model."}
|
||||
</TooltipContent>
|
||||
)}
|
||||
</Tooltip>
|
||||
<span className="text-xs text-content-secondary">%</span>
|
||||
<Tooltip>
|
||||
<TooltipTrigger asChild>
|
||||
<Button
|
||||
size="icon"
|
||||
variant="subtle"
|
||||
className={cn(
|
||||
"size-7",
|
||||
hasOverride
|
||||
? "opacity-100"
|
||||
: "pointer-events-none opacity-0",
|
||||
)}
|
||||
aria-label={`Reset ${modelName} to default`}
|
||||
aria-hidden={!hasOverride}
|
||||
tabIndex={hasOverride ? 0 : -1}
|
||||
disabled={isThisModelMutating || !hasOverride}
|
||||
onClick={() => handleReset(modelConfig.id)}
|
||||
>
|
||||
<RotateCcwIcon className="size-3.5" />
|
||||
</Button>
|
||||
</TooltipTrigger>
|
||||
{hasOverride && (
|
||||
<TooltipContent>
|
||||
Reset to default (
|
||||
{modelConfig.compression_threshold}%)
|
||||
</TooltipContent>
|
||||
)}
|
||||
</Tooltip>
|
||||
</div>
|
||||
{isInvalid && (
|
||||
<span className="sr-only" aria-live="polite">
|
||||
Enter a whole number between 0 and 100.
|
||||
</span>
|
||||
)}
|
||||
{isDraftDisablingCompaction && (
|
||||
<span className="sr-only" aria-live="polite">
|
||||
Setting 100% will disable auto-compaction for this
|
||||
model.
|
||||
</span>
|
||||
)}
|
||||
</TableCell>
|
||||
</TableRow>
|
||||
);
|
||||
})}
|
||||
</TableBody>
|
||||
{(dirtyRows.length > 0 || hasAnyErrors || hasAnyDrafts) && (
|
||||
<TableFooter className="bg-transparent">
|
||||
<TableRow className="border-0">
|
||||
<TableCell colSpan={3} className="border-0 p-0">
|
||||
<div className="flex items-center justify-end gap-2 px-3 py-1.5">
|
||||
<Button
|
||||
size="sm"
|
||||
variant="outline"
|
||||
type="button"
|
||||
onClick={handleCancelAll}
|
||||
disabled={hasAnyPending}
|
||||
>
|
||||
Cancel
|
||||
</Button>
|
||||
{dirtyRows.length > 0 && (
|
||||
<Button
|
||||
size="sm"
|
||||
type="button"
|
||||
disabled={hasAnyPending}
|
||||
onClick={handleSaveAll}
|
||||
>
|
||||
{hasAnyPending
|
||||
? "Saving..."
|
||||
: `Save ${dirtyRows.length} ${dirtyRows.length === 1 ? "change" : "changes"}`}
|
||||
</Button>
|
||||
)}
|
||||
</div>
|
||||
</TableCell>
|
||||
<>
|
||||
<Table>
|
||||
<TableHeader>
|
||||
<TableRow>
|
||||
<TableHead>Model</TableHead>
|
||||
<TableHead className="w-0 whitespace-nowrap">Default</TableHead>
|
||||
<TableHead className="w-0 whitespace-nowrap">
|
||||
Threshold
|
||||
</TableHead>
|
||||
</TableRow>
|
||||
</TableFooter>
|
||||
)}
|
||||
</Table>
|
||||
</TableHeader>
|
||||
<TableBody>
|
||||
{enabledModelConfigs.map((modelConfig) => {
|
||||
const existingOverride = overridesByModelID.get(modelConfig.id);
|
||||
const hasOverride = overridesByModelID.has(modelConfig.id);
|
||||
const draftValue =
|
||||
drafts[modelConfig.id] ??
|
||||
(existingOverride !== undefined
|
||||
? String(existingOverride)
|
||||
: "");
|
||||
const parsedDraftValue = parseThresholdDraft(draftValue);
|
||||
const isThisModelMutating = pendingModels.has(modelConfig.id);
|
||||
const isInvalid =
|
||||
draftValue.length > 0 && parsedDraftValue === null;
|
||||
// Only warn when user-typed, not when loaded from
|
||||
// the server.
|
||||
const isDraftDisablingCompaction =
|
||||
draftValue === "100" && drafts[modelConfig.id] !== undefined;
|
||||
const rowError = rowErrors[modelConfig.id];
|
||||
const modelName = modelConfig.display_name || modelConfig.model;
|
||||
|
||||
return (
|
||||
<TableRow key={modelConfig.id}>
|
||||
<TableCell className="text-[13px] font-medium text-content-primary">
|
||||
{modelName}
|
||||
{rowError && (
|
||||
<p
|
||||
aria-live="polite"
|
||||
className="m-0 mt-0.5 text-2xs font-normal text-content-destructive"
|
||||
>
|
||||
{rowError}
|
||||
</p>
|
||||
)}
|
||||
</TableCell>
|
||||
<TableCell className="w-0 whitespace-nowrap tabular-nums">
|
||||
{modelConfig.compression_threshold}%
|
||||
</TableCell>
|
||||
<TableCell className="w-0 whitespace-nowrap">
|
||||
<div className="flex items-center gap-1.5">
|
||||
<Tooltip>
|
||||
<TooltipTrigger asChild>
|
||||
<Input
|
||||
aria-label={`${modelName} compaction threshold`}
|
||||
aria-invalid={isInvalid || undefined}
|
||||
type="number"
|
||||
min={0}
|
||||
max={100}
|
||||
inputMode="numeric"
|
||||
className={cn(
|
||||
"h-7 w-16 px-2 text-xs tabular-nums",
|
||||
isInvalid &&
|
||||
"border-content-destructive focus:ring-content-destructive/30",
|
||||
)}
|
||||
value={draftValue}
|
||||
placeholder={String(
|
||||
modelConfig.compression_threshold,
|
||||
)}
|
||||
onChange={(event) => {
|
||||
setDrafts((currentDrafts) => ({
|
||||
...currentDrafts,
|
||||
[modelConfig.id]: event.target.value,
|
||||
}));
|
||||
clearRowError(modelConfig.id);
|
||||
}}
|
||||
disabled={isThisModelMutating}
|
||||
/>
|
||||
</TooltipTrigger>
|
||||
{(isInvalid || isDraftDisablingCompaction) && (
|
||||
<TooltipContent>
|
||||
{isInvalid
|
||||
? "Enter a whole number between 0 and 100."
|
||||
: "Setting 100% will disable auto-compaction for this model."}
|
||||
</TooltipContent>
|
||||
)}
|
||||
</Tooltip>
|
||||
<span className="text-xs text-content-secondary">
|
||||
%
|
||||
</span>
|
||||
<Tooltip>
|
||||
<TooltipTrigger asChild>
|
||||
<Button
|
||||
size="icon"
|
||||
variant="subtle"
|
||||
className={cn(
|
||||
"size-7",
|
||||
hasOverride
|
||||
? "opacity-100"
|
||||
: "pointer-events-none opacity-0",
|
||||
)}
|
||||
aria-label={`Reset ${modelName} to default`}
|
||||
aria-hidden={!hasOverride}
|
||||
tabIndex={hasOverride ? 0 : -1}
|
||||
disabled={isThisModelMutating || !hasOverride}
|
||||
onClick={() => handleReset(modelConfig.id)}
|
||||
>
|
||||
<RotateCcwIcon className="size-3.5" />
|
||||
</Button>
|
||||
</TooltipTrigger>
|
||||
{hasOverride && (
|
||||
<TooltipContent>
|
||||
Reset to default (
|
||||
{modelConfig.compression_threshold}%)
|
||||
</TooltipContent>
|
||||
)}
|
||||
</Tooltip>
|
||||
</div>
|
||||
{isInvalid && (
|
||||
<span className="sr-only" aria-live="polite">
|
||||
Enter a whole number between 0 and 100.
|
||||
</span>
|
||||
)}
|
||||
{isDraftDisablingCompaction && (
|
||||
<span className="sr-only" aria-live="polite">
|
||||
Setting 100% will disable auto-compaction for this
|
||||
model.
|
||||
</span>
|
||||
)}
|
||||
</TableCell>
|
||||
</TableRow>
|
||||
);
|
||||
})}
|
||||
</TableBody>
|
||||
</Table>
|
||||
<div className="flex justify-end gap-2 pt-2">
|
||||
<Button
|
||||
size="sm"
|
||||
variant="outline"
|
||||
type="button"
|
||||
onClick={handleCancelAll}
|
||||
disabled={hasAnyPending}
|
||||
>
|
||||
Cancel
|
||||
</Button>
|
||||
<Button
|
||||
size="sm"
|
||||
type="button"
|
||||
disabled={hasAnyPending || dirtyRows.length === 0}
|
||||
onClick={handleSaveAll}
|
||||
>
|
||||
{hasAnyPending
|
||||
? "Saving..."
|
||||
: dirtyRows.length > 0
|
||||
? `Save ${dirtyRows.length} ${dirtyRows.length === 1 ? "change" : "changes"}`
|
||||
: "Save"}
|
||||
</Button>
|
||||
</div>
|
||||
</>
|
||||
)}
|
||||
</div>
|
||||
);
|
||||
|
||||
@@ -0,0 +1,35 @@
|
||||
import type { DateRangeValue } from "#/components/DateRangePicker/DateRangePicker";
|
||||
|
||||
/**
|
||||
* Returns true when the given date falls exactly on midnight
|
||||
* (00:00:00.000). Date-range pickers use midnight of the *following*
|
||||
* day as the exclusive upper bound for a full-day selection. Detecting
|
||||
* this lets call sites subtract 1 ms (or 1 day) so the UI shows the
|
||||
* inclusive end date instead.
|
||||
*/
|
||||
function isMidnight(date: Date): boolean {
|
||||
return (
|
||||
date.getHours() === 0 &&
|
||||
date.getMinutes() === 0 &&
|
||||
date.getSeconds() === 0 &&
|
||||
date.getMilliseconds() === 0
|
||||
);
|
||||
}
|
||||
|
||||
/**
|
||||
* When the user picks an explicit date range whose end boundary is
|
||||
* midnight of the following day, adjust it by −1 ms so the
|
||||
* DateRangePicker highlights the inclusive end date.
|
||||
*/
|
||||
export function toInclusiveDateRange(
|
||||
dateRange: DateRangeValue,
|
||||
hasExplicitDateRange: boolean,
|
||||
): DateRangeValue {
|
||||
if (hasExplicitDateRange && isMidnight(dateRange.endDate)) {
|
||||
return {
|
||||
startDate: dateRange.startDate,
|
||||
endDate: new Date(dateRange.endDate.getTime() - 1),
|
||||
};
|
||||
}
|
||||
return dateRange;
|
||||
}
|
||||
@@ -71,7 +71,6 @@ describe("AuditPage", () => {
|
||||
const getAuditLogsSpy = vi.spyOn(API, "getAuditLogs").mockResolvedValue({
|
||||
audit_logs: [MockAuditLog, MockAuditLog2],
|
||||
count: 2,
|
||||
count_cap: 0,
|
||||
});
|
||||
|
||||
// When
|
||||
@@ -91,7 +90,6 @@ describe("AuditPage", () => {
|
||||
vi.spyOn(API, "getAuditLogs").mockResolvedValue({
|
||||
audit_logs: [MockAuditLog],
|
||||
count: 1,
|
||||
count_cap: 0,
|
||||
});
|
||||
|
||||
await renderPage();
|
||||
@@ -116,7 +114,6 @@ describe("AuditPage", () => {
|
||||
vi.spyOn(API, "getAuditLogs").mockResolvedValue({
|
||||
audit_logs: [MockAuditLog],
|
||||
count: 1,
|
||||
count_cap: 0,
|
||||
});
|
||||
|
||||
await renderPage();
|
||||
@@ -143,11 +140,9 @@ describe("AuditPage", () => {
|
||||
|
||||
describe("Filtering", () => {
|
||||
it("filters by URL", async () => {
|
||||
const getAuditLogsSpy = vi.spyOn(API, "getAuditLogs").mockResolvedValue({
|
||||
audit_logs: [MockAuditLog],
|
||||
count: 1,
|
||||
count_cap: 0,
|
||||
});
|
||||
const getAuditLogsSpy = vi
|
||||
.spyOn(API, "getAuditLogs")
|
||||
.mockResolvedValue({ audit_logs: [MockAuditLog], count: 1 });
|
||||
|
||||
const query = "resource_type:workspace action:create";
|
||||
await renderPage({ filter: query });
|
||||
@@ -178,29 +173,4 @@ describe("AuditPage", () => {
|
||||
);
|
||||
});
|
||||
});
|
||||
|
||||
describe("Capped count", () => {
|
||||
it("shows capped count indicator and navigates to next page with correct offset", async () => {
|
||||
vi.spyOn(API, "getAuditLogs").mockResolvedValue({
|
||||
audit_logs: [MockAuditLog, MockAuditLog2],
|
||||
count: 2001,
|
||||
count_cap: 2000,
|
||||
});
|
||||
|
||||
const user = userEvent.setup();
|
||||
await renderPage();
|
||||
|
||||
await screen.findByText(/2,000\+/);
|
||||
|
||||
await user.click(screen.getByRole("button", { name: /next page/i }));
|
||||
|
||||
await waitFor(() =>
|
||||
expect(API.getAuditLogs).toHaveBeenLastCalledWith<[AuditLogsRequest]>({
|
||||
limit: DEFAULT_RECORDS_PER_PAGE,
|
||||
offset: DEFAULT_RECORDS_PER_PAGE,
|
||||
q: "",
|
||||
}),
|
||||
);
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
@@ -69,7 +69,6 @@ describe("ConnectionLogPage", () => {
|
||||
MockDisconnectedSSHConnectionLog,
|
||||
],
|
||||
count: 2,
|
||||
count_cap: 0,
|
||||
});
|
||||
|
||||
// When
|
||||
@@ -96,7 +95,6 @@ describe("ConnectionLogPage", () => {
|
||||
.mockResolvedValue({
|
||||
connection_logs: [MockConnectedSSHConnectionLog],
|
||||
count: 1,
|
||||
count_cap: 0,
|
||||
});
|
||||
|
||||
const query = "type:ssh status:ongoing";
|
||||
|
||||
+42
-50
@@ -12,8 +12,6 @@ import (
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"golang.org/x/sync/singleflight"
|
||||
|
||||
"github.com/cenkalti/backoff/v4"
|
||||
"github.com/google/uuid"
|
||||
"github.com/tailscale/wireguard-go/tun"
|
||||
@@ -465,8 +463,6 @@ type Conn struct {
|
||||
|
||||
trafficStats *connstats.Statistics
|
||||
lastNetInfo *tailcfg.NetInfo
|
||||
|
||||
awaitReachableGroup singleflight.Group
|
||||
}
|
||||
|
||||
func (c *Conn) GetNetInfo() *tailcfg.NetInfo {
|
||||
@@ -603,60 +599,56 @@ func (c *Conn) DERPMap() *tailcfg.DERPMap {
|
||||
// address is reachable. It's the callers responsibility to provide
|
||||
// a timeout, otherwise this function will block forever.
|
||||
func (c *Conn) AwaitReachable(ctx context.Context, ip netip.Addr) bool {
|
||||
result, _, _ := c.awaitReachableGroup.Do(ip.String(), func() (interface{}, error) {
|
||||
ctx, cancel := context.WithCancel(ctx)
|
||||
defer cancel() // Cancel all pending pings on exit.
|
||||
ctx, cancel := context.WithCancel(ctx)
|
||||
defer cancel() // Cancel all pending pings on exit.
|
||||
|
||||
completedCtx, completed := context.WithCancel(context.Background())
|
||||
defer completed()
|
||||
completedCtx, completed := context.WithCancel(context.Background())
|
||||
defer completed()
|
||||
|
||||
run := func() {
|
||||
// Safety timeout, initially we'll have around 10-20 goroutines
|
||||
// running in parallel. The exponential backoff will converge
|
||||
// around ~1 ping / 30s, this means we'll have around 10-20
|
||||
// goroutines pending towards the end as well.
|
||||
ctx, cancel := context.WithTimeout(ctx, 5*time.Minute)
|
||||
defer cancel()
|
||||
run := func() {
|
||||
// Safety timeout, initially we'll have around 10-20 goroutines
|
||||
// running in parallel. The exponential backoff will converge
|
||||
// around ~1 ping / 30s, this means we'll have around 10-20
|
||||
// goroutines pending towards the end as well.
|
||||
ctx, cancel := context.WithTimeout(ctx, 5*time.Minute)
|
||||
defer cancel()
|
||||
|
||||
// For reachability, we use TSMP ping, which pings at the IP layer,
|
||||
// and therefore requires that wireguard and the netstack are up.
|
||||
// If we don't wait for wireguard to be up, we could miss a
|
||||
// handshake, and it might take 5 seconds for the handshake to be
|
||||
// retried. A 5s initial round trip can set us up for poor TCP
|
||||
// performance, since the initial round-trip-time sets the initial
|
||||
// retransmit timeout.
|
||||
_, _, _, err := c.pingWithType(ctx, ip, tailcfg.PingTSMP)
|
||||
if err == nil {
|
||||
completed()
|
||||
}
|
||||
// For reachability, we use TSMP ping, which pings at the IP layer, and
|
||||
// therefore requires that wireguard and the netstack are up. If we
|
||||
// don't wait for wireguard to be up, we could miss a handshake, and it
|
||||
// might take 5 seconds for the handshake to be retried. A 5s initial
|
||||
// round trip can set us up for poor TCP performance, since the initial
|
||||
// round-trip-time sets the initial retransmit timeout.
|
||||
_, _, _, err := c.pingWithType(ctx, ip, tailcfg.PingTSMP)
|
||||
if err == nil {
|
||||
completed()
|
||||
}
|
||||
}
|
||||
|
||||
eb := backoff.NewExponentialBackOff()
|
||||
eb.MaxElapsedTime = 0
|
||||
eb.InitialInterval = 50 * time.Millisecond
|
||||
eb.MaxInterval = 5 * time.Second
|
||||
// Consume the first interval since
|
||||
// we'll fire off a ping immediately.
|
||||
_ = eb.NextBackOff()
|
||||
eb := backoff.NewExponentialBackOff()
|
||||
eb.MaxElapsedTime = 0
|
||||
eb.InitialInterval = 50 * time.Millisecond
|
||||
eb.MaxInterval = 30 * time.Second
|
||||
// Consume the first interval since
|
||||
// we'll fire off a ping immediately.
|
||||
_ = eb.NextBackOff()
|
||||
|
||||
t := backoff.NewTicker(eb)
|
||||
defer t.Stop()
|
||||
t := backoff.NewTicker(eb)
|
||||
defer t.Stop()
|
||||
|
||||
go run()
|
||||
for {
|
||||
select {
|
||||
case <-completedCtx.Done():
|
||||
return true, nil
|
||||
case <-t.C:
|
||||
// Pings can take a while, so we can run multiple
|
||||
// in parallel to return ASAP.
|
||||
go run()
|
||||
case <-ctx.Done():
|
||||
return false, nil
|
||||
}
|
||||
go run()
|
||||
for {
|
||||
select {
|
||||
case <-completedCtx.Done():
|
||||
return true
|
||||
case <-t.C:
|
||||
// Pings can take a while, so we can run multiple
|
||||
// in parallel to return ASAP.
|
||||
go run()
|
||||
case <-ctx.Done():
|
||||
return false
|
||||
}
|
||||
})
|
||||
return result.(bool)
|
||||
}
|
||||
}
|
||||
|
||||
// Closed is a channel that ends when the connection has
|
||||
|
||||
Reference in New Issue
Block a user