Compare commits
25 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| e85be9f42e | |||
| cfd7730194 | |||
| 1937ada0cd | |||
| d64cd6415d | |||
| c1851d9453 | |||
| 8f73453681 | |||
| 165db3d31c | |||
| 1bd1516fd1 | |||
| 81ba35a987 | |||
| 53d63cf8e9 | |||
| 4213a43b53 | |||
| 5453a6c6d6 | |||
| 21c08a37d7 | |||
| 2bd261fbbf | |||
| cffc68df58 | |||
| 6e5335df1e | |||
| 16265e834e | |||
| 565a15bc9b | |||
| 76a2cb1af5 | |||
| 684f21740d | |||
| 86ca61d6ca | |||
| f0521cfa3c | |||
| 0c5d189aff | |||
| d7c8213eee | |||
| 63924ac687 |
Generated
+6
@@ -14175,6 +14175,9 @@ const docTemplate = `{
|
||||
},
|
||||
"count": {
|
||||
"type": "integer"
|
||||
},
|
||||
"count_cap": {
|
||||
"type": "integer"
|
||||
}
|
||||
}
|
||||
},
|
||||
@@ -14496,6 +14499,9 @@ const docTemplate = `{
|
||||
},
|
||||
"count": {
|
||||
"type": "integer"
|
||||
},
|
||||
"count_cap": {
|
||||
"type": "integer"
|
||||
}
|
||||
}
|
||||
},
|
||||
|
||||
Generated
+6
@@ -12739,6 +12739,9 @@
|
||||
},
|
||||
"count": {
|
||||
"type": "integer"
|
||||
},
|
||||
"count_cap": {
|
||||
"type": "integer"
|
||||
}
|
||||
}
|
||||
},
|
||||
@@ -13039,6 +13042,9 @@
|
||||
},
|
||||
"count": {
|
||||
"type": "integer"
|
||||
},
|
||||
"count_cap": {
|
||||
"type": "integer"
|
||||
}
|
||||
}
|
||||
},
|
||||
|
||||
+8
-1
@@ -26,6 +26,11 @@ import (
|
||||
"github.com/coder/coder/v2/codersdk"
|
||||
)
|
||||
|
||||
// Limit the count query to avoid a slow sequential scan due to joins
|
||||
// on a large table. Set to 0 to disable capping (but also see the note
|
||||
// in the SQL query).
|
||||
const auditLogCountCap = 2000
|
||||
|
||||
// @Summary Get audit logs
|
||||
// @ID get-audit-logs
|
||||
// @Security CoderSessionToken
|
||||
@@ -66,7 +71,7 @@ func (api *API) auditLogs(rw http.ResponseWriter, r *http.Request) {
|
||||
countFilter.Username = ""
|
||||
}
|
||||
|
||||
// Use the same filters to count the number of audit logs
|
||||
countFilter.CountCap = auditLogCountCap
|
||||
count, err := api.Database.CountAuditLogs(ctx, countFilter)
|
||||
if dbauthz.IsNotAuthorizedError(err) {
|
||||
httpapi.Forbidden(rw)
|
||||
@@ -81,6 +86,7 @@ func (api *API) auditLogs(rw http.ResponseWriter, r *http.Request) {
|
||||
httpapi.Write(ctx, rw, http.StatusOK, codersdk.AuditLogResponse{
|
||||
AuditLogs: []codersdk.AuditLog{},
|
||||
Count: 0,
|
||||
CountCap: auditLogCountCap,
|
||||
})
|
||||
return
|
||||
}
|
||||
@@ -98,6 +104,7 @@ func (api *API) auditLogs(rw http.ResponseWriter, r *http.Request) {
|
||||
httpapi.Write(ctx, rw, http.StatusOK, codersdk.AuditLogResponse{
|
||||
AuditLogs: api.convertAuditLogs(ctx, dblogs),
|
||||
Count: count,
|
||||
CountCap: auditLogCountCap,
|
||||
})
|
||||
}
|
||||
|
||||
|
||||
@@ -2155,17 +2155,12 @@ func (q *querier) DeleteUserChatProviderKey(ctx context.Context, arg database.De
|
||||
return q.db.DeleteUserChatProviderKey(ctx, arg)
|
||||
}
|
||||
|
||||
func (q *querier) DeleteUserSecret(ctx context.Context, id uuid.UUID) error {
|
||||
// First get the secret to check ownership
|
||||
secret, err := q.GetUserSecret(ctx, id)
|
||||
if err != nil {
|
||||
func (q *querier) DeleteUserSecretByUserIDAndName(ctx context.Context, arg database.DeleteUserSecretByUserIDAndNameParams) error {
|
||||
obj := rbac.ResourceUserSecret.WithOwner(arg.UserID.String())
|
||||
if err := q.authorizeContext(ctx, policy.ActionDelete, obj); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := q.authorizeContext(ctx, policy.ActionDelete, secret); err != nil {
|
||||
return err
|
||||
}
|
||||
return q.db.DeleteUserSecret(ctx, id)
|
||||
return q.db.DeleteUserSecretByUserIDAndName(ctx, arg)
|
||||
}
|
||||
|
||||
func (q *querier) DeleteWebpushSubscriptionByUserIDAndEndpoint(ctx context.Context, arg database.DeleteWebpushSubscriptionByUserIDAndEndpointParams) error {
|
||||
@@ -4128,19 +4123,6 @@ func (q *querier) GetUserNotificationPreferences(ctx context.Context, userID uui
|
||||
return q.db.GetUserNotificationPreferences(ctx, userID)
|
||||
}
|
||||
|
||||
func (q *querier) GetUserSecret(ctx context.Context, id uuid.UUID) (database.UserSecret, error) {
|
||||
// First get the secret to check ownership
|
||||
secret, err := q.db.GetUserSecret(ctx, id)
|
||||
if err != nil {
|
||||
return database.UserSecret{}, err
|
||||
}
|
||||
|
||||
if err := q.authorizeContext(ctx, policy.ActionRead, secret); err != nil {
|
||||
return database.UserSecret{}, err
|
||||
}
|
||||
return secret, nil
|
||||
}
|
||||
|
||||
func (q *querier) GetUserSecretByUserIDAndName(ctx context.Context, arg database.GetUserSecretByUserIDAndNameParams) (database.UserSecret, error) {
|
||||
obj := rbac.ResourceUserSecret.WithOwner(arg.UserID.String())
|
||||
if err := q.authorizeContext(ctx, policy.ActionRead, obj); err != nil {
|
||||
@@ -5524,7 +5506,7 @@ func (q *querier) ListUserChatCompactionThresholds(ctx context.Context, userID u
|
||||
return q.db.ListUserChatCompactionThresholds(ctx, userID)
|
||||
}
|
||||
|
||||
func (q *querier) ListUserSecrets(ctx context.Context, userID uuid.UUID) ([]database.UserSecret, error) {
|
||||
func (q *querier) ListUserSecrets(ctx context.Context, userID uuid.UUID) ([]database.ListUserSecretsRow, error) {
|
||||
obj := rbac.ResourceUserSecret.WithOwner(userID.String())
|
||||
if err := q.authorizeContext(ctx, policy.ActionRead, obj); err != nil {
|
||||
return nil, err
|
||||
@@ -5532,6 +5514,16 @@ func (q *querier) ListUserSecrets(ctx context.Context, userID uuid.UUID) ([]data
|
||||
return q.db.ListUserSecrets(ctx, userID)
|
||||
}
|
||||
|
||||
func (q *querier) ListUserSecretsWithValues(ctx context.Context, userID uuid.UUID) ([]database.UserSecret, error) {
|
||||
// This query returns decrypted secret values and must only be called
|
||||
// from system contexts (provisioner, agent manifest). REST API
|
||||
// handlers should use ListUserSecrets (metadata only).
|
||||
if err := q.authorizeContext(ctx, policy.ActionRead, rbac.ResourceSystem); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return q.db.ListUserSecretsWithValues(ctx, userID)
|
||||
}
|
||||
|
||||
func (q *querier) ListWorkspaceAgentPortShares(ctx context.Context, workspaceID uuid.UUID) ([]database.WorkspaceAgentPortShare, error) {
|
||||
workspace, err := q.db.GetWorkspaceByID(ctx, workspaceID)
|
||||
if err != nil {
|
||||
@@ -5782,15 +5774,15 @@ func (q *querier) UpdateChatByID(ctx context.Context, arg database.UpdateChatByI
|
||||
return q.db.UpdateChatByID(ctx, arg)
|
||||
}
|
||||
|
||||
func (q *querier) UpdateChatHeartbeat(ctx context.Context, arg database.UpdateChatHeartbeatParams) (int64, error) {
|
||||
chat, err := q.db.GetChatByID(ctx, arg.ID)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
func (q *querier) UpdateChatHeartbeats(ctx context.Context, arg database.UpdateChatHeartbeatsParams) ([]uuid.UUID, error) {
|
||||
// The batch heartbeat is a system-level operation filtered by
|
||||
// worker_id. Authorization is enforced by the AsChatd context
|
||||
// at the call site rather than per-row, because checking each
|
||||
// row individually would defeat the purpose of batching.
|
||||
if err := q.authorizeContext(ctx, policy.ActionUpdate, rbac.ResourceChat); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := q.authorizeContext(ctx, policy.ActionUpdate, chat); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return q.db.UpdateChatHeartbeat(ctx, arg)
|
||||
return q.db.UpdateChatHeartbeats(ctx, arg)
|
||||
}
|
||||
|
||||
func (q *querier) UpdateChatLabelsByID(ctx context.Context, arg database.UpdateChatLabelsByIDParams) (database.Chat, error) {
|
||||
@@ -6632,17 +6624,12 @@ func (q *querier) UpdateUserRoles(ctx context.Context, arg database.UpdateUserRo
|
||||
return q.db.UpdateUserRoles(ctx, arg)
|
||||
}
|
||||
|
||||
func (q *querier) UpdateUserSecret(ctx context.Context, arg database.UpdateUserSecretParams) (database.UserSecret, error) {
|
||||
// First get the secret to check ownership
|
||||
secret, err := q.db.GetUserSecret(ctx, arg.ID)
|
||||
if err != nil {
|
||||
func (q *querier) UpdateUserSecretByUserIDAndName(ctx context.Context, arg database.UpdateUserSecretByUserIDAndNameParams) (database.UserSecret, error) {
|
||||
obj := rbac.ResourceUserSecret.WithOwner(arg.UserID.String())
|
||||
if err := q.authorizeContext(ctx, policy.ActionUpdate, obj); err != nil {
|
||||
return database.UserSecret{}, err
|
||||
}
|
||||
|
||||
if err := q.authorizeContext(ctx, policy.ActionUpdate, secret); err != nil {
|
||||
return database.UserSecret{}, err
|
||||
}
|
||||
return q.db.UpdateUserSecret(ctx, arg)
|
||||
return q.db.UpdateUserSecretByUserIDAndName(ctx, arg)
|
||||
}
|
||||
|
||||
func (q *querier) UpdateUserStatus(ctx context.Context, arg database.UpdateUserStatusParams) (database.User, error) {
|
||||
|
||||
@@ -842,15 +842,15 @@ func (s *MethodTestSuite) TestChats() {
|
||||
dbm.EXPECT().UpdateChatStatusPreserveUpdatedAt(gomock.Any(), arg).Return(chat, nil).AnyTimes()
|
||||
check.Args(arg).Asserts(chat, policy.ActionUpdate).Returns(chat)
|
||||
}))
|
||||
s.Run("UpdateChatHeartbeat", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
|
||||
chat := testutil.Fake(s.T(), faker, database.Chat{})
|
||||
arg := database.UpdateChatHeartbeatParams{
|
||||
ID: chat.ID,
|
||||
s.Run("UpdateChatHeartbeats", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
|
||||
resultID := uuid.New()
|
||||
arg := database.UpdateChatHeartbeatsParams{
|
||||
IDs: []uuid.UUID{resultID},
|
||||
WorkerID: uuid.New(),
|
||||
Now: time.Now(),
|
||||
}
|
||||
dbm.EXPECT().GetChatByID(gomock.Any(), chat.ID).Return(chat, nil).AnyTimes()
|
||||
dbm.EXPECT().UpdateChatHeartbeat(gomock.Any(), arg).Return(int64(1), nil).AnyTimes()
|
||||
check.Args(arg).Asserts(chat, policy.ActionUpdate).Returns(int64(1))
|
||||
dbm.EXPECT().UpdateChatHeartbeats(gomock.Any(), arg).Return([]uuid.UUID{resultID}, nil).AnyTimes()
|
||||
check.Args(arg).Asserts(rbac.ResourceChat, policy.ActionUpdate).Returns([]uuid.UUID{resultID})
|
||||
}))
|
||||
s.Run("UpdateChatMessageByID", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
|
||||
chat := testutil.Fake(s.T(), faker, database.Chat{})
|
||||
@@ -5346,19 +5346,20 @@ func (s *MethodTestSuite) TestUserSecrets() {
|
||||
Asserts(rbac.ResourceUserSecret.WithOwner(user.ID.String()), policy.ActionRead).
|
||||
Returns(secret)
|
||||
}))
|
||||
s.Run("GetUserSecret", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
|
||||
secret := testutil.Fake(s.T(), faker, database.UserSecret{})
|
||||
dbm.EXPECT().GetUserSecret(gomock.Any(), secret.ID).Return(secret, nil).AnyTimes()
|
||||
check.Args(secret.ID).
|
||||
Asserts(secret, policy.ActionRead).
|
||||
Returns(secret)
|
||||
}))
|
||||
s.Run("ListUserSecrets", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
|
||||
user := testutil.Fake(s.T(), faker, database.User{})
|
||||
secret := testutil.Fake(s.T(), faker, database.UserSecret{UserID: user.ID})
|
||||
dbm.EXPECT().ListUserSecrets(gomock.Any(), user.ID).Return([]database.UserSecret{secret}, nil).AnyTimes()
|
||||
row := testutil.Fake(s.T(), faker, database.ListUserSecretsRow{UserID: user.ID})
|
||||
dbm.EXPECT().ListUserSecrets(gomock.Any(), user.ID).Return([]database.ListUserSecretsRow{row}, nil).AnyTimes()
|
||||
check.Args(user.ID).
|
||||
Asserts(rbac.ResourceUserSecret.WithOwner(user.ID.String()), policy.ActionRead).
|
||||
Returns([]database.ListUserSecretsRow{row})
|
||||
}))
|
||||
s.Run("ListUserSecretsWithValues", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
|
||||
user := testutil.Fake(s.T(), faker, database.User{})
|
||||
secret := testutil.Fake(s.T(), faker, database.UserSecret{UserID: user.ID})
|
||||
dbm.EXPECT().ListUserSecretsWithValues(gomock.Any(), user.ID).Return([]database.UserSecret{secret}, nil).AnyTimes()
|
||||
check.Args(user.ID).
|
||||
Asserts(rbac.ResourceSystem, policy.ActionRead).
|
||||
Returns([]database.UserSecret{secret})
|
||||
}))
|
||||
s.Run("CreateUserSecret", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
|
||||
@@ -5370,22 +5371,21 @@ func (s *MethodTestSuite) TestUserSecrets() {
|
||||
Asserts(rbac.ResourceUserSecret.WithOwner(user.ID.String()), policy.ActionCreate).
|
||||
Returns(ret)
|
||||
}))
|
||||
s.Run("UpdateUserSecret", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
|
||||
secret := testutil.Fake(s.T(), faker, database.UserSecret{})
|
||||
updated := testutil.Fake(s.T(), faker, database.UserSecret{ID: secret.ID})
|
||||
arg := database.UpdateUserSecretParams{ID: secret.ID}
|
||||
dbm.EXPECT().GetUserSecret(gomock.Any(), secret.ID).Return(secret, nil).AnyTimes()
|
||||
dbm.EXPECT().UpdateUserSecret(gomock.Any(), arg).Return(updated, nil).AnyTimes()
|
||||
s.Run("UpdateUserSecretByUserIDAndName", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
|
||||
user := testutil.Fake(s.T(), faker, database.User{})
|
||||
updated := testutil.Fake(s.T(), faker, database.UserSecret{UserID: user.ID})
|
||||
arg := database.UpdateUserSecretByUserIDAndNameParams{UserID: user.ID, Name: "test"}
|
||||
dbm.EXPECT().UpdateUserSecretByUserIDAndName(gomock.Any(), arg).Return(updated, nil).AnyTimes()
|
||||
check.Args(arg).
|
||||
Asserts(secret, policy.ActionUpdate).
|
||||
Asserts(rbac.ResourceUserSecret.WithOwner(user.ID.String()), policy.ActionUpdate).
|
||||
Returns(updated)
|
||||
}))
|
||||
s.Run("DeleteUserSecret", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
|
||||
secret := testutil.Fake(s.T(), faker, database.UserSecret{})
|
||||
dbm.EXPECT().GetUserSecret(gomock.Any(), secret.ID).Return(secret, nil).AnyTimes()
|
||||
dbm.EXPECT().DeleteUserSecret(gomock.Any(), secret.ID).Return(nil).AnyTimes()
|
||||
check.Args(secret.ID).
|
||||
Asserts(secret, policy.ActionRead, secret, policy.ActionDelete).
|
||||
s.Run("DeleteUserSecretByUserIDAndName", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
|
||||
user := testutil.Fake(s.T(), faker, database.User{})
|
||||
arg := database.DeleteUserSecretByUserIDAndNameParams{UserID: user.ID, Name: "test"}
|
||||
dbm.EXPECT().DeleteUserSecretByUserIDAndName(gomock.Any(), arg).Return(nil).AnyTimes()
|
||||
check.Args(arg).
|
||||
Asserts(rbac.ResourceUserSecret.WithOwner(user.ID.String()), policy.ActionDelete).
|
||||
Returns()
|
||||
}))
|
||||
}
|
||||
|
||||
@@ -1597,6 +1597,7 @@ func UserSecret(t testing.TB, db database.Store, seed database.UserSecret) datab
|
||||
Name: takeFirst(seed.Name, "secret-name"),
|
||||
Description: takeFirst(seed.Description, "secret description"),
|
||||
Value: takeFirst(seed.Value, "secret value"),
|
||||
ValueKeyID: seed.ValueKeyID,
|
||||
EnvName: takeFirst(seed.EnvName, "SECRET_ENV_NAME"),
|
||||
FilePath: takeFirst(seed.FilePath, "~/secret/file/path"),
|
||||
})
|
||||
|
||||
@@ -712,11 +712,11 @@ func (m queryMetricsStore) DeleteUserChatProviderKey(ctx context.Context, arg da
|
||||
return r0
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) DeleteUserSecret(ctx context.Context, id uuid.UUID) error {
|
||||
func (m queryMetricsStore) DeleteUserSecretByUserIDAndName(ctx context.Context, arg database.DeleteUserSecretByUserIDAndNameParams) error {
|
||||
start := time.Now()
|
||||
r0 := m.s.DeleteUserSecret(ctx, id)
|
||||
m.queryLatencies.WithLabelValues("DeleteUserSecret").Observe(time.Since(start).Seconds())
|
||||
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "DeleteUserSecret").Inc()
|
||||
r0 := m.s.DeleteUserSecretByUserIDAndName(ctx, arg)
|
||||
m.queryLatencies.WithLabelValues("DeleteUserSecretByUserIDAndName").Observe(time.Since(start).Seconds())
|
||||
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "DeleteUserSecretByUserIDAndName").Inc()
|
||||
return r0
|
||||
}
|
||||
|
||||
@@ -2624,14 +2624,6 @@ func (m queryMetricsStore) GetUserNotificationPreferences(ctx context.Context, u
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) GetUserSecret(ctx context.Context, id uuid.UUID) (database.UserSecret, error) {
|
||||
start := time.Now()
|
||||
r0, r1 := m.s.GetUserSecret(ctx, id)
|
||||
m.queryLatencies.WithLabelValues("GetUserSecret").Observe(time.Since(start).Seconds())
|
||||
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "GetUserSecret").Inc()
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) GetUserSecretByUserIDAndName(ctx context.Context, arg database.GetUserSecretByUserIDAndNameParams) (database.UserSecret, error) {
|
||||
start := time.Now()
|
||||
r0, r1 := m.s.GetUserSecretByUserIDAndName(ctx, arg)
|
||||
@@ -3920,7 +3912,7 @@ func (m queryMetricsStore) ListUserChatCompactionThresholds(ctx context.Context,
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) ListUserSecrets(ctx context.Context, userID uuid.UUID) ([]database.UserSecret, error) {
|
||||
func (m queryMetricsStore) ListUserSecrets(ctx context.Context, userID uuid.UUID) ([]database.ListUserSecretsRow, error) {
|
||||
start := time.Now()
|
||||
r0, r1 := m.s.ListUserSecrets(ctx, userID)
|
||||
m.queryLatencies.WithLabelValues("ListUserSecrets").Observe(time.Since(start).Seconds())
|
||||
@@ -3928,6 +3920,14 @@ func (m queryMetricsStore) ListUserSecrets(ctx context.Context, userID uuid.UUID
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) ListUserSecretsWithValues(ctx context.Context, userID uuid.UUID) ([]database.UserSecret, error) {
|
||||
start := time.Now()
|
||||
r0, r1 := m.s.ListUserSecretsWithValues(ctx, userID)
|
||||
m.queryLatencies.WithLabelValues("ListUserSecretsWithValues").Observe(time.Since(start).Seconds())
|
||||
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "ListUserSecretsWithValues").Inc()
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) ListWorkspaceAgentPortShares(ctx context.Context, workspaceID uuid.UUID) ([]database.WorkspaceAgentPortShare, error) {
|
||||
start := time.Now()
|
||||
r0, r1 := m.s.ListWorkspaceAgentPortShares(ctx, workspaceID)
|
||||
@@ -4136,11 +4136,11 @@ func (m queryMetricsStore) UpdateChatByID(ctx context.Context, arg database.Upda
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) UpdateChatHeartbeat(ctx context.Context, arg database.UpdateChatHeartbeatParams) (int64, error) {
|
||||
func (m queryMetricsStore) UpdateChatHeartbeats(ctx context.Context, arg database.UpdateChatHeartbeatsParams) ([]uuid.UUID, error) {
|
||||
start := time.Now()
|
||||
r0, r1 := m.s.UpdateChatHeartbeat(ctx, arg)
|
||||
m.queryLatencies.WithLabelValues("UpdateChatHeartbeat").Observe(time.Since(start).Seconds())
|
||||
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "UpdateChatHeartbeat").Inc()
|
||||
r0, r1 := m.s.UpdateChatHeartbeats(ctx, arg)
|
||||
m.queryLatencies.WithLabelValues("UpdateChatHeartbeats").Observe(time.Since(start).Seconds())
|
||||
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "UpdateChatHeartbeats").Inc()
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
@@ -4696,11 +4696,11 @@ func (m queryMetricsStore) UpdateUserRoles(ctx context.Context, arg database.Upd
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) UpdateUserSecret(ctx context.Context, arg database.UpdateUserSecretParams) (database.UserSecret, error) {
|
||||
func (m queryMetricsStore) UpdateUserSecretByUserIDAndName(ctx context.Context, arg database.UpdateUserSecretByUserIDAndNameParams) (database.UserSecret, error) {
|
||||
start := time.Now()
|
||||
r0, r1 := m.s.UpdateUserSecret(ctx, arg)
|
||||
m.queryLatencies.WithLabelValues("UpdateUserSecret").Observe(time.Since(start).Seconds())
|
||||
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "UpdateUserSecret").Inc()
|
||||
r0, r1 := m.s.UpdateUserSecretByUserIDAndName(ctx, arg)
|
||||
m.queryLatencies.WithLabelValues("UpdateUserSecretByUserIDAndName").Observe(time.Since(start).Seconds())
|
||||
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "UpdateUserSecretByUserIDAndName").Inc()
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
|
||||
@@ -1199,18 +1199,18 @@ func (mr *MockStoreMockRecorder) DeleteUserChatProviderKey(ctx, arg any) *gomock
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteUserChatProviderKey", reflect.TypeOf((*MockStore)(nil).DeleteUserChatProviderKey), ctx, arg)
|
||||
}
|
||||
|
||||
// DeleteUserSecret mocks base method.
|
||||
func (m *MockStore) DeleteUserSecret(ctx context.Context, id uuid.UUID) error {
|
||||
// DeleteUserSecretByUserIDAndName mocks base method.
|
||||
func (m *MockStore) DeleteUserSecretByUserIDAndName(ctx context.Context, arg database.DeleteUserSecretByUserIDAndNameParams) error {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "DeleteUserSecret", ctx, id)
|
||||
ret := m.ctrl.Call(m, "DeleteUserSecretByUserIDAndName", ctx, arg)
|
||||
ret0, _ := ret[0].(error)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// DeleteUserSecret indicates an expected call of DeleteUserSecret.
|
||||
func (mr *MockStoreMockRecorder) DeleteUserSecret(ctx, id any) *gomock.Call {
|
||||
// DeleteUserSecretByUserIDAndName indicates an expected call of DeleteUserSecretByUserIDAndName.
|
||||
func (mr *MockStoreMockRecorder) DeleteUserSecretByUserIDAndName(ctx, arg any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteUserSecret", reflect.TypeOf((*MockStore)(nil).DeleteUserSecret), ctx, id)
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteUserSecretByUserIDAndName", reflect.TypeOf((*MockStore)(nil).DeleteUserSecretByUserIDAndName), ctx, arg)
|
||||
}
|
||||
|
||||
// DeleteWebpushSubscriptionByUserIDAndEndpoint mocks base method.
|
||||
@@ -4907,21 +4907,6 @@ func (mr *MockStoreMockRecorder) GetUserNotificationPreferences(ctx, userID any)
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetUserNotificationPreferences", reflect.TypeOf((*MockStore)(nil).GetUserNotificationPreferences), ctx, userID)
|
||||
}
|
||||
|
||||
// GetUserSecret mocks base method.
|
||||
func (m *MockStore) GetUserSecret(ctx context.Context, id uuid.UUID) (database.UserSecret, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "GetUserSecret", ctx, id)
|
||||
ret0, _ := ret[0].(database.UserSecret)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// GetUserSecret indicates an expected call of GetUserSecret.
|
||||
func (mr *MockStoreMockRecorder) GetUserSecret(ctx, id any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetUserSecret", reflect.TypeOf((*MockStore)(nil).GetUserSecret), ctx, id)
|
||||
}
|
||||
|
||||
// GetUserSecretByUserIDAndName mocks base method.
|
||||
func (m *MockStore) GetUserSecretByUserIDAndName(ctx context.Context, arg database.GetUserSecretByUserIDAndNameParams) (database.UserSecret, error) {
|
||||
m.ctrl.T.Helper()
|
||||
@@ -7412,10 +7397,10 @@ func (mr *MockStoreMockRecorder) ListUserChatCompactionThresholds(ctx, userID an
|
||||
}
|
||||
|
||||
// ListUserSecrets mocks base method.
|
||||
func (m *MockStore) ListUserSecrets(ctx context.Context, userID uuid.UUID) ([]database.UserSecret, error) {
|
||||
func (m *MockStore) ListUserSecrets(ctx context.Context, userID uuid.UUID) ([]database.ListUserSecretsRow, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "ListUserSecrets", ctx, userID)
|
||||
ret0, _ := ret[0].([]database.UserSecret)
|
||||
ret0, _ := ret[0].([]database.ListUserSecretsRow)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
@@ -7426,6 +7411,21 @@ func (mr *MockStoreMockRecorder) ListUserSecrets(ctx, userID any) *gomock.Call {
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ListUserSecrets", reflect.TypeOf((*MockStore)(nil).ListUserSecrets), ctx, userID)
|
||||
}
|
||||
|
||||
// ListUserSecretsWithValues mocks base method.
|
||||
func (m *MockStore) ListUserSecretsWithValues(ctx context.Context, userID uuid.UUID) ([]database.UserSecret, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "ListUserSecretsWithValues", ctx, userID)
|
||||
ret0, _ := ret[0].([]database.UserSecret)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// ListUserSecretsWithValues indicates an expected call of ListUserSecretsWithValues.
|
||||
func (mr *MockStoreMockRecorder) ListUserSecretsWithValues(ctx, userID any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ListUserSecretsWithValues", reflect.TypeOf((*MockStore)(nil).ListUserSecretsWithValues), ctx, userID)
|
||||
}
|
||||
|
||||
// ListWorkspaceAgentPortShares mocks base method.
|
||||
func (m *MockStore) ListWorkspaceAgentPortShares(ctx context.Context, workspaceID uuid.UUID) ([]database.WorkspaceAgentPortShare, error) {
|
||||
m.ctrl.T.Helper()
|
||||
@@ -7835,19 +7835,19 @@ func (mr *MockStoreMockRecorder) UpdateChatByID(ctx, arg any) *gomock.Call {
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateChatByID", reflect.TypeOf((*MockStore)(nil).UpdateChatByID), ctx, arg)
|
||||
}
|
||||
|
||||
// UpdateChatHeartbeat mocks base method.
|
||||
func (m *MockStore) UpdateChatHeartbeat(ctx context.Context, arg database.UpdateChatHeartbeatParams) (int64, error) {
|
||||
// UpdateChatHeartbeats mocks base method.
|
||||
func (m *MockStore) UpdateChatHeartbeats(ctx context.Context, arg database.UpdateChatHeartbeatsParams) ([]uuid.UUID, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "UpdateChatHeartbeat", ctx, arg)
|
||||
ret0, _ := ret[0].(int64)
|
||||
ret := m.ctrl.Call(m, "UpdateChatHeartbeats", ctx, arg)
|
||||
ret0, _ := ret[0].([]uuid.UUID)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// UpdateChatHeartbeat indicates an expected call of UpdateChatHeartbeat.
|
||||
func (mr *MockStoreMockRecorder) UpdateChatHeartbeat(ctx, arg any) *gomock.Call {
|
||||
// UpdateChatHeartbeats indicates an expected call of UpdateChatHeartbeats.
|
||||
func (mr *MockStoreMockRecorder) UpdateChatHeartbeats(ctx, arg any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateChatHeartbeat", reflect.TypeOf((*MockStore)(nil).UpdateChatHeartbeat), ctx, arg)
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateChatHeartbeats", reflect.TypeOf((*MockStore)(nil).UpdateChatHeartbeats), ctx, arg)
|
||||
}
|
||||
|
||||
// UpdateChatLabelsByID mocks base method.
|
||||
@@ -8854,19 +8854,19 @@ func (mr *MockStoreMockRecorder) UpdateUserRoles(ctx, arg any) *gomock.Call {
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateUserRoles", reflect.TypeOf((*MockStore)(nil).UpdateUserRoles), ctx, arg)
|
||||
}
|
||||
|
||||
// UpdateUserSecret mocks base method.
|
||||
func (m *MockStore) UpdateUserSecret(ctx context.Context, arg database.UpdateUserSecretParams) (database.UserSecret, error) {
|
||||
// UpdateUserSecretByUserIDAndName mocks base method.
|
||||
func (m *MockStore) UpdateUserSecretByUserIDAndName(ctx context.Context, arg database.UpdateUserSecretByUserIDAndNameParams) (database.UserSecret, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "UpdateUserSecret", ctx, arg)
|
||||
ret := m.ctrl.Call(m, "UpdateUserSecretByUserIDAndName", ctx, arg)
|
||||
ret0, _ := ret[0].(database.UserSecret)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// UpdateUserSecret indicates an expected call of UpdateUserSecret.
|
||||
func (mr *MockStoreMockRecorder) UpdateUserSecret(ctx, arg any) *gomock.Call {
|
||||
// UpdateUserSecretByUserIDAndName indicates an expected call of UpdateUserSecretByUserIDAndName.
|
||||
func (mr *MockStoreMockRecorder) UpdateUserSecretByUserIDAndName(ctx, arg any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateUserSecret", reflect.TypeOf((*MockStore)(nil).UpdateUserSecret), ctx, arg)
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateUserSecretByUserIDAndName", reflect.TypeOf((*MockStore)(nil).UpdateUserSecretByUserIDAndName), ctx, arg)
|
||||
}
|
||||
|
||||
// UpdateUserStatus mocks base method.
|
||||
|
||||
@@ -584,6 +584,7 @@ func (q *sqlQuerier) CountAuthorizedAuditLogs(ctx context.Context, arg CountAudi
|
||||
arg.DateTo,
|
||||
arg.BuildReason,
|
||||
arg.RequestID,
|
||||
arg.CountCap,
|
||||
)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
@@ -720,6 +721,7 @@ func (q *sqlQuerier) CountAuthorizedConnectionLogs(ctx context.Context, arg Coun
|
||||
arg.WorkspaceID,
|
||||
arg.ConnectionID,
|
||||
arg.Status,
|
||||
arg.CountCap,
|
||||
)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
|
||||
@@ -145,5 +145,13 @@ func extractWhereClause(query string) string {
|
||||
// Remove SQL comments
|
||||
whereClause = regexp.MustCompile(`(?m)--.*$`).ReplaceAllString(whereClause, "")
|
||||
|
||||
// Normalize indentation so subquery wrapping doesn't cause
|
||||
// mismatches.
|
||||
lines := strings.Split(whereClause, "\n")
|
||||
for i, line := range lines {
|
||||
lines[i] = strings.TrimLeft(line, " \t")
|
||||
}
|
||||
whereClause = strings.Join(lines, "\n")
|
||||
|
||||
return strings.TrimSpace(whereClause)
|
||||
}
|
||||
|
||||
@@ -81,8 +81,8 @@ func newMsgQueue(ctx context.Context, l Listener, le ListenerWithErr) *msgQueue
|
||||
}
|
||||
|
||||
func (q *msgQueue) run() {
|
||||
var batch [maxDrainBatch]msgOrErr
|
||||
for {
|
||||
// wait until there is something on the queue or we are closed
|
||||
q.cond.L.Lock()
|
||||
for q.size == 0 && !q.closed {
|
||||
q.cond.Wait()
|
||||
@@ -91,28 +91,32 @@ func (q *msgQueue) run() {
|
||||
q.cond.L.Unlock()
|
||||
return
|
||||
}
|
||||
item := q.q[q.front]
|
||||
q.front = (q.front + 1) % BufferSize
|
||||
q.size--
|
||||
// Drain up to maxDrainBatch items while holding the lock.
|
||||
n := min(q.size, maxDrainBatch)
|
||||
for i := range n {
|
||||
batch[i] = q.q[q.front]
|
||||
q.front = (q.front + 1) % BufferSize
|
||||
}
|
||||
q.size -= n
|
||||
q.cond.L.Unlock()
|
||||
|
||||
// process item without holding lock
|
||||
if item.err == nil {
|
||||
// real message
|
||||
if q.l != nil {
|
||||
q.l(q.ctx, item.msg)
|
||||
// Dispatch each message individually without holding the lock.
|
||||
for i := range n {
|
||||
item := batch[i]
|
||||
if item.err == nil {
|
||||
if q.l != nil {
|
||||
q.l(q.ctx, item.msg)
|
||||
continue
|
||||
}
|
||||
if q.le != nil {
|
||||
q.le(q.ctx, item.msg, nil)
|
||||
continue
|
||||
}
|
||||
continue
|
||||
}
|
||||
if q.le != nil {
|
||||
q.le(q.ctx, item.msg, nil)
|
||||
continue
|
||||
q.le(q.ctx, nil, item.err)
|
||||
}
|
||||
// unhittable
|
||||
continue
|
||||
}
|
||||
// if the listener wants errors, send it.
|
||||
if q.le != nil {
|
||||
q.le(q.ctx, nil, item.err)
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -233,6 +237,12 @@ type PGPubsub struct {
|
||||
// for a subscriber before dropping messages.
|
||||
const BufferSize = 2048
|
||||
|
||||
// maxDrainBatch is the maximum number of messages to drain from the ring
|
||||
// buffer per iteration. Batching amortizes the cost of mutex
|
||||
// acquire/release and cond.Wait across many messages, improving drain
|
||||
// throughput during bursts.
|
||||
const maxDrainBatch = 256
|
||||
|
||||
// Subscribe calls the listener when an event matching the name is received.
|
||||
func (p *PGPubsub) Subscribe(event string, listener Listener) (cancel func(), err error) {
|
||||
return p.subscribeQueue(event, newMsgQueue(context.Background(), listener, nil))
|
||||
|
||||
@@ -152,7 +152,7 @@ type sqlcQuerier interface {
|
||||
DeleteTask(ctx context.Context, arg DeleteTaskParams) (uuid.UUID, error)
|
||||
DeleteUserChatCompactionThreshold(ctx context.Context, arg DeleteUserChatCompactionThresholdParams) error
|
||||
DeleteUserChatProviderKey(ctx context.Context, arg DeleteUserChatProviderKeyParams) error
|
||||
DeleteUserSecret(ctx context.Context, id uuid.UUID) error
|
||||
DeleteUserSecretByUserIDAndName(ctx context.Context, arg DeleteUserSecretByUserIDAndNameParams) error
|
||||
DeleteWebpushSubscriptionByUserIDAndEndpoint(ctx context.Context, arg DeleteWebpushSubscriptionByUserIDAndEndpointParams) error
|
||||
DeleteWebpushSubscriptions(ctx context.Context, ids []uuid.UUID) error
|
||||
DeleteWorkspaceACLByID(ctx context.Context, id uuid.UUID) error
|
||||
@@ -598,7 +598,6 @@ type sqlcQuerier interface {
|
||||
GetUserLinkByUserIDLoginType(ctx context.Context, arg GetUserLinkByUserIDLoginTypeParams) (UserLink, error)
|
||||
GetUserLinksByUserID(ctx context.Context, userID uuid.UUID) ([]UserLink, error)
|
||||
GetUserNotificationPreferences(ctx context.Context, userID uuid.UUID) ([]NotificationPreference, error)
|
||||
GetUserSecret(ctx context.Context, id uuid.UUID) (UserSecret, error)
|
||||
GetUserSecretByUserIDAndName(ctx context.Context, arg GetUserSecretByUserIDAndNameParams) (UserSecret, error)
|
||||
// GetUserStatusCounts returns the count of users in each status over time.
|
||||
// The time range is inclusively defined by the start_time and end_time parameters.
|
||||
@@ -818,7 +817,13 @@ type sqlcQuerier interface {
|
||||
ListProvisionerKeysByOrganizationExcludeReserved(ctx context.Context, organizationID uuid.UUID) ([]ProvisionerKey, error)
|
||||
ListTasks(ctx context.Context, arg ListTasksParams) ([]Task, error)
|
||||
ListUserChatCompactionThresholds(ctx context.Context, userID uuid.UUID) ([]UserConfig, error)
|
||||
ListUserSecrets(ctx context.Context, userID uuid.UUID) ([]UserSecret, error)
|
||||
// Returns metadata only (no value or value_key_id) for the
|
||||
// REST API list and get endpoints.
|
||||
ListUserSecrets(ctx context.Context, userID uuid.UUID) ([]ListUserSecretsRow, error)
|
||||
// Returns all columns including the secret value. Used by the
|
||||
// provisioner (build-time injection) and the agent manifest
|
||||
// (runtime injection).
|
||||
ListUserSecretsWithValues(ctx context.Context, userID uuid.UUID) ([]UserSecret, error)
|
||||
ListWorkspaceAgentPortShares(ctx context.Context, workspaceID uuid.UUID) ([]WorkspaceAgentPortShare, error)
|
||||
MarkAllInboxNotificationsAsRead(ctx context.Context, arg MarkAllInboxNotificationsAsReadParams) error
|
||||
OIDCClaimFieldValues(ctx context.Context, arg OIDCClaimFieldValuesParams) ([]string, error)
|
||||
@@ -870,9 +875,11 @@ type sqlcQuerier interface {
|
||||
UpdateAPIKeyByID(ctx context.Context, arg UpdateAPIKeyByIDParams) error
|
||||
UpdateChatBuildAgentBinding(ctx context.Context, arg UpdateChatBuildAgentBindingParams) (Chat, error)
|
||||
UpdateChatByID(ctx context.Context, arg UpdateChatByIDParams) (Chat, error)
|
||||
// Bumps the heartbeat timestamp for a running chat so that other
|
||||
// replicas know the worker is still alive.
|
||||
UpdateChatHeartbeat(ctx context.Context, arg UpdateChatHeartbeatParams) (int64, error)
|
||||
// Bumps the heartbeat timestamp for the given set of chat IDs,
|
||||
// provided they are still running and owned by the specified
|
||||
// worker. Returns the IDs that were actually updated so the
|
||||
// caller can detect stolen or completed chats via set-difference.
|
||||
UpdateChatHeartbeats(ctx context.Context, arg UpdateChatHeartbeatsParams) ([]uuid.UUID, error)
|
||||
UpdateChatLabelsByID(ctx context.Context, arg UpdateChatLabelsByIDParams) (Chat, error)
|
||||
// Updates the cached injected context parts (AGENTS.md +
|
||||
// skills) on the chat row. Called only when context changes
|
||||
@@ -955,7 +962,7 @@ type sqlcQuerier interface {
|
||||
UpdateUserProfile(ctx context.Context, arg UpdateUserProfileParams) (User, error)
|
||||
UpdateUserQuietHoursSchedule(ctx context.Context, arg UpdateUserQuietHoursScheduleParams) (User, error)
|
||||
UpdateUserRoles(ctx context.Context, arg UpdateUserRolesParams) (User, error)
|
||||
UpdateUserSecret(ctx context.Context, arg UpdateUserSecretParams) (UserSecret, error)
|
||||
UpdateUserSecretByUserIDAndName(ctx context.Context, arg UpdateUserSecretByUserIDAndNameParams) (UserSecret, error)
|
||||
UpdateUserStatus(ctx context.Context, arg UpdateUserStatusParams) (User, error)
|
||||
UpdateUserTaskNotificationAlertDismissed(ctx context.Context, arg UpdateUserTaskNotificationAlertDismissedParams) (bool, error)
|
||||
UpdateUserTerminalFont(ctx context.Context, arg UpdateUserTerminalFontParams) (UserConfig, error)
|
||||
|
||||
@@ -7339,13 +7339,7 @@ func TestUserSecretsCRUDOperations(t *testing.T) {
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, secretID, createdSecret.ID)
|
||||
|
||||
// 2. READ by ID
|
||||
readSecret, err := db.GetUserSecret(ctx, createdSecret.ID)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, createdSecret.ID, readSecret.ID)
|
||||
assert.Equal(t, "workflow-secret", readSecret.Name)
|
||||
|
||||
// 3. READ by UserID and Name
|
||||
// 2. READ by UserID and Name
|
||||
readByNameParams := database.GetUserSecretByUserIDAndNameParams{
|
||||
UserID: testUser.ID,
|
||||
Name: "workflow-secret",
|
||||
@@ -7353,33 +7347,43 @@ func TestUserSecretsCRUDOperations(t *testing.T) {
|
||||
readByNameSecret, err := db.GetUserSecretByUserIDAndName(ctx, readByNameParams)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, createdSecret.ID, readByNameSecret.ID)
|
||||
assert.Equal(t, "workflow-secret", readByNameSecret.Name)
|
||||
|
||||
// 4. LIST
|
||||
// 3. LIST (metadata only)
|
||||
secrets, err := db.ListUserSecrets(ctx, testUser.ID)
|
||||
require.NoError(t, err)
|
||||
require.Len(t, secrets, 1)
|
||||
assert.Equal(t, createdSecret.ID, secrets[0].ID)
|
||||
|
||||
// 5. UPDATE
|
||||
updateParams := database.UpdateUserSecretParams{
|
||||
ID: createdSecret.ID,
|
||||
Description: "Updated workflow description",
|
||||
Value: "updated-workflow-value",
|
||||
EnvName: "UPDATED_WORKFLOW_ENV",
|
||||
FilePath: "/updated/workflow/path",
|
||||
// 4. LIST with values
|
||||
secretsWithValues, err := db.ListUserSecretsWithValues(ctx, testUser.ID)
|
||||
require.NoError(t, err)
|
||||
require.Len(t, secretsWithValues, 1)
|
||||
assert.Equal(t, "workflow-value", secretsWithValues[0].Value)
|
||||
|
||||
// 5. UPDATE (partial - only description)
|
||||
updateParams := database.UpdateUserSecretByUserIDAndNameParams{
|
||||
UserID: testUser.ID,
|
||||
Name: "workflow-secret",
|
||||
UpdateDescription: true,
|
||||
Description: "Updated workflow description",
|
||||
}
|
||||
|
||||
updatedSecret, err := db.UpdateUserSecret(ctx, updateParams)
|
||||
updatedSecret, err := db.UpdateUserSecretByUserIDAndName(ctx, updateParams)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "Updated workflow description", updatedSecret.Description)
|
||||
assert.Equal(t, "updated-workflow-value", updatedSecret.Value)
|
||||
assert.Equal(t, "workflow-value", updatedSecret.Value) // Value unchanged
|
||||
assert.Equal(t, "WORKFLOW_ENV", updatedSecret.EnvName) // EnvName unchanged
|
||||
|
||||
// 6. DELETE
|
||||
err = db.DeleteUserSecret(ctx, createdSecret.ID)
|
||||
err = db.DeleteUserSecretByUserIDAndName(ctx, database.DeleteUserSecretByUserIDAndNameParams{
|
||||
UserID: testUser.ID,
|
||||
Name: "workflow-secret",
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
// Verify deletion
|
||||
_, err = db.GetUserSecret(ctx, createdSecret.ID)
|
||||
_, err = db.GetUserSecretByUserIDAndName(ctx, readByNameParams)
|
||||
require.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "no rows in result set")
|
||||
|
||||
@@ -7449,9 +7453,13 @@ func TestUserSecretsCRUDOperations(t *testing.T) {
|
||||
})
|
||||
|
||||
// Verify both secrets exist
|
||||
_, err = db.GetUserSecret(ctx, secret1.ID)
|
||||
_, err = db.GetUserSecretByUserIDAndName(ctx, database.GetUserSecretByUserIDAndNameParams{
|
||||
UserID: testUser.ID, Name: secret1.Name,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
_, err = db.GetUserSecret(ctx, secret2.ID)
|
||||
_, err = db.GetUserSecretByUserIDAndName(ctx, database.GetUserSecretByUserIDAndNameParams{
|
||||
UserID: testUser.ID, Name: secret2.Name,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
})
|
||||
}
|
||||
@@ -7474,14 +7482,14 @@ func TestUserSecretsAuthorization(t *testing.T) {
|
||||
org := dbgen.Organization(t, db, database.Organization{})
|
||||
|
||||
// Create secrets for users
|
||||
user1Secret := dbgen.UserSecret(t, db, database.UserSecret{
|
||||
_ = dbgen.UserSecret(t, db, database.UserSecret{
|
||||
UserID: user1.ID,
|
||||
Name: "user1-secret",
|
||||
Description: "User 1's secret",
|
||||
Value: "user1-value",
|
||||
})
|
||||
|
||||
user2Secret := dbgen.UserSecret(t, db, database.UserSecret{
|
||||
_ = dbgen.UserSecret(t, db, database.UserSecret{
|
||||
UserID: user2.ID,
|
||||
Name: "user2-secret",
|
||||
Description: "User 2's secret",
|
||||
@@ -7491,7 +7499,8 @@ func TestUserSecretsAuthorization(t *testing.T) {
|
||||
testCases := []struct {
|
||||
name string
|
||||
subject rbac.Subject
|
||||
secretID uuid.UUID
|
||||
lookupUserID uuid.UUID
|
||||
lookupName string
|
||||
expectedAccess bool
|
||||
}{
|
||||
{
|
||||
@@ -7501,7 +7510,8 @@ func TestUserSecretsAuthorization(t *testing.T) {
|
||||
Roles: rbac.RoleIdentifiers{rbac.RoleMember()},
|
||||
Scope: rbac.ScopeAll,
|
||||
},
|
||||
secretID: user1Secret.ID,
|
||||
lookupUserID: user1.ID,
|
||||
lookupName: "user1-secret",
|
||||
expectedAccess: true,
|
||||
},
|
||||
{
|
||||
@@ -7511,7 +7521,8 @@ func TestUserSecretsAuthorization(t *testing.T) {
|
||||
Roles: rbac.RoleIdentifiers{rbac.RoleMember()},
|
||||
Scope: rbac.ScopeAll,
|
||||
},
|
||||
secretID: user2Secret.ID,
|
||||
lookupUserID: user2.ID,
|
||||
lookupName: "user2-secret",
|
||||
expectedAccess: false,
|
||||
},
|
||||
{
|
||||
@@ -7521,7 +7532,8 @@ func TestUserSecretsAuthorization(t *testing.T) {
|
||||
Roles: rbac.RoleIdentifiers{rbac.RoleOwner()},
|
||||
Scope: rbac.ScopeAll,
|
||||
},
|
||||
secretID: user1Secret.ID,
|
||||
lookupUserID: user1.ID,
|
||||
lookupName: "user1-secret",
|
||||
expectedAccess: false,
|
||||
},
|
||||
{
|
||||
@@ -7531,7 +7543,8 @@ func TestUserSecretsAuthorization(t *testing.T) {
|
||||
Roles: rbac.RoleIdentifiers{rbac.ScopedRoleOrgAdmin(org.ID)},
|
||||
Scope: rbac.ScopeAll,
|
||||
},
|
||||
secretID: user1Secret.ID,
|
||||
lookupUserID: user1.ID,
|
||||
lookupName: "user1-secret",
|
||||
expectedAccess: false,
|
||||
},
|
||||
}
|
||||
@@ -7543,8 +7556,10 @@ func TestUserSecretsAuthorization(t *testing.T) {
|
||||
|
||||
authCtx := dbauthz.As(ctx, tc.subject)
|
||||
|
||||
// Test GetUserSecret
|
||||
_, err := authDB.GetUserSecret(authCtx, tc.secretID)
|
||||
_, err := authDB.GetUserSecretByUserIDAndName(authCtx, database.GetUserSecretByUserIDAndNameParams{
|
||||
UserID: tc.lookupUserID,
|
||||
Name: tc.lookupName,
|
||||
})
|
||||
|
||||
if tc.expectedAccess {
|
||||
require.NoError(t, err, "expected access to be granted")
|
||||
|
||||
+362
-259
@@ -2275,93 +2275,105 @@ func (q *sqlQuerier) UpdateAPIKeyByID(ctx context.Context, arg UpdateAPIKeyByIDP
|
||||
}
|
||||
|
||||
const countAuditLogs = `-- name: CountAuditLogs :one
|
||||
SELECT COUNT(*)
|
||||
FROM audit_logs
|
||||
LEFT JOIN users ON audit_logs.user_id = users.id
|
||||
LEFT JOIN organizations ON audit_logs.organization_id = organizations.id
|
||||
-- First join on workspaces to get the initial workspace create
|
||||
-- to workspace build 1 id. This is because the first create is
|
||||
-- is a different audit log than subsequent starts.
|
||||
LEFT JOIN workspaces ON audit_logs.resource_type = 'workspace'
|
||||
AND audit_logs.resource_id = workspaces.id
|
||||
-- Get the reason from the build if the resource type
|
||||
-- is a workspace_build
|
||||
LEFT JOIN workspace_builds wb_build ON audit_logs.resource_type = 'workspace_build'
|
||||
AND audit_logs.resource_id = wb_build.id
|
||||
-- Get the reason from the build #1 if this is the first
|
||||
-- workspace create.
|
||||
LEFT JOIN workspace_builds wb_workspace ON audit_logs.resource_type = 'workspace'
|
||||
AND audit_logs.action = 'create'
|
||||
AND workspaces.id = wb_workspace.workspace_id
|
||||
AND wb_workspace.build_number = 1
|
||||
WHERE
|
||||
-- Filter resource_type
|
||||
CASE
|
||||
WHEN $1::text != '' THEN resource_type = $1::resource_type
|
||||
ELSE true
|
||||
END
|
||||
-- Filter resource_id
|
||||
AND CASE
|
||||
WHEN $2::uuid != '00000000-0000-0000-0000-000000000000'::uuid THEN resource_id = $2
|
||||
ELSE true
|
||||
END
|
||||
-- Filter organization_id
|
||||
AND CASE
|
||||
WHEN $3::uuid != '00000000-0000-0000-0000-000000000000'::uuid THEN audit_logs.organization_id = $3
|
||||
ELSE true
|
||||
END
|
||||
-- Filter by resource_target
|
||||
AND CASE
|
||||
WHEN $4::text != '' THEN resource_target = $4
|
||||
ELSE true
|
||||
END
|
||||
-- Filter action
|
||||
AND CASE
|
||||
WHEN $5::text != '' THEN action = $5::audit_action
|
||||
ELSE true
|
||||
END
|
||||
-- Filter by user_id
|
||||
AND CASE
|
||||
WHEN $6::uuid != '00000000-0000-0000-0000-000000000000'::uuid THEN user_id = $6
|
||||
ELSE true
|
||||
END
|
||||
-- Filter by username
|
||||
AND CASE
|
||||
WHEN $7::text != '' THEN user_id = (
|
||||
SELECT id
|
||||
FROM users
|
||||
WHERE lower(username) = lower($7)
|
||||
AND deleted = false
|
||||
)
|
||||
ELSE true
|
||||
END
|
||||
-- Filter by user_email
|
||||
AND CASE
|
||||
WHEN $8::text != '' THEN users.email = $8
|
||||
ELSE true
|
||||
END
|
||||
-- Filter by date_from
|
||||
AND CASE
|
||||
WHEN $9::timestamp with time zone != '0001-01-01 00:00:00Z' THEN "time" >= $9
|
||||
ELSE true
|
||||
END
|
||||
-- Filter by date_to
|
||||
AND CASE
|
||||
WHEN $10::timestamp with time zone != '0001-01-01 00:00:00Z' THEN "time" <= $10
|
||||
ELSE true
|
||||
END
|
||||
-- Filter by build_reason
|
||||
AND CASE
|
||||
WHEN $11::text != '' THEN COALESCE(wb_build.reason::text, wb_workspace.reason::text) = $11
|
||||
ELSE true
|
||||
END
|
||||
-- Filter request_id
|
||||
AND CASE
|
||||
WHEN $12::uuid != '00000000-0000-0000-0000-000000000000'::uuid THEN audit_logs.request_id = $12
|
||||
ELSE true
|
||||
END
|
||||
-- Authorize Filter clause will be injected below in CountAuthorizedAuditLogs
|
||||
-- @authorize_filter
|
||||
SELECT COUNT(*) FROM (
|
||||
SELECT 1
|
||||
FROM audit_logs
|
||||
LEFT JOIN users ON audit_logs.user_id = users.id
|
||||
LEFT JOIN organizations ON audit_logs.organization_id = organizations.id
|
||||
-- First join on workspaces to get the initial workspace create
|
||||
-- to workspace build 1 id. This is because the first create is
|
||||
-- is a different audit log than subsequent starts.
|
||||
LEFT JOIN workspaces ON audit_logs.resource_type = 'workspace'
|
||||
AND audit_logs.resource_id = workspaces.id
|
||||
-- Get the reason from the build if the resource type
|
||||
-- is a workspace_build
|
||||
LEFT JOIN workspace_builds wb_build ON audit_logs.resource_type = 'workspace_build'
|
||||
AND audit_logs.resource_id = wb_build.id
|
||||
-- Get the reason from the build #1 if this is the first
|
||||
-- workspace create.
|
||||
LEFT JOIN workspace_builds wb_workspace ON audit_logs.resource_type = 'workspace'
|
||||
AND audit_logs.action = 'create'
|
||||
AND workspaces.id = wb_workspace.workspace_id
|
||||
AND wb_workspace.build_number = 1
|
||||
WHERE
|
||||
-- Filter resource_type
|
||||
CASE
|
||||
WHEN $1::text != '' THEN resource_type = $1::resource_type
|
||||
ELSE true
|
||||
END
|
||||
-- Filter resource_id
|
||||
AND CASE
|
||||
WHEN $2::uuid != '00000000-0000-0000-0000-000000000000'::uuid THEN resource_id = $2
|
||||
ELSE true
|
||||
END
|
||||
-- Filter organization_id
|
||||
AND CASE
|
||||
WHEN $3::uuid != '00000000-0000-0000-0000-000000000000'::uuid THEN audit_logs.organization_id = $3
|
||||
ELSE true
|
||||
END
|
||||
-- Filter by resource_target
|
||||
AND CASE
|
||||
WHEN $4::text != '' THEN resource_target = $4
|
||||
ELSE true
|
||||
END
|
||||
-- Filter action
|
||||
AND CASE
|
||||
WHEN $5::text != '' THEN action = $5::audit_action
|
||||
ELSE true
|
||||
END
|
||||
-- Filter by user_id
|
||||
AND CASE
|
||||
WHEN $6::uuid != '00000000-0000-0000-0000-000000000000'::uuid THEN user_id = $6
|
||||
ELSE true
|
||||
END
|
||||
-- Filter by username
|
||||
AND CASE
|
||||
WHEN $7::text != '' THEN user_id = (
|
||||
SELECT id
|
||||
FROM users
|
||||
WHERE lower(username) = lower($7)
|
||||
AND deleted = false
|
||||
)
|
||||
ELSE true
|
||||
END
|
||||
-- Filter by user_email
|
||||
AND CASE
|
||||
WHEN $8::text != '' THEN users.email = $8
|
||||
ELSE true
|
||||
END
|
||||
-- Filter by date_from
|
||||
AND CASE
|
||||
WHEN $9::timestamp with time zone != '0001-01-01 00:00:00Z' THEN "time" >= $9
|
||||
ELSE true
|
||||
END
|
||||
-- Filter by date_to
|
||||
AND CASE
|
||||
WHEN $10::timestamp with time zone != '0001-01-01 00:00:00Z' THEN "time" <= $10
|
||||
ELSE true
|
||||
END
|
||||
-- Filter by build_reason
|
||||
AND CASE
|
||||
WHEN $11::text != '' THEN COALESCE(wb_build.reason::text, wb_workspace.reason::text) = $11
|
||||
ELSE true
|
||||
END
|
||||
-- Filter request_id
|
||||
AND CASE
|
||||
WHEN $12::uuid != '00000000-0000-0000-0000-000000000000'::uuid THEN audit_logs.request_id = $12
|
||||
ELSE true
|
||||
END
|
||||
-- Authorize Filter clause will be injected below in CountAuthorizedAuditLogs
|
||||
-- @authorize_filter
|
||||
-- Avoid a slow scan on a large table with joins. The caller
|
||||
-- passes the count cap and we add 1 so the frontend can detect
|
||||
-- capping and show "... of N+". A cap of 0 means no limit (NULLIF
|
||||
-- -> NULL + 1 = NULL).
|
||||
-- NOTE: Parameterizing this so that we can easily change from,
|
||||
-- e.g., 2000 to 5000. However, use literal NULL (or no LIMIT)
|
||||
-- here if disabling the capping on a large table permanently.
|
||||
-- This way the PG planner can plan parallel execution for
|
||||
-- potential large wins.
|
||||
LIMIT NULLIF($13::int, 0) + 1
|
||||
) AS limited_count
|
||||
`
|
||||
|
||||
type CountAuditLogsParams struct {
|
||||
@@ -2377,6 +2389,7 @@ type CountAuditLogsParams struct {
|
||||
DateTo time.Time `db:"date_to" json:"date_to"`
|
||||
BuildReason string `db:"build_reason" json:"build_reason"`
|
||||
RequestID uuid.UUID `db:"request_id" json:"request_id"`
|
||||
CountCap int32 `db:"count_cap" json:"count_cap"`
|
||||
}
|
||||
|
||||
func (q *sqlQuerier) CountAuditLogs(ctx context.Context, arg CountAuditLogsParams) (int64, error) {
|
||||
@@ -2393,6 +2406,7 @@ func (q *sqlQuerier) CountAuditLogs(ctx context.Context, arg CountAuditLogsParam
|
||||
arg.DateTo,
|
||||
arg.BuildReason,
|
||||
arg.RequestID,
|
||||
arg.CountCap,
|
||||
)
|
||||
var count int64
|
||||
err := row.Scan(&count)
|
||||
@@ -6601,30 +6615,49 @@ func (q *sqlQuerier) UpdateChatByID(ctx context.Context, arg UpdateChatByIDParam
|
||||
return i, err
|
||||
}
|
||||
|
||||
const updateChatHeartbeat = `-- name: UpdateChatHeartbeat :execrows
|
||||
const updateChatHeartbeats = `-- name: UpdateChatHeartbeats :many
|
||||
UPDATE
|
||||
chats
|
||||
SET
|
||||
heartbeat_at = NOW()
|
||||
heartbeat_at = $1::timestamptz
|
||||
WHERE
|
||||
id = $1::uuid
|
||||
AND worker_id = $2::uuid
|
||||
id = ANY($2::uuid[])
|
||||
AND worker_id = $3::uuid
|
||||
AND status = 'running'::chat_status
|
||||
RETURNING id
|
||||
`
|
||||
|
||||
type UpdateChatHeartbeatParams struct {
|
||||
ID uuid.UUID `db:"id" json:"id"`
|
||||
WorkerID uuid.UUID `db:"worker_id" json:"worker_id"`
|
||||
type UpdateChatHeartbeatsParams struct {
|
||||
Now time.Time `db:"now" json:"now"`
|
||||
IDs []uuid.UUID `db:"ids" json:"ids"`
|
||||
WorkerID uuid.UUID `db:"worker_id" json:"worker_id"`
|
||||
}
|
||||
|
||||
// Bumps the heartbeat timestamp for a running chat so that other
|
||||
// replicas know the worker is still alive.
|
||||
func (q *sqlQuerier) UpdateChatHeartbeat(ctx context.Context, arg UpdateChatHeartbeatParams) (int64, error) {
|
||||
result, err := q.db.ExecContext(ctx, updateChatHeartbeat, arg.ID, arg.WorkerID)
|
||||
// Bumps the heartbeat timestamp for the given set of chat IDs,
|
||||
// provided they are still running and owned by the specified
|
||||
// worker. Returns the IDs that were actually updated so the
|
||||
// caller can detect stolen or completed chats via set-difference.
|
||||
func (q *sqlQuerier) UpdateChatHeartbeats(ctx context.Context, arg UpdateChatHeartbeatsParams) ([]uuid.UUID, error) {
|
||||
rows, err := q.db.QueryContext(ctx, updateChatHeartbeats, arg.Now, pq.Array(arg.IDs), arg.WorkerID)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
return nil, err
|
||||
}
|
||||
return result.RowsAffected()
|
||||
defer rows.Close()
|
||||
var items []uuid.UUID
|
||||
for rows.Next() {
|
||||
var id uuid.UUID
|
||||
if err := rows.Scan(&id); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
items = append(items, id)
|
||||
}
|
||||
if err := rows.Close(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := rows.Err(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return items, nil
|
||||
}
|
||||
|
||||
const updateChatLabelsByID = `-- name: UpdateChatLabelsByID :one
|
||||
@@ -7571,110 +7604,113 @@ func (q *sqlQuerier) BatchUpsertConnectionLogs(ctx context.Context, arg BatchUps
|
||||
}
|
||||
|
||||
const countConnectionLogs = `-- name: CountConnectionLogs :one
|
||||
SELECT
|
||||
COUNT(*) AS count
|
||||
FROM
|
||||
connection_logs
|
||||
JOIN users AS workspace_owner ON
|
||||
connection_logs.workspace_owner_id = workspace_owner.id
|
||||
LEFT JOIN users ON
|
||||
connection_logs.user_id = users.id
|
||||
JOIN organizations ON
|
||||
connection_logs.organization_id = organizations.id
|
||||
WHERE
|
||||
-- Filter organization_id
|
||||
CASE
|
||||
WHEN $1 :: uuid != '00000000-0000-0000-0000-000000000000'::uuid THEN
|
||||
connection_logs.organization_id = $1
|
||||
ELSE true
|
||||
END
|
||||
-- Filter by workspace owner username
|
||||
AND CASE
|
||||
WHEN $2 :: text != '' THEN
|
||||
workspace_owner_id = (
|
||||
SELECT id FROM users
|
||||
WHERE lower(username) = lower($2) AND deleted = false
|
||||
)
|
||||
ELSE true
|
||||
END
|
||||
-- Filter by workspace_owner_id
|
||||
AND CASE
|
||||
WHEN $3 :: uuid != '00000000-0000-0000-0000-000000000000'::uuid THEN
|
||||
workspace_owner_id = $3
|
||||
ELSE true
|
||||
END
|
||||
-- Filter by workspace_owner_email
|
||||
AND CASE
|
||||
WHEN $4 :: text != '' THEN
|
||||
workspace_owner_id = (
|
||||
SELECT id FROM users
|
||||
WHERE email = $4 AND deleted = false
|
||||
)
|
||||
ELSE true
|
||||
END
|
||||
-- Filter by type
|
||||
AND CASE
|
||||
WHEN $5 :: text != '' THEN
|
||||
type = $5 :: connection_type
|
||||
ELSE true
|
||||
END
|
||||
-- Filter by user_id
|
||||
AND CASE
|
||||
WHEN $6 :: uuid != '00000000-0000-0000-0000-000000000000'::uuid THEN
|
||||
user_id = $6
|
||||
ELSE true
|
||||
END
|
||||
-- Filter by username
|
||||
AND CASE
|
||||
WHEN $7 :: text != '' THEN
|
||||
user_id = (
|
||||
SELECT id FROM users
|
||||
WHERE lower(username) = lower($7) AND deleted = false
|
||||
)
|
||||
ELSE true
|
||||
END
|
||||
-- Filter by user_email
|
||||
AND CASE
|
||||
WHEN $8 :: text != '' THEN
|
||||
users.email = $8
|
||||
ELSE true
|
||||
END
|
||||
-- Filter by connected_after
|
||||
AND CASE
|
||||
WHEN $9 :: timestamp with time zone != '0001-01-01 00:00:00Z' THEN
|
||||
connect_time >= $9
|
||||
ELSE true
|
||||
END
|
||||
-- Filter by connected_before
|
||||
AND CASE
|
||||
WHEN $10 :: timestamp with time zone != '0001-01-01 00:00:00Z' THEN
|
||||
connect_time <= $10
|
||||
ELSE true
|
||||
END
|
||||
-- Filter by workspace_id
|
||||
AND CASE
|
||||
WHEN $11 :: uuid != '00000000-0000-0000-0000-000000000000'::uuid THEN
|
||||
connection_logs.workspace_id = $11
|
||||
ELSE true
|
||||
END
|
||||
-- Filter by connection_id
|
||||
AND CASE
|
||||
WHEN $12 :: uuid != '00000000-0000-0000-0000-000000000000'::uuid THEN
|
||||
connection_logs.connection_id = $12
|
||||
ELSE true
|
||||
END
|
||||
-- Filter by whether the session has a disconnect_time
|
||||
AND CASE
|
||||
WHEN $13 :: text != '' THEN
|
||||
(($13 = 'ongoing' AND disconnect_time IS NULL) OR
|
||||
($13 = 'completed' AND disconnect_time IS NOT NULL)) AND
|
||||
-- Exclude web events, since we don't know their close time.
|
||||
"type" NOT IN ('workspace_app', 'port_forwarding')
|
||||
ELSE true
|
||||
END
|
||||
-- Authorize Filter clause will be injected below in
|
||||
-- CountAuthorizedConnectionLogs
|
||||
-- @authorize_filter
|
||||
SELECT COUNT(*) AS count FROM (
|
||||
SELECT 1
|
||||
FROM
|
||||
connection_logs
|
||||
JOIN users AS workspace_owner ON
|
||||
connection_logs.workspace_owner_id = workspace_owner.id
|
||||
LEFT JOIN users ON
|
||||
connection_logs.user_id = users.id
|
||||
JOIN organizations ON
|
||||
connection_logs.organization_id = organizations.id
|
||||
WHERE
|
||||
-- Filter organization_id
|
||||
CASE
|
||||
WHEN $1 :: uuid != '00000000-0000-0000-0000-000000000000'::uuid THEN
|
||||
connection_logs.organization_id = $1
|
||||
ELSE true
|
||||
END
|
||||
-- Filter by workspace owner username
|
||||
AND CASE
|
||||
WHEN $2 :: text != '' THEN
|
||||
workspace_owner_id = (
|
||||
SELECT id FROM users
|
||||
WHERE lower(username) = lower($2) AND deleted = false
|
||||
)
|
||||
ELSE true
|
||||
END
|
||||
-- Filter by workspace_owner_id
|
||||
AND CASE
|
||||
WHEN $3 :: uuid != '00000000-0000-0000-0000-000000000000'::uuid THEN
|
||||
workspace_owner_id = $3
|
||||
ELSE true
|
||||
END
|
||||
-- Filter by workspace_owner_email
|
||||
AND CASE
|
||||
WHEN $4 :: text != '' THEN
|
||||
workspace_owner_id = (
|
||||
SELECT id FROM users
|
||||
WHERE email = $4 AND deleted = false
|
||||
)
|
||||
ELSE true
|
||||
END
|
||||
-- Filter by type
|
||||
AND CASE
|
||||
WHEN $5 :: text != '' THEN
|
||||
type = $5 :: connection_type
|
||||
ELSE true
|
||||
END
|
||||
-- Filter by user_id
|
||||
AND CASE
|
||||
WHEN $6 :: uuid != '00000000-0000-0000-0000-000000000000'::uuid THEN
|
||||
user_id = $6
|
||||
ELSE true
|
||||
END
|
||||
-- Filter by username
|
||||
AND CASE
|
||||
WHEN $7 :: text != '' THEN
|
||||
user_id = (
|
||||
SELECT id FROM users
|
||||
WHERE lower(username) = lower($7) AND deleted = false
|
||||
)
|
||||
ELSE true
|
||||
END
|
||||
-- Filter by user_email
|
||||
AND CASE
|
||||
WHEN $8 :: text != '' THEN
|
||||
users.email = $8
|
||||
ELSE true
|
||||
END
|
||||
-- Filter by connected_after
|
||||
AND CASE
|
||||
WHEN $9 :: timestamp with time zone != '0001-01-01 00:00:00Z' THEN
|
||||
connect_time >= $9
|
||||
ELSE true
|
||||
END
|
||||
-- Filter by connected_before
|
||||
AND CASE
|
||||
WHEN $10 :: timestamp with time zone != '0001-01-01 00:00:00Z' THEN
|
||||
connect_time <= $10
|
||||
ELSE true
|
||||
END
|
||||
-- Filter by workspace_id
|
||||
AND CASE
|
||||
WHEN $11 :: uuid != '00000000-0000-0000-0000-000000000000'::uuid THEN
|
||||
connection_logs.workspace_id = $11
|
||||
ELSE true
|
||||
END
|
||||
-- Filter by connection_id
|
||||
AND CASE
|
||||
WHEN $12 :: uuid != '00000000-0000-0000-0000-000000000000'::uuid THEN
|
||||
connection_logs.connection_id = $12
|
||||
ELSE true
|
||||
END
|
||||
-- Filter by whether the session has a disconnect_time
|
||||
AND CASE
|
||||
WHEN $13 :: text != '' THEN
|
||||
(($13 = 'ongoing' AND disconnect_time IS NULL) OR
|
||||
($13 = 'completed' AND disconnect_time IS NOT NULL)) AND
|
||||
-- Exclude web events, since we don't know their close time.
|
||||
"type" NOT IN ('workspace_app', 'port_forwarding')
|
||||
ELSE true
|
||||
END
|
||||
-- Authorize Filter clause will be injected below in
|
||||
-- CountAuthorizedConnectionLogs
|
||||
-- @authorize_filter
|
||||
-- NOTE: See the CountAuditLogs LIMIT note.
|
||||
LIMIT NULLIF($14::int, 0) + 1
|
||||
) AS limited_count
|
||||
`
|
||||
|
||||
type CountConnectionLogsParams struct {
|
||||
@@ -7691,6 +7727,7 @@ type CountConnectionLogsParams struct {
|
||||
WorkspaceID uuid.UUID `db:"workspace_id" json:"workspace_id"`
|
||||
ConnectionID uuid.UUID `db:"connection_id" json:"connection_id"`
|
||||
Status string `db:"status" json:"status"`
|
||||
CountCap int32 `db:"count_cap" json:"count_cap"`
|
||||
}
|
||||
|
||||
func (q *sqlQuerier) CountConnectionLogs(ctx context.Context, arg CountConnectionLogsParams) (int64, error) {
|
||||
@@ -7708,6 +7745,7 @@ func (q *sqlQuerier) CountConnectionLogs(ctx context.Context, arg CountConnectio
|
||||
arg.WorkspaceID,
|
||||
arg.ConnectionID,
|
||||
arg.Status,
|
||||
arg.CountCap,
|
||||
)
|
||||
var count int64
|
||||
err := row.Scan(&count)
|
||||
@@ -22601,21 +22639,30 @@ INSERT INTO user_secrets (
|
||||
name,
|
||||
description,
|
||||
value,
|
||||
value_key_id,
|
||||
env_name,
|
||||
file_path
|
||||
) VALUES (
|
||||
$1, $2, $3, $4, $5, $6, $7
|
||||
$1,
|
||||
$2,
|
||||
$3,
|
||||
$4,
|
||||
$5,
|
||||
$6,
|
||||
$7,
|
||||
$8
|
||||
) RETURNING id, user_id, name, description, value, env_name, file_path, created_at, updated_at, value_key_id
|
||||
`
|
||||
|
||||
type CreateUserSecretParams struct {
|
||||
ID uuid.UUID `db:"id" json:"id"`
|
||||
UserID uuid.UUID `db:"user_id" json:"user_id"`
|
||||
Name string `db:"name" json:"name"`
|
||||
Description string `db:"description" json:"description"`
|
||||
Value string `db:"value" json:"value"`
|
||||
EnvName string `db:"env_name" json:"env_name"`
|
||||
FilePath string `db:"file_path" json:"file_path"`
|
||||
ID uuid.UUID `db:"id" json:"id"`
|
||||
UserID uuid.UUID `db:"user_id" json:"user_id"`
|
||||
Name string `db:"name" json:"name"`
|
||||
Description string `db:"description" json:"description"`
|
||||
Value string `db:"value" json:"value"`
|
||||
ValueKeyID sql.NullString `db:"value_key_id" json:"value_key_id"`
|
||||
EnvName string `db:"env_name" json:"env_name"`
|
||||
FilePath string `db:"file_path" json:"file_path"`
|
||||
}
|
||||
|
||||
func (q *sqlQuerier) CreateUserSecret(ctx context.Context, arg CreateUserSecretParams) (UserSecret, error) {
|
||||
@@ -22625,6 +22672,7 @@ func (q *sqlQuerier) CreateUserSecret(ctx context.Context, arg CreateUserSecretP
|
||||
arg.Name,
|
||||
arg.Description,
|
||||
arg.Value,
|
||||
arg.ValueKeyID,
|
||||
arg.EnvName,
|
||||
arg.FilePath,
|
||||
)
|
||||
@@ -22644,41 +22692,24 @@ func (q *sqlQuerier) CreateUserSecret(ctx context.Context, arg CreateUserSecretP
|
||||
return i, err
|
||||
}
|
||||
|
||||
const deleteUserSecret = `-- name: DeleteUserSecret :exec
|
||||
const deleteUserSecretByUserIDAndName = `-- name: DeleteUserSecretByUserIDAndName :exec
|
||||
DELETE FROM user_secrets
|
||||
WHERE id = $1
|
||||
WHERE user_id = $1 AND name = $2
|
||||
`
|
||||
|
||||
func (q *sqlQuerier) DeleteUserSecret(ctx context.Context, id uuid.UUID) error {
|
||||
_, err := q.db.ExecContext(ctx, deleteUserSecret, id)
|
||||
type DeleteUserSecretByUserIDAndNameParams struct {
|
||||
UserID uuid.UUID `db:"user_id" json:"user_id"`
|
||||
Name string `db:"name" json:"name"`
|
||||
}
|
||||
|
||||
func (q *sqlQuerier) DeleteUserSecretByUserIDAndName(ctx context.Context, arg DeleteUserSecretByUserIDAndNameParams) error {
|
||||
_, err := q.db.ExecContext(ctx, deleteUserSecretByUserIDAndName, arg.UserID, arg.Name)
|
||||
return err
|
||||
}
|
||||
|
||||
const getUserSecret = `-- name: GetUserSecret :one
|
||||
SELECT id, user_id, name, description, value, env_name, file_path, created_at, updated_at, value_key_id FROM user_secrets
|
||||
WHERE id = $1
|
||||
`
|
||||
|
||||
func (q *sqlQuerier) GetUserSecret(ctx context.Context, id uuid.UUID) (UserSecret, error) {
|
||||
row := q.db.QueryRowContext(ctx, getUserSecret, id)
|
||||
var i UserSecret
|
||||
err := row.Scan(
|
||||
&i.ID,
|
||||
&i.UserID,
|
||||
&i.Name,
|
||||
&i.Description,
|
||||
&i.Value,
|
||||
&i.EnvName,
|
||||
&i.FilePath,
|
||||
&i.CreatedAt,
|
||||
&i.UpdatedAt,
|
||||
&i.ValueKeyID,
|
||||
)
|
||||
return i, err
|
||||
}
|
||||
|
||||
const getUserSecretByUserIDAndName = `-- name: GetUserSecretByUserIDAndName :one
|
||||
SELECT id, user_id, name, description, value, env_name, file_path, created_at, updated_at, value_key_id FROM user_secrets
|
||||
SELECT id, user_id, name, description, value, env_name, file_path, created_at, updated_at, value_key_id
|
||||
FROM user_secrets
|
||||
WHERE user_id = $1 AND name = $2
|
||||
`
|
||||
|
||||
@@ -22706,17 +22737,76 @@ func (q *sqlQuerier) GetUserSecretByUserIDAndName(ctx context.Context, arg GetUs
|
||||
}
|
||||
|
||||
const listUserSecrets = `-- name: ListUserSecrets :many
|
||||
SELECT id, user_id, name, description, value, env_name, file_path, created_at, updated_at, value_key_id FROM user_secrets
|
||||
SELECT
|
||||
id, user_id, name, description,
|
||||
env_name, file_path,
|
||||
created_at, updated_at
|
||||
FROM user_secrets
|
||||
WHERE user_id = $1
|
||||
ORDER BY name ASC
|
||||
`
|
||||
|
||||
func (q *sqlQuerier) ListUserSecrets(ctx context.Context, userID uuid.UUID) ([]UserSecret, error) {
|
||||
type ListUserSecretsRow struct {
|
||||
ID uuid.UUID `db:"id" json:"id"`
|
||||
UserID uuid.UUID `db:"user_id" json:"user_id"`
|
||||
Name string `db:"name" json:"name"`
|
||||
Description string `db:"description" json:"description"`
|
||||
EnvName string `db:"env_name" json:"env_name"`
|
||||
FilePath string `db:"file_path" json:"file_path"`
|
||||
CreatedAt time.Time `db:"created_at" json:"created_at"`
|
||||
UpdatedAt time.Time `db:"updated_at" json:"updated_at"`
|
||||
}
|
||||
|
||||
// Returns metadata only (no value or value_key_id) for the
|
||||
// REST API list and get endpoints.
|
||||
func (q *sqlQuerier) ListUserSecrets(ctx context.Context, userID uuid.UUID) ([]ListUserSecretsRow, error) {
|
||||
rows, err := q.db.QueryContext(ctx, listUserSecrets, userID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
var items []ListUserSecretsRow
|
||||
for rows.Next() {
|
||||
var i ListUserSecretsRow
|
||||
if err := rows.Scan(
|
||||
&i.ID,
|
||||
&i.UserID,
|
||||
&i.Name,
|
||||
&i.Description,
|
||||
&i.EnvName,
|
||||
&i.FilePath,
|
||||
&i.CreatedAt,
|
||||
&i.UpdatedAt,
|
||||
); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
items = append(items, i)
|
||||
}
|
||||
if err := rows.Close(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := rows.Err(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return items, nil
|
||||
}
|
||||
|
||||
const listUserSecretsWithValues = `-- name: ListUserSecretsWithValues :many
|
||||
SELECT id, user_id, name, description, value, env_name, file_path, created_at, updated_at, value_key_id
|
||||
FROM user_secrets
|
||||
WHERE user_id = $1
|
||||
ORDER BY name ASC
|
||||
`
|
||||
|
||||
// Returns all columns including the secret value. Used by the
|
||||
// provisioner (build-time injection) and the agent manifest
|
||||
// (runtime injection).
|
||||
func (q *sqlQuerier) ListUserSecretsWithValues(ctx context.Context, userID uuid.UUID) ([]UserSecret, error) {
|
||||
rows, err := q.db.QueryContext(ctx, listUserSecretsWithValues, userID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
var items []UserSecret
|
||||
for rows.Next() {
|
||||
var i UserSecret
|
||||
@@ -22745,33 +22835,46 @@ func (q *sqlQuerier) ListUserSecrets(ctx context.Context, userID uuid.UUID) ([]U
|
||||
return items, nil
|
||||
}
|
||||
|
||||
const updateUserSecret = `-- name: UpdateUserSecret :one
|
||||
const updateUserSecretByUserIDAndName = `-- name: UpdateUserSecretByUserIDAndName :one
|
||||
UPDATE user_secrets
|
||||
SET
|
||||
description = $2,
|
||||
value = $3,
|
||||
env_name = $4,
|
||||
file_path = $5,
|
||||
updated_at = CURRENT_TIMESTAMP
|
||||
WHERE id = $1
|
||||
value = CASE WHEN $1::bool THEN $2 ELSE value END,
|
||||
value_key_id = CASE WHEN $1::bool THEN $3 ELSE value_key_id END,
|
||||
description = CASE WHEN $4::bool THEN $5 ELSE description END,
|
||||
env_name = CASE WHEN $6::bool THEN $7 ELSE env_name END,
|
||||
file_path = CASE WHEN $8::bool THEN $9 ELSE file_path END,
|
||||
updated_at = CURRENT_TIMESTAMP
|
||||
WHERE user_id = $10 AND name = $11
|
||||
RETURNING id, user_id, name, description, value, env_name, file_path, created_at, updated_at, value_key_id
|
||||
`
|
||||
|
||||
type UpdateUserSecretParams struct {
|
||||
ID uuid.UUID `db:"id" json:"id"`
|
||||
Description string `db:"description" json:"description"`
|
||||
Value string `db:"value" json:"value"`
|
||||
EnvName string `db:"env_name" json:"env_name"`
|
||||
FilePath string `db:"file_path" json:"file_path"`
|
||||
type UpdateUserSecretByUserIDAndNameParams struct {
|
||||
UpdateValue bool `db:"update_value" json:"update_value"`
|
||||
Value string `db:"value" json:"value"`
|
||||
ValueKeyID sql.NullString `db:"value_key_id" json:"value_key_id"`
|
||||
UpdateDescription bool `db:"update_description" json:"update_description"`
|
||||
Description string `db:"description" json:"description"`
|
||||
UpdateEnvName bool `db:"update_env_name" json:"update_env_name"`
|
||||
EnvName string `db:"env_name" json:"env_name"`
|
||||
UpdateFilePath bool `db:"update_file_path" json:"update_file_path"`
|
||||
FilePath string `db:"file_path" json:"file_path"`
|
||||
UserID uuid.UUID `db:"user_id" json:"user_id"`
|
||||
Name string `db:"name" json:"name"`
|
||||
}
|
||||
|
||||
func (q *sqlQuerier) UpdateUserSecret(ctx context.Context, arg UpdateUserSecretParams) (UserSecret, error) {
|
||||
row := q.db.QueryRowContext(ctx, updateUserSecret,
|
||||
arg.ID,
|
||||
arg.Description,
|
||||
func (q *sqlQuerier) UpdateUserSecretByUserIDAndName(ctx context.Context, arg UpdateUserSecretByUserIDAndNameParams) (UserSecret, error) {
|
||||
row := q.db.QueryRowContext(ctx, updateUserSecretByUserIDAndName,
|
||||
arg.UpdateValue,
|
||||
arg.Value,
|
||||
arg.ValueKeyID,
|
||||
arg.UpdateDescription,
|
||||
arg.Description,
|
||||
arg.UpdateEnvName,
|
||||
arg.EnvName,
|
||||
arg.UpdateFilePath,
|
||||
arg.FilePath,
|
||||
arg.UserID,
|
||||
arg.Name,
|
||||
)
|
||||
var i UserSecret
|
||||
err := row.Scan(
|
||||
|
||||
@@ -149,94 +149,105 @@ VALUES (
|
||||
RETURNING *;
|
||||
|
||||
-- name: CountAuditLogs :one
|
||||
SELECT COUNT(*)
|
||||
FROM audit_logs
|
||||
LEFT JOIN users ON audit_logs.user_id = users.id
|
||||
LEFT JOIN organizations ON audit_logs.organization_id = organizations.id
|
||||
-- First join on workspaces to get the initial workspace create
|
||||
-- to workspace build 1 id. This is because the first create is
|
||||
-- is a different audit log than subsequent starts.
|
||||
LEFT JOIN workspaces ON audit_logs.resource_type = 'workspace'
|
||||
AND audit_logs.resource_id = workspaces.id
|
||||
-- Get the reason from the build if the resource type
|
||||
-- is a workspace_build
|
||||
LEFT JOIN workspace_builds wb_build ON audit_logs.resource_type = 'workspace_build'
|
||||
AND audit_logs.resource_id = wb_build.id
|
||||
-- Get the reason from the build #1 if this is the first
|
||||
-- workspace create.
|
||||
LEFT JOIN workspace_builds wb_workspace ON audit_logs.resource_type = 'workspace'
|
||||
AND audit_logs.action = 'create'
|
||||
AND workspaces.id = wb_workspace.workspace_id
|
||||
AND wb_workspace.build_number = 1
|
||||
WHERE
|
||||
-- Filter resource_type
|
||||
CASE
|
||||
WHEN @resource_type::text != '' THEN resource_type = @resource_type::resource_type
|
||||
ELSE true
|
||||
END
|
||||
-- Filter resource_id
|
||||
AND CASE
|
||||
WHEN @resource_id::uuid != '00000000-0000-0000-0000-000000000000'::uuid THEN resource_id = @resource_id
|
||||
ELSE true
|
||||
END
|
||||
-- Filter organization_id
|
||||
AND CASE
|
||||
WHEN @organization_id::uuid != '00000000-0000-0000-0000-000000000000'::uuid THEN audit_logs.organization_id = @organization_id
|
||||
ELSE true
|
||||
END
|
||||
-- Filter by resource_target
|
||||
AND CASE
|
||||
WHEN @resource_target::text != '' THEN resource_target = @resource_target
|
||||
ELSE true
|
||||
END
|
||||
-- Filter action
|
||||
AND CASE
|
||||
WHEN @action::text != '' THEN action = @action::audit_action
|
||||
ELSE true
|
||||
END
|
||||
-- Filter by user_id
|
||||
AND CASE
|
||||
WHEN @user_id::uuid != '00000000-0000-0000-0000-000000000000'::uuid THEN user_id = @user_id
|
||||
ELSE true
|
||||
END
|
||||
-- Filter by username
|
||||
AND CASE
|
||||
WHEN @username::text != '' THEN user_id = (
|
||||
SELECT id
|
||||
FROM users
|
||||
WHERE lower(username) = lower(@username)
|
||||
AND deleted = false
|
||||
)
|
||||
ELSE true
|
||||
END
|
||||
-- Filter by user_email
|
||||
AND CASE
|
||||
WHEN @email::text != '' THEN users.email = @email
|
||||
ELSE true
|
||||
END
|
||||
-- Filter by date_from
|
||||
AND CASE
|
||||
WHEN @date_from::timestamp with time zone != '0001-01-01 00:00:00Z' THEN "time" >= @date_from
|
||||
ELSE true
|
||||
END
|
||||
-- Filter by date_to
|
||||
AND CASE
|
||||
WHEN @date_to::timestamp with time zone != '0001-01-01 00:00:00Z' THEN "time" <= @date_to
|
||||
ELSE true
|
||||
END
|
||||
-- Filter by build_reason
|
||||
AND CASE
|
||||
WHEN @build_reason::text != '' THEN COALESCE(wb_build.reason::text, wb_workspace.reason::text) = @build_reason
|
||||
ELSE true
|
||||
END
|
||||
-- Filter request_id
|
||||
AND CASE
|
||||
WHEN @request_id::uuid != '00000000-0000-0000-0000-000000000000'::uuid THEN audit_logs.request_id = @request_id
|
||||
ELSE true
|
||||
END
|
||||
-- Authorize Filter clause will be injected below in CountAuthorizedAuditLogs
|
||||
-- @authorize_filter
|
||||
;
|
||||
SELECT COUNT(*) FROM (
|
||||
SELECT 1
|
||||
FROM audit_logs
|
||||
LEFT JOIN users ON audit_logs.user_id = users.id
|
||||
LEFT JOIN organizations ON audit_logs.organization_id = organizations.id
|
||||
-- First join on workspaces to get the initial workspace create
|
||||
-- to workspace build 1 id. This is because the first create is
|
||||
-- is a different audit log than subsequent starts.
|
||||
LEFT JOIN workspaces ON audit_logs.resource_type = 'workspace'
|
||||
AND audit_logs.resource_id = workspaces.id
|
||||
-- Get the reason from the build if the resource type
|
||||
-- is a workspace_build
|
||||
LEFT JOIN workspace_builds wb_build ON audit_logs.resource_type = 'workspace_build'
|
||||
AND audit_logs.resource_id = wb_build.id
|
||||
-- Get the reason from the build #1 if this is the first
|
||||
-- workspace create.
|
||||
LEFT JOIN workspace_builds wb_workspace ON audit_logs.resource_type = 'workspace'
|
||||
AND audit_logs.action = 'create'
|
||||
AND workspaces.id = wb_workspace.workspace_id
|
||||
AND wb_workspace.build_number = 1
|
||||
WHERE
|
||||
-- Filter resource_type
|
||||
CASE
|
||||
WHEN @resource_type::text != '' THEN resource_type = @resource_type::resource_type
|
||||
ELSE true
|
||||
END
|
||||
-- Filter resource_id
|
||||
AND CASE
|
||||
WHEN @resource_id::uuid != '00000000-0000-0000-0000-000000000000'::uuid THEN resource_id = @resource_id
|
||||
ELSE true
|
||||
END
|
||||
-- Filter organization_id
|
||||
AND CASE
|
||||
WHEN @organization_id::uuid != '00000000-0000-0000-0000-000000000000'::uuid THEN audit_logs.organization_id = @organization_id
|
||||
ELSE true
|
||||
END
|
||||
-- Filter by resource_target
|
||||
AND CASE
|
||||
WHEN @resource_target::text != '' THEN resource_target = @resource_target
|
||||
ELSE true
|
||||
END
|
||||
-- Filter action
|
||||
AND CASE
|
||||
WHEN @action::text != '' THEN action = @action::audit_action
|
||||
ELSE true
|
||||
END
|
||||
-- Filter by user_id
|
||||
AND CASE
|
||||
WHEN @user_id::uuid != '00000000-0000-0000-0000-000000000000'::uuid THEN user_id = @user_id
|
||||
ELSE true
|
||||
END
|
||||
-- Filter by username
|
||||
AND CASE
|
||||
WHEN @username::text != '' THEN user_id = (
|
||||
SELECT id
|
||||
FROM users
|
||||
WHERE lower(username) = lower(@username)
|
||||
AND deleted = false
|
||||
)
|
||||
ELSE true
|
||||
END
|
||||
-- Filter by user_email
|
||||
AND CASE
|
||||
WHEN @email::text != '' THEN users.email = @email
|
||||
ELSE true
|
||||
END
|
||||
-- Filter by date_from
|
||||
AND CASE
|
||||
WHEN @date_from::timestamp with time zone != '0001-01-01 00:00:00Z' THEN "time" >= @date_from
|
||||
ELSE true
|
||||
END
|
||||
-- Filter by date_to
|
||||
AND CASE
|
||||
WHEN @date_to::timestamp with time zone != '0001-01-01 00:00:00Z' THEN "time" <= @date_to
|
||||
ELSE true
|
||||
END
|
||||
-- Filter by build_reason
|
||||
AND CASE
|
||||
WHEN @build_reason::text != '' THEN COALESCE(wb_build.reason::text, wb_workspace.reason::text) = @build_reason
|
||||
ELSE true
|
||||
END
|
||||
-- Filter request_id
|
||||
AND CASE
|
||||
WHEN @request_id::uuid != '00000000-0000-0000-0000-000000000000'::uuid THEN audit_logs.request_id = @request_id
|
||||
ELSE true
|
||||
END
|
||||
-- Authorize Filter clause will be injected below in CountAuthorizedAuditLogs
|
||||
-- @authorize_filter
|
||||
-- Avoid a slow scan on a large table with joins. The caller
|
||||
-- passes the count cap and we add 1 so the frontend can detect
|
||||
-- capping and show "... of N+". A cap of 0 means no limit (NULLIF
|
||||
-- -> NULL + 1 = NULL).
|
||||
-- NOTE: Parameterizing this so that we can easily change from,
|
||||
-- e.g., 2000 to 5000. However, use literal NULL (or no LIMIT)
|
||||
-- here if disabling the capping on a large table permanently.
|
||||
-- This way the PG planner can plan parallel execution for
|
||||
-- potential large wins.
|
||||
LIMIT NULLIF(@count_cap::int, 0) + 1
|
||||
) AS limited_count;
|
||||
|
||||
-- name: DeleteOldAuditLogConnectionEvents :exec
|
||||
DELETE FROM audit_logs
|
||||
|
||||
@@ -674,17 +674,20 @@ WHERE
|
||||
status = 'running'::chat_status
|
||||
AND heartbeat_at < @stale_threshold::timestamptz;
|
||||
|
||||
-- name: UpdateChatHeartbeat :execrows
|
||||
-- Bumps the heartbeat timestamp for a running chat so that other
|
||||
-- replicas know the worker is still alive.
|
||||
-- name: UpdateChatHeartbeats :many
|
||||
-- Bumps the heartbeat timestamp for the given set of chat IDs,
|
||||
-- provided they are still running and owned by the specified
|
||||
-- worker. Returns the IDs that were actually updated so the
|
||||
-- caller can detect stolen or completed chats via set-difference.
|
||||
UPDATE
|
||||
chats
|
||||
SET
|
||||
heartbeat_at = NOW()
|
||||
heartbeat_at = @now::timestamptz
|
||||
WHERE
|
||||
id = @id::uuid
|
||||
id = ANY(@ids::uuid[])
|
||||
AND worker_id = @worker_id::uuid
|
||||
AND status = 'running'::chat_status;
|
||||
AND status = 'running'::chat_status
|
||||
RETURNING id;
|
||||
|
||||
-- name: GetChatDiffStatusByChatID :one
|
||||
SELECT
|
||||
|
||||
@@ -133,111 +133,113 @@ OFFSET
|
||||
@offset_opt;
|
||||
|
||||
-- name: CountConnectionLogs :one
|
||||
SELECT
|
||||
COUNT(*) AS count
|
||||
FROM
|
||||
connection_logs
|
||||
JOIN users AS workspace_owner ON
|
||||
connection_logs.workspace_owner_id = workspace_owner.id
|
||||
LEFT JOIN users ON
|
||||
connection_logs.user_id = users.id
|
||||
JOIN organizations ON
|
||||
connection_logs.organization_id = organizations.id
|
||||
WHERE
|
||||
-- Filter organization_id
|
||||
CASE
|
||||
WHEN @organization_id :: uuid != '00000000-0000-0000-0000-000000000000'::uuid THEN
|
||||
connection_logs.organization_id = @organization_id
|
||||
ELSE true
|
||||
END
|
||||
-- Filter by workspace owner username
|
||||
AND CASE
|
||||
WHEN @workspace_owner :: text != '' THEN
|
||||
workspace_owner_id = (
|
||||
SELECT id FROM users
|
||||
WHERE lower(username) = lower(@workspace_owner) AND deleted = false
|
||||
)
|
||||
ELSE true
|
||||
END
|
||||
-- Filter by workspace_owner_id
|
||||
AND CASE
|
||||
WHEN @workspace_owner_id :: uuid != '00000000-0000-0000-0000-000000000000'::uuid THEN
|
||||
workspace_owner_id = @workspace_owner_id
|
||||
ELSE true
|
||||
END
|
||||
-- Filter by workspace_owner_email
|
||||
AND CASE
|
||||
WHEN @workspace_owner_email :: text != '' THEN
|
||||
workspace_owner_id = (
|
||||
SELECT id FROM users
|
||||
WHERE email = @workspace_owner_email AND deleted = false
|
||||
)
|
||||
ELSE true
|
||||
END
|
||||
-- Filter by type
|
||||
AND CASE
|
||||
WHEN @type :: text != '' THEN
|
||||
type = @type :: connection_type
|
||||
ELSE true
|
||||
END
|
||||
-- Filter by user_id
|
||||
AND CASE
|
||||
WHEN @user_id :: uuid != '00000000-0000-0000-0000-000000000000'::uuid THEN
|
||||
user_id = @user_id
|
||||
ELSE true
|
||||
END
|
||||
-- Filter by username
|
||||
AND CASE
|
||||
WHEN @username :: text != '' THEN
|
||||
user_id = (
|
||||
SELECT id FROM users
|
||||
WHERE lower(username) = lower(@username) AND deleted = false
|
||||
)
|
||||
ELSE true
|
||||
END
|
||||
-- Filter by user_email
|
||||
AND CASE
|
||||
WHEN @user_email :: text != '' THEN
|
||||
users.email = @user_email
|
||||
ELSE true
|
||||
END
|
||||
-- Filter by connected_after
|
||||
AND CASE
|
||||
WHEN @connected_after :: timestamp with time zone != '0001-01-01 00:00:00Z' THEN
|
||||
connect_time >= @connected_after
|
||||
ELSE true
|
||||
END
|
||||
-- Filter by connected_before
|
||||
AND CASE
|
||||
WHEN @connected_before :: timestamp with time zone != '0001-01-01 00:00:00Z' THEN
|
||||
connect_time <= @connected_before
|
||||
ELSE true
|
||||
END
|
||||
-- Filter by workspace_id
|
||||
AND CASE
|
||||
WHEN @workspace_id :: uuid != '00000000-0000-0000-0000-000000000000'::uuid THEN
|
||||
connection_logs.workspace_id = @workspace_id
|
||||
ELSE true
|
||||
END
|
||||
-- Filter by connection_id
|
||||
AND CASE
|
||||
WHEN @connection_id :: uuid != '00000000-0000-0000-0000-000000000000'::uuid THEN
|
||||
connection_logs.connection_id = @connection_id
|
||||
ELSE true
|
||||
END
|
||||
-- Filter by whether the session has a disconnect_time
|
||||
AND CASE
|
||||
WHEN @status :: text != '' THEN
|
||||
((@status = 'ongoing' AND disconnect_time IS NULL) OR
|
||||
(@status = 'completed' AND disconnect_time IS NOT NULL)) AND
|
||||
-- Exclude web events, since we don't know their close time.
|
||||
"type" NOT IN ('workspace_app', 'port_forwarding')
|
||||
ELSE true
|
||||
END
|
||||
-- Authorize Filter clause will be injected below in
|
||||
-- CountAuthorizedConnectionLogs
|
||||
-- @authorize_filter
|
||||
;
|
||||
SELECT COUNT(*) AS count FROM (
|
||||
SELECT 1
|
||||
FROM
|
||||
connection_logs
|
||||
JOIN users AS workspace_owner ON
|
||||
connection_logs.workspace_owner_id = workspace_owner.id
|
||||
LEFT JOIN users ON
|
||||
connection_logs.user_id = users.id
|
||||
JOIN organizations ON
|
||||
connection_logs.organization_id = organizations.id
|
||||
WHERE
|
||||
-- Filter organization_id
|
||||
CASE
|
||||
WHEN @organization_id :: uuid != '00000000-0000-0000-0000-000000000000'::uuid THEN
|
||||
connection_logs.organization_id = @organization_id
|
||||
ELSE true
|
||||
END
|
||||
-- Filter by workspace owner username
|
||||
AND CASE
|
||||
WHEN @workspace_owner :: text != '' THEN
|
||||
workspace_owner_id = (
|
||||
SELECT id FROM users
|
||||
WHERE lower(username) = lower(@workspace_owner) AND deleted = false
|
||||
)
|
||||
ELSE true
|
||||
END
|
||||
-- Filter by workspace_owner_id
|
||||
AND CASE
|
||||
WHEN @workspace_owner_id :: uuid != '00000000-0000-0000-0000-000000000000'::uuid THEN
|
||||
workspace_owner_id = @workspace_owner_id
|
||||
ELSE true
|
||||
END
|
||||
-- Filter by workspace_owner_email
|
||||
AND CASE
|
||||
WHEN @workspace_owner_email :: text != '' THEN
|
||||
workspace_owner_id = (
|
||||
SELECT id FROM users
|
||||
WHERE email = @workspace_owner_email AND deleted = false
|
||||
)
|
||||
ELSE true
|
||||
END
|
||||
-- Filter by type
|
||||
AND CASE
|
||||
WHEN @type :: text != '' THEN
|
||||
type = @type :: connection_type
|
||||
ELSE true
|
||||
END
|
||||
-- Filter by user_id
|
||||
AND CASE
|
||||
WHEN @user_id :: uuid != '00000000-0000-0000-0000-000000000000'::uuid THEN
|
||||
user_id = @user_id
|
||||
ELSE true
|
||||
END
|
||||
-- Filter by username
|
||||
AND CASE
|
||||
WHEN @username :: text != '' THEN
|
||||
user_id = (
|
||||
SELECT id FROM users
|
||||
WHERE lower(username) = lower(@username) AND deleted = false
|
||||
)
|
||||
ELSE true
|
||||
END
|
||||
-- Filter by user_email
|
||||
AND CASE
|
||||
WHEN @user_email :: text != '' THEN
|
||||
users.email = @user_email
|
||||
ELSE true
|
||||
END
|
||||
-- Filter by connected_after
|
||||
AND CASE
|
||||
WHEN @connected_after :: timestamp with time zone != '0001-01-01 00:00:00Z' THEN
|
||||
connect_time >= @connected_after
|
||||
ELSE true
|
||||
END
|
||||
-- Filter by connected_before
|
||||
AND CASE
|
||||
WHEN @connected_before :: timestamp with time zone != '0001-01-01 00:00:00Z' THEN
|
||||
connect_time <= @connected_before
|
||||
ELSE true
|
||||
END
|
||||
-- Filter by workspace_id
|
||||
AND CASE
|
||||
WHEN @workspace_id :: uuid != '00000000-0000-0000-0000-000000000000'::uuid THEN
|
||||
connection_logs.workspace_id = @workspace_id
|
||||
ELSE true
|
||||
END
|
||||
-- Filter by connection_id
|
||||
AND CASE
|
||||
WHEN @connection_id :: uuid != '00000000-0000-0000-0000-000000000000'::uuid THEN
|
||||
connection_logs.connection_id = @connection_id
|
||||
ELSE true
|
||||
END
|
||||
-- Filter by whether the session has a disconnect_time
|
||||
AND CASE
|
||||
WHEN @status :: text != '' THEN
|
||||
((@status = 'ongoing' AND disconnect_time IS NULL) OR
|
||||
(@status = 'completed' AND disconnect_time IS NOT NULL)) AND
|
||||
-- Exclude web events, since we don't know their close time.
|
||||
"type" NOT IN ('workspace_app', 'port_forwarding')
|
||||
ELSE true
|
||||
END
|
||||
-- Authorize Filter clause will be injected below in
|
||||
-- CountAuthorizedConnectionLogs
|
||||
-- @authorize_filter
|
||||
-- NOTE: See the CountAuditLogs LIMIT note.
|
||||
LIMIT NULLIF(@count_cap::int, 0) + 1
|
||||
) AS limited_count;
|
||||
|
||||
-- name: DeleteOldConnectionLogs :execrows
|
||||
WITH old_logs AS (
|
||||
|
||||
@@ -1,14 +1,26 @@
|
||||
-- name: GetUserSecretByUserIDAndName :one
|
||||
SELECT * FROM user_secrets
|
||||
WHERE user_id = $1 AND name = $2;
|
||||
|
||||
-- name: GetUserSecret :one
|
||||
SELECT * FROM user_secrets
|
||||
WHERE id = $1;
|
||||
SELECT *
|
||||
FROM user_secrets
|
||||
WHERE user_id = @user_id AND name = @name;
|
||||
|
||||
-- name: ListUserSecrets :many
|
||||
SELECT * FROM user_secrets
|
||||
WHERE user_id = $1
|
||||
-- Returns metadata only (no value or value_key_id) for the
|
||||
-- REST API list and get endpoints.
|
||||
SELECT
|
||||
id, user_id, name, description,
|
||||
env_name, file_path,
|
||||
created_at, updated_at
|
||||
FROM user_secrets
|
||||
WHERE user_id = @user_id
|
||||
ORDER BY name ASC;
|
||||
|
||||
-- name: ListUserSecretsWithValues :many
|
||||
-- Returns all columns including the secret value. Used by the
|
||||
-- provisioner (build-time injection) and the agent manifest
|
||||
-- (runtime injection).
|
||||
SELECT *
|
||||
FROM user_secrets
|
||||
WHERE user_id = @user_id
|
||||
ORDER BY name ASC;
|
||||
|
||||
-- name: CreateUserSecret :one
|
||||
@@ -18,23 +30,32 @@ INSERT INTO user_secrets (
|
||||
name,
|
||||
description,
|
||||
value,
|
||||
value_key_id,
|
||||
env_name,
|
||||
file_path
|
||||
) VALUES (
|
||||
$1, $2, $3, $4, $5, $6, $7
|
||||
@id,
|
||||
@user_id,
|
||||
@name,
|
||||
@description,
|
||||
@value,
|
||||
@value_key_id,
|
||||
@env_name,
|
||||
@file_path
|
||||
) RETURNING *;
|
||||
|
||||
-- name: UpdateUserSecret :one
|
||||
-- name: UpdateUserSecretByUserIDAndName :one
|
||||
UPDATE user_secrets
|
||||
SET
|
||||
description = $2,
|
||||
value = $3,
|
||||
env_name = $4,
|
||||
file_path = $5,
|
||||
updated_at = CURRENT_TIMESTAMP
|
||||
WHERE id = $1
|
||||
value = CASE WHEN @update_value::bool THEN @value ELSE value END,
|
||||
value_key_id = CASE WHEN @update_value::bool THEN @value_key_id ELSE value_key_id END,
|
||||
description = CASE WHEN @update_description::bool THEN @description ELSE description END,
|
||||
env_name = CASE WHEN @update_env_name::bool THEN @env_name ELSE env_name END,
|
||||
file_path = CASE WHEN @update_file_path::bool THEN @file_path ELSE file_path END,
|
||||
updated_at = CURRENT_TIMESTAMP
|
||||
WHERE user_id = @user_id AND name = @name
|
||||
RETURNING *;
|
||||
|
||||
-- name: DeleteUserSecret :exec
|
||||
-- name: DeleteUserSecretByUserIDAndName :exec
|
||||
DELETE FROM user_secrets
|
||||
WHERE id = $1;
|
||||
WHERE user_id = @user_id AND name = @name;
|
||||
|
||||
@@ -298,6 +298,40 @@ neq(input.object.owner, "");
|
||||
ExpectedSQL: p("'' = 'org-id'"),
|
||||
VariableConverter: regosql.ChatConverter(),
|
||||
},
|
||||
{
|
||||
Name: "AuditLogUUID",
|
||||
Queries: []string{
|
||||
`"8c0b9bdc-a013-4b14-a49b-5747bc335708" = input.object.org_owner`,
|
||||
`input.object.org_owner != ""`,
|
||||
`neq(input.object.org_owner, "8c0b9bdc-a013-4b14-a49b-5747bc335708")`,
|
||||
`input.object.org_owner in {"8c0b9bdc-a013-4b14-a49b-5747bc335708", "05f58202-4bfc-43ce-9ba4-5ff6e0174a71"}`,
|
||||
`"read" in input.object.acl_group_list[input.object.org_owner]`,
|
||||
},
|
||||
ExpectedSQL: p(
|
||||
p("audit_logs.organization_id = '8c0b9bdc-a013-4b14-a49b-5747bc335708'::uuid") + " OR " +
|
||||
p("audit_logs.organization_id IS NOT NULL") + " OR " +
|
||||
p("audit_logs.organization_id != '8c0b9bdc-a013-4b14-a49b-5747bc335708'::uuid") + " OR " +
|
||||
p("audit_logs.organization_id = ANY(ARRAY ['05f58202-4bfc-43ce-9ba4-5ff6e0174a71'::uuid,'8c0b9bdc-a013-4b14-a49b-5747bc335708'::uuid])") + " OR " +
|
||||
"(false)"),
|
||||
VariableConverter: regosql.AuditLogConverter(),
|
||||
},
|
||||
{
|
||||
Name: "ConnectionLogUUID",
|
||||
Queries: []string{
|
||||
`"8c0b9bdc-a013-4b14-a49b-5747bc335708" = input.object.org_owner`,
|
||||
`input.object.org_owner != ""`,
|
||||
`neq(input.object.org_owner, "8c0b9bdc-a013-4b14-a49b-5747bc335708")`,
|
||||
`input.object.org_owner in {"8c0b9bdc-a013-4b14-a49b-5747bc335708"}`,
|
||||
`"read" in input.object.acl_group_list[input.object.org_owner]`,
|
||||
},
|
||||
ExpectedSQL: p(
|
||||
p("connection_logs.organization_id = '8c0b9bdc-a013-4b14-a49b-5747bc335708'::uuid") + " OR " +
|
||||
p("connection_logs.organization_id IS NOT NULL") + " OR " +
|
||||
p("connection_logs.organization_id != '8c0b9bdc-a013-4b14-a49b-5747bc335708'::uuid") + " OR " +
|
||||
p("connection_logs.organization_id = ANY(ARRAY ['8c0b9bdc-a013-4b14-a49b-5747bc335708'::uuid])") + " OR " +
|
||||
"(false)"),
|
||||
VariableConverter: regosql.ConnectionLogConverter(),
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
|
||||
@@ -53,7 +53,7 @@ func WorkspaceConverter() *sqltypes.VariableConverter {
|
||||
func AuditLogConverter() *sqltypes.VariableConverter {
|
||||
matcher := sqltypes.NewVariableConverter().RegisterMatcher(
|
||||
resourceIDMatcher(),
|
||||
sqltypes.StringVarMatcher("COALESCE(audit_logs.organization_id :: text, '')", []string{"input", "object", "org_owner"}),
|
||||
sqltypes.UUIDVarMatcher("audit_logs.organization_id", []string{"input", "object", "org_owner"}),
|
||||
// Audit logs have no user owner, only owner by an organization.
|
||||
sqltypes.AlwaysFalse(userOwnerMatcher()),
|
||||
)
|
||||
@@ -67,7 +67,7 @@ func AuditLogConverter() *sqltypes.VariableConverter {
|
||||
func ConnectionLogConverter() *sqltypes.VariableConverter {
|
||||
matcher := sqltypes.NewVariableConverter().RegisterMatcher(
|
||||
resourceIDMatcher(),
|
||||
sqltypes.StringVarMatcher("COALESCE(connection_logs.organization_id :: text, '')", []string{"input", "object", "org_owner"}),
|
||||
sqltypes.UUIDVarMatcher("connection_logs.organization_id", []string{"input", "object", "org_owner"}),
|
||||
// Connection logs have no user owner, only owner by an organization.
|
||||
sqltypes.AlwaysFalse(userOwnerMatcher()),
|
||||
)
|
||||
|
||||
@@ -0,0 +1,114 @@
|
||||
package sqltypes
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/open-policy-agent/opa/ast"
|
||||
"golang.org/x/xerrors"
|
||||
)
|
||||
|
||||
var (
|
||||
_ VariableMatcher = astUUIDVar{}
|
||||
_ Node = astUUIDVar{}
|
||||
_ SupportsEquality = astUUIDVar{}
|
||||
)
|
||||
|
||||
// astUUIDVar is a variable that represents a UUID column. Unlike
|
||||
// astStringVar it emits native UUID comparisons (column = 'val'::uuid)
|
||||
// instead of text-based ones (COALESCE(column::text, ”) = 'val').
|
||||
// This allows PostgreSQL to use indexes on UUID columns.
|
||||
type astUUIDVar struct {
|
||||
Source RegoSource
|
||||
FieldPath []string
|
||||
ColumnString string
|
||||
}
|
||||
|
||||
func UUIDVarMatcher(sqlColumn string, regoPath []string) VariableMatcher {
|
||||
return astUUIDVar{FieldPath: regoPath, ColumnString: sqlColumn}
|
||||
}
|
||||
|
||||
func (astUUIDVar) UseAs() Node { return astUUIDVar{} }
|
||||
|
||||
func (u astUUIDVar) ConvertVariable(rego ast.Ref) (Node, bool) {
|
||||
left, err := RegoVarPath(u.FieldPath, rego)
|
||||
if err == nil && len(left) == 0 {
|
||||
return astUUIDVar{
|
||||
Source: RegoSource(rego.String()),
|
||||
FieldPath: u.FieldPath,
|
||||
ColumnString: u.ColumnString,
|
||||
}, true
|
||||
}
|
||||
|
||||
return nil, false
|
||||
}
|
||||
|
||||
func (u astUUIDVar) SQLString(_ *SQLGenerator) string {
|
||||
return u.ColumnString
|
||||
}
|
||||
|
||||
// EqualsSQLString handles equality comparisons for UUID columns.
|
||||
// Rego always produces string literals, so we accept AstString and
|
||||
// cast the literal to ::uuid in the output SQL. This lets PG use
|
||||
// native UUID indexes instead of falling back to text comparisons.
|
||||
// nolint:revive
|
||||
func (u astUUIDVar) EqualsSQLString(cfg *SQLGenerator, not bool, other Node) (string, error) {
|
||||
switch other.UseAs().(type) {
|
||||
case AstString:
|
||||
// The other side is a rego string literal like
|
||||
// "8c0b9bdc-a013-4b14-a49b-5747bc335708". Emit a comparison
|
||||
// that casts the literal to uuid so PG can use indexes:
|
||||
// column = 'val'::uuid
|
||||
// instead of the text-based:
|
||||
// 'val' = COALESCE(column::text, '')
|
||||
s, ok := other.(AstString)
|
||||
if !ok {
|
||||
return "", xerrors.Errorf("expected AstString, got %T", other)
|
||||
}
|
||||
if s.Value == "" {
|
||||
// Empty string in rego means "no value". Compare the
|
||||
// column against NULL since UUID columns represent
|
||||
// absent values as NULL, not empty strings.
|
||||
op := "IS NULL"
|
||||
if not {
|
||||
op = "IS NOT NULL"
|
||||
}
|
||||
return fmt.Sprintf("%s %s", u.ColumnString, op), nil
|
||||
}
|
||||
return fmt.Sprintf("%s %s '%s'::uuid",
|
||||
u.ColumnString, equalsOp(not), s.Value), nil
|
||||
case astUUIDVar:
|
||||
return basicSQLEquality(cfg, not, u, other), nil
|
||||
default:
|
||||
return "", xerrors.Errorf("unsupported equality: %T %s %T",
|
||||
u, equalsOp(not), other)
|
||||
}
|
||||
}
|
||||
|
||||
// ContainedInSQL implements SupportsContainedIn so that a UUID column
|
||||
// can appear in membership checks like `col = ANY(ARRAY[...])`. The
|
||||
// array elements are rego strings, so we cast each to ::uuid.
|
||||
func (u astUUIDVar) ContainedInSQL(_ *SQLGenerator, haystack Node) (string, error) {
|
||||
arr, ok := haystack.(ASTArray)
|
||||
if !ok {
|
||||
return "", xerrors.Errorf("unsupported containedIn: %T in %T", u, haystack)
|
||||
}
|
||||
|
||||
if len(arr.Value) == 0 {
|
||||
return "false", nil
|
||||
}
|
||||
|
||||
// Build ARRAY['uuid1'::uuid, 'uuid2'::uuid, ...]
|
||||
values := make([]string, 0, len(arr.Value))
|
||||
for _, v := range arr.Value {
|
||||
s, ok := v.(AstString)
|
||||
if !ok {
|
||||
return "", xerrors.Errorf("expected AstString array element, got %T", v)
|
||||
}
|
||||
values = append(values, fmt.Sprintf("'%s'::uuid", s.Value))
|
||||
}
|
||||
|
||||
return fmt.Sprintf("%s = ANY(ARRAY [%s])",
|
||||
u.ColumnString,
|
||||
strings.Join(values, ",")), nil
|
||||
}
|
||||
@@ -66,7 +66,7 @@ func AuditLogs(ctx context.Context, db database.Store, query string) (database.G
|
||||
}
|
||||
|
||||
// Prepare the count filter, which uses the same parameters as the GetAuditLogsOffsetParams.
|
||||
// nolint:exhaustruct // UserID is not obtained from the query parameters.
|
||||
// nolint:exhaustruct // UserID and CountCap are not obtained from the query parameters.
|
||||
countFilter := database.CountAuditLogsParams{
|
||||
RequestID: filter.RequestID,
|
||||
ResourceID: filter.ResourceID,
|
||||
@@ -123,6 +123,7 @@ func ConnectionLogs(ctx context.Context, db database.Store, query string, apiKey
|
||||
}
|
||||
|
||||
// This MUST be kept in sync with the above
|
||||
// nolint:exhaustruct // CountCap is not obtained from the query parameters.
|
||||
countFilter := database.CountConnectionLogsParams{
|
||||
OrganizationID: filter.OrganizationID,
|
||||
WorkspaceOwner: filter.WorkspaceOwner,
|
||||
|
||||
+37
-19
@@ -19,6 +19,7 @@ import (
|
||||
"github.com/google/uuid"
|
||||
"github.com/prometheus/client_golang/prometheus"
|
||||
"go.opentelemetry.io/otel/trace"
|
||||
"golang.org/x/sync/singleflight"
|
||||
"golang.org/x/xerrors"
|
||||
"tailscale.com/derp"
|
||||
"tailscale.com/tailcfg"
|
||||
@@ -389,6 +390,7 @@ type MultiAgentController struct {
|
||||
// connections to the destination
|
||||
tickets map[uuid.UUID]map[uuid.UUID]struct{}
|
||||
coordination *tailnet.BasicCoordination
|
||||
sendGroup singleflight.Group
|
||||
|
||||
cancel context.CancelFunc
|
||||
expireOldAgentsDone chan struct{}
|
||||
@@ -418,28 +420,44 @@ func (m *MultiAgentController) New(client tailnet.CoordinatorClient) tailnet.Clo
|
||||
|
||||
func (m *MultiAgentController) ensureAgent(agentID uuid.UUID) error {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
|
||||
_, ok := m.connectionTimes[agentID]
|
||||
// If we don't have the agent, subscribe.
|
||||
if !ok {
|
||||
m.logger.Debug(context.Background(),
|
||||
"subscribing to agent", slog.F("agent_id", agentID))
|
||||
if m.coordination != nil {
|
||||
err := m.coordination.Client.Send(&proto.CoordinateRequest{
|
||||
AddTunnel: &proto.CoordinateRequest_Tunnel{Id: agentID[:]},
|
||||
})
|
||||
if err != nil {
|
||||
err = xerrors.Errorf("subscribe agent: %w", err)
|
||||
m.coordination.SendErr(err)
|
||||
_ = m.coordination.Client.Close()
|
||||
m.coordination = nil
|
||||
return err
|
||||
}
|
||||
}
|
||||
m.tickets[agentID] = map[uuid.UUID]struct{}{}
|
||||
if ok {
|
||||
m.connectionTimes[agentID] = time.Now()
|
||||
m.mu.Unlock()
|
||||
return nil
|
||||
}
|
||||
m.mu.Unlock()
|
||||
|
||||
m.logger.Debug(context.Background(),
|
||||
"subscribing to agent", slog.F("agent_id", agentID))
|
||||
|
||||
_, err, _ := m.sendGroup.Do(agentID.String(), func() (interface{}, error) {
|
||||
m.mu.Lock()
|
||||
coord := m.coordination
|
||||
m.mu.Unlock()
|
||||
if coord == nil {
|
||||
return nil, xerrors.New("no active coordination")
|
||||
}
|
||||
err := coord.Client.Send(&proto.CoordinateRequest{
|
||||
AddTunnel: &proto.CoordinateRequest_Tunnel{Id: agentID[:]},
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
m.mu.Lock()
|
||||
m.tickets[agentID] = map[uuid.UUID]struct{}{}
|
||||
m.mu.Unlock()
|
||||
return nil, nil
|
||||
})
|
||||
if err != nil {
|
||||
m.logger.Error(context.Background(), "ensureAgent send failed",
|
||||
slog.F("agent_id", agentID), slog.Error(err))
|
||||
return xerrors.Errorf("send AddTunnel: %w", err)
|
||||
}
|
||||
|
||||
m.mu.Lock()
|
||||
m.connectionTimes[agentID] = time.Now()
|
||||
m.mu.Unlock()
|
||||
return nil
|
||||
}
|
||||
|
||||
|
||||
@@ -730,7 +730,10 @@ func (s *Server) workspaceAgentPTY(rw http.ResponseWriter, r *http.Request) {
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
log := s.Logger.With(slog.F("agent_id", appToken.AgentID))
|
||||
log := s.Logger.With(
|
||||
slog.F("agent_id", appToken.AgentID),
|
||||
slog.F("workspace_id", appToken.WorkspaceID),
|
||||
)
|
||||
log.Debug(ctx, "resolved PTY request")
|
||||
|
||||
values := r.URL.Query()
|
||||
@@ -765,19 +768,21 @@ func (s *Server) workspaceAgentPTY(rw http.ResponseWriter, r *http.Request) {
|
||||
})
|
||||
return
|
||||
}
|
||||
go httpapi.HeartbeatClose(ctx, s.Logger, cancel, conn)
|
||||
|
||||
ctx, wsNetConn := WebsocketNetConn(ctx, conn, websocket.MessageBinary)
|
||||
defer wsNetConn.Close() // Also closes conn.
|
||||
|
||||
go httpapi.HeartbeatClose(ctx, log, cancel, conn)
|
||||
|
||||
dialStart := time.Now()
|
||||
|
||||
agentConn, release, err := s.AgentProvider.AgentConn(ctx, appToken.AgentID)
|
||||
if err != nil {
|
||||
log.Debug(ctx, "dial workspace agent", slog.Error(err))
|
||||
log.Debug(ctx, "dial workspace agent", slog.Error(err), slog.F("elapsed_ms", time.Since(dialStart).Milliseconds()))
|
||||
_ = conn.Close(websocket.StatusInternalError, httpapi.WebsocketCloseSprintf("dial workspace agent: %s", err))
|
||||
return
|
||||
}
|
||||
defer release()
|
||||
log.Debug(ctx, "dialed workspace agent")
|
||||
log.Debug(ctx, "dialed workspace agent", slog.F("elapsed_ms", time.Since(dialStart).Milliseconds()))
|
||||
// #nosec G115 - Safe conversion for terminal height/width which are expected to be within uint16 range (0-65535)
|
||||
ptNetConn, err := agentConn.ReconnectingPTY(ctx, reconnect, uint16(height), uint16(width), r.URL.Query().Get("command"), func(arp *workspacesdk.AgentReconnectingPTYInit) {
|
||||
arp.Container = container
|
||||
@@ -785,12 +790,12 @@ func (s *Server) workspaceAgentPTY(rw http.ResponseWriter, r *http.Request) {
|
||||
arp.BackendType = backendType
|
||||
})
|
||||
if err != nil {
|
||||
log.Debug(ctx, "dial reconnecting pty server in workspace agent", slog.Error(err))
|
||||
log.Debug(ctx, "dial reconnecting pty server in workspace agent", slog.Error(err), slog.F("elapsed_ms", time.Since(dialStart).Milliseconds()))
|
||||
_ = conn.Close(websocket.StatusInternalError, httpapi.WebsocketCloseSprintf("dial: %s", err))
|
||||
return
|
||||
}
|
||||
defer ptNetConn.Close()
|
||||
log.Debug(ctx, "obtained PTY")
|
||||
log.Debug(ctx, "obtained PTY", slog.F("elapsed_ms", time.Since(dialStart).Milliseconds()))
|
||||
|
||||
report := newStatsReportFromSignedToken(*appToken)
|
||||
s.collectStats(report)
|
||||
@@ -800,7 +805,7 @@ func (s *Server) workspaceAgentPTY(rw http.ResponseWriter, r *http.Request) {
|
||||
}()
|
||||
|
||||
agentssh.Bicopy(ctx, wsNetConn, ptNetConn)
|
||||
log.Debug(ctx, "pty Bicopy finished")
|
||||
log.Debug(ctx, "pty Bicopy finished", slog.F("elapsed_ms", time.Since(dialStart).Milliseconds()))
|
||||
}
|
||||
|
||||
func (s *Server) collectStats(stats StatsReport) {
|
||||
|
||||
+124
-28
@@ -7,6 +7,7 @@ import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"maps"
|
||||
"net/http"
|
||||
"slices"
|
||||
"strconv"
|
||||
@@ -151,6 +152,12 @@ type Server struct {
|
||||
inFlightChatStaleAfter time.Duration
|
||||
chatHeartbeatInterval time.Duration
|
||||
|
||||
// heartbeatMu guards heartbeatRegistry.
|
||||
heartbeatMu sync.Mutex
|
||||
// heartbeatRegistry maps chat IDs to their cancel functions
|
||||
// and workspace state for the centralized heartbeat loop.
|
||||
heartbeatRegistry map[uuid.UUID]*heartbeatEntry
|
||||
|
||||
// wakeCh is signaled by SendMessage, EditMessage, CreateChat,
|
||||
// and PromoteQueued so the run loop calls processOnce
|
||||
// immediately instead of waiting for the next ticker.
|
||||
@@ -706,6 +713,17 @@ type chatStreamState struct {
|
||||
bufferRetainedAt time.Time
|
||||
}
|
||||
|
||||
// heartbeatEntry tracks a single chat's cancel function and workspace
|
||||
// state for the centralized heartbeat loop. Instead of spawning a
|
||||
// per-chat goroutine, processChat registers an entry here and the
|
||||
// single heartbeatLoop goroutine handles all chats.
|
||||
type heartbeatEntry struct {
|
||||
cancelWithCause context.CancelCauseFunc
|
||||
chatID uuid.UUID
|
||||
workspaceID uuid.NullUUID
|
||||
logger slog.Logger
|
||||
}
|
||||
|
||||
// resetDropCounters zeroes the rate-limiting state for both buffer
|
||||
// and subscriber drop warnings. The caller must hold s.mu.
|
||||
func (s *chatStreamState) resetDropCounters() {
|
||||
@@ -2420,8 +2438,8 @@ func New(cfg Config) *Server {
|
||||
clock: clk,
|
||||
recordingSem: make(chan struct{}, maxConcurrentRecordingUploads),
|
||||
wakeCh: make(chan struct{}, 1),
|
||||
heartbeatRegistry: make(map[uuid.UUID]*heartbeatEntry),
|
||||
}
|
||||
|
||||
//nolint:gocritic // The chat processor uses a scoped chatd context.
|
||||
ctx = dbauthz.AsChatd(ctx)
|
||||
|
||||
@@ -2461,6 +2479,9 @@ func (p *Server) start(ctx context.Context) {
|
||||
// to handle chats orphaned by crashed or redeployed workers.
|
||||
p.recoverStaleChats(ctx)
|
||||
|
||||
// Single heartbeat loop for all chats on this replica.
|
||||
go p.heartbeatLoop(ctx)
|
||||
|
||||
acquireTicker := p.clock.NewTicker(
|
||||
p.pendingChatAcquireInterval,
|
||||
"chatd",
|
||||
@@ -2730,6 +2751,97 @@ func (p *Server) cleanupStreamIfIdle(chatID uuid.UUID, state *chatStreamState) {
|
||||
p.workspaceMCPToolsCache.Delete(chatID)
|
||||
}
|
||||
|
||||
// registerHeartbeat enrolls a chat in the centralized batch
|
||||
// heartbeat loop. Must be called after chatCtx is created.
|
||||
func (p *Server) registerHeartbeat(entry *heartbeatEntry) {
|
||||
p.heartbeatMu.Lock()
|
||||
defer p.heartbeatMu.Unlock()
|
||||
if _, exists := p.heartbeatRegistry[entry.chatID]; exists {
|
||||
p.logger.Warn(context.Background(),
|
||||
"duplicate heartbeat registration, skipping",
|
||||
slog.F("chat_id", entry.chatID))
|
||||
return
|
||||
}
|
||||
p.heartbeatRegistry[entry.chatID] = entry
|
||||
}
|
||||
|
||||
// unregisterHeartbeat removes a chat from the centralized
|
||||
// heartbeat loop when chat processing finishes.
|
||||
func (p *Server) unregisterHeartbeat(chatID uuid.UUID) {
|
||||
p.heartbeatMu.Lock()
|
||||
defer p.heartbeatMu.Unlock()
|
||||
delete(p.heartbeatRegistry, chatID)
|
||||
}
|
||||
|
||||
// heartbeatLoop runs in a single goroutine, issuing one batch
|
||||
// heartbeat query per interval for all registered chats.
|
||||
func (p *Server) heartbeatLoop(ctx context.Context) {
|
||||
ticker := p.clock.NewTicker(p.chatHeartbeatInterval, "chatd", "batch-heartbeat")
|
||||
defer ticker.Stop()
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case <-ticker.C:
|
||||
p.heartbeatTick(ctx)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// heartbeatTick issues a single batch UPDATE for all running chats
|
||||
// owned by this worker. Chats missing from the result set are
|
||||
// interrupted (stolen by another replica or already completed).
|
||||
func (p *Server) heartbeatTick(ctx context.Context) {
|
||||
// Snapshot the registry under the lock.
|
||||
p.heartbeatMu.Lock()
|
||||
snapshot := maps.Clone(p.heartbeatRegistry)
|
||||
p.heartbeatMu.Unlock()
|
||||
|
||||
if len(snapshot) == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
// Collect the IDs we believe we own.
|
||||
ids := slices.Collect(maps.Keys(snapshot))
|
||||
|
||||
//nolint:gocritic // AsChatd provides narrowly-scoped daemon
|
||||
// access for batch-updating heartbeats.
|
||||
chatdCtx := dbauthz.AsChatd(ctx)
|
||||
updatedIDs, err := p.db.UpdateChatHeartbeats(chatdCtx, database.UpdateChatHeartbeatsParams{
|
||||
IDs: ids,
|
||||
WorkerID: p.workerID,
|
||||
Now: p.clock.Now(),
|
||||
})
|
||||
if err != nil {
|
||||
p.logger.Error(ctx, "batch heartbeat failed", slog.Error(err))
|
||||
return
|
||||
}
|
||||
|
||||
// Build a set of IDs that were successfully updated.
|
||||
updated := make(map[uuid.UUID]struct{}, len(updatedIDs))
|
||||
for _, id := range updatedIDs {
|
||||
updated[id] = struct{}{}
|
||||
}
|
||||
|
||||
// Interrupt registered chats that were not in the result
|
||||
// (stolen by another replica or already completed).
|
||||
for id, entry := range snapshot {
|
||||
if _, ok := updated[id]; !ok {
|
||||
entry.logger.Warn(ctx, "chat not in batch heartbeat result, interrupting")
|
||||
entry.cancelWithCause(chatloop.ErrInterrupted)
|
||||
continue
|
||||
}
|
||||
// Bump workspace usage for surviving chats.
|
||||
newWsID := p.trackWorkspaceUsage(ctx, entry.chatID, entry.workspaceID, entry.logger)
|
||||
// Update workspace ID in the registry for next tick.
|
||||
p.heartbeatMu.Lock()
|
||||
if current, exists := p.heartbeatRegistry[id]; exists {
|
||||
current.workspaceID = newWsID
|
||||
}
|
||||
p.heartbeatMu.Unlock()
|
||||
}
|
||||
}
|
||||
|
||||
func (p *Server) Subscribe(
|
||||
ctx context.Context,
|
||||
chatID uuid.UUID,
|
||||
@@ -3575,33 +3687,17 @@ func (p *Server) processChat(ctx context.Context, chat database.Chat) {
|
||||
}
|
||||
}()
|
||||
|
||||
// Periodically update the heartbeat so other replicas know this
|
||||
// worker is still alive. The goroutine stops when chatCtx is
|
||||
// canceled (either by completion or interruption).
|
||||
go func() {
|
||||
ticker := p.clock.NewTicker(p.chatHeartbeatInterval, "chatd", "heartbeat")
|
||||
defer ticker.Stop()
|
||||
for {
|
||||
select {
|
||||
case <-chatCtx.Done():
|
||||
return
|
||||
case <-ticker.C:
|
||||
rows, err := p.db.UpdateChatHeartbeat(chatCtx, database.UpdateChatHeartbeatParams{
|
||||
ID: chat.ID,
|
||||
WorkerID: p.workerID,
|
||||
})
|
||||
if err != nil {
|
||||
logger.Warn(chatCtx, "failed to update chat heartbeat", slog.Error(err))
|
||||
continue
|
||||
}
|
||||
if rows == 0 {
|
||||
cancel(chatloop.ErrInterrupted)
|
||||
return
|
||||
}
|
||||
chat.WorkspaceID = p.trackWorkspaceUsage(chatCtx, chat.ID, chat.WorkspaceID, logger)
|
||||
}
|
||||
}
|
||||
}()
|
||||
// Register with the centralized heartbeat loop instead of
|
||||
// running a per-chat goroutine. The loop issues a single batch
|
||||
// UPDATE for all chats on this worker and detects stolen chats
|
||||
// via set-difference.
|
||||
p.registerHeartbeat(&heartbeatEntry{
|
||||
cancelWithCause: cancel,
|
||||
chatID: chat.ID,
|
||||
workspaceID: chat.WorkspaceID,
|
||||
logger: logger,
|
||||
})
|
||||
defer p.unregisterHeartbeat(chat.ID)
|
||||
|
||||
// Start buffering stream events BEFORE publishing the running
|
||||
// status. This closes a race where a subscriber sees
|
||||
|
||||
@@ -21,6 +21,7 @@ import (
|
||||
dbpubsub "github.com/coder/coder/v2/coderd/database/pubsub"
|
||||
coderdpubsub "github.com/coder/coder/v2/coderd/pubsub"
|
||||
"github.com/coder/coder/v2/coderd/x/chatd/chaterror"
|
||||
"github.com/coder/coder/v2/coderd/x/chatd/chatloop"
|
||||
"github.com/coder/coder/v2/coderd/x/chatd/chatprovider"
|
||||
"github.com/coder/coder/v2/coderd/x/chatd/chattest"
|
||||
"github.com/coder/coder/v2/coderd/x/chatd/chattool"
|
||||
@@ -2071,6 +2072,7 @@ func TestProcessChat_IgnoresStaleControlNotification(t *testing.T) {
|
||||
workerID: workerID,
|
||||
chatHeartbeatInterval: time.Minute,
|
||||
configCache: newChatConfigCache(ctx, db, clock),
|
||||
heartbeatRegistry: make(map[uuid.UUID]*heartbeatEntry),
|
||||
}
|
||||
|
||||
// Publish a stale "pending" notification on the control channel
|
||||
@@ -2133,3 +2135,130 @@ func TestProcessChat_IgnoresStaleControlNotification(t *testing.T) {
|
||||
require.Equal(t, database.ChatStatusError, finalStatus,
|
||||
"processChat should have reached runChat (error), not been interrupted (waiting)")
|
||||
}
|
||||
|
||||
// TestHeartbeatTick_StolenChatIsInterrupted verifies that when the
|
||||
// batch heartbeat UPDATE does not return a registered chat's ID
|
||||
// (because another replica stole it or it was completed), the
|
||||
// heartbeat tick cancels that chat's context with ErrInterrupted
|
||||
// while leaving surviving chats untouched.
|
||||
func TestHeartbeatTick_StolenChatIsInterrupted(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctx := testutil.Context(t, testutil.WaitShort)
|
||||
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true})
|
||||
ctrl := gomock.NewController(t)
|
||||
db := dbmock.NewMockStore(ctrl)
|
||||
clock := quartz.NewMock(t)
|
||||
|
||||
workerID := uuid.New()
|
||||
|
||||
server := &Server{
|
||||
db: db,
|
||||
logger: logger,
|
||||
clock: clock,
|
||||
workerID: workerID,
|
||||
chatHeartbeatInterval: time.Minute,
|
||||
heartbeatRegistry: make(map[uuid.UUID]*heartbeatEntry),
|
||||
}
|
||||
|
||||
// Create three chats with independent cancel functions.
|
||||
chat1 := uuid.New()
|
||||
chat2 := uuid.New()
|
||||
chat3 := uuid.New()
|
||||
|
||||
_, cancel1 := context.WithCancelCause(ctx)
|
||||
_, cancel2 := context.WithCancelCause(ctx)
|
||||
ctx3, cancel3 := context.WithCancelCause(ctx)
|
||||
|
||||
server.registerHeartbeat(&heartbeatEntry{
|
||||
cancelWithCause: cancel1,
|
||||
chatID: chat1,
|
||||
logger: logger,
|
||||
})
|
||||
server.registerHeartbeat(&heartbeatEntry{
|
||||
cancelWithCause: cancel2,
|
||||
chatID: chat2,
|
||||
logger: logger,
|
||||
})
|
||||
server.registerHeartbeat(&heartbeatEntry{
|
||||
cancelWithCause: cancel3,
|
||||
chatID: chat3,
|
||||
logger: logger,
|
||||
})
|
||||
|
||||
// The batch UPDATE returns only chat1 and chat2 —
|
||||
// chat3 was "stolen" by another replica.
|
||||
db.EXPECT().UpdateChatHeartbeats(gomock.Any(), gomock.Any()).DoAndReturn(
|
||||
func(_ context.Context, params database.UpdateChatHeartbeatsParams) ([]uuid.UUID, error) {
|
||||
require.Equal(t, workerID, params.WorkerID)
|
||||
require.Len(t, params.IDs, 3)
|
||||
// Return only chat1 and chat2 as surviving.
|
||||
return []uuid.UUID{chat1, chat2}, nil
|
||||
},
|
||||
)
|
||||
|
||||
server.heartbeatTick(ctx)
|
||||
|
||||
// chat3's context should be canceled with ErrInterrupted.
|
||||
require.ErrorIs(t, context.Cause(ctx3), chatloop.ErrInterrupted,
|
||||
"stolen chat should be interrupted")
|
||||
|
||||
// chat3 should have been removed from the registry by
|
||||
// unregister (in production this happens via defer in
|
||||
// processChat). The heartbeat tick itself does not
|
||||
// unregister — it only cancels. Verify the entry is
|
||||
// still present (processChat's defer would clean it up).
|
||||
server.heartbeatMu.Lock()
|
||||
_, chat1Exists := server.heartbeatRegistry[chat1]
|
||||
_, chat2Exists := server.heartbeatRegistry[chat2]
|
||||
_, chat3Exists := server.heartbeatRegistry[chat3]
|
||||
server.heartbeatMu.Unlock()
|
||||
|
||||
require.True(t, chat1Exists, "surviving chat1 should remain registered")
|
||||
require.True(t, chat2Exists, "surviving chat2 should remain registered")
|
||||
require.True(t, chat3Exists,
|
||||
"stolen chat3 should still be in registry (processChat defer removes it)")
|
||||
}
|
||||
|
||||
// TestHeartbeatTick_DBErrorDoesNotInterruptChats verifies that a
|
||||
// transient database failure causes the tick to log and return
|
||||
// without canceling any registered chats.
|
||||
func TestHeartbeatTick_DBErrorDoesNotInterruptChats(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctx := testutil.Context(t, testutil.WaitShort)
|
||||
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true})
|
||||
ctrl := gomock.NewController(t)
|
||||
db := dbmock.NewMockStore(ctrl)
|
||||
clock := quartz.NewMock(t)
|
||||
|
||||
server := &Server{
|
||||
db: db,
|
||||
logger: logger,
|
||||
clock: clock,
|
||||
workerID: uuid.New(),
|
||||
chatHeartbeatInterval: time.Minute,
|
||||
heartbeatRegistry: make(map[uuid.UUID]*heartbeatEntry),
|
||||
}
|
||||
|
||||
chatID := uuid.New()
|
||||
chatCtx, cancel := context.WithCancelCause(ctx)
|
||||
|
||||
server.registerHeartbeat(&heartbeatEntry{
|
||||
cancelWithCause: cancel,
|
||||
chatID: chatID,
|
||||
logger: logger,
|
||||
})
|
||||
|
||||
// Simulate a transient DB error.
|
||||
db.EXPECT().UpdateChatHeartbeats(gomock.Any(), gomock.Any()).Return(
|
||||
nil, xerrors.New("connection reset"),
|
||||
)
|
||||
|
||||
server.heartbeatTick(ctx)
|
||||
|
||||
// Chat should NOT be interrupted — the tick logged and
|
||||
// returned early.
|
||||
require.NoError(t, chatCtx.Err(),
|
||||
"chat context should not be canceled on transient DB error")
|
||||
}
|
||||
|
||||
@@ -474,7 +474,7 @@ func TestArchiveChatInterruptsActiveProcessing(t *testing.T) {
|
||||
require.Equal(t, 1, userMessages, "expected queued message to stay queued after archive")
|
||||
}
|
||||
|
||||
func TestUpdateChatHeartbeatRequiresOwnership(t *testing.T) {
|
||||
func TestUpdateChatHeartbeatsRequiresOwnership(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
db, ps := dbtestutil.NewDB(t)
|
||||
@@ -501,19 +501,24 @@ func TestUpdateChatHeartbeatRequiresOwnership(t *testing.T) {
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
rows, err := db.UpdateChatHeartbeat(ctx, database.UpdateChatHeartbeatParams{
|
||||
ID: chat.ID,
|
||||
// Wrong worker_id should return no IDs.
|
||||
ids, err := db.UpdateChatHeartbeats(ctx, database.UpdateChatHeartbeatsParams{
|
||||
IDs: []uuid.UUID{chat.ID},
|
||||
WorkerID: uuid.New(),
|
||||
Now: time.Now(),
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, int64(0), rows)
|
||||
require.Empty(t, ids)
|
||||
|
||||
rows, err = db.UpdateChatHeartbeat(ctx, database.UpdateChatHeartbeatParams{
|
||||
ID: chat.ID,
|
||||
// Correct worker_id should return the chat's ID.
|
||||
ids, err = db.UpdateChatHeartbeats(ctx, database.UpdateChatHeartbeatsParams{
|
||||
IDs: []uuid.UUID{chat.ID},
|
||||
WorkerID: workerID,
|
||||
Now: time.Now(),
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, int64(1), rows)
|
||||
require.Len(t, ids, 1)
|
||||
require.Equal(t, chat.ID, ids[0])
|
||||
}
|
||||
|
||||
func TestSendMessageQueueBehaviorQueuesWhenBusy(t *testing.T) {
|
||||
|
||||
@@ -9,6 +9,7 @@ import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"slices"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
@@ -49,10 +50,11 @@ const connectTimeout = 10 * time.Second
|
||||
const toolCallTimeout = 60 * time.Second
|
||||
|
||||
// ConnectAll connects to all configured MCP servers, discovers
|
||||
// their tools, and returns them as fantasy.AgentTool values. It
|
||||
// skips servers that fail to connect and logs warnings. The
|
||||
// returned cleanup function must be called to close all
|
||||
// connections.
|
||||
// their tools, and returns them as fantasy.AgentTool values.
|
||||
// Tools are sorted by their prefixed name so callers
|
||||
// receive a deterministic order. It skips servers that fail to
|
||||
// connect and logs warnings. The returned cleanup function
|
||||
// must be called to close all connections.
|
||||
func ConnectAll(
|
||||
ctx context.Context,
|
||||
logger slog.Logger,
|
||||
@@ -108,7 +110,9 @@ func ConnectAll(
|
||||
}
|
||||
|
||||
mu.Lock()
|
||||
clients = append(clients, mcpClient)
|
||||
if mcpClient != nil {
|
||||
clients = append(clients, mcpClient)
|
||||
}
|
||||
tools = append(tools, serverTools...)
|
||||
mu.Unlock()
|
||||
return nil
|
||||
@@ -119,6 +123,31 @@ func ConnectAll(
|
||||
// discarded.
|
||||
_ = eg.Wait()
|
||||
|
||||
// Sort tools by prefixed name for deterministic ordering
|
||||
// regardless of goroutine completion order. Ties, possible
|
||||
// when the __ separator produces ambiguous prefixed names,
|
||||
// are broken by config ID. Stable prompt construction
|
||||
// depends on consistent tool ordering.
|
||||
slices.SortFunc(tools, func(a, b fantasy.AgentTool) int {
|
||||
// All tools in this slice are mcpToolWrapper values
|
||||
// created by connectOne above, so these checked
|
||||
// assertions should always succeed. The config ID
|
||||
// tiebreaker resolves the __ separator ambiguity
|
||||
// documented at the top of this file.
|
||||
aTool, ok := a.(MCPToolIdentifier)
|
||||
if !ok {
|
||||
panic(fmt.Sprintf("unexpected tool type %T", a))
|
||||
}
|
||||
bTool, ok := b.(MCPToolIdentifier)
|
||||
if !ok {
|
||||
panic(fmt.Sprintf("unexpected tool type %T", b))
|
||||
}
|
||||
return cmp.Or(
|
||||
cmp.Compare(a.Info().Name, b.Info().Name),
|
||||
cmp.Compare(aTool.MCPServerConfigID().String(), bTool.MCPServerConfigID().String()),
|
||||
)
|
||||
})
|
||||
|
||||
return tools, cleanup
|
||||
}
|
||||
|
||||
|
||||
@@ -63,6 +63,17 @@ func greetTool() mcpserver.ServerTool {
|
||||
}
|
||||
}
|
||||
|
||||
// makeTool returns a ServerTool with the given name and a
|
||||
// no-op handler that always returns "ok".
|
||||
func makeTool(name string) mcpserver.ServerTool {
|
||||
return mcpserver.ServerTool{
|
||||
Tool: mcp.NewTool(name),
|
||||
Handler: func(_ context.Context, _ mcp.CallToolRequest) (*mcp.CallToolResult, error) {
|
||||
return mcp.NewToolResultText("ok"), nil
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// makeConfig builds a database.MCPServerConfig suitable for tests.
|
||||
func makeConfig(slug, url string) database.MCPServerConfig {
|
||||
return database.MCPServerConfig{
|
||||
@@ -198,6 +209,121 @@ func TestConnectAll_MultipleServers(t *testing.T) {
|
||||
assert.Contains(t, names, "beta__greet")
|
||||
}
|
||||
|
||||
func TestConnectAll_NoToolsAfterFiltering(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := context.Background()
|
||||
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true})
|
||||
|
||||
ts := newTestMCPServer(t, echoTool())
|
||||
|
||||
cfg := makeConfig("filtered", ts.URL)
|
||||
cfg.ToolAllowList = []string{"greet"}
|
||||
|
||||
tools, cleanup := mcpclient.ConnectAll(
|
||||
ctx,
|
||||
logger,
|
||||
[]database.MCPServerConfig{cfg},
|
||||
nil,
|
||||
)
|
||||
|
||||
require.Empty(t, tools)
|
||||
assert.NotPanics(t, cleanup)
|
||||
}
|
||||
|
||||
func TestConnectAll_DeterministicOrder(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
t.Run("AcrossServers", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := context.Background()
|
||||
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true})
|
||||
|
||||
ts1 := newTestMCPServer(t, makeTool("zebra"))
|
||||
ts2 := newTestMCPServer(t, makeTool("alpha"))
|
||||
ts3 := newTestMCPServer(t, makeTool("middle"))
|
||||
|
||||
tools, cleanup := mcpclient.ConnectAll(
|
||||
ctx,
|
||||
logger,
|
||||
[]database.MCPServerConfig{
|
||||
makeConfig("srv3", ts3.URL),
|
||||
makeConfig("srv1", ts1.URL),
|
||||
makeConfig("srv2", ts2.URL),
|
||||
},
|
||||
nil,
|
||||
)
|
||||
t.Cleanup(cleanup)
|
||||
|
||||
require.Len(t, tools, 3)
|
||||
// Sorted by full prefixed name (slug__tool), so slug
|
||||
// order determines the sequence, not the tool name.
|
||||
assert.Equal(t,
|
||||
[]string{"srv1__zebra", "srv2__alpha", "srv3__middle"},
|
||||
toolNames(tools),
|
||||
)
|
||||
})
|
||||
|
||||
t.Run("WithMultiToolServer", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := context.Background()
|
||||
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true})
|
||||
|
||||
multi := newTestMCPServer(t, makeTool("zeta"), makeTool("beta"))
|
||||
other := newTestMCPServer(t, makeTool("gamma"))
|
||||
|
||||
tools, cleanup := mcpclient.ConnectAll(
|
||||
ctx,
|
||||
logger,
|
||||
[]database.MCPServerConfig{
|
||||
makeConfig("zzz", multi.URL),
|
||||
makeConfig("aaa", other.URL),
|
||||
},
|
||||
nil,
|
||||
)
|
||||
t.Cleanup(cleanup)
|
||||
|
||||
require.Len(t, tools, 3)
|
||||
assert.Equal(t,
|
||||
[]string{"aaa__gamma", "zzz__beta", "zzz__zeta"},
|
||||
toolNames(tools),
|
||||
)
|
||||
})
|
||||
|
||||
t.Run("TiebreakByConfigID", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := context.Background()
|
||||
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true})
|
||||
|
||||
ts1 := newTestMCPServer(t, makeTool("b__z"))
|
||||
ts2 := newTestMCPServer(t, makeTool("z"))
|
||||
|
||||
// Use fixed UUIDs so the tiebreaker order is
|
||||
// predictable. Both servers produce the same prefixed
|
||||
// name, a__b__z, due to the __ separator ambiguity.
|
||||
cfg1 := makeConfig("a", ts1.URL)
|
||||
cfg1.ID = uuid.MustParse("00000000-0000-0000-0000-000000000002")
|
||||
|
||||
cfg2 := makeConfig("a__b", ts2.URL)
|
||||
cfg2.ID = uuid.MustParse("00000000-0000-0000-0000-000000000001")
|
||||
|
||||
tools, cleanup := mcpclient.ConnectAll(
|
||||
ctx,
|
||||
logger,
|
||||
[]database.MCPServerConfig{cfg1, cfg2},
|
||||
nil,
|
||||
)
|
||||
t.Cleanup(cleanup)
|
||||
|
||||
require.Len(t, tools, 2)
|
||||
assert.Equal(t, []string{"a__b__z", "a__b__z"}, toolNames(tools))
|
||||
|
||||
id0 := tools[0].(mcpclient.MCPToolIdentifier).MCPServerConfigID()
|
||||
id1 := tools[1].(mcpclient.MCPToolIdentifier).MCPServerConfigID()
|
||||
assert.Equal(t, cfg2.ID, id0, "lower config ID should sort first")
|
||||
assert.Equal(t, cfg1.ID, id1, "higher config ID should sort second")
|
||||
})
|
||||
}
|
||||
|
||||
func TestConnectAll_AuthHeaders(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := context.Background()
|
||||
|
||||
@@ -212,6 +212,7 @@ type AuditLogsRequest struct {
|
||||
type AuditLogResponse struct {
|
||||
AuditLogs []AuditLog `json:"audit_logs"`
|
||||
Count int64 `json:"count"`
|
||||
CountCap int64 `json:"count_cap"`
|
||||
}
|
||||
|
||||
type CreateTestAuditLogRequest struct {
|
||||
|
||||
@@ -96,6 +96,7 @@ type ConnectionLogsRequest struct {
|
||||
type ConnectionLogResponse struct {
|
||||
ConnectionLogs []ConnectionLog `json:"connection_logs"`
|
||||
Count int64 `json:"count"`
|
||||
CountCap int64 `json:"count_cap"`
|
||||
}
|
||||
|
||||
func (c *Client) ConnectionLogs(ctx context.Context, req ConnectionLogsRequest) (ConnectionLogResponse, error) {
|
||||
|
||||
Generated
+2
-1
@@ -90,7 +90,8 @@ curl -X GET http://coder-server:8080/api/v2/audit?limit=0 \
|
||||
"user_agent": "string"
|
||||
}
|
||||
],
|
||||
"count": 0
|
||||
"count": 0,
|
||||
"count_cap": 0
|
||||
}
|
||||
```
|
||||
|
||||
|
||||
Generated
+2
-1
@@ -291,7 +291,8 @@ curl -X GET http://coder-server:8080/api/v2/connectionlog?limit=0 \
|
||||
"workspace_owner_username": "string"
|
||||
}
|
||||
],
|
||||
"count": 0
|
||||
"count": 0,
|
||||
"count_cap": 0
|
||||
}
|
||||
```
|
||||
|
||||
|
||||
Generated
+6
-2
@@ -1740,7 +1740,8 @@
|
||||
"user_agent": "string"
|
||||
}
|
||||
],
|
||||
"count": 0
|
||||
"count": 0,
|
||||
"count_cap": 0
|
||||
}
|
||||
```
|
||||
|
||||
@@ -1750,6 +1751,7 @@
|
||||
|--------------|-------------------------------------------------|----------|--------------|-------------|
|
||||
| `audit_logs` | array of [codersdk.AuditLog](#codersdkauditlog) | false | | |
|
||||
| `count` | integer | false | | |
|
||||
| `count_cap` | integer | false | | |
|
||||
|
||||
## codersdk.AuthMethod
|
||||
|
||||
@@ -2173,7 +2175,8 @@ AuthorizationObject can represent a "set" of objects, such as: all workspaces in
|
||||
"workspace_owner_username": "string"
|
||||
}
|
||||
],
|
||||
"count": 0
|
||||
"count": 0,
|
||||
"count_cap": 0
|
||||
}
|
||||
```
|
||||
|
||||
@@ -2183,6 +2186,7 @@ AuthorizationObject can represent a "set" of objects, such as: all workspaces in
|
||||
|-------------------|-----------------------------------------------------------|----------|--------------|-------------|
|
||||
| `connection_logs` | array of [codersdk.ConnectionLog](#codersdkconnectionlog) | false | | |
|
||||
| `count` | integer | false | | |
|
||||
| `count_cap` | integer | false | | |
|
||||
|
||||
## codersdk.ConnectionLogSSHInfo
|
||||
|
||||
|
||||
@@ -16,6 +16,9 @@ import (
|
||||
"github.com/coder/coder/v2/codersdk"
|
||||
)
|
||||
|
||||
// NOTE: See the auditLogCountCap note.
|
||||
const connectionLogCountCap = 2000
|
||||
|
||||
// @Summary Get connection logs
|
||||
// @ID get-connection-logs
|
||||
// @Security CoderSessionToken
|
||||
@@ -49,6 +52,7 @@ func (api *API) connectionLogs(rw http.ResponseWriter, r *http.Request) {
|
||||
// #nosec G115 - Safe conversion as pagination limit is expected to be within int32 range
|
||||
filter.LimitOpt = int32(page.Limit)
|
||||
|
||||
countFilter.CountCap = connectionLogCountCap
|
||||
count, err := api.Database.CountConnectionLogs(ctx, countFilter)
|
||||
if dbauthz.IsNotAuthorizedError(err) {
|
||||
httpapi.Forbidden(rw)
|
||||
@@ -63,6 +67,7 @@ func (api *API) connectionLogs(rw http.ResponseWriter, r *http.Request) {
|
||||
httpapi.Write(ctx, rw, http.StatusOK, codersdk.ConnectionLogResponse{
|
||||
ConnectionLogs: []codersdk.ConnectionLog{},
|
||||
Count: 0,
|
||||
CountCap: connectionLogCountCap,
|
||||
})
|
||||
return
|
||||
}
|
||||
@@ -80,6 +85,7 @@ func (api *API) connectionLogs(rw http.ResponseWriter, r *http.Request) {
|
||||
httpapi.Write(ctx, rw, http.StatusOK, codersdk.ConnectionLogResponse{
|
||||
ConnectionLogs: convertConnectionLogs(dblogs),
|
||||
Count: count,
|
||||
CountCap: connectionLogCountCap,
|
||||
})
|
||||
}
|
||||
|
||||
|
||||
@@ -12,6 +12,7 @@ import (
|
||||
|
||||
"github.com/cenkalti/backoff/v4"
|
||||
"github.com/google/uuid"
|
||||
"golang.org/x/sync/singleflight"
|
||||
"golang.org/x/xerrors"
|
||||
gProto "google.golang.org/protobuf/proto"
|
||||
|
||||
@@ -33,9 +34,9 @@ const (
|
||||
eventReadyForHandshake = "tailnet_ready_for_handshake"
|
||||
HeartbeatPeriod = time.Second * 2
|
||||
MissedHeartbeats = 3
|
||||
numQuerierWorkers = 10
|
||||
numQuerierWorkers = 40
|
||||
numBinderWorkers = 10
|
||||
numTunnelerWorkers = 10
|
||||
numTunnelerWorkers = 20
|
||||
numHandshakerWorkers = 5
|
||||
dbMaxBackoff = 10 * time.Second
|
||||
cleanupPeriod = time.Hour
|
||||
@@ -770,6 +771,9 @@ func (m *mapper) bestToUpdate(best map[uuid.UUID]mapping) *proto.CoordinateRespo
|
||||
|
||||
for k := range m.sent {
|
||||
if _, ok := best[k]; !ok {
|
||||
m.logger.Debug(m.ctx, "peer no longer in best mappings, sending DISCONNECTED",
|
||||
slog.F("peer_id", k),
|
||||
)
|
||||
resp.PeerUpdates = append(resp.PeerUpdates, &proto.CoordinateResponse_PeerUpdate{
|
||||
Id: agpl.UUIDToByteSlice(k),
|
||||
Kind: proto.CoordinateResponse_PeerUpdate_DISCONNECTED,
|
||||
@@ -820,6 +824,8 @@ type querier struct {
|
||||
mu sync.Mutex
|
||||
mappers map[mKey]*mapper
|
||||
healthy bool
|
||||
|
||||
resyncGroup singleflight.Group
|
||||
}
|
||||
|
||||
func newQuerier(ctx context.Context,
|
||||
@@ -958,7 +964,7 @@ func (q *querier) cleanupConn(c *connIO) {
|
||||
|
||||
// maxBatchSize is the maximum number of keys to process in a single batch
|
||||
// query.
|
||||
const maxBatchSize = 50
|
||||
const maxBatchSize = 200
|
||||
|
||||
func (q *querier) peerUpdateWorker() {
|
||||
defer q.wg.Done()
|
||||
@@ -1207,8 +1213,13 @@ func (q *querier) subscribe() {
|
||||
func (q *querier) listenPeer(_ context.Context, msg []byte, err error) {
|
||||
if xerrors.Is(err, pubsub.ErrDroppedMessages) {
|
||||
q.logger.Warn(q.ctx, "pubsub may have dropped peer updates")
|
||||
// we need to schedule a full resync of peer mappings
|
||||
q.resyncPeerMappings()
|
||||
// Schedule a full resync asynchronously so we don't block the
|
||||
// pubsub drain goroutine. Singleflight coalesces concurrent
|
||||
// resync requests.
|
||||
go q.resyncGroup.Do("resync", func() (any, error) {
|
||||
q.resyncPeerMappings()
|
||||
return nil, nil
|
||||
})
|
||||
return
|
||||
}
|
||||
if err != nil {
|
||||
@@ -1234,8 +1245,13 @@ func (q *querier) listenPeer(_ context.Context, msg []byte, err error) {
|
||||
func (q *querier) listenTunnel(_ context.Context, msg []byte, err error) {
|
||||
if xerrors.Is(err, pubsub.ErrDroppedMessages) {
|
||||
q.logger.Warn(q.ctx, "pubsub may have dropped tunnel updates")
|
||||
// we need to schedule a full resync of peer mappings
|
||||
q.resyncPeerMappings()
|
||||
// Schedule a full resync asynchronously so we don't block the
|
||||
// pubsub drain goroutine. Singleflight coalesces concurrent
|
||||
// resync requests.
|
||||
go q.resyncGroup.Do("resync", func() (any, error) {
|
||||
q.resyncPeerMappings()
|
||||
return nil, nil
|
||||
})
|
||||
return
|
||||
}
|
||||
if err != nil {
|
||||
@@ -1601,6 +1617,10 @@ func (h *heartbeats) filter(mappings []mapping) []mapping {
|
||||
// the only mapping available for it. Newer mappings will take
|
||||
// precedence.
|
||||
m.kind = proto.CoordinateResponse_PeerUpdate_LOST
|
||||
h.logger.Debug(h.ctx, "mapping rewritten to LOST due to missed heartbeats",
|
||||
slog.F("peer_id", m.peer),
|
||||
slog.F("coordinator_id", m.coordinator),
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -76,11 +76,11 @@ replace github.com/aquasecurity/trivy => github.com/coder/trivy v0.0.0-202603091
|
||||
// https://github.com/spf13/afero/pull/487
|
||||
replace github.com/spf13/afero => github.com/aslilac/afero v0.0.0-20250403163713-f06e86036696
|
||||
|
||||
// Forked from kylecarbs/fantasy (cj/go1.25 branch) which adds:
|
||||
// Forked from coder/fantasy (cj/go1.25 branch) which adds:
|
||||
// 1) Anthropic computer use + thinking effort
|
||||
// 2) Go 1.25 downgrade for Windows CI compat
|
||||
// 3) ibetitsmike/fantasy#4 — skip ephemeral replay items when store=false
|
||||
replace charm.land/fantasy => github.com/kylecarbs/fantasy v0.0.0-20260325145725-112927d9b6d8
|
||||
replace charm.land/fantasy => github.com/coder/fantasy v0.0.0-20260325145725-112927d9b6d8
|
||||
|
||||
replace github.com/charmbracelet/anthropic-sdk-go => github.com/kylecarbs/anthropic-sdk-go v0.0.0-20260223140439-63879b0b8dab
|
||||
|
||||
|
||||
@@ -322,6 +322,8 @@ github.com/coder/bubbletea v1.2.2-0.20241212190825-007a1cdb2c41 h1:SBN/DA63+ZHwu
|
||||
github.com/coder/bubbletea v1.2.2-0.20241212190825-007a1cdb2c41/go.mod h1:I9ULxr64UaOSUv7hcb3nX4kowodJCVS7vt7VVJk/kW4=
|
||||
github.com/coder/clistat v1.2.1 h1:P9/10njXMyj5cWzIU5wkRsSy5LVQH49+tcGMsAgWX0w=
|
||||
github.com/coder/clistat v1.2.1/go.mod h1:m7SC0uj88eEERgvF8Kn6+w6XF21BeSr+15f7GoLAw0A=
|
||||
github.com/coder/fantasy v0.0.0-20260325145725-112927d9b6d8 h1:n+6v+yT1B6V4oSGPmXFh7mul1E+RzG9rnqp50Vb7M/w=
|
||||
github.com/coder/fantasy v0.0.0-20260325145725-112927d9b6d8/go.mod h1:ktfNX0xDpIKeggZbP/j5IYJci6pyMOR3WmZSfz9XLYw=
|
||||
github.com/coder/flog v1.1.0 h1:kbAes1ai8fIS5OeV+QAnKBQE22ty1jRF/mcAwHpLBa4=
|
||||
github.com/coder/flog v1.1.0/go.mod h1:UQlQvrkJBvnRGo69Le8E24Tcl5SJleAAR7gYEHzAmdQ=
|
||||
github.com/coder/go-httpstat v0.0.0-20230801153223-321c88088322 h1:m0lPZjlQ7vdVpRBPKfYIFlmgevoTkBxB10wv6l2gOaU=
|
||||
@@ -813,8 +815,6 @@ github.com/kylecarbs/anthropic-sdk-go v0.0.0-20260223140439-63879b0b8dab h1:5UMY
|
||||
github.com/kylecarbs/anthropic-sdk-go v0.0.0-20260223140439-63879b0b8dab/go.mod h1:hqlYqR7uPKOKfnNeicUbZp0Ps0GeYFlKYtwh5HGDCx8=
|
||||
github.com/kylecarbs/chroma/v2 v2.0.0-20240401211003-9e036e0631f3 h1:Z9/bo5PSeMutpdiKYNt/TTSfGM1Ll0naj3QzYX9VxTc=
|
||||
github.com/kylecarbs/chroma/v2 v2.0.0-20240401211003-9e036e0631f3/go.mod h1:BUGjjsD+ndS6eX37YgTchSEG+Jg9Jv1GiZs9sqPqztk=
|
||||
github.com/kylecarbs/fantasy v0.0.0-20260325145725-112927d9b6d8 h1:fZ0208U3B438fDSHCc/GNioPIyaFqn6eBsQTO61QtrI=
|
||||
github.com/kylecarbs/fantasy v0.0.0-20260325145725-112927d9b6d8/go.mod h1:ktfNX0xDpIKeggZbP/j5IYJci6pyMOR3WmZSfz9XLYw=
|
||||
github.com/kylecarbs/openai-go/v3 v3.0.0-20260319113850-9477dcaedcae h1:xlFZNX4nnxpj9Cf6mTwD3pirXGNtBJ/6COsf9iZmsL0=
|
||||
github.com/kylecarbs/openai-go/v3 v3.0.0-20260319113850-9477dcaedcae/go.mod h1:cdufnVK14cWcT9qA1rRtrXx4FTRsgbDPW7Ia7SS5cZo=
|
||||
github.com/kylecarbs/spinner v1.18.2-0.20220329160715-20702b5af89e h1:OP0ZMFeZkUnOzTFRfpuK3m7Kp4fNvC6qN+exwj7aI4M=
|
||||
|
||||
+44
-35
@@ -68,17 +68,17 @@ func runRelease(ctx context.Context, inv *serpent.Invocation, executor ReleaseEx
|
||||
return xerrors.Errorf("detecting branch: %w", err)
|
||||
}
|
||||
|
||||
// Match standard release branches (release/2.32) and RC
|
||||
// branches (release/2.32-rc.0).
|
||||
branchRe := regexp.MustCompile(`^release/(\d+)\.(\d+)(?:-rc\.(\d+))?$`)
|
||||
// Match release branches (release/X.Y). RCs are tagged
|
||||
// from main, not from release branches.
|
||||
branchRe := regexp.MustCompile(`^release/(\d+)\.(\d+)$`)
|
||||
m := branchRe.FindStringSubmatch(currentBranch)
|
||||
if m == nil {
|
||||
warnf(w, "Current branch %q is not a release branch (release/X.Y or release/X.Y-rc.N).", currentBranch)
|
||||
warnf(w, "Current branch %q is not a release branch (release/X.Y).", currentBranch)
|
||||
branchInput, err := cliui.Prompt(inv, cliui.PromptOptions{
|
||||
Text: "Enter the release branch to use (e.g. release/2.21 or release/2.21-rc.0)",
|
||||
Text: "Enter the release branch to use (e.g. release/2.21)",
|
||||
Validate: func(s string) error {
|
||||
if !branchRe.MatchString(s) {
|
||||
return xerrors.New("must be in format release/X.Y or release/X.Y-rc.N (e.g. release/2.21 or release/2.21-rc.0)")
|
||||
return xerrors.New("must be in format release/X.Y (e.g. release/2.21)")
|
||||
}
|
||||
return nil
|
||||
},
|
||||
@@ -91,10 +91,6 @@ func runRelease(ctx context.Context, inv *serpent.Invocation, executor ReleaseEx
|
||||
}
|
||||
branchMajor, _ := strconv.Atoi(m[1])
|
||||
branchMinor, _ := strconv.Atoi(m[2])
|
||||
branchRC := -1 // -1 means not an RC branch.
|
||||
if m[3] != "" {
|
||||
branchRC, _ = strconv.Atoi(m[3])
|
||||
}
|
||||
successf(w, "Using release branch: %s", currentBranch)
|
||||
|
||||
// --- Fetch & sync check ---
|
||||
@@ -138,31 +134,44 @@ func runRelease(ctx context.Context, inv *serpent.Invocation, executor ReleaseEx
|
||||
}
|
||||
}
|
||||
|
||||
// changelogBaseRef is the git ref used as the starting point
|
||||
// for release notes generation. When a tag already exists in
|
||||
// this minor series we use it directly. For the first release
|
||||
// on a new minor no matching tag exists, so we compute the
|
||||
// merge-base with the previous minor's release branch instead.
|
||||
// This works even when that branch has no tags yet (it was
|
||||
// just cut and pushed). As a last resort we fall back to the
|
||||
// latest reachable tag from a previous minor.
|
||||
var changelogBaseRef string
|
||||
if prevVersion != nil {
|
||||
changelogBaseRef = prevVersion.String()
|
||||
} else {
|
||||
prevReleaseBranch := fmt.Sprintf("release/%d.%d", branchMajor, branchMinor-1)
|
||||
if err := gitRun("fetch", "--quiet", "origin", prevReleaseBranch); err != nil {
|
||||
warnf(w, "Could not fetch %s: %v", prevReleaseBranch, err)
|
||||
}
|
||||
if mb, mbErr := gitOutput("merge-base", "HEAD", "origin/"+prevReleaseBranch); mbErr == nil && mb != "" {
|
||||
changelogBaseRef = mb
|
||||
infof(w, "Using merge-base with %s as changelog base: %s", prevReleaseBranch, mb[:12])
|
||||
} else {
|
||||
// No previous release branch found; fall back to
|
||||
// the latest reachable tag from a previous minor.
|
||||
for _, t := range mergedTags {
|
||||
if t.Major == branchMajor && t.Minor < branchMinor {
|
||||
changelogBaseRef = t.String()
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
var suggested version
|
||||
if prevVersion == nil {
|
||||
infof(w, "No previous release tag found on this branch.")
|
||||
suggested = version{Major: branchMajor, Minor: branchMinor, Patch: 0}
|
||||
if branchRC >= 0 {
|
||||
suggested.Pre = fmt.Sprintf("rc.%d", branchRC)
|
||||
}
|
||||
} else {
|
||||
infof(w, "Previous release tag: %s", prevVersion.String())
|
||||
if branchRC >= 0 {
|
||||
// On an RC branch, suggest the next RC for
|
||||
// the same base version.
|
||||
nextRC := 0
|
||||
if prevVersion.IsRC() {
|
||||
nextRC = prevVersion.rcNumber() + 1
|
||||
}
|
||||
suggested = version{
|
||||
Major: prevVersion.Major,
|
||||
Minor: prevVersion.Minor,
|
||||
Patch: prevVersion.Patch,
|
||||
Pre: fmt.Sprintf("rc.%d", nextRC),
|
||||
}
|
||||
} else {
|
||||
suggested = version{Major: prevVersion.Major, Minor: prevVersion.Minor, Patch: prevVersion.Patch + 1}
|
||||
}
|
||||
suggested = version{Major: prevVersion.Major, Minor: prevVersion.Minor, Patch: prevVersion.Patch + 1}
|
||||
}
|
||||
|
||||
fmt.Fprintln(w)
|
||||
@@ -366,8 +375,8 @@ func runRelease(ctx context.Context, inv *serpent.Invocation, executor ReleaseEx
|
||||
infof(w, "Generating release notes...")
|
||||
|
||||
commitRange := "HEAD"
|
||||
if prevVersion != nil {
|
||||
commitRange = prevVersion.String() + "..HEAD"
|
||||
if changelogBaseRef != "" {
|
||||
commitRange = changelogBaseRef + "..HEAD"
|
||||
}
|
||||
|
||||
commits, err := commitLog(commitRange)
|
||||
@@ -473,16 +482,16 @@ func runRelease(ctx context.Context, inv *serpent.Invocation, executor ReleaseEx
|
||||
}
|
||||
if !hasContent {
|
||||
prevStr := "the beginning of time"
|
||||
if prevVersion != nil {
|
||||
prevStr = prevVersion.String()
|
||||
if changelogBaseRef != "" {
|
||||
prevStr = changelogBaseRef
|
||||
}
|
||||
fmt.Fprintf(¬es, "\n_No changes since %s._\n", prevStr)
|
||||
}
|
||||
|
||||
// Compare link.
|
||||
if prevVersion != nil {
|
||||
if changelogBaseRef != "" {
|
||||
fmt.Fprintf(¬es, "\nCompare: [`%s...%s`](https://github.com/%s/%s/compare/%s...%s)\n",
|
||||
prevVersion, newVersion, owner, repo, prevVersion, newVersion)
|
||||
changelogBaseRef, newVersion, owner, repo, changelogBaseRef, newVersion)
|
||||
}
|
||||
|
||||
// Container image.
|
||||
|
||||
Generated
+2
@@ -913,6 +913,7 @@ export interface AuditLog {
|
||||
export interface AuditLogResponse {
|
||||
readonly audit_logs: readonly AuditLog[];
|
||||
readonly count: number;
|
||||
readonly count_cap: number;
|
||||
}
|
||||
|
||||
// From codersdk/audit.go
|
||||
@@ -2269,6 +2270,7 @@ export interface ConnectionLog {
|
||||
export interface ConnectionLogResponse {
|
||||
readonly connection_logs: readonly ConnectionLog[];
|
||||
readonly count: number;
|
||||
readonly count_cap: number;
|
||||
}
|
||||
|
||||
// From codersdk/connectionlog.go
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
import type { Interpolation, Theme } from "@emotion/react";
|
||||
import type { FC, HTMLAttributes } from "react";
|
||||
import type { LogLevel } from "#/api/typesGenerated";
|
||||
import { MONOSPACE_FONT_FAMILY } from "#/theme/constants";
|
||||
import { cn } from "#/utils/cn";
|
||||
|
||||
const DEFAULT_LOG_LINE_SIDE_PADDING = 24;
|
||||
|
||||
@@ -17,65 +16,40 @@ type LogLineProps = {
|
||||
level: LogLevel;
|
||||
} & HTMLAttributes<HTMLPreElement>;
|
||||
|
||||
export const LogLine: FC<LogLineProps> = ({ level, ...divProps }) => {
|
||||
export const LogLine: FC<LogLineProps> = ({ level, className, ...props }) => {
|
||||
return (
|
||||
<pre
|
||||
css={styles.line}
|
||||
className={`${level} ${divProps.className} logs-line`}
|
||||
{...divProps}
|
||||
{...props}
|
||||
className={cn(
|
||||
"logs-line",
|
||||
"m-0 break-all flex items-center h-auto",
|
||||
"text-[13px] text-content-primary font-mono",
|
||||
level === "error" &&
|
||||
"bg-surface-error text-content-error [&_.dashed-line]:bg-border-error",
|
||||
level === "debug" &&
|
||||
"bg-surface-sky text-content-sky [&_.dashed-line]:bg-border-sky",
|
||||
level === "warn" &&
|
||||
"bg-surface-warning text-content-warning [&_.dashed-line]:bg-border-warning",
|
||||
className,
|
||||
)}
|
||||
style={{
|
||||
padding: `0 var(--log-line-side-padding, ${DEFAULT_LOG_LINE_SIDE_PADDING}px)`,
|
||||
}}
|
||||
/>
|
||||
);
|
||||
};
|
||||
|
||||
export const LogLinePrefix: FC<HTMLAttributes<HTMLSpanElement>> = (props) => {
|
||||
return <pre css={styles.prefix} {...props} />;
|
||||
export const LogLinePrefix: FC<HTMLAttributes<HTMLSpanElement>> = ({
|
||||
className,
|
||||
...props
|
||||
}) => {
|
||||
return (
|
||||
<pre
|
||||
className={cn(
|
||||
"select-none m-0 inline-block text-content-secondary mr-6",
|
||||
className,
|
||||
)}
|
||||
{...props}
|
||||
/>
|
||||
);
|
||||
};
|
||||
|
||||
const styles = {
|
||||
line: (theme) => ({
|
||||
margin: 0,
|
||||
wordBreak: "break-all",
|
||||
display: "flex",
|
||||
alignItems: "center",
|
||||
fontSize: 13,
|
||||
color: theme.palette.text.primary,
|
||||
fontFamily: MONOSPACE_FONT_FAMILY,
|
||||
height: "auto",
|
||||
padding: `0 var(--log-line-side-padding, ${DEFAULT_LOG_LINE_SIDE_PADDING}px)`,
|
||||
|
||||
"&.error": {
|
||||
backgroundColor: theme.roles.error.background,
|
||||
color: theme.roles.error.text,
|
||||
|
||||
"& .dashed-line": {
|
||||
backgroundColor: theme.roles.error.outline,
|
||||
},
|
||||
},
|
||||
|
||||
"&.debug": {
|
||||
backgroundColor: theme.roles.notice.background,
|
||||
color: theme.roles.notice.text,
|
||||
|
||||
"& .dashed-line": {
|
||||
backgroundColor: theme.roles.notice.outline,
|
||||
},
|
||||
},
|
||||
|
||||
"&.warn": {
|
||||
backgroundColor: theme.roles.warning.background,
|
||||
color: theme.roles.warning.text,
|
||||
|
||||
"& .dashed-line": {
|
||||
backgroundColor: theme.roles.warning.outline,
|
||||
},
|
||||
},
|
||||
}),
|
||||
|
||||
prefix: (theme) => ({
|
||||
userSelect: "none",
|
||||
margin: 0,
|
||||
display: "inline-block",
|
||||
color: theme.palette.text.secondary,
|
||||
marginRight: 24,
|
||||
}),
|
||||
} satisfies Record<string, Interpolation<Theme>>;
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
import type { Interpolation, Theme } from "@emotion/react";
|
||||
import dayjs from "dayjs";
|
||||
import type { FC } from "react";
|
||||
import { cn } from "#/utils/cn";
|
||||
import { type Line, LogLine, LogLinePrefix } from "./LogLine";
|
||||
|
||||
export const DEFAULT_LOG_LINE_SIDE_PADDING = 24;
|
||||
@@ -17,7 +17,17 @@ export const Logs: FC<LogsProps> = ({
|
||||
className = "",
|
||||
}) => {
|
||||
return (
|
||||
<div css={styles.root} className={`${className} logs-container`}>
|
||||
<div
|
||||
className={cn(
|
||||
"logs-container",
|
||||
"min-h-40 py-2 rounded-lg overflow-x-auto bg-surface-primary",
|
||||
"[&:not(:last-child)]:border-0",
|
||||
"[&:not(:last-child)]:border-solid",
|
||||
"[&:not(:last-child)]:border-b-border",
|
||||
"[&:not(:last-child)]:rounded-none",
|
||||
className,
|
||||
)}
|
||||
>
|
||||
<div className="min-w-fit">
|
||||
{lines.map((line) => (
|
||||
<LogLine key={line.id} level={line.level}>
|
||||
@@ -33,18 +43,3 @@ export const Logs: FC<LogsProps> = ({
|
||||
</div>
|
||||
);
|
||||
};
|
||||
|
||||
const styles = {
|
||||
root: (theme) => ({
|
||||
minHeight: 156,
|
||||
padding: "8px 0",
|
||||
borderRadius: 8,
|
||||
overflowX: "auto",
|
||||
background: theme.palette.background.default,
|
||||
|
||||
"&:not(:last-child)": {
|
||||
borderBottom: `1px solid ${theme.palette.divider}`,
|
||||
borderRadius: 0,
|
||||
},
|
||||
}),
|
||||
} satisfies Record<string, Interpolation<Theme>>;
|
||||
|
||||
@@ -7,6 +7,7 @@ type PaginationHeaderProps = {
|
||||
limit: number;
|
||||
totalRecords: number | undefined;
|
||||
currentOffsetStart: number | undefined;
|
||||
countIsCapped?: boolean;
|
||||
|
||||
// Temporary escape hatch until Workspaces can be switched over to using
|
||||
// PaginationContainer
|
||||
@@ -18,6 +19,7 @@ export const PaginationAmount: FC<PaginationHeaderProps> = ({
|
||||
limit,
|
||||
totalRecords,
|
||||
currentOffsetStart,
|
||||
countIsCapped,
|
||||
className,
|
||||
}) => {
|
||||
const theme = useTheme();
|
||||
@@ -52,10 +54,16 @@ export const PaginationAmount: FC<PaginationHeaderProps> = ({
|
||||
<strong>
|
||||
{(
|
||||
currentOffsetStart +
|
||||
Math.min(limit - 1, totalRecords - currentOffsetStart)
|
||||
(countIsCapped
|
||||
? limit - 1
|
||||
: Math.min(limit - 1, totalRecords - currentOffsetStart))
|
||||
).toLocaleString()}
|
||||
</strong>{" "}
|
||||
of <strong>{totalRecords.toLocaleString()}</strong>{" "}
|
||||
of{" "}
|
||||
<strong>
|
||||
{totalRecords.toLocaleString()}
|
||||
{countIsCapped && "+"}
|
||||
</strong>{" "}
|
||||
{paginationUnitLabel}
|
||||
</div>
|
||||
)}
|
||||
|
||||
@@ -18,6 +18,7 @@ export const mockPaginationResultBase: ResultBase = {
|
||||
limit: 25,
|
||||
hasNextPage: false,
|
||||
hasPreviousPage: false,
|
||||
countIsCapped: false,
|
||||
goToPreviousPage: () => {},
|
||||
goToNextPage: () => {},
|
||||
goToFirstPage: () => {},
|
||||
@@ -33,6 +34,7 @@ export const mockInitialRenderResult: PaginationResult = {
|
||||
hasPreviousPage: false,
|
||||
totalRecords: undefined,
|
||||
totalPages: undefined,
|
||||
countIsCapped: false,
|
||||
};
|
||||
|
||||
export const mockSuccessResult: PaginationResult = {
|
||||
|
||||
@@ -94,7 +94,7 @@ export const FirstPageWithTonsOfData: Story = {
|
||||
currentPage: 2,
|
||||
currentOffsetStart: 1000,
|
||||
totalRecords: 123_456,
|
||||
totalPages: 1235,
|
||||
totalPages: 4939,
|
||||
hasPreviousPage: false,
|
||||
hasNextPage: true,
|
||||
isPlaceholderData: false,
|
||||
@@ -135,3 +135,54 @@ export const SecondPageWithData: Story = {
|
||||
children: <div>New data for page 2</div>,
|
||||
},
|
||||
};
|
||||
|
||||
export const CappedCountFirstPage: Story = {
|
||||
args: {
|
||||
query: {
|
||||
...mockPaginationResultBase,
|
||||
isSuccess: true,
|
||||
currentPage: 1,
|
||||
currentOffsetStart: 1,
|
||||
totalRecords: 2000,
|
||||
totalPages: 80,
|
||||
hasPreviousPage: false,
|
||||
hasNextPage: true,
|
||||
isPlaceholderData: false,
|
||||
countIsCapped: true,
|
||||
},
|
||||
},
|
||||
};
|
||||
|
||||
export const CappedCountMiddlePage: Story = {
|
||||
args: {
|
||||
query: {
|
||||
...mockPaginationResultBase,
|
||||
isSuccess: true,
|
||||
currentPage: 3,
|
||||
currentOffsetStart: 51,
|
||||
totalRecords: 2000,
|
||||
totalPages: 80,
|
||||
hasPreviousPage: true,
|
||||
hasNextPage: true,
|
||||
isPlaceholderData: false,
|
||||
countIsCapped: true,
|
||||
},
|
||||
},
|
||||
};
|
||||
|
||||
export const CappedCountBeyondKnownPages: Story = {
|
||||
args: {
|
||||
query: {
|
||||
...mockPaginationResultBase,
|
||||
isSuccess: true,
|
||||
currentPage: 85,
|
||||
currentOffsetStart: 2101,
|
||||
totalRecords: 2000,
|
||||
totalPages: 85,
|
||||
hasPreviousPage: true,
|
||||
hasNextPage: true,
|
||||
isPlaceholderData: false,
|
||||
countIsCapped: true,
|
||||
},
|
||||
},
|
||||
};
|
||||
|
||||
@@ -27,12 +27,14 @@ export const PaginationContainer: FC<PaginationProps> = ({
|
||||
totalRecords={query.totalRecords}
|
||||
currentOffsetStart={query.currentOffsetStart}
|
||||
paginationUnitLabel={paginationUnitLabel}
|
||||
countIsCapped={query.countIsCapped}
|
||||
className="justify-end"
|
||||
/>
|
||||
|
||||
{query.isSuccess && (
|
||||
<PaginationWidgetBase
|
||||
totalRecords={query.totalRecords}
|
||||
totalPages={query.totalPages}
|
||||
currentPage={query.currentPage}
|
||||
pageSize={query.limit}
|
||||
onPageChange={query.onPageChange}
|
||||
|
||||
@@ -12,6 +12,10 @@ export type PaginationWidgetBaseProps = {
|
||||
|
||||
hasPreviousPage?: boolean;
|
||||
hasNextPage?: boolean;
|
||||
/** Override the computed totalPages.
|
||||
* Used when, e.g., the row count is capped and the user navigates beyond
|
||||
* the known range, so totalPages stays at least as high as currentPage. */
|
||||
totalPages?: number;
|
||||
};
|
||||
|
||||
export const PaginationWidgetBase: FC<PaginationWidgetBaseProps> = ({
|
||||
@@ -21,8 +25,9 @@ export const PaginationWidgetBase: FC<PaginationWidgetBaseProps> = ({
|
||||
onPageChange,
|
||||
hasPreviousPage,
|
||||
hasNextPage,
|
||||
totalPages: totalPagesProp,
|
||||
}) => {
|
||||
const totalPages = Math.ceil(totalRecords / pageSize);
|
||||
const totalPages = totalPagesProp ?? Math.ceil(totalRecords / pageSize);
|
||||
|
||||
if (totalPages < 2) {
|
||||
return null;
|
||||
|
||||
@@ -1 +0,0 @@
|
||||
export { useTabOverflowKebabMenu } from "./useTabOverflowKebabMenu";
|
||||
@@ -0,0 +1,274 @@
|
||||
import {
|
||||
type RefObject,
|
||||
useCallback,
|
||||
useEffect,
|
||||
useRef,
|
||||
useState,
|
||||
} from "react";
|
||||
|
||||
type TabValue = {
|
||||
value: string;
|
||||
};
|
||||
|
||||
type UseKebabMenuOptions<T extends TabValue> = {
|
||||
tabs: readonly T[];
|
||||
enabled: boolean;
|
||||
isActive: boolean;
|
||||
overflowTriggerWidth?: number;
|
||||
};
|
||||
|
||||
type UseKebabMenuResult<T extends TabValue> = {
|
||||
containerRef: RefObject<HTMLDivElement | null>;
|
||||
visibleTabs: T[];
|
||||
overflowTabs: T[];
|
||||
getTabMeasureProps: (tabValue: string) => Record<string, string>;
|
||||
};
|
||||
|
||||
const ALWAYS_VISIBLE_TABS_COUNT = 1;
|
||||
const DATA_ATTR_TAB_VALUE = "data-tab-overflow-item-value";
|
||||
|
||||
/**
|
||||
* Splits tabs into visible and overflow groups based on container width.
|
||||
*
|
||||
* Tabs must render with `getTabMeasureProps()` so this hook can measure
|
||||
* trigger widths from the DOM.
|
||||
*/
|
||||
export const useKebabMenu = <T extends TabValue>({
|
||||
tabs,
|
||||
enabled,
|
||||
isActive,
|
||||
overflowTriggerWidth = 44,
|
||||
}: UseKebabMenuOptions<T>): UseKebabMenuResult<T> => {
|
||||
const containerRef = useRef<HTMLDivElement>(null);
|
||||
const tabsRef = useRef<readonly T[]>(tabs);
|
||||
tabsRef.current = tabs;
|
||||
const previousTabsRef = useRef<readonly T[]>(tabs);
|
||||
const availableWidthRef = useRef<number | null>(null);
|
||||
// Width cache prevents oscillation when overflow tabs are not mounted.
|
||||
const tabWidthByValueRef = useRef<Record<string, number>>({});
|
||||
const [overflowTabValues, setTabValues] = useState<string[]>([]);
|
||||
|
||||
const recalculateOverflow = useCallback(
|
||||
(availableWidth: number) => {
|
||||
if (!enabled || !isActive) {
|
||||
// Keep this update idempotent to avoid render loops.
|
||||
setTabValues((currentValues) => {
|
||||
if (currentValues.length === 0) {
|
||||
return currentValues;
|
||||
}
|
||||
return [];
|
||||
});
|
||||
return;
|
||||
}
|
||||
|
||||
const container = containerRef.current;
|
||||
if (!container) {
|
||||
return;
|
||||
}
|
||||
const currentTabs = tabsRef.current;
|
||||
|
||||
const tabWidthByValue = measureTabWidths({
|
||||
tabs: currentTabs,
|
||||
container,
|
||||
previousTabWidthByValue: tabWidthByValueRef.current,
|
||||
});
|
||||
tabWidthByValueRef.current = tabWidthByValue;
|
||||
|
||||
const nextOverflowValues = calculateTabValues({
|
||||
tabs: currentTabs,
|
||||
availableWidth,
|
||||
tabWidthByValue,
|
||||
overflowTriggerWidth,
|
||||
});
|
||||
|
||||
setTabValues((currentValues) => {
|
||||
// Avoid state updates when the computed overflow did not change.
|
||||
if (areStringArraysEqual(currentValues, nextOverflowValues)) {
|
||||
return currentValues;
|
||||
}
|
||||
return nextOverflowValues;
|
||||
});
|
||||
},
|
||||
[enabled, isActive, overflowTriggerWidth],
|
||||
);
|
||||
|
||||
useEffect(() => {
|
||||
if (previousTabsRef.current === tabs) {
|
||||
// No change in tabs, no need to recalculate.
|
||||
return;
|
||||
}
|
||||
previousTabsRef.current = tabs;
|
||||
if (availableWidthRef.current === null) {
|
||||
// First mount, no width available yet.
|
||||
return;
|
||||
}
|
||||
recalculateOverflow(availableWidthRef.current);
|
||||
}, [recalculateOverflow, tabs]);
|
||||
|
||||
useEffect(() => {
|
||||
const container = containerRef.current;
|
||||
if (!container) {
|
||||
return;
|
||||
}
|
||||
|
||||
// Recompute whenever ResizeObserver reports a container width change.
|
||||
const observer = new ResizeObserver(([entry]) => {
|
||||
if (!entry) {
|
||||
return;
|
||||
}
|
||||
availableWidthRef.current = entry.contentRect.width;
|
||||
recalculateOverflow(entry.contentRect.width);
|
||||
});
|
||||
observer.observe(container);
|
||||
return () => observer.disconnect();
|
||||
}, [recalculateOverflow]);
|
||||
|
||||
const overflowTabValuesSet = new Set(overflowTabValues);
|
||||
const { visibleTabs, overflowTabs } = tabs.reduce<{
|
||||
visibleTabs: T[];
|
||||
overflowTabs: T[];
|
||||
}>(
|
||||
(tabGroups, tab) => {
|
||||
if (overflowTabValuesSet.has(tab.value)) {
|
||||
tabGroups.overflowTabs.push(tab);
|
||||
} else {
|
||||
tabGroups.visibleTabs.push(tab);
|
||||
}
|
||||
return tabGroups;
|
||||
},
|
||||
{ visibleTabs: [], overflowTabs: [] },
|
||||
);
|
||||
|
||||
const getTabMeasureProps = (tabValue: string) => {
|
||||
return { [DATA_ATTR_TAB_VALUE]: tabValue };
|
||||
};
|
||||
|
||||
return {
|
||||
containerRef,
|
||||
visibleTabs,
|
||||
overflowTabs,
|
||||
getTabMeasureProps,
|
||||
};
|
||||
};
|
||||
|
||||
const calculateTabValues = <T extends TabValue>({
|
||||
tabs,
|
||||
availableWidth,
|
||||
tabWidthByValue,
|
||||
overflowTriggerWidth,
|
||||
}: {
|
||||
tabs: readonly T[];
|
||||
availableWidth: number;
|
||||
tabWidthByValue: Readonly<Record<string, number>>;
|
||||
overflowTriggerWidth: number;
|
||||
}): string[] => {
|
||||
const tabWidthByValueMap = new Map<string, number>();
|
||||
for (const tab of tabs) {
|
||||
tabWidthByValueMap.set(tab.value, tabWidthByValue[tab.value] ?? 0);
|
||||
}
|
||||
|
||||
const firstOptionalTabIndex = Math.min(
|
||||
ALWAYS_VISIBLE_TABS_COUNT,
|
||||
tabs.length,
|
||||
);
|
||||
if (firstOptionalTabIndex >= tabs.length) {
|
||||
return [];
|
||||
}
|
||||
|
||||
const alwaysVisibleTabs = tabs.slice(0, firstOptionalTabIndex);
|
||||
const optionalTabs = tabs.slice(firstOptionalTabIndex);
|
||||
const alwaysVisibleWidth = alwaysVisibleTabs.reduce((total, tab) => {
|
||||
return total + (tabWidthByValueMap.get(tab.value) ?? 0);
|
||||
}, 0);
|
||||
const firstTabIndex = findFirstTabIndex({
|
||||
optionalTabs,
|
||||
optionalTabWidths: optionalTabs.map((tab) => {
|
||||
return tabWidthByValueMap.get(tab.value) ?? 0;
|
||||
}),
|
||||
startingUsedWidth: alwaysVisibleWidth,
|
||||
availableWidth,
|
||||
overflowTriggerWidth,
|
||||
});
|
||||
|
||||
if (firstTabIndex === -1) {
|
||||
return [];
|
||||
}
|
||||
|
||||
return optionalTabs
|
||||
.slice(firstTabIndex)
|
||||
.map((overflowTab) => overflowTab.value);
|
||||
};
|
||||
|
||||
const measureTabWidths = <T extends TabValue>({
|
||||
tabs,
|
||||
container,
|
||||
previousTabWidthByValue,
|
||||
}: {
|
||||
tabs: readonly T[];
|
||||
container: HTMLDivElement;
|
||||
previousTabWidthByValue: Readonly<Record<string, number>>;
|
||||
}): Record<string, number> => {
|
||||
const nextTabWidthByValue = { ...previousTabWidthByValue };
|
||||
for (const tab of tabs) {
|
||||
const tabElement = container.querySelector<HTMLElement>(
|
||||
`[${DATA_ATTR_TAB_VALUE}="${tab.value}"]`,
|
||||
);
|
||||
if (tabElement) {
|
||||
nextTabWidthByValue[tab.value] = tabElement.offsetWidth;
|
||||
}
|
||||
}
|
||||
return nextTabWidthByValue;
|
||||
};
|
||||
|
||||
const findFirstTabIndex = ({
|
||||
optionalTabs,
|
||||
optionalTabWidths,
|
||||
startingUsedWidth,
|
||||
availableWidth,
|
||||
overflowTriggerWidth,
|
||||
}: {
|
||||
optionalTabs: readonly TabValue[];
|
||||
optionalTabWidths: readonly number[];
|
||||
startingUsedWidth: number;
|
||||
availableWidth: number;
|
||||
overflowTriggerWidth: number;
|
||||
}): number => {
|
||||
const result = optionalTabs.reduce(
|
||||
(acc, _tab, index) => {
|
||||
if (acc.firstTabIndex !== -1) {
|
||||
return acc;
|
||||
}
|
||||
|
||||
const tabWidth = optionalTabWidths[index] ?? 0;
|
||||
const hasMoreTabs = index < optionalTabs.length - 1;
|
||||
// Reserve kebab trigger width whenever additional tabs remain.
|
||||
const widthNeeded =
|
||||
acc.usedWidth + tabWidth + (hasMoreTabs ? overflowTriggerWidth : 0);
|
||||
|
||||
if (widthNeeded <= availableWidth) {
|
||||
return {
|
||||
usedWidth: acc.usedWidth + tabWidth,
|
||||
firstTabIndex: -1,
|
||||
};
|
||||
}
|
||||
|
||||
return {
|
||||
usedWidth: acc.usedWidth,
|
||||
firstTabIndex: index,
|
||||
};
|
||||
},
|
||||
{ usedWidth: startingUsedWidth, firstTabIndex: -1 },
|
||||
);
|
||||
|
||||
return result.firstTabIndex;
|
||||
};
|
||||
|
||||
const areStringArraysEqual = (
|
||||
left: readonly string[],
|
||||
right: readonly string[],
|
||||
): boolean => {
|
||||
return (
|
||||
left.length === right.length &&
|
||||
left.every((value, index) => value === right[index])
|
||||
);
|
||||
};
|
||||
@@ -1,157 +0,0 @@
|
||||
import {
|
||||
type RefObject,
|
||||
useCallback,
|
||||
useEffect,
|
||||
useLayoutEffect,
|
||||
useMemo,
|
||||
useRef,
|
||||
useState,
|
||||
} from "react";
|
||||
|
||||
type TabLike = {
|
||||
value: string;
|
||||
};
|
||||
|
||||
type UseTabOverflowKebabMenuOptions<TTab extends TabLike> = {
|
||||
tabs: readonly TTab[];
|
||||
enabled: boolean;
|
||||
isActive: boolean;
|
||||
alwaysVisibleTabsCount?: number;
|
||||
overflowTriggerWidthPx?: number;
|
||||
};
|
||||
|
||||
type UseTabOverflowKebabMenuResult<TTab extends TabLike> = {
|
||||
containerRef: RefObject<HTMLDivElement | null>;
|
||||
visibleTabs: TTab[];
|
||||
overflowTabs: TTab[];
|
||||
getTabMeasureProps: (tabValue: string) => Record<string, string>;
|
||||
};
|
||||
|
||||
const DATA_ATTR_TAB_VALUE = "data-tab-overflow-item-value";
|
||||
|
||||
export const useTabOverflowKebabMenu = <TTab extends TabLike>({
|
||||
tabs,
|
||||
enabled,
|
||||
isActive,
|
||||
alwaysVisibleTabsCount = 1,
|
||||
overflowTriggerWidthPx = 44,
|
||||
}: UseTabOverflowKebabMenuOptions<TTab>): UseTabOverflowKebabMenuResult<TTab> => {
|
||||
const containerRef = useRef<HTMLDivElement>(null);
|
||||
const tabWidthByValueRef = useRef<Record<string, number>>({});
|
||||
const [overflowTabValues, setOverflowTabValues] = useState<string[]>([]);
|
||||
|
||||
const recalculateOverflow = useCallback(() => {
|
||||
if (!enabled) {
|
||||
setOverflowTabValues([]);
|
||||
return;
|
||||
}
|
||||
|
||||
const container = containerRef.current;
|
||||
if (!container) {
|
||||
return;
|
||||
}
|
||||
|
||||
for (const tab of tabs) {
|
||||
const tabElement = container.querySelector<HTMLElement>(
|
||||
`[${DATA_ATTR_TAB_VALUE}="${tab.value}"]`,
|
||||
);
|
||||
if (tabElement) {
|
||||
tabWidthByValueRef.current[tab.value] = tabElement.offsetWidth;
|
||||
}
|
||||
}
|
||||
|
||||
const alwaysVisibleTabs = tabs.slice(0, alwaysVisibleTabsCount);
|
||||
const optionalTabs = tabs.slice(alwaysVisibleTabsCount);
|
||||
if (optionalTabs.length === 0) {
|
||||
setOverflowTabValues([]);
|
||||
return;
|
||||
}
|
||||
|
||||
const alwaysVisibleWidth = alwaysVisibleTabs.reduce((total, tab) => {
|
||||
return total + (tabWidthByValueRef.current[tab.value] ?? 0);
|
||||
}, 0);
|
||||
|
||||
const availableWidth = container.clientWidth;
|
||||
let usedWidth = alwaysVisibleWidth;
|
||||
const nextOverflowValues: string[] = [];
|
||||
|
||||
for (let i = 0; i < optionalTabs.length; i++) {
|
||||
const tab = optionalTabs[i];
|
||||
const tabWidth = tabWidthByValueRef.current[tab.value] ?? 0;
|
||||
const hasMoreTabsAfterCurrent = i < optionalTabs.length - 1;
|
||||
const widthNeeded =
|
||||
usedWidth +
|
||||
tabWidth +
|
||||
(hasMoreTabsAfterCurrent ? overflowTriggerWidthPx : 0);
|
||||
|
||||
if (widthNeeded <= availableWidth) {
|
||||
usedWidth += tabWidth;
|
||||
continue;
|
||||
}
|
||||
|
||||
nextOverflowValues.push(
|
||||
...optionalTabs.slice(i).map((overflowTab) => overflowTab.value),
|
||||
);
|
||||
break;
|
||||
}
|
||||
|
||||
setOverflowTabValues((currentValues) => {
|
||||
if (
|
||||
currentValues.length === nextOverflowValues.length &&
|
||||
currentValues.every(
|
||||
(value, index) => value === nextOverflowValues[index],
|
||||
)
|
||||
) {
|
||||
return currentValues;
|
||||
}
|
||||
return nextOverflowValues;
|
||||
});
|
||||
}, [alwaysVisibleTabsCount, enabled, overflowTriggerWidthPx, tabs]);
|
||||
|
||||
useLayoutEffect(() => {
|
||||
if (!isActive) {
|
||||
return;
|
||||
}
|
||||
recalculateOverflow();
|
||||
}, [isActive, recalculateOverflow]);
|
||||
|
||||
useEffect(() => {
|
||||
if (!isActive) {
|
||||
return;
|
||||
}
|
||||
const container = containerRef.current;
|
||||
if (!container) {
|
||||
return;
|
||||
}
|
||||
const observer = new ResizeObserver(() => {
|
||||
recalculateOverflow();
|
||||
});
|
||||
observer.observe(container);
|
||||
return () => observer.disconnect();
|
||||
}, [isActive, recalculateOverflow]);
|
||||
|
||||
const overflowTabValuesSet = useMemo(
|
||||
() => new Set(overflowTabValues),
|
||||
[overflowTabValues],
|
||||
);
|
||||
|
||||
const visibleTabs = useMemo(
|
||||
() => tabs.filter((tab) => !overflowTabValuesSet.has(tab.value)),
|
||||
[tabs, overflowTabValuesSet],
|
||||
);
|
||||
const overflowTabs = useMemo(
|
||||
() => tabs.filter((tab) => overflowTabValuesSet.has(tab.value)),
|
||||
[tabs, overflowTabValuesSet],
|
||||
);
|
||||
|
||||
const getTabMeasureProps = useCallback((tabValue: string) => {
|
||||
return { [DATA_ATTR_TAB_VALUE]: tabValue };
|
||||
}, []);
|
||||
|
||||
return {
|
||||
containerRef,
|
||||
visibleTabs,
|
||||
overflowTabs,
|
||||
getTabMeasureProps,
|
||||
};
|
||||
};
|
||||
@@ -258,6 +258,78 @@ describe(usePaginatedQuery.name, () => {
|
||||
});
|
||||
});
|
||||
|
||||
describe("Capped count behavior", () => {
|
||||
const mockQueryKey = vi.fn(() => ["mock"]);
|
||||
|
||||
// Returns count 2001 (capped) with items on pages up to page 84
|
||||
// (84 * 25 = 2100 items total).
|
||||
const mockCappedQueryFn = vi.fn(({ pageNumber, limit }) => {
|
||||
const totalItems = 2100;
|
||||
const offset = (pageNumber - 1) * limit;
|
||||
// Returns 0 items when the requested page is past the end, simulating
|
||||
// an empty server response.
|
||||
const itemsOnPage = Math.max(0, Math.min(limit, totalItems - offset));
|
||||
return Promise.resolve({
|
||||
data: new Array(itemsOnPage).fill(pageNumber),
|
||||
count: 2001,
|
||||
count_cap: 2000,
|
||||
});
|
||||
});
|
||||
|
||||
it("Caps totalRecords at 2000 when count exceeds cap", async () => {
|
||||
const { result } = await render({
|
||||
queryKey: mockQueryKey,
|
||||
queryFn: mockCappedQueryFn,
|
||||
});
|
||||
|
||||
await waitFor(() => expect(result.current.isSuccess).toBe(true));
|
||||
expect(result.current.totalRecords).toBe(2000);
|
||||
});
|
||||
|
||||
it("hasNextPage is true when count is capped", async () => {
|
||||
const { result } = await render(
|
||||
{ queryKey: mockQueryKey, queryFn: mockCappedQueryFn },
|
||||
"/?page=80",
|
||||
);
|
||||
|
||||
await waitFor(() => expect(result.current.isSuccess).toBe(true));
|
||||
expect(result.current.hasNextPage).toBe(true);
|
||||
});
|
||||
|
||||
it("hasPreviousPage is true when count is capped and page is beyond cap", async () => {
|
||||
const { result } = await render(
|
||||
{ queryKey: mockQueryKey, queryFn: mockCappedQueryFn },
|
||||
"/?page=83",
|
||||
);
|
||||
|
||||
await waitFor(() => expect(result.current.isSuccess).toBe(true));
|
||||
expect(result.current.hasPreviousPage).toBe(true);
|
||||
});
|
||||
|
||||
it("Does not redirect to last page when count is capped and page is valid", async () => {
|
||||
const { result } = await render(
|
||||
{ queryKey: mockQueryKey, queryFn: mockCappedQueryFn },
|
||||
"/?page=83",
|
||||
);
|
||||
|
||||
await waitFor(() => expect(result.current.isSuccess).toBe(true));
|
||||
// Should stay on page 83 — not redirect to page 80.
|
||||
expect(result.current.currentPage).toBe(83);
|
||||
});
|
||||
|
||||
it("Redirects to last known page when navigating beyond actual data", async () => {
|
||||
const { result } = await render(
|
||||
{ queryKey: mockQueryKey, queryFn: mockCappedQueryFn },
|
||||
"/?page=999",
|
||||
);
|
||||
|
||||
// Page 999 has no items. Should redirect to page 81
|
||||
// (ceil(2001 / 25) = 81), the last page guaranteed to
|
||||
// have data.
|
||||
await waitFor(() => expect(result.current.currentPage).toBe(81));
|
||||
});
|
||||
});
|
||||
|
||||
describe("Passing in searchParams property", () => {
|
||||
const mockQueryKey = vi.fn(() => ["mock"]);
|
||||
const mockQueryFn = vi.fn(({ pageNumber, limit }) =>
|
||||
|
||||
@@ -144,16 +144,44 @@ export function usePaginatedQuery<
|
||||
placeholderData: keepPreviousData,
|
||||
});
|
||||
|
||||
const totalRecords = query.data?.count;
|
||||
const totalPages =
|
||||
totalRecords !== undefined ? Math.ceil(totalRecords / limit) : undefined;
|
||||
const count = query.data?.count;
|
||||
const countCap = query.data?.count_cap;
|
||||
const countIsCapped =
|
||||
countCap !== undefined &&
|
||||
countCap > 0 &&
|
||||
count !== undefined &&
|
||||
count > countCap;
|
||||
const totalRecords = countIsCapped ? countCap : count;
|
||||
let totalPages =
|
||||
totalRecords !== undefined
|
||||
? Math.max(
|
||||
Math.ceil(totalRecords / limit),
|
||||
// True count is not known; let them navigate forward
|
||||
// until they hit an empty page (checked below).
|
||||
countIsCapped ? currentPage : 0,
|
||||
)
|
||||
: undefined;
|
||||
|
||||
// When the true count is unknown, the user can navigate past
|
||||
// all actual data. If that happens, we need to redirect (via
|
||||
// updatePageIfInvalid) to the last page guaranteed to be not
|
||||
// empty.
|
||||
const pageIsEmpty =
|
||||
query.data != null &&
|
||||
!Object.values(query.data).some((v) => Array.isArray(v) && v.length > 0);
|
||||
if (pageIsEmpty) {
|
||||
totalPages = count !== undefined ? Math.ceil(count / limit) : 1;
|
||||
}
|
||||
|
||||
const hasNextPage =
|
||||
totalRecords !== undefined && limit + currentPageOffset < totalRecords;
|
||||
totalRecords !== undefined &&
|
||||
((countIsCapped && !pageIsEmpty) ||
|
||||
limit + currentPageOffset < totalRecords);
|
||||
const hasPreviousPage =
|
||||
totalRecords !== undefined &&
|
||||
currentPage > 1 &&
|
||||
currentPageOffset - limit < totalRecords;
|
||||
((countIsCapped && !pageIsEmpty) ||
|
||||
currentPageOffset - limit < totalRecords);
|
||||
|
||||
const queryClient = useQueryClient();
|
||||
const prefetchPage = useEffectEvent((newPage: number) => {
|
||||
@@ -224,10 +252,14 @@ export function usePaginatedQuery<
|
||||
});
|
||||
|
||||
useEffect(() => {
|
||||
if (!query.isFetching && totalPages !== undefined) {
|
||||
if (
|
||||
!query.isFetching &&
|
||||
totalPages !== undefined &&
|
||||
currentPage > totalPages
|
||||
) {
|
||||
void updatePageIfInvalid(totalPages);
|
||||
}
|
||||
}, [updatePageIfInvalid, query.isFetching, totalPages]);
|
||||
}, [updatePageIfInvalid, query.isFetching, totalPages, currentPage]);
|
||||
|
||||
const onPageChange = (newPage: number) => {
|
||||
// Page 1 is the only page that can be safely navigated to without knowing
|
||||
@@ -236,7 +268,12 @@ export function usePaginatedQuery<
|
||||
return;
|
||||
}
|
||||
|
||||
const cleanedInput = clamp(Math.trunc(newPage), 1, totalPages ?? 1);
|
||||
// If the true count is unknown, we allow navigating past the
|
||||
// known page range.
|
||||
const upperBound = countIsCapped
|
||||
? Number.MAX_SAFE_INTEGER
|
||||
: (totalPages ?? 1);
|
||||
const cleanedInput = clamp(Math.trunc(newPage), 1, upperBound);
|
||||
if (Number.isNaN(cleanedInput)) {
|
||||
return;
|
||||
}
|
||||
@@ -274,6 +311,7 @@ export function usePaginatedQuery<
|
||||
totalRecords: totalRecords as number,
|
||||
totalPages: totalPages as number,
|
||||
currentOffsetStart: currentPageOffset + 1,
|
||||
countIsCapped,
|
||||
}
|
||||
: {
|
||||
isSuccess: false,
|
||||
@@ -282,6 +320,7 @@ export function usePaginatedQuery<
|
||||
totalRecords: undefined,
|
||||
totalPages: undefined,
|
||||
currentOffsetStart: undefined,
|
||||
countIsCapped: false as const,
|
||||
}),
|
||||
};
|
||||
|
||||
@@ -323,6 +362,7 @@ export type PaginationResultInfo = {
|
||||
totalRecords: undefined;
|
||||
totalPages: undefined;
|
||||
currentOffsetStart: undefined;
|
||||
countIsCapped: false;
|
||||
}
|
||||
| {
|
||||
isSuccess: true;
|
||||
@@ -331,6 +371,7 @@ export type PaginationResultInfo = {
|
||||
totalRecords: number;
|
||||
totalPages: number;
|
||||
currentOffsetStart: number;
|
||||
countIsCapped: boolean;
|
||||
}
|
||||
);
|
||||
|
||||
@@ -417,6 +458,7 @@ type QueryPageParamsWithPayload<TPayload = never> = QueryPageParams & {
|
||||
*/
|
||||
export type PaginatedData = {
|
||||
count: number;
|
||||
count_cap?: number;
|
||||
};
|
||||
|
||||
/**
|
||||
|
||||
@@ -2,7 +2,7 @@ import type { Meta, StoryObj } from "@storybook/react-vite";
|
||||
import { expect, spyOn, userEvent, waitFor, within } from "storybook/test";
|
||||
import { API } from "#/api/api";
|
||||
import { workspaceAgentContainersKey } from "#/api/queries/workspaces";
|
||||
import type * as TypesGen from "#/api/typesGenerated";
|
||||
import type { WorkspaceAgentLogSource } from "#/api/typesGenerated";
|
||||
import { getPreferredProxy } from "#/contexts/ProxyContext";
|
||||
import { chromatic } from "#/testHelpers/chromatic";
|
||||
import * as M from "#/testHelpers/entities";
|
||||
@@ -76,6 +76,8 @@ const defaultAgentMetadata = [
|
||||
},
|
||||
];
|
||||
|
||||
const fixedLogTimestamp = "2021-05-05T00:00:00.000Z";
|
||||
|
||||
const logs = [
|
||||
"\x1b[91mCloning Git repository...",
|
||||
"\x1b[2;37;41mStarting Docker Daemon...",
|
||||
@@ -87,10 +89,10 @@ const logs = [
|
||||
level: "info",
|
||||
output: line,
|
||||
source_id: M.MockWorkspaceAgentLogSource.id,
|
||||
created_at: new Date().toISOString(),
|
||||
created_at: fixedLogTimestamp,
|
||||
}));
|
||||
|
||||
const installScriptLogSource: TypesGen.WorkspaceAgentLogSource = {
|
||||
const installScriptLogSource: WorkspaceAgentLogSource = {
|
||||
...M.MockWorkspaceAgentLogSource,
|
||||
id: "f2ee4b8d-b09d-4f4e-a1f1-5e4adf7d53bb",
|
||||
display_name: "Install Script",
|
||||
@@ -102,60 +104,24 @@ const tabbedLogs = [
|
||||
level: "info",
|
||||
output: "startup: preparing workspace",
|
||||
source_id: M.MockWorkspaceAgentLogSource.id,
|
||||
created_at: new Date().toISOString(),
|
||||
created_at: fixedLogTimestamp,
|
||||
},
|
||||
{
|
||||
id: 101,
|
||||
level: "info",
|
||||
output: "install: pnpm install",
|
||||
source_id: installScriptLogSource.id,
|
||||
created_at: new Date().toISOString(),
|
||||
created_at: fixedLogTimestamp,
|
||||
},
|
||||
{
|
||||
id: 102,
|
||||
level: "info",
|
||||
output: "install: setup complete",
|
||||
source_id: installScriptLogSource.id,
|
||||
created_at: new Date().toISOString(),
|
||||
created_at: fixedLogTimestamp,
|
||||
},
|
||||
];
|
||||
|
||||
const overflowLogSources: TypesGen.WorkspaceAgentLogSource[] = [
|
||||
M.MockWorkspaceAgentLogSource,
|
||||
{
|
||||
...M.MockWorkspaceAgentLogSource,
|
||||
id: "58f5db69-5f78-496f-bce1-0686f5525aa1",
|
||||
display_name: "code-server",
|
||||
icon: "/icon/code.svg",
|
||||
},
|
||||
{
|
||||
...M.MockWorkspaceAgentLogSource,
|
||||
id: "f39d758c-bce2-4f41-8d70-58fdb1f0f729",
|
||||
display_name: "Install and start AgentAPI",
|
||||
icon: "/icon/claude.svg",
|
||||
},
|
||||
{
|
||||
...M.MockWorkspaceAgentLogSource,
|
||||
id: "bf7529b8-1787-4a20-b54f-eb894680e48f",
|
||||
display_name: "Mux",
|
||||
icon: "/icon/mux.svg",
|
||||
},
|
||||
{
|
||||
...M.MockWorkspaceAgentLogSource,
|
||||
id: "0d6ebde6-c534-4551-9f91-bfd98bfb04f4",
|
||||
display_name: "Portable Desktop",
|
||||
icon: "/icon/portable-desktop.svg",
|
||||
},
|
||||
];
|
||||
|
||||
const overflowLogs = overflowLogSources.map((source, index) => ({
|
||||
id: 200 + index,
|
||||
level: "info",
|
||||
output: `${source.display_name}: line`,
|
||||
source_id: source.id,
|
||||
created_at: new Date().toISOString(),
|
||||
}));
|
||||
|
||||
const meta: Meta<typeof AgentRow> = {
|
||||
title: "components/AgentRow",
|
||||
component: AgentRow,
|
||||
@@ -438,44 +404,3 @@ export const LogsTabs: Story = {
|
||||
await expect(canvas.getByText("install: pnpm install")).toBeVisible();
|
||||
},
|
||||
};
|
||||
|
||||
export const LogsTabsOverflow: Story = {
|
||||
args: {
|
||||
agent: {
|
||||
...M.MockWorkspaceAgentReady,
|
||||
logs_length: overflowLogs.length,
|
||||
log_sources: overflowLogSources,
|
||||
},
|
||||
},
|
||||
parameters: {
|
||||
webSocket: [
|
||||
{
|
||||
event: "message",
|
||||
data: JSON.stringify(overflowLogs),
|
||||
},
|
||||
],
|
||||
},
|
||||
render: (args) => (
|
||||
<div className="max-w-[320px]">
|
||||
<AgentRow {...args} />
|
||||
</div>
|
||||
),
|
||||
play: async ({ canvasElement }) => {
|
||||
const canvas = within(canvasElement);
|
||||
const page = within(canvasElement.ownerDocument.body);
|
||||
await userEvent.click(canvas.getByRole("button", { name: "Logs" }));
|
||||
await userEvent.click(
|
||||
canvas.getByRole("button", { name: "More log tabs" }),
|
||||
);
|
||||
const overflowItems = await page.findAllByRole("menuitemradio");
|
||||
const selectedItem = overflowItems[0];
|
||||
const selectedSource = selectedItem.textContent;
|
||||
if (!selectedSource) {
|
||||
throw new Error("Overflow menu item must have text content.");
|
||||
}
|
||||
await userEvent.click(selectedItem);
|
||||
await waitFor(() =>
|
||||
expect(canvas.getByText(`${selectedSource}: line`)).toBeVisible(),
|
||||
);
|
||||
},
|
||||
};
|
||||
|
||||
@@ -8,10 +8,9 @@ import {
|
||||
} from "lucide-react";
|
||||
import {
|
||||
type FC,
|
||||
useCallback,
|
||||
type ReactNode,
|
||||
useEffect,
|
||||
useLayoutEffect,
|
||||
useMemo,
|
||||
useRef,
|
||||
useState,
|
||||
} from "react";
|
||||
@@ -42,7 +41,7 @@ import {
|
||||
TabsList,
|
||||
TabsTrigger,
|
||||
} from "#/components/Tabs/Tabs";
|
||||
import { useTabOverflowKebabMenu } from "#/components/Tabs/utils";
|
||||
import { useKebabMenu } from "#/components/Tabs/utils/useKebabMenu";
|
||||
import { useProxy } from "#/contexts/ProxyContext";
|
||||
import { useClipboard } from "#/hooks/useClipboard";
|
||||
import { useFeatureVisibility } from "#/modules/dashboard/useFeatureVisibility";
|
||||
@@ -162,7 +161,7 @@ export const AgentRow: FC<AgentRowProps> = ({
|
||||
// This is a bit of a hack on the react-window API to get the scroll position.
|
||||
// If we're scrolled to the bottom, we want to keep the list scrolled to the bottom.
|
||||
// This makes it feel similar to a terminal that auto-scrolls downwards!
|
||||
const handleLogScroll = useCallback((props: ListOnScrollProps) => {
|
||||
const handleLogScroll = (props: ListOnScrollProps) => {
|
||||
if (
|
||||
props.scrollOffset === 0 ||
|
||||
props.scrollUpdateWasRequested ||
|
||||
@@ -179,7 +178,7 @@ export const AgentRow: FC<AgentRowProps> = ({
|
||||
logListDivRef.current.scrollHeight -
|
||||
(props.scrollOffset + parent.clientHeight);
|
||||
setBottomOfLogs(distanceFromBottom < AGENT_LOG_LINE_HEIGHT);
|
||||
}, []);
|
||||
};
|
||||
|
||||
const devcontainers = useAgentContainers(agent);
|
||||
|
||||
@@ -211,59 +210,56 @@ export const AgentRow: FC<AgentRowProps> = ({
|
||||
);
|
||||
|
||||
const [selectedLogTab, setSelectedLogTab] = useState("all");
|
||||
const logTabs = useMemo(() => {
|
||||
const sourceLogTabs = agent.log_sources
|
||||
.filter((logSource) => {
|
||||
// Remove the logSources that have no entries.
|
||||
return agentLogs.some(
|
||||
(log) =>
|
||||
log.source_id === logSource.id && (log.output?.length ?? 0) > 0,
|
||||
);
|
||||
})
|
||||
.map((logSource) => ({
|
||||
// Show the icon for the log source if it has one.
|
||||
// In the startup script case, we show a bespoke play icon.
|
||||
startIcon: logSource.icon ? (
|
||||
<ExternalImage
|
||||
src={logSource.icon}
|
||||
alt=""
|
||||
className="size-icon-xs shrink-0"
|
||||
/>
|
||||
) : logSource.display_name === STARTUP_SCRIPT_DISPLAY_NAME ? (
|
||||
<PlayIcon className="size-icon-xs shrink-0" />
|
||||
) : null,
|
||||
title: logSource.display_name,
|
||||
value: logSource.id,
|
||||
}));
|
||||
const startupScriptLogTab = sourceLogTabs.find(
|
||||
(tab) => tab.title === STARTUP_SCRIPT_DISPLAY_NAME,
|
||||
);
|
||||
const sortedSourceLogTabs = sourceLogTabs
|
||||
.filter((tab) => tab !== startupScriptLogTab)
|
||||
.sort((a, b) => a.title.localeCompare(b.title));
|
||||
return [
|
||||
{
|
||||
title: "All Logs",
|
||||
value: "all",
|
||||
},
|
||||
...(startupScriptLogTab ? [startupScriptLogTab] : []),
|
||||
...sortedSourceLogTabs,
|
||||
] as {
|
||||
startIcon?: React.ReactNode;
|
||||
title: string;
|
||||
value: string;
|
||||
}[];
|
||||
}, [agent.log_sources, agentLogs]);
|
||||
const sourceLogTabs = agent.log_sources
|
||||
.filter((logSource) => {
|
||||
// Remove the logSources that have no entries.
|
||||
return agentLogs.some(
|
||||
(log) =>
|
||||
log.source_id === logSource.id && (log.output?.length ?? 0) > 0,
|
||||
);
|
||||
})
|
||||
.map((logSource) => ({
|
||||
// Show the icon for the log source if it has one.
|
||||
// In the startup script case, we show a bespoke play icon.
|
||||
startIcon: logSource.icon ? (
|
||||
<ExternalImage
|
||||
src={logSource.icon}
|
||||
alt=""
|
||||
className="size-icon-xs shrink-0"
|
||||
/>
|
||||
) : logSource.display_name === STARTUP_SCRIPT_DISPLAY_NAME ? (
|
||||
<PlayIcon className="size-icon-xs shrink-0" />
|
||||
) : null,
|
||||
title: logSource.display_name,
|
||||
value: logSource.id,
|
||||
}));
|
||||
const startupScriptLogTab = sourceLogTabs.find(
|
||||
(tab) => tab.title === STARTUP_SCRIPT_DISPLAY_NAME,
|
||||
);
|
||||
const sortedSourceLogTabs = sourceLogTabs
|
||||
.filter((tab) => tab !== startupScriptLogTab)
|
||||
.sort((a, b) => a.title.localeCompare(b.title));
|
||||
const logTabs: {
|
||||
startIcon?: ReactNode;
|
||||
title: string;
|
||||
value: string;
|
||||
}[] = [
|
||||
{
|
||||
title: "All Logs",
|
||||
value: "all",
|
||||
},
|
||||
...(startupScriptLogTab ? [startupScriptLogTab] : []),
|
||||
...sortedSourceLogTabs,
|
||||
];
|
||||
const {
|
||||
containerRef: logTabsListContainerRef,
|
||||
visibleTabs: visibleLogTabs,
|
||||
overflowTabs: overflowLogTabs,
|
||||
getTabMeasureProps,
|
||||
} = useTabOverflowKebabMenu({
|
||||
} = useKebabMenu({
|
||||
tabs: logTabs,
|
||||
enabled: true,
|
||||
isActive: showLogs,
|
||||
alwaysVisibleTabsCount: 1,
|
||||
});
|
||||
const overflowLogTabValuesSet = new Set(
|
||||
overflowLogTabs.map((tab) => tab.value),
|
||||
@@ -279,16 +275,29 @@ export const AgentRow: FC<AgentRowProps> = ({
|
||||
level: log.level,
|
||||
sourceId: log.source_id,
|
||||
}));
|
||||
const allLogsText = agentLogs.map((log) => log.output).join("\n");
|
||||
const selectedLogsText = selectedLogs.map((log) => log.output).join("\n");
|
||||
const hasSelectedLogs = selectedLogs.length > 0;
|
||||
const hasAnyLogs = agentLogs.length > 0;
|
||||
const { showCopiedSuccess, copyToClipboard } = useClipboard();
|
||||
const selectedLogTabTitle =
|
||||
logTabs.find((tab) => tab.value === selectedLogTab)?.title ?? "Logs";
|
||||
const sanitizedTabTitle = selectedLogTabTitle
|
||||
.toLowerCase()
|
||||
.replaceAll(/[^a-z0-9]+/g, "-")
|
||||
.replaceAll(/(^-|-$)/g, "");
|
||||
const logFilenameSuffix = sanitizedTabTitle || "logs";
|
||||
const downloadableLogSets = logTabs
|
||||
.filter((tab) => tab.value !== "all")
|
||||
.map((tab) => {
|
||||
const logsText = agentLogs
|
||||
.filter((log) => log.source_id === tab.value)
|
||||
.map((log) => log.output)
|
||||
.join("\n");
|
||||
const filenameSuffix = tab.title
|
||||
.toLowerCase()
|
||||
.replaceAll(/[^a-z0-9]+/g, "-")
|
||||
.replaceAll(/(^-|-$)/g, "");
|
||||
return {
|
||||
label: tab.title,
|
||||
filenameSuffix: filenameSuffix || tab.value,
|
||||
logsText,
|
||||
startIcon: tab.startIcon,
|
||||
};
|
||||
});
|
||||
|
||||
return (
|
||||
<div
|
||||
@@ -547,9 +556,9 @@ export const AgentRow: FC<AgentRowProps> = ({
|
||||
</Button>
|
||||
<DownloadSelectedAgentLogsButton
|
||||
agentName={agent.name}
|
||||
filenameSuffix={logFilenameSuffix}
|
||||
logsText={selectedLogsText}
|
||||
disabled={!hasSelectedLogs}
|
||||
logSets={downloadableLogSets}
|
||||
allLogsText={allLogsText}
|
||||
disabled={!hasAnyLogs}
|
||||
/>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
@@ -1,14 +1,27 @@
|
||||
import { saveAs } from "file-saver";
|
||||
import { DownloadIcon } from "lucide-react";
|
||||
import { type FC, useState } from "react";
|
||||
import { ChevronDownIcon, DownloadIcon, PackageIcon } from "lucide-react";
|
||||
import { type FC, type ReactNode, useState } from "react";
|
||||
import { toast } from "sonner";
|
||||
import { getErrorDetail } from "#/api/errors";
|
||||
import { Button } from "#/components/Button/Button";
|
||||
import {
|
||||
DropdownMenu,
|
||||
DropdownMenuContent,
|
||||
DropdownMenuItem,
|
||||
DropdownMenuTrigger,
|
||||
} from "#/components/DropdownMenu/DropdownMenu";
|
||||
|
||||
type DownloadableLogSet = {
|
||||
label: string;
|
||||
filenameSuffix: string;
|
||||
logsText: string;
|
||||
startIcon?: ReactNode;
|
||||
};
|
||||
|
||||
type DownloadSelectedAgentLogsButtonProps = {
|
||||
agentName: string;
|
||||
filenameSuffix: string;
|
||||
logsText: string;
|
||||
logSets: readonly DownloadableLogSet[];
|
||||
allLogsText: string;
|
||||
disabled?: boolean;
|
||||
download?: (file: Blob, filename: string) => void | Promise<void>;
|
||||
};
|
||||
@@ -17,13 +30,13 @@ export const DownloadSelectedAgentLogsButton: FC<
|
||||
DownloadSelectedAgentLogsButtonProps
|
||||
> = ({
|
||||
agentName,
|
||||
filenameSuffix,
|
||||
logsText,
|
||||
logSets,
|
||||
allLogsText,
|
||||
disabled = false,
|
||||
download = saveAs,
|
||||
}) => {
|
||||
const [isDownloading, setIsDownloading] = useState(false);
|
||||
const handleDownload = async () => {
|
||||
const downloadLogs = async (logsText: string, filenameSuffix: string) => {
|
||||
try {
|
||||
setIsDownloading(true);
|
||||
const file = new Blob([logsText], { type: "text/plain" });
|
||||
@@ -37,15 +50,40 @@ export const DownloadSelectedAgentLogsButton: FC<
|
||||
}
|
||||
};
|
||||
|
||||
const hasAllLogs = allLogsText.length > 0;
|
||||
|
||||
return (
|
||||
<Button
|
||||
variant="subtle"
|
||||
size="sm"
|
||||
disabled={disabled || isDownloading}
|
||||
onClick={handleDownload}
|
||||
>
|
||||
<DownloadIcon />
|
||||
{isDownloading ? "Downloading..." : "Download logs"}
|
||||
</Button>
|
||||
<DropdownMenu>
|
||||
<DropdownMenuTrigger asChild>
|
||||
<Button variant="subtle" size="sm" disabled={disabled || isDownloading}>
|
||||
<DownloadIcon />
|
||||
{isDownloading ? "Downloading..." : "Download logs"}
|
||||
<ChevronDownIcon className="size-icon-sm" />
|
||||
</Button>
|
||||
</DropdownMenuTrigger>
|
||||
<DropdownMenuContent align="end">
|
||||
<DropdownMenuItem
|
||||
disabled={!hasAllLogs}
|
||||
onSelect={() => {
|
||||
downloadLogs(allLogsText, "all-logs");
|
||||
}}
|
||||
>
|
||||
<PackageIcon />
|
||||
Download all logs
|
||||
</DropdownMenuItem>
|
||||
{logSets.map((logSet) => (
|
||||
<DropdownMenuItem
|
||||
key={logSet.filenameSuffix}
|
||||
disabled={logSet.logsText.length === 0}
|
||||
onSelect={() => {
|
||||
downloadLogs(logSet.logsText, logSet.filenameSuffix);
|
||||
}}
|
||||
>
|
||||
{logSet.startIcon}
|
||||
<span>Download {logSet.label}</span>
|
||||
</DropdownMenuItem>
|
||||
))}
|
||||
</DropdownMenuContent>
|
||||
</DropdownMenu>
|
||||
);
|
||||
};
|
||||
|
||||
@@ -646,17 +646,24 @@ const AgentChatPage: FC = () => {
|
||||
const isRegenerateTitleDisabled = isArchived || isRegeneratingThisChat;
|
||||
const chatLastModelConfigID = chatRecord?.last_model_config_id;
|
||||
|
||||
const sendMutation = useMutation(
|
||||
// Destructure mutation results directly so the React Compiler
|
||||
// tracks stable primitives/functions instead of the whole result
|
||||
// object (TanStack Query v5 recreates it every render via object
|
||||
// spread). Keeping no intermediate variable prevents future code
|
||||
// from accidentally closing over the unstable object.
|
||||
const { isPending: isSendPending, mutateAsync: sendMessage } = useMutation(
|
||||
createChatMessage(queryClient, agentId ?? ""),
|
||||
);
|
||||
const editMutation = useMutation(editChatMessage(queryClient, agentId ?? ""));
|
||||
const interruptMutation = useMutation(
|
||||
const { isPending: isEditPending, mutateAsync: editMessage } = useMutation(
|
||||
editChatMessage(queryClient, agentId ?? ""),
|
||||
);
|
||||
const { isPending: isInterruptPending, mutateAsync: interrupt } = useMutation(
|
||||
interruptChat(queryClient, agentId ?? ""),
|
||||
);
|
||||
const deleteQueuedMutation = useMutation(
|
||||
const { mutateAsync: deleteQueuedMessage } = useMutation(
|
||||
deleteChatQueuedMessage(queryClient, agentId ?? ""),
|
||||
);
|
||||
const promoteQueuedMutation = useMutation(
|
||||
const { mutateAsync: promoteQueuedMessage } = useMutation(
|
||||
promoteChatQueuedMessage(queryClient, agentId ?? ""),
|
||||
);
|
||||
|
||||
@@ -754,9 +761,7 @@ const AgentChatPage: FC = () => {
|
||||
hasUserFixableModelProviders,
|
||||
});
|
||||
const isSubmissionPending =
|
||||
sendMutation.isPending ||
|
||||
editMutation.isPending ||
|
||||
interruptMutation.isPending;
|
||||
isSendPending || isEditPending || isInterruptPending;
|
||||
const isInputDisabled = !hasModelOptions || isArchived;
|
||||
|
||||
const handleUsageLimitError = (error: unknown): void => {
|
||||
@@ -842,7 +847,7 @@ const AgentChatPage: FC = () => {
|
||||
setPendingEditMessageId(editedMessageID);
|
||||
scrollToBottomRef.current?.();
|
||||
try {
|
||||
await editMutation.mutateAsync({
|
||||
await editMessage({
|
||||
messageId: editedMessageID,
|
||||
req: request,
|
||||
});
|
||||
@@ -873,9 +878,9 @@ const AgentChatPage: FC = () => {
|
||||
// For queued sends the WebSocket status events handle
|
||||
// clearing; for non-queued sends we clear explicitly
|
||||
// below. Clearing eagerly causes a visible cutoff.
|
||||
let response: Awaited<ReturnType<typeof sendMutation.mutateAsync>>;
|
||||
let response: Awaited<ReturnType<typeof sendMessage>>;
|
||||
try {
|
||||
response = await sendMutation.mutateAsync(request);
|
||||
response = await sendMessage(request);
|
||||
} catch (error) {
|
||||
handleUsageLimitError(error);
|
||||
throw error;
|
||||
@@ -908,10 +913,10 @@ const AgentChatPage: FC = () => {
|
||||
};
|
||||
|
||||
const handleInterrupt = () => {
|
||||
if (!agentId || interruptMutation.isPending) {
|
||||
if (!agentId || isInterruptPending) {
|
||||
return;
|
||||
}
|
||||
void interruptMutation.mutateAsync();
|
||||
void interrupt();
|
||||
};
|
||||
|
||||
const handleDeleteQueuedMessage = async (id: number) => {
|
||||
@@ -920,7 +925,7 @@ const AgentChatPage: FC = () => {
|
||||
previousQueuedMessages.filter((message) => message.id !== id),
|
||||
);
|
||||
try {
|
||||
await deleteQueuedMutation.mutateAsync(id);
|
||||
await deleteQueuedMessage(id);
|
||||
} catch (error) {
|
||||
store.setQueuedMessages(previousQueuedMessages);
|
||||
throw error;
|
||||
@@ -941,7 +946,7 @@ const AgentChatPage: FC = () => {
|
||||
store.clearStreamError();
|
||||
store.setChatStatus("pending");
|
||||
try {
|
||||
const promotedMessage = await promoteQueuedMutation.mutateAsync(id);
|
||||
const promotedMessage = await promoteQueuedMessage(id);
|
||||
// Insert the promoted message into the store and cache
|
||||
// immediately so it appears in the timeline without
|
||||
// waiting for the WebSocket to deliver it.
|
||||
@@ -990,7 +995,8 @@ const AgentChatPage: FC = () => {
|
||||
? `ssh ${workspaceAgent.name}.${workspace.name}.${workspace.owner_name}.${sshConfigQuery.data.hostname_suffix}`
|
||||
: undefined;
|
||||
|
||||
const generateKeyMutation = useMutation({
|
||||
// See mutation destructuring comment above (React Compiler).
|
||||
const { mutate: generateKey } = useMutation({
|
||||
mutationFn: () => API.getApiKey(),
|
||||
});
|
||||
|
||||
@@ -1005,7 +1011,7 @@ const AgentChatPage: FC = () => {
|
||||
const repoRoots = Array.from(gitWatcher.repositories.keys()).sort();
|
||||
const folder = repoRoots[0] ?? workspaceAgent.expanded_directory;
|
||||
|
||||
generateKeyMutation.mutate(undefined, {
|
||||
generateKey(undefined, {
|
||||
onSuccess: ({ key }) => {
|
||||
location.href = getVSCodeHref(editor, {
|
||||
owner: workspace.owner_name,
|
||||
@@ -1141,7 +1147,7 @@ const AgentChatPage: FC = () => {
|
||||
compressionThreshold={compressionThreshold}
|
||||
isInputDisabled={isInputDisabled}
|
||||
isSubmissionPending={isSubmissionPending}
|
||||
isInterruptPending={interruptMutation.isPending}
|
||||
isInterruptPending={isInterruptPending}
|
||||
isSidebarCollapsed={isSidebarCollapsed}
|
||||
onToggleSidebarCollapsed={onToggleSidebarCollapsed}
|
||||
showSidebarPanel={showSidebarPanel}
|
||||
|
||||
@@ -0,0 +1,63 @@
|
||||
import { BookOpenIcon, LoaderIcon, TriangleAlertIcon } from "lucide-react";
|
||||
import type React from "react";
|
||||
import { ScrollArea } from "#/components/ScrollArea/ScrollArea";
|
||||
import {
|
||||
Tooltip,
|
||||
TooltipContent,
|
||||
TooltipTrigger,
|
||||
} from "#/components/Tooltip/Tooltip";
|
||||
import { cn } from "#/utils/cn";
|
||||
import { Response } from "../Response";
|
||||
import { ToolCollapsible } from "./ToolCollapsible";
|
||||
import type { ToolStatus } from "./utils";
|
||||
|
||||
export const ReadSkillTool: React.FC<{
|
||||
label: string;
|
||||
body: string;
|
||||
status: ToolStatus;
|
||||
isError: boolean;
|
||||
errorMessage?: string;
|
||||
}> = ({ label, body, status, isError, errorMessage }) => {
|
||||
const hasContent = body.length > 0;
|
||||
const isRunning = status === "running";
|
||||
|
||||
return (
|
||||
<ToolCollapsible
|
||||
className="w-full"
|
||||
hasContent={hasContent}
|
||||
header={
|
||||
<>
|
||||
<BookOpenIcon className="h-4 w-4 shrink-0 text-content-secondary" />
|
||||
<span className={cn("text-sm", "text-content-secondary")}>
|
||||
{isRunning ? `Reading ${label}…` : `Read ${label}`}
|
||||
</span>
|
||||
{isError && (
|
||||
<Tooltip>
|
||||
<TooltipTrigger asChild>
|
||||
<TriangleAlertIcon className="h-3.5 w-3.5 shrink-0 text-content-secondary" />
|
||||
</TooltipTrigger>
|
||||
<TooltipContent>
|
||||
{errorMessage || "Failed to read skill"}
|
||||
</TooltipContent>
|
||||
</Tooltip>
|
||||
)}
|
||||
{isRunning && (
|
||||
<LoaderIcon className="h-3.5 w-3.5 shrink-0 animate-spin motion-reduce:animate-none text-content-secondary" />
|
||||
)}
|
||||
</>
|
||||
}
|
||||
>
|
||||
{body && (
|
||||
<ScrollArea
|
||||
className="mt-1.5 rounded-md border border-solid border-border-default"
|
||||
viewportClassName="max-h-64"
|
||||
scrollBarClassName="w-1.5"
|
||||
>
|
||||
<div className="px-3 py-2">
|
||||
<Response>{body}</Response>
|
||||
</div>
|
||||
</ScrollArea>
|
||||
)}
|
||||
</ToolCollapsible>
|
||||
);
|
||||
};
|
||||
@@ -1342,6 +1342,12 @@ export const ReadSkillCompleted: Story = {
|
||||
play: async ({ canvasElement }) => {
|
||||
const canvas = within(canvasElement);
|
||||
expect(canvas.getByText(/Read skill deep-review/)).toBeInTheDocument();
|
||||
// Expand the collapsible to verify markdown body renders.
|
||||
const toggle = canvas.getByRole("button");
|
||||
await userEvent.click(toggle);
|
||||
await waitFor(() => {
|
||||
expect(canvas.getByText("Deep Review Skill")).toBeInTheDocument();
|
||||
});
|
||||
},
|
||||
};
|
||||
|
||||
@@ -1393,6 +1399,12 @@ export const ReadSkillFileCompleted: Story = {
|
||||
expect(
|
||||
canvas.getByText(/Read deep-review\/roles\/security-reviewer\.md/),
|
||||
).toBeInTheDocument();
|
||||
// Expand the collapsible to verify markdown content renders.
|
||||
const toggle = canvas.getByRole("button");
|
||||
await userEvent.click(toggle);
|
||||
await waitFor(() => {
|
||||
expect(canvas.getByText("Security Reviewer Role")).toBeInTheDocument();
|
||||
});
|
||||
},
|
||||
};
|
||||
|
||||
|
||||
@@ -23,6 +23,7 @@ import { ListTemplatesTool } from "./ListTemplatesTool";
|
||||
import { ProcessOutputTool } from "./ProcessOutputTool";
|
||||
import { ProposePlanTool } from "./ProposePlanTool";
|
||||
import { ReadFileTool } from "./ReadFileTool";
|
||||
import { ReadSkillTool } from "./ReadSkillTool";
|
||||
import { ReadTemplateTool } from "./ReadTemplateTool";
|
||||
import { SubagentTool } from "./SubagentTool";
|
||||
import { ToolCollapsible } from "./ToolCollapsible";
|
||||
@@ -210,6 +211,55 @@ const ReadFileRenderer: FC<ToolRendererProps> = ({
|
||||
);
|
||||
};
|
||||
|
||||
const ReadSkillRenderer: FC<ToolRendererProps> = ({
|
||||
status,
|
||||
args,
|
||||
result,
|
||||
isError,
|
||||
}) => {
|
||||
const parsedArgs = parseArgs(args);
|
||||
const skillName = parsedArgs ? asString(parsedArgs.name) : "";
|
||||
const rec = asRecord(result);
|
||||
const body = rec ? asString(rec.body) : "";
|
||||
|
||||
return (
|
||||
<ReadSkillTool
|
||||
label={skillName ? `skill ${skillName}` : "skill"}
|
||||
body={body}
|
||||
status={status}
|
||||
isError={isError}
|
||||
errorMessage={rec ? asString(rec.error || rec.message) : undefined}
|
||||
/>
|
||||
);
|
||||
};
|
||||
|
||||
const ReadSkillFileRenderer: FC<ToolRendererProps> = ({
|
||||
status,
|
||||
args,
|
||||
result,
|
||||
isError,
|
||||
}) => {
|
||||
const parsedArgs = parseArgs(args);
|
||||
const skillName = parsedArgs ? asString(parsedArgs.name) : "";
|
||||
const filePath = parsedArgs ? asString(parsedArgs.path) : "";
|
||||
const label =
|
||||
skillName && filePath
|
||||
? `${skillName}/${filePath}`
|
||||
: skillName || filePath || "skill file";
|
||||
const rec = asRecord(result);
|
||||
const content = rec ? asString(rec.content) : "";
|
||||
|
||||
return (
|
||||
<ReadSkillTool
|
||||
label={label}
|
||||
body={content}
|
||||
status={status}
|
||||
isError={isError}
|
||||
errorMessage={rec ? asString(rec.error || rec.message) : undefined}
|
||||
/>
|
||||
);
|
||||
};
|
||||
|
||||
const WriteFileRenderer: FC<ToolRendererProps> = ({
|
||||
status,
|
||||
args,
|
||||
@@ -667,6 +717,8 @@ const toolRenderers: Record<string, FC<ToolRendererProps>> = {
|
||||
create_workspace: CreateWorkspaceRenderer,
|
||||
list_templates: ListTemplatesRenderer,
|
||||
read_template: ReadTemplateRenderer,
|
||||
read_skill: ReadSkillRenderer,
|
||||
read_skill_file: ReadSkillFileRenderer,
|
||||
spawn_agent: SubagentRenderer,
|
||||
wait_agent: SubagentRenderer,
|
||||
message_agent: SubagentRenderer,
|
||||
|
||||
+4
-1
@@ -255,7 +255,7 @@ export const ProviderAccordionCards: Story = {
|
||||
expect(body.queryByText("OpenAI")).not.toBeInTheDocument();
|
||||
|
||||
await userEvent.click(body.getByRole("button", { name: /OpenRouter/i }));
|
||||
await expect(body.getByLabelText("Base URL")).toBeInTheDocument();
|
||||
await expect(await body.findByLabelText("Base URL")).toBeInTheDocument();
|
||||
},
|
||||
};
|
||||
|
||||
@@ -462,6 +462,9 @@ export const CreateAndUpdateProvider: Story = {
|
||||
await waitFor(() => {
|
||||
expect(args.onCreateProvider).toHaveBeenCalledTimes(1);
|
||||
});
|
||||
await waitFor(() => {
|
||||
expect(body.getByRole("button", { name: "Save changes" })).toBeDisabled();
|
||||
});
|
||||
expect(args.onCreateProvider).toHaveBeenCalledWith(
|
||||
expect.objectContaining({
|
||||
provider: "openai",
|
||||
|
||||
@@ -71,6 +71,7 @@ describe("AuditPage", () => {
|
||||
const getAuditLogsSpy = vi.spyOn(API, "getAuditLogs").mockResolvedValue({
|
||||
audit_logs: [MockAuditLog, MockAuditLog2],
|
||||
count: 2,
|
||||
count_cap: 0,
|
||||
});
|
||||
|
||||
// When
|
||||
@@ -90,6 +91,7 @@ describe("AuditPage", () => {
|
||||
vi.spyOn(API, "getAuditLogs").mockResolvedValue({
|
||||
audit_logs: [MockAuditLog],
|
||||
count: 1,
|
||||
count_cap: 0,
|
||||
});
|
||||
|
||||
await renderPage();
|
||||
@@ -114,6 +116,7 @@ describe("AuditPage", () => {
|
||||
vi.spyOn(API, "getAuditLogs").mockResolvedValue({
|
||||
audit_logs: [MockAuditLog],
|
||||
count: 1,
|
||||
count_cap: 0,
|
||||
});
|
||||
|
||||
await renderPage();
|
||||
@@ -140,9 +143,11 @@ describe("AuditPage", () => {
|
||||
|
||||
describe("Filtering", () => {
|
||||
it("filters by URL", async () => {
|
||||
const getAuditLogsSpy = vi
|
||||
.spyOn(API, "getAuditLogs")
|
||||
.mockResolvedValue({ audit_logs: [MockAuditLog], count: 1 });
|
||||
const getAuditLogsSpy = vi.spyOn(API, "getAuditLogs").mockResolvedValue({
|
||||
audit_logs: [MockAuditLog],
|
||||
count: 1,
|
||||
count_cap: 0,
|
||||
});
|
||||
|
||||
const query = "resource_type:workspace action:create";
|
||||
await renderPage({ filter: query });
|
||||
@@ -173,4 +178,29 @@ describe("AuditPage", () => {
|
||||
);
|
||||
});
|
||||
});
|
||||
|
||||
describe("Capped count", () => {
|
||||
it("shows capped count indicator and navigates to next page with correct offset", async () => {
|
||||
vi.spyOn(API, "getAuditLogs").mockResolvedValue({
|
||||
audit_logs: [MockAuditLog, MockAuditLog2],
|
||||
count: 2001,
|
||||
count_cap: 2000,
|
||||
});
|
||||
|
||||
const user = userEvent.setup();
|
||||
await renderPage();
|
||||
|
||||
await screen.findByText(/2,000\+/);
|
||||
|
||||
await user.click(screen.getByRole("button", { name: /next page/i }));
|
||||
|
||||
await waitFor(() =>
|
||||
expect(API.getAuditLogs).toHaveBeenLastCalledWith<[AuditLogsRequest]>({
|
||||
limit: DEFAULT_RECORDS_PER_PAGE,
|
||||
offset: DEFAULT_RECORDS_PER_PAGE,
|
||||
q: "",
|
||||
}),
|
||||
);
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
@@ -69,6 +69,7 @@ describe("ConnectionLogPage", () => {
|
||||
MockDisconnectedSSHConnectionLog,
|
||||
],
|
||||
count: 2,
|
||||
count_cap: 0,
|
||||
});
|
||||
|
||||
// When
|
||||
@@ -95,6 +96,7 @@ describe("ConnectionLogPage", () => {
|
||||
.mockResolvedValue({
|
||||
connection_logs: [MockConnectedSSHConnectionLog],
|
||||
count: 1,
|
||||
count_cap: 0,
|
||||
});
|
||||
|
||||
const query = "type:ssh status:ongoing";
|
||||
|
||||
+10
-1
@@ -732,7 +732,16 @@ func (l *peerLifecycle) setLostTimer(c *configMaps) {
|
||||
if l.lostTimer != nil {
|
||||
l.lostTimer.Stop()
|
||||
}
|
||||
ttl := lostTimeout - c.clock.Since(l.lastHandshake)
|
||||
var ttl time.Duration
|
||||
if l.lastHandshake.IsZero() {
|
||||
// Peer has never completed a handshake. Give it the full
|
||||
// lostTimeout to establish one rather than deleting it
|
||||
// immediately. A zero lastHandshake just means WireGuard
|
||||
// hasn't connected yet, not that the peer is gone.
|
||||
ttl = lostTimeout
|
||||
} else {
|
||||
ttl = lostTimeout - c.clock.Since(l.lastHandshake)
|
||||
}
|
||||
if ttl <= 0 {
|
||||
ttl = time.Nanosecond
|
||||
}
|
||||
|
||||
@@ -641,6 +641,97 @@ func TestConfigMaps_updatePeers_lost(t *testing.T) {
|
||||
_ = testutil.TryReceive(ctx, t, done)
|
||||
}
|
||||
|
||||
func TestConfigMaps_updatePeers_lost_zero_handshake(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := testutil.Context(t, testutil.WaitShort)
|
||||
logger := testutil.Logger(t)
|
||||
fEng := newFakeEngineConfigurable()
|
||||
nodePrivateKey := key.NewNode()
|
||||
nodeID := tailcfg.NodeID(5)
|
||||
discoKey := key.NewDisco()
|
||||
uut := newConfigMaps(logger, fEng, nodeID, nodePrivateKey, discoKey.Public(), CoderDNSSuffixFQDN)
|
||||
defer uut.close()
|
||||
mClock := quartz.NewMock(t)
|
||||
uut.clock = mClock
|
||||
|
||||
p1ID := uuid.UUID{1}
|
||||
p1Node := newTestNode(1)
|
||||
p1n, err := NodeToProto(p1Node)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Respond to the status request from updatePeers(NODE) with no
|
||||
// handshake information, so lastHandshake stays zero.
|
||||
expectNoStatus := func() <-chan struct{} {
|
||||
called := make(chan struct{})
|
||||
go func() {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
t.Error("timeout waiting for status")
|
||||
return
|
||||
case b := <-fEng.status:
|
||||
_ = b // don't add any peer
|
||||
}
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
t.Error("timeout sending done")
|
||||
case fEng.statusDone <- struct{}{}:
|
||||
close(called)
|
||||
}
|
||||
}()
|
||||
return called
|
||||
}
|
||||
|
||||
// Add the peer via NODE update — no handshake in status.
|
||||
s1 := expectNoStatus()
|
||||
updates := []*proto.CoordinateResponse_PeerUpdate{
|
||||
{
|
||||
Id: p1ID[:],
|
||||
Kind: proto.CoordinateResponse_PeerUpdate_NODE,
|
||||
Node: p1n,
|
||||
},
|
||||
}
|
||||
uut.updatePeers(updates)
|
||||
nm := testutil.TryReceive(ctx, t, fEng.setNetworkMap)
|
||||
r := testutil.TryReceive(ctx, t, fEng.reconfig)
|
||||
require.Len(t, nm.Peers, 1)
|
||||
require.Len(t, r.wg.Peers, 1)
|
||||
_ = testutil.TryReceive(ctx, t, s1)
|
||||
|
||||
// Mark the peer as LOST, still with no handshake.
|
||||
s2 := expectNoStatus()
|
||||
updates[0].Kind = proto.CoordinateResponse_PeerUpdate_LOST
|
||||
updates[0].Node = nil
|
||||
uut.updatePeers(updates)
|
||||
_ = testutil.TryReceive(ctx, t, s2)
|
||||
|
||||
// Peer should NOT be removed immediately.
|
||||
select {
|
||||
case <-fEng.setNetworkMap:
|
||||
t.Fatal("should not reprogram")
|
||||
default:
|
||||
// OK!
|
||||
}
|
||||
|
||||
// Prepare a status response for when the lost timer fires after
|
||||
// lostTimeout. Return empty status (no handshake ever happened).
|
||||
s3 := expectNoStatus()
|
||||
mClock.Advance(lostTimeout).MustWait(ctx)
|
||||
_ = testutil.TryReceive(ctx, t, s3)
|
||||
|
||||
// Now the peer should be removed.
|
||||
nm = testutil.TryReceive(ctx, t, fEng.setNetworkMap)
|
||||
r = testutil.TryReceive(ctx, t, fEng.reconfig)
|
||||
require.Len(t, nm.Peers, 0)
|
||||
require.Len(t, r.wg.Peers, 0)
|
||||
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
defer close(done)
|
||||
uut.close()
|
||||
}()
|
||||
_ = testutil.TryReceive(ctx, t, done)
|
||||
}
|
||||
|
||||
func TestConfigMaps_updatePeers_lost_and_found(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := testutil.Context(t, testutil.WaitShort)
|
||||
|
||||
+50
-42
@@ -12,6 +12,8 @@ import (
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"golang.org/x/sync/singleflight"
|
||||
|
||||
"github.com/cenkalti/backoff/v4"
|
||||
"github.com/google/uuid"
|
||||
"github.com/tailscale/wireguard-go/tun"
|
||||
@@ -463,6 +465,8 @@ type Conn struct {
|
||||
|
||||
trafficStats *connstats.Statistics
|
||||
lastNetInfo *tailcfg.NetInfo
|
||||
|
||||
awaitReachableGroup singleflight.Group
|
||||
}
|
||||
|
||||
func (c *Conn) GetNetInfo() *tailcfg.NetInfo {
|
||||
@@ -599,56 +603,60 @@ func (c *Conn) DERPMap() *tailcfg.DERPMap {
|
||||
// address is reachable. It's the callers responsibility to provide
|
||||
// a timeout, otherwise this function will block forever.
|
||||
func (c *Conn) AwaitReachable(ctx context.Context, ip netip.Addr) bool {
|
||||
ctx, cancel := context.WithCancel(ctx)
|
||||
defer cancel() // Cancel all pending pings on exit.
|
||||
result, _, _ := c.awaitReachableGroup.Do(ip.String(), func() (interface{}, error) {
|
||||
ctx, cancel := context.WithCancel(ctx)
|
||||
defer cancel() // Cancel all pending pings on exit.
|
||||
|
||||
completedCtx, completed := context.WithCancel(context.Background())
|
||||
defer completed()
|
||||
completedCtx, completed := context.WithCancel(context.Background())
|
||||
defer completed()
|
||||
|
||||
run := func() {
|
||||
// Safety timeout, initially we'll have around 10-20 goroutines
|
||||
// running in parallel. The exponential backoff will converge
|
||||
// around ~1 ping / 30s, this means we'll have around 10-20
|
||||
// goroutines pending towards the end as well.
|
||||
ctx, cancel := context.WithTimeout(ctx, 5*time.Minute)
|
||||
defer cancel()
|
||||
run := func() {
|
||||
// Safety timeout, initially we'll have around 10-20 goroutines
|
||||
// running in parallel. The exponential backoff will converge
|
||||
// around ~1 ping / 30s, this means we'll have around 10-20
|
||||
// goroutines pending towards the end as well.
|
||||
ctx, cancel := context.WithTimeout(ctx, 5*time.Minute)
|
||||
defer cancel()
|
||||
|
||||
// For reachability, we use TSMP ping, which pings at the IP layer, and
|
||||
// therefore requires that wireguard and the netstack are up. If we
|
||||
// don't wait for wireguard to be up, we could miss a handshake, and it
|
||||
// might take 5 seconds for the handshake to be retried. A 5s initial
|
||||
// round trip can set us up for poor TCP performance, since the initial
|
||||
// round-trip-time sets the initial retransmit timeout.
|
||||
_, _, _, err := c.pingWithType(ctx, ip, tailcfg.PingTSMP)
|
||||
if err == nil {
|
||||
completed()
|
||||
// For reachability, we use TSMP ping, which pings at the IP layer,
|
||||
// and therefore requires that wireguard and the netstack are up.
|
||||
// If we don't wait for wireguard to be up, we could miss a
|
||||
// handshake, and it might take 5 seconds for the handshake to be
|
||||
// retried. A 5s initial round trip can set us up for poor TCP
|
||||
// performance, since the initial round-trip-time sets the initial
|
||||
// retransmit timeout.
|
||||
_, _, _, err := c.pingWithType(ctx, ip, tailcfg.PingTSMP)
|
||||
if err == nil {
|
||||
completed()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
eb := backoff.NewExponentialBackOff()
|
||||
eb.MaxElapsedTime = 0
|
||||
eb.InitialInterval = 50 * time.Millisecond
|
||||
eb.MaxInterval = 30 * time.Second
|
||||
// Consume the first interval since
|
||||
// we'll fire off a ping immediately.
|
||||
_ = eb.NextBackOff()
|
||||
eb := backoff.NewExponentialBackOff()
|
||||
eb.MaxElapsedTime = 0
|
||||
eb.InitialInterval = 50 * time.Millisecond
|
||||
eb.MaxInterval = 5 * time.Second
|
||||
// Consume the first interval since
|
||||
// we'll fire off a ping immediately.
|
||||
_ = eb.NextBackOff()
|
||||
|
||||
t := backoff.NewTicker(eb)
|
||||
defer t.Stop()
|
||||
t := backoff.NewTicker(eb)
|
||||
defer t.Stop()
|
||||
|
||||
go run()
|
||||
for {
|
||||
select {
|
||||
case <-completedCtx.Done():
|
||||
return true
|
||||
case <-t.C:
|
||||
// Pings can take a while, so we can run multiple
|
||||
// in parallel to return ASAP.
|
||||
go run()
|
||||
case <-ctx.Done():
|
||||
return false
|
||||
go run()
|
||||
for {
|
||||
select {
|
||||
case <-completedCtx.Done():
|
||||
return true, nil
|
||||
case <-t.C:
|
||||
// Pings can take a while, so we can run multiple
|
||||
// in parallel to return ASAP.
|
||||
go run()
|
||||
case <-ctx.Done():
|
||||
return false, nil
|
||||
}
|
||||
}
|
||||
}
|
||||
})
|
||||
return result.(bool)
|
||||
}
|
||||
|
||||
// Closed is a channel that ends when the connection has
|
||||
|
||||
Reference in New Issue
Block a user