Compare commits

...

25 Commits

Author SHA1 Message Date
Jon Ayers e85be9f42e fix(tailnet): give peers with no handshake full lostTimeout before removal 2026-04-10 05:43:18 +00:00
Jon Ayers cfd7730194 chore(enterprise/tailnet): add debug logging for LOST and DISCONNECTED peer updates 2026-04-10 00:40:57 +00:00
Jon Ayers 1937ada0cd fix: use enriched logger for HeartbeatClose, reduce AwaitReachable backoff to 5s 2026-04-09 23:45:59 +00:00
Jon Ayers d64cd6415d revert: move HeartbeatClose back before agent dial 2026-04-09 22:37:31 +00:00
Jon Ayers c1851d9453 chore(coderd/workspaceapps): add workspace_id and elapsed time to PTY dial logs 2026-04-09 22:35:44 +00:00
Jon Ayers 8f73453681 fix(coderd/workspaceapps): move HeartbeatClose after agent dial, add 1m setup timeout 2026-04-09 21:39:28 +00:00
Jon Ayers 165db3d31c perf(enterprise/tailnet): increase coordinator worker counts and batch size for 10k scale 2026-04-09 21:26:52 +00:00
Jon Ayers 1bd1516fd1 perf(tailnet): singleflight AwaitReachable to deduplicate concurrent ping storms 2026-04-09 19:34:22 +00:00
Jon Ayers 81ba35a987 fix(coderd/tailnet): move ensureAgent Send outside mutex using singleflight 2026-04-09 19:12:27 +00:00
Jon Ayers 53d63cf8e9 perf(coderd/database/pubsub): batch-drain msgQueue to amortize lock overhead
Replace the one-at-a-time dequeue loop in msgQueue.run() with a batch
drain that copies up to 256 messages per lock acquisition. This
amortizes mutex acquire/release and cond.Wait costs across many
messages, improving drain throughput during bursts and reducing the
likelihood of ring buffer overflow.
2026-04-08 00:02:29 +00:00
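The batch-drain idea in this commit message can be sketched as follows. This is a minimal illustration, not the actual coderd `pubsub` implementation; the `msgQueue` shape and handler signature are assumptions, with only the 256 batch size taken from the message:

```go
package main

import (
	"fmt"
	"sync"
)

// batchSize mirrors the "up to 256 messages per lock acquisition"
// described in the commit message.
const batchSize = 256

// msgQueue is a toy queue whose run loop drains messages in batches:
// instead of one dequeue per mutex acquire/cond.Wait cycle, it copies
// up to batchSize messages out under a single lock, amortizing lock
// overhead across the whole batch.
type msgQueue struct {
	mu     sync.Mutex
	cond   *sync.Cond
	msgs   []string
	closed bool
}

func newMsgQueue() *msgQueue {
	q := &msgQueue{}
	q.cond = sync.NewCond(&q.mu)
	return q
}

func (q *msgQueue) enqueue(msg string) {
	q.mu.Lock()
	q.msgs = append(q.msgs, msg)
	q.mu.Unlock()
	q.cond.Signal()
}

func (q *msgQueue) close() {
	q.mu.Lock()
	q.closed = true
	q.mu.Unlock()
	q.cond.Broadcast()
}

// run drains the queue until it is closed and empty, invoking the
// handler for each message outside the lock.
func (q *msgQueue) run(handler func(string)) {
	batch := make([]string, 0, batchSize)
	for {
		q.mu.Lock()
		for len(q.msgs) == 0 && !q.closed {
			q.cond.Wait()
		}
		if len(q.msgs) == 0 && q.closed {
			q.mu.Unlock()
			return
		}
		n := len(q.msgs)
		if n > batchSize {
			n = batchSize
		}
		batch = append(batch[:0], q.msgs[:n]...)
		q.msgs = q.msgs[n:]
		q.mu.Unlock()
		// Handle the whole batch without holding the lock.
		for _, m := range batch {
			handler(m)
		}
	}
}

func main() {
	q := newMsgQueue()
	for i := 0; i < 1000; i++ {
		q.enqueue(fmt.Sprintf("msg-%d", i))
	}
	q.close()
	count := 0
	q.run(func(string) { count++ })
	fmt.Println(count) // 1000
}
```

The key point is that the cost of one lock round-trip is paid per batch rather than per message, which matters most during bursts, exactly when a bounded ring buffer is at risk of overflowing.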
Jon Ayers 4213a43b53 fix(enterprise/tailnet): async singleflight-coalesced resyncPeerMappings in pubsub callbacks
Replace synchronous resyncPeerMappings() calls in listenPeer and
listenTunnel with async goroutines using singleflight.Do. This
prevents blocking the pubsub drain goroutine when ErrDroppedMessages
arrives, avoiding cascading buffer overflows.
2026-04-08 00:02:21 +00:00
Garrett Delfosse 5453a6c6d6 fix(scripts/releaser): simplify branch regex and fix changelog range (#23947)
Two fixes for the release script:

**1. Branch regex cleanup** — Simplified to only match `release/X.Y`. Removed support for `release/X.Y.Z` and `release/X.Y-rc.N` branch formats. RCs are now tagged from main (not from release branches), and the three-segment `release/X.Y.Z` format will not be used going forward.

**2. Changelog range for first release on a new minor** — When no tags match the branch's major.minor, the commit range fell back to `HEAD` (entire git history, ~13k lines of changelog). Now computes `git merge-base` with the previous minor's release branch (e.g. `origin/release/2.32`) as the changelog starting point. This works even when that branch has no tags pushed yet. Falls back to the latest reachable tag from a previous minor if the branch doesn't exist.
2026-04-07 17:07:21 +00:00
Jake Howell 21c08a37d7 feat: de-mui <LogLine /> and <Logs /> (#24043)
Migrated LogLine and Logs components from Emotion CSS-in-JS to Tailwind
CSS classes.

- Replaced Emotion `css` prop and theme-based styling with Tailwind
utility classes in `LogLine` and `LogLinePrefix` components
- Converted CSS-in-JS styles object to conditional Tailwind classes
using the `cn` utility function
- Updated log level styling (error, debug, warn) to use Tailwind classes
with design token references
- Migrated the Logs container component styling from Emotion to Tailwind
classes
- Removed Emotion imports and theme dependencies
2026-04-07 16:35:10 +00:00
Jake Howell 2bd261fbbf fix: cleanup useKebabMenu code (#24042)
Refactored the tab overflow hook by renaming `useTabOverflowKebabMenu`
to `useKebabMenu` and removing the configurable `alwaysVisibleTabsCount`
parameter.

- Renamed `useTabOverflowKebabMenu` to `useKebabMenu` and moved it to a
new file
- Removed the `alwaysVisibleTabsCount` parameter and hardcoded it to 1
tab as `ALWAYS_VISIBLE_TABS_COUNT`
- Removed the `utils/index.ts` export file for the Tabs component
- Updated the import in `AgentRow.tsx` to use the new hook name and
removed the `alwaysVisibleTabsCount` prop
- Refactored the internal logic to use a more functional approach with
`reduce` instead of imperative loops
- Added better performance optimizations to prevent unnecessary
re-renders
2026-04-08 02:25:18 +10:00
Kyle Carberry cffc68df58 feat(site): render read_skill body as markdown (#24069) 2026-04-07 11:50:21 -04:00
Jake Howell 6e5335df1e feat: implement new workspace download logs dropdown (#23963)
This PR improves the agent log download functionality by replacing the
single download button with a comprehensive dropdown menu system.

- Replaced single download button with a dropdown menu offering multiple
download options
- Added ability to download all logs or individual log sources
separately
- Updated download button to show chevron icon indicating dropdown
functionality
- Enhanced download options with appropriate icons for each log source

<img width="370" height="305" alt="image"
src="https://github.com/user-attachments/assets/ddf025f5-f936-499a-9165-6e81b62d6860"
/>
2026-04-07 15:27:43 +00:00
Kyle Carberry 16265e834e chore: update fantasy fork to use github.com/coder/fantasy (#24100)
Moves the `charm.land/fantasy` replace directive from
`github.com/kylecarbs/fantasy` to `github.com/coder/fantasy`, pointing
at the same `cj/go1.25` branch and commit (`112927d9b6d8`).

> Generated by Coder Agents
2026-04-07 16:11:49 +01:00
Zach 565a15bc9b feat: update user secrets queries for REST API and injection (#23998)
Update queries as prep work for user secrets API development:
- Switch all lookups and mutations from ID-based to user_id + name
- Split list query into metadata-only (for API responses) and
with-values (for provisioner/agent)
- Add partial update support using CASE WHEN pattern for write-only
value fields
- Include value_key_id in create for dbcrypt encryption support
- Update dbauthz wrappers and remove stale methods from dbmetrics
2026-04-07 09:03:28 -06:00
Ethan 76a2cb1af5 fix(site/src/pages/AgentsPage): reset provider form after create (#23975)
Previously, after creating a provider config in the agents provider
editor, the Save changes button stayed enabled for the lifetime of the
mounted form. The form kept the pre-create local baseline, so the
freshly-saved values still looked dirty.

Key `ProviderForm` by provider config identity so React remounts the
form when a config is created and re-establishes the pristine state from
the saved provider values.
2026-04-08 00:32:36 +10:00
Kyle Carberry 684f21740d perf(coderd): batch chat heartbeat queries into single UPDATE per interval (#24037)
## Summary

Replaces N per-chat heartbeat goroutines with a single centralized
heartbeat loop that issues one `UPDATE` per 30s interval for all running
chats on a worker.

## Problem

Each running chat spawned a dedicated goroutine that issued an
individual `UPDATE chats SET heartbeat_at = NOW() WHERE id = $1 AND
worker_id = $2 AND status = 'running'` query every 30 seconds. At 10,000
concurrent chats this produces **~333 DB queries/second** just for
heartbeats, plus ~333 `ActivityBumpWorkspace` CTE queries/second from
`trackWorkspaceUsage`.

## Solution

New `UpdateChatHeartbeats` (plural) SQL query replaces the old singular
`UpdateChatHeartbeat`:

```sql
UPDATE chats
SET    heartbeat_at = @now::timestamptz
WHERE  worker_id = @worker_id::uuid
  AND  status = 'running'::chat_status
RETURNING id;
```

A single `heartbeatLoop` goroutine on the `Server`:
1. Ticks every `chatHeartbeatInterval` (30s)
2. Issues one batch UPDATE for all registered chats
3. Detects stolen/completed chats via set-difference (equivalent of old
`rows == 0`)
4. Calls `trackWorkspaceUsage` for surviving chats

`processChat` registers an entry in the heartbeat registry instead of
spawning a goroutine.
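The set-difference detection in step 3 can be sketched as follows. This is an illustrative registry only; the actual `Server` field names, entry shape, and cancel plumbing are assumptions:

```go
package main

import (
	"fmt"
	"sync"
)

// heartbeatRegistry tracks the chats a worker believes it is running.
// Each tick, the heartbeat loop issues one batch UPDATE that RETURNING's
// the IDs it actually touched; any registered ID missing from that
// result was stolen or completed — the batch equivalent of the old
// per-chat rows == 0 check.
type heartbeatRegistry struct {
	mu    sync.Mutex
	chats map[string]func(error) // chat ID -> cancel func
}

func newHeartbeatRegistry() *heartbeatRegistry {
	return &heartbeatRegistry{chats: make(map[string]func(error))}
}

// register is what processChat would call instead of spawning a
// per-chat heartbeat goroutine.
func (r *heartbeatRegistry) register(id string, cancel func(error)) {
	r.mu.Lock()
	defer r.mu.Unlock()
	r.chats[id] = cancel
}

// reconcile compares registered IDs against the IDs returned by the
// batch UPDATE, cancels and deregisters the missing ones, and returns
// the lost IDs. Survivors are the ones to bump workspace usage for.
func (r *heartbeatRegistry) reconcile(updated []string) []string {
	alive := make(map[string]struct{}, len(updated))
	for _, id := range updated {
		alive[id] = struct{}{}
	}
	r.mu.Lock()
	defer r.mu.Unlock()
	var lost []string
	for id, cancel := range r.chats {
		if _, ok := alive[id]; !ok {
			cancel(fmt.Errorf("chat %s stolen or completed", id))
			delete(r.chats, id)
			lost = append(lost, id)
		}
	}
	return lost
}

func main() {
	r := newHeartbeatRegistry()
	for _, id := range []string{"a", "b", "c"} {
		r.register(id, func(error) {})
	}
	// The batch UPDATE only confirmed "a" and "c"; "b" was stolen.
	fmt.Println(r.reconcile([]string{"a", "c"})) // [b]
}
```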

## Impact

| Metric | Before (10K chats) | After (10K chats) |
|---|---|---|
| Heartbeat queries/sec | ~333 | ~0.03 (1 per 30s per replica) |
| Heartbeat goroutines | 10,000 | 1 |
| Self-interrupt detection | Per-chat `rows==0` | Batch set-difference |

---

> 🤖 Generated by Coder Agents

<details><summary>Implementation notes</summary>

- Uses `@now` parameter instead of `NOW()` so tests with `quartz.Mock`
can control timestamps.
- `heartbeatEntry` stores `context.CancelCauseFunc` + workspace state
for the centralized loop.
- `recoverStaleChats` is unaffected — it reads `heartbeat_at` which is
still updated.
- The old singular `UpdateChatHeartbeat` is removed entirely.
- `dbauthz` wrapper uses system-level `rbac.ResourceChat` authorization
(same pattern as `AcquireChats`).

</details>
2026-04-07 10:25:46 -04:00
George K 86ca61d6ca perf: cap count queries and emit native UUID comparisons for audit/connection logs (#23835)
Audit and connection log pages were timing out due to expensive COUNT(*)
queries over large tables. This commit adds opt-in count capping: responses can include a `count_cap` field signaling that the count was truncated at a threshold, avoiding the full table scans that caused page timeouts.

Text-cast UUID comparisons in regosql-generated authorization queries
also contributed to the slowdown by preventing index usage for connection
and audit log queries. These now emit native UUID operators.

Frontend changes handle the capped state in usePaginatedQuery and
PaginationWidget, optionally displaying a capped count in the pagination
UI (e.g. "Showing 2,076 to 2,100 of 2,000+ logs").

Related to:
https://linear.app/codercom/issue/PLAT-31/connectionaudit-log-performance-issue
2026-04-07 07:24:53 -07:00
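A hedged sketch of how a client might render the capped total described above. The helper is hypothetical (not the actual `PaginationWidget` code) and omits thousands separators; the 2000 cap value comes from the diff below:

```go
package main

import "fmt"

// formatCappedCount renders a pagination total, appending "+" when the
// server-reported count reached the cap and may therefore be truncated.
// A capLimit of 0 means capping is disabled.
func formatCappedCount(count, capLimit int64) string {
	if capLimit > 0 && count >= capLimit {
		return fmt.Sprintf("%d+", capLimit)
	}
	return fmt.Sprintf("%d", count)
}

func main() {
	fmt.Println(formatCappedCount(1234, 2000)) // 1234
	fmt.Println(formatCappedCount(2500, 2000)) // 2000+
	fmt.Println(formatCappedCount(987, 0))     // cap disabled: 987
}
```

On the database side, the usual trick behind such a cap (the commit's exact SQL is not shown here) is counting a capped subquery, e.g. `SELECT count(*) FROM (SELECT 1 FROM audit_logs WHERE ... LIMIT 2000) t`, so the scan stops at the threshold instead of touching every matching row.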
Jake Howell f0521cfa3c fix: resolve <LogLine /> storybook flake (#24084)
This pull request stabilizes the story by pinning the log content to a consistent date, so it no longer changes with every new storybook artifact.

Closes https://github.com/coder/internal/issues/1454
2026-04-08 00:17:06 +10:00
Danielle Maywood 0c5d189aff fix(site): stabilize mutation callbacks for React Compiler memoization (#24089) 2026-04-07 15:05:27 +01:00
Michael Suchacz d7c8213eee fix(coderd/x/chatd/mcpclient): deterministic external MCP tool ordering (#24075)
> This PR was authored by Mux on behalf of Mike.

External MCP tools returned by `ConnectAll` were ordered by goroutine
completion, making the tool list nondeterministic across chat turns.
This broke prompt-cache stability since tools are serialized in order.

Sort tools by their model-visible name after all connections complete,
matching the existing pattern in workspace MCP tools
(`agent/x/agentmcp/manager.go`). Also guards against a nil-client panic
in cleanup when a connected server contributes zero tools after
filtering.
2026-04-07 14:42:30 +02:00
Cian Johnston 63924ac687 fix(site): use async findByLabelText in ProviderAccordionCards story (#24087)
- Use async `findByLabelText` instead of sync `getByLabelText` in
`ProviderAccordionCards` story
- Same bug fixed in #23999 for three other stories but missed for this
one

> 🤖 Written by a Coder Agent. Will be reviewed by a human.
2026-04-07 14:13:56 +02:00
65 changed files with 2551 additions and 1245 deletions
+6
@@ -14175,6 +14175,9 @@ const docTemplate = `{
 },
 "count": {
 "type": "integer"
+},
+"count_cap": {
+"type": "integer"
 }
 }
 },
@@ -14496,6 +14499,9 @@ const docTemplate = `{
 },
 "count": {
 "type": "integer"
+},
+"count_cap": {
+"type": "integer"
 }
 }
 },
+6
@@ -12739,6 +12739,9 @@
 },
 "count": {
 "type": "integer"
+},
+"count_cap": {
+"type": "integer"
 }
 }
 },
@@ -13039,6 +13042,9 @@
 },
 "count": {
 "type": "integer"
+},
+"count_cap": {
+"type": "integer"
 }
 }
 },
+8 -1
@@ -26,6 +26,11 @@ import (
 	"github.com/coder/coder/v2/codersdk"
 )
 
+// Limit the count query to avoid a slow sequential scan due to joins
+// on a large table. Set to 0 to disable capping (but also see the note
+// in the SQL query).
+const auditLogCountCap = 2000
+
 // @Summary Get audit logs
 // @ID get-audit-logs
 // @Security CoderSessionToken
@@ -66,7 +71,7 @@ func (api *API) auditLogs(rw http.ResponseWriter, r *http.Request) {
 		countFilter.Username = ""
 	}
 
-	// Use the same filters to count the number of audit logs
+	countFilter.CountCap = auditLogCountCap
 	count, err := api.Database.CountAuditLogs(ctx, countFilter)
 	if dbauthz.IsNotAuthorizedError(err) {
 		httpapi.Forbidden(rw)
@@ -81,6 +86,7 @@ func (api *API) auditLogs(rw http.ResponseWriter, r *http.Request) {
 	httpapi.Write(ctx, rw, http.StatusOK, codersdk.AuditLogResponse{
 		AuditLogs: []codersdk.AuditLog{},
 		Count:     0,
+		CountCap:  auditLogCountCap,
 	})
 	return
 }
@@ -98,6 +104,7 @@ func (api *API) auditLogs(rw http.ResponseWriter, r *http.Request) {
 	httpapi.Write(ctx, rw, http.StatusOK, codersdk.AuditLogResponse{
 		AuditLogs: api.convertAuditLogs(ctx, dblogs),
 		Count:     count,
+		CountCap:  auditLogCountCap,
 	})
 }
+27 -40
@@ -2155,17 +2155,12 @@ func (q *querier) DeleteUserChatProviderKey(ctx context.Context, arg database.De
 	return q.db.DeleteUserChatProviderKey(ctx, arg)
 }
 
-func (q *querier) DeleteUserSecret(ctx context.Context, id uuid.UUID) error {
-	// First get the secret to check ownership
-	secret, err := q.GetUserSecret(ctx, id)
-	if err != nil {
-		return err
-	}
-	if err := q.authorizeContext(ctx, policy.ActionDelete, secret); err != nil {
-		return err
-	}
-	return q.db.DeleteUserSecret(ctx, id)
+func (q *querier) DeleteUserSecretByUserIDAndName(ctx context.Context, arg database.DeleteUserSecretByUserIDAndNameParams) error {
+	obj := rbac.ResourceUserSecret.WithOwner(arg.UserID.String())
+	if err := q.authorizeContext(ctx, policy.ActionDelete, obj); err != nil {
+		return err
+	}
+	return q.db.DeleteUserSecretByUserIDAndName(ctx, arg)
 }
 
 func (q *querier) DeleteWebpushSubscriptionByUserIDAndEndpoint(ctx context.Context, arg database.DeleteWebpushSubscriptionByUserIDAndEndpointParams) error {
@@ -4128,19 +4123,6 @@ func (q *querier) GetUserNotificationPreferences(ctx context.Context, userID uui
 	return q.db.GetUserNotificationPreferences(ctx, userID)
 }
 
-func (q *querier) GetUserSecret(ctx context.Context, id uuid.UUID) (database.UserSecret, error) {
-	// First get the secret to check ownership
-	secret, err := q.db.GetUserSecret(ctx, id)
-	if err != nil {
-		return database.UserSecret{}, err
-	}
-	if err := q.authorizeContext(ctx, policy.ActionRead, secret); err != nil {
-		return database.UserSecret{}, err
-	}
-	return secret, nil
-}
-
 func (q *querier) GetUserSecretByUserIDAndName(ctx context.Context, arg database.GetUserSecretByUserIDAndNameParams) (database.UserSecret, error) {
 	obj := rbac.ResourceUserSecret.WithOwner(arg.UserID.String())
 	if err := q.authorizeContext(ctx, policy.ActionRead, obj); err != nil {
@@ -5524,7 +5506,7 @@ func (q *querier) ListUserChatCompactionThresholds(ctx context.Context, userID u
 	return q.db.ListUserChatCompactionThresholds(ctx, userID)
 }
 
-func (q *querier) ListUserSecrets(ctx context.Context, userID uuid.UUID) ([]database.UserSecret, error) {
+func (q *querier) ListUserSecrets(ctx context.Context, userID uuid.UUID) ([]database.ListUserSecretsRow, error) {
 	obj := rbac.ResourceUserSecret.WithOwner(userID.String())
 	if err := q.authorizeContext(ctx, policy.ActionRead, obj); err != nil {
 		return nil, err
@@ -5532,6 +5514,16 @@ func (q *querier) ListUserSecrets(ctx context.Context, userID uuid.UUID) ([]data
 	return q.db.ListUserSecrets(ctx, userID)
 }
 
+func (q *querier) ListUserSecretsWithValues(ctx context.Context, userID uuid.UUID) ([]database.UserSecret, error) {
+	// This query returns decrypted secret values and must only be called
+	// from system contexts (provisioner, agent manifest). REST API
+	// handlers should use ListUserSecrets (metadata only).
+	if err := q.authorizeContext(ctx, policy.ActionRead, rbac.ResourceSystem); err != nil {
+		return nil, err
+	}
+	return q.db.ListUserSecretsWithValues(ctx, userID)
+}
+
 func (q *querier) ListWorkspaceAgentPortShares(ctx context.Context, workspaceID uuid.UUID) ([]database.WorkspaceAgentPortShare, error) {
 	workspace, err := q.db.GetWorkspaceByID(ctx, workspaceID)
 	if err != nil {
@@ -5782,15 +5774,15 @@ func (q *querier) UpdateChatByID(ctx context.Context, arg database.UpdateChatByI
 	return q.db.UpdateChatByID(ctx, arg)
 }
 
-func (q *querier) UpdateChatHeartbeat(ctx context.Context, arg database.UpdateChatHeartbeatParams) (int64, error) {
-	chat, err := q.db.GetChatByID(ctx, arg.ID)
-	if err != nil {
-		return 0, err
-	}
-	if err := q.authorizeContext(ctx, policy.ActionUpdate, chat); err != nil {
-		return 0, err
-	}
-	return q.db.UpdateChatHeartbeat(ctx, arg)
+func (q *querier) UpdateChatHeartbeats(ctx context.Context, arg database.UpdateChatHeartbeatsParams) ([]uuid.UUID, error) {
+	// The batch heartbeat is a system-level operation filtered by
+	// worker_id. Authorization is enforced by the AsChatd context
+	// at the call site rather than per-row, because checking each
+	// row individually would defeat the purpose of batching.
+	if err := q.authorizeContext(ctx, policy.ActionUpdate, rbac.ResourceChat); err != nil {
+		return nil, err
+	}
+	return q.db.UpdateChatHeartbeats(ctx, arg)
 }
 
 func (q *querier) UpdateChatLabelsByID(ctx context.Context, arg database.UpdateChatLabelsByIDParams) (database.Chat, error) {
@@ -6632,17 +6624,12 @@ func (q *querier) UpdateUserRoles(ctx context.Context, arg database.UpdateUserRo
 	return q.db.UpdateUserRoles(ctx, arg)
 }
 
-func (q *querier) UpdateUserSecret(ctx context.Context, arg database.UpdateUserSecretParams) (database.UserSecret, error) {
-	// First get the secret to check ownership
-	secret, err := q.db.GetUserSecret(ctx, arg.ID)
-	if err != nil {
-		return database.UserSecret{}, err
-	}
-	if err := q.authorizeContext(ctx, policy.ActionUpdate, secret); err != nil {
-		return database.UserSecret{}, err
-	}
-	return q.db.UpdateUserSecret(ctx, arg)
+func (q *querier) UpdateUserSecretByUserIDAndName(ctx context.Context, arg database.UpdateUserSecretByUserIDAndNameParams) (database.UserSecret, error) {
+	obj := rbac.ResourceUserSecret.WithOwner(arg.UserID.String())
+	if err := q.authorizeContext(ctx, policy.ActionUpdate, obj); err != nil {
+		return database.UserSecret{}, err
+	}
+	return q.db.UpdateUserSecretByUserIDAndName(ctx, arg)
 }
 
 func (q *querier) UpdateUserStatus(ctx context.Context, arg database.UpdateUserStatusParams) (database.User, error) {
+29 -29
@@ -842,15 +842,15 @@ func (s *MethodTestSuite) TestChats() {
 		dbm.EXPECT().UpdateChatStatusPreserveUpdatedAt(gomock.Any(), arg).Return(chat, nil).AnyTimes()
 		check.Args(arg).Asserts(chat, policy.ActionUpdate).Returns(chat)
 	}))
-	s.Run("UpdateChatHeartbeat", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
-		chat := testutil.Fake(s.T(), faker, database.Chat{})
-		arg := database.UpdateChatHeartbeatParams{
-			ID:       chat.ID,
-			WorkerID: uuid.New(),
-			Now:      time.Now(),
-		}
-		dbm.EXPECT().GetChatByID(gomock.Any(), chat.ID).Return(chat, nil).AnyTimes()
-		dbm.EXPECT().UpdateChatHeartbeat(gomock.Any(), arg).Return(int64(1), nil).AnyTimes()
-		check.Args(arg).Asserts(chat, policy.ActionUpdate).Returns(int64(1))
+	s.Run("UpdateChatHeartbeats", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
+		resultID := uuid.New()
+		arg := database.UpdateChatHeartbeatsParams{
+			IDs:      []uuid.UUID{resultID},
+			WorkerID: uuid.New(),
+			Now:      time.Now(),
+		}
+		dbm.EXPECT().UpdateChatHeartbeats(gomock.Any(), arg).Return([]uuid.UUID{resultID}, nil).AnyTimes()
+		check.Args(arg).Asserts(rbac.ResourceChat, policy.ActionUpdate).Returns([]uuid.UUID{resultID})
 	}))
 	s.Run("UpdateChatMessageByID", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
 		chat := testutil.Fake(s.T(), faker, database.Chat{})
@@ -5346,19 +5346,20 @@ func (s *MethodTestSuite) TestUserSecrets() {
 			Asserts(rbac.ResourceUserSecret.WithOwner(user.ID.String()), policy.ActionRead).
 			Returns(secret)
 	}))
-	s.Run("GetUserSecret", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
-		secret := testutil.Fake(s.T(), faker, database.UserSecret{})
-		dbm.EXPECT().GetUserSecret(gomock.Any(), secret.ID).Return(secret, nil).AnyTimes()
-		check.Args(secret.ID).
-			Asserts(secret, policy.ActionRead).
-			Returns(secret)
-	}))
 	s.Run("ListUserSecrets", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
 		user := testutil.Fake(s.T(), faker, database.User{})
-		secret := testutil.Fake(s.T(), faker, database.UserSecret{UserID: user.ID})
-		dbm.EXPECT().ListUserSecrets(gomock.Any(), user.ID).Return([]database.UserSecret{secret}, nil).AnyTimes()
+		row := testutil.Fake(s.T(), faker, database.ListUserSecretsRow{UserID: user.ID})
+		dbm.EXPECT().ListUserSecrets(gomock.Any(), user.ID).Return([]database.ListUserSecretsRow{row}, nil).AnyTimes()
 		check.Args(user.ID).
 			Asserts(rbac.ResourceUserSecret.WithOwner(user.ID.String()), policy.ActionRead).
-			Returns([]database.UserSecret{secret})
+			Returns([]database.ListUserSecretsRow{row})
 	}))
+	s.Run("ListUserSecretsWithValues", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
+		user := testutil.Fake(s.T(), faker, database.User{})
+		secret := testutil.Fake(s.T(), faker, database.UserSecret{UserID: user.ID})
+		dbm.EXPECT().ListUserSecretsWithValues(gomock.Any(), user.ID).Return([]database.UserSecret{secret}, nil).AnyTimes()
+		check.Args(user.ID).
+			Asserts(rbac.ResourceSystem, policy.ActionRead).
+			Returns([]database.UserSecret{secret})
+	}))
 	s.Run("CreateUserSecret", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
@@ -5370,22 +5371,21 @@ func (s *MethodTestSuite) TestUserSecrets() {
 			Asserts(rbac.ResourceUserSecret.WithOwner(user.ID.String()), policy.ActionCreate).
 			Returns(ret)
 	}))
-	s.Run("UpdateUserSecret", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
-		secret := testutil.Fake(s.T(), faker, database.UserSecret{})
-		updated := testutil.Fake(s.T(), faker, database.UserSecret{ID: secret.ID})
-		arg := database.UpdateUserSecretParams{ID: secret.ID}
-		dbm.EXPECT().GetUserSecret(gomock.Any(), secret.ID).Return(secret, nil).AnyTimes()
-		dbm.EXPECT().UpdateUserSecret(gomock.Any(), arg).Return(updated, nil).AnyTimes()
+	s.Run("UpdateUserSecretByUserIDAndName", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
+		user := testutil.Fake(s.T(), faker, database.User{})
+		updated := testutil.Fake(s.T(), faker, database.UserSecret{UserID: user.ID})
+		arg := database.UpdateUserSecretByUserIDAndNameParams{UserID: user.ID, Name: "test"}
+		dbm.EXPECT().UpdateUserSecretByUserIDAndName(gomock.Any(), arg).Return(updated, nil).AnyTimes()
 		check.Args(arg).
-			Asserts(secret, policy.ActionUpdate).
+			Asserts(rbac.ResourceUserSecret.WithOwner(user.ID.String()), policy.ActionUpdate).
 			Returns(updated)
 	}))
-	s.Run("DeleteUserSecret", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
-		secret := testutil.Fake(s.T(), faker, database.UserSecret{})
-		dbm.EXPECT().GetUserSecret(gomock.Any(), secret.ID).Return(secret, nil).AnyTimes()
-		dbm.EXPECT().DeleteUserSecret(gomock.Any(), secret.ID).Return(nil).AnyTimes()
-		check.Args(secret.ID).
-			Asserts(secret, policy.ActionRead, secret, policy.ActionDelete).
+	s.Run("DeleteUserSecretByUserIDAndName", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
+		user := testutil.Fake(s.T(), faker, database.User{})
+		arg := database.DeleteUserSecretByUserIDAndNameParams{UserID: user.ID, Name: "test"}
+		dbm.EXPECT().DeleteUserSecretByUserIDAndName(gomock.Any(), arg).Return(nil).AnyTimes()
+		check.Args(arg).
+			Asserts(rbac.ResourceUserSecret.WithOwner(user.ID.String()), policy.ActionDelete).
 			Returns()
 	}))
 }
+1
@@ -1597,6 +1597,7 @@ func UserSecret(t testing.TB, db database.Store, seed database.UserSecret) datab
 		Name:        takeFirst(seed.Name, "secret-name"),
 		Description: takeFirst(seed.Description, "secret description"),
 		Value:       takeFirst(seed.Value, "secret value"),
+		ValueKeyID:  seed.ValueKeyID,
 		EnvName:     takeFirst(seed.EnvName, "SECRET_ENV_NAME"),
 		FilePath:    takeFirst(seed.FilePath, "~/secret/file/path"),
 	})
+21 -21
@@ -712,11 +712,11 @@ func (m queryMetricsStore) DeleteUserChatProviderKey(ctx context.Context, arg da
 	return r0
 }
 
-func (m queryMetricsStore) DeleteUserSecret(ctx context.Context, id uuid.UUID) error {
+func (m queryMetricsStore) DeleteUserSecretByUserIDAndName(ctx context.Context, arg database.DeleteUserSecretByUserIDAndNameParams) error {
 	start := time.Now()
-	r0 := m.s.DeleteUserSecret(ctx, id)
-	m.queryLatencies.WithLabelValues("DeleteUserSecret").Observe(time.Since(start).Seconds())
-	m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "DeleteUserSecret").Inc()
+	r0 := m.s.DeleteUserSecretByUserIDAndName(ctx, arg)
+	m.queryLatencies.WithLabelValues("DeleteUserSecretByUserIDAndName").Observe(time.Since(start).Seconds())
+	m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "DeleteUserSecretByUserIDAndName").Inc()
 	return r0
 }
 
@@ -2624,14 +2624,6 @@ func (m queryMetricsStore) GetUserNotificationPreferences(ctx context.Context, u
 	return r0, r1
 }
 
-func (m queryMetricsStore) GetUserSecret(ctx context.Context, id uuid.UUID) (database.UserSecret, error) {
-	start := time.Now()
-	r0, r1 := m.s.GetUserSecret(ctx, id)
-	m.queryLatencies.WithLabelValues("GetUserSecret").Observe(time.Since(start).Seconds())
-	m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "GetUserSecret").Inc()
-	return r0, r1
-}
-
 func (m queryMetricsStore) GetUserSecretByUserIDAndName(ctx context.Context, arg database.GetUserSecretByUserIDAndNameParams) (database.UserSecret, error) {
 	start := time.Now()
 	r0, r1 := m.s.GetUserSecretByUserIDAndName(ctx, arg)
@@ -3920,7 +3912,7 @@ func (m queryMetricsStore) ListUserChatCompactionThresholds(ctx context.Context,
 	return r0, r1
 }
 
-func (m queryMetricsStore) ListUserSecrets(ctx context.Context, userID uuid.UUID) ([]database.UserSecret, error) {
+func (m queryMetricsStore) ListUserSecrets(ctx context.Context, userID uuid.UUID) ([]database.ListUserSecretsRow, error) {
 	start := time.Now()
 	r0, r1 := m.s.ListUserSecrets(ctx, userID)
 	m.queryLatencies.WithLabelValues("ListUserSecrets").Observe(time.Since(start).Seconds())
@@ -3928,6 +3920,14 @@ func (m queryMetricsStore) ListUserSecrets(ctx context.Context, userID uuid.UUID
 	return r0, r1
 }
 
+func (m queryMetricsStore) ListUserSecretsWithValues(ctx context.Context, userID uuid.UUID) ([]database.UserSecret, error) {
+	start := time.Now()
+	r0, r1 := m.s.ListUserSecretsWithValues(ctx, userID)
+	m.queryLatencies.WithLabelValues("ListUserSecretsWithValues").Observe(time.Since(start).Seconds())
+	m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "ListUserSecretsWithValues").Inc()
+	return r0, r1
+}
+
 func (m queryMetricsStore) ListWorkspaceAgentPortShares(ctx context.Context, workspaceID uuid.UUID) ([]database.WorkspaceAgentPortShare, error) {
 	start := time.Now()
 	r0, r1 := m.s.ListWorkspaceAgentPortShares(ctx, workspaceID)
@@ -4136,11 +4136,11 @@ func (m queryMetricsStore) UpdateChatByID(ctx context.Context, arg database.Upda
 	return r0, r1
 }
 
-func (m queryMetricsStore) UpdateChatHeartbeat(ctx context.Context, arg database.UpdateChatHeartbeatParams) (int64, error) {
+func (m queryMetricsStore) UpdateChatHeartbeats(ctx context.Context, arg database.UpdateChatHeartbeatsParams) ([]uuid.UUID, error) {
 	start := time.Now()
-	r0, r1 := m.s.UpdateChatHeartbeat(ctx, arg)
-	m.queryLatencies.WithLabelValues("UpdateChatHeartbeat").Observe(time.Since(start).Seconds())
-	m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "UpdateChatHeartbeat").Inc()
+	r0, r1 := m.s.UpdateChatHeartbeats(ctx, arg)
+	m.queryLatencies.WithLabelValues("UpdateChatHeartbeats").Observe(time.Since(start).Seconds())
+	m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "UpdateChatHeartbeats").Inc()
 	return r0, r1
 }
 
@@ -4696,11 +4696,11 @@ func (m queryMetricsStore) UpdateUserRoles(ctx context.Context, arg database.Upd
 	return r0, r1
 }
 
-func (m queryMetricsStore) UpdateUserSecret(ctx context.Context, arg database.UpdateUserSecretParams) (database.UserSecret, error) {
+func (m queryMetricsStore) UpdateUserSecretByUserIDAndName(ctx context.Context, arg database.UpdateUserSecretByUserIDAndNameParams) (database.UserSecret, error) {
 	start := time.Now()
-	r0, r1 := m.s.UpdateUserSecret(ctx, arg)
-	m.queryLatencies.WithLabelValues("UpdateUserSecret").Observe(time.Since(start).Seconds())
-	m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "UpdateUserSecret").Inc()
+	r0, r1 := m.s.UpdateUserSecretByUserIDAndName(ctx, arg)
+	m.queryLatencies.WithLabelValues("UpdateUserSecretByUserIDAndName").Observe(time.Since(start).Seconds())
+	m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "UpdateUserSecretByUserIDAndName").Inc()
 	return r0, r1
 }
+36 -36
@@ -1199,18 +1199,18 @@ func (mr *MockStoreMockRecorder) DeleteUserChatProviderKey(ctx, arg any) *gomock
 	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteUserChatProviderKey", reflect.TypeOf((*MockStore)(nil).DeleteUserChatProviderKey), ctx, arg)
 }
 
-// DeleteUserSecret mocks base method.
-func (m *MockStore) DeleteUserSecret(ctx context.Context, id uuid.UUID) error {
+// DeleteUserSecretByUserIDAndName mocks base method.
+func (m *MockStore) DeleteUserSecretByUserIDAndName(ctx context.Context, arg database.DeleteUserSecretByUserIDAndNameParams) error {
 	m.ctrl.T.Helper()
-	ret := m.ctrl.Call(m, "DeleteUserSecret", ctx, id)
+	ret := m.ctrl.Call(m, "DeleteUserSecretByUserIDAndName", ctx, arg)
 	ret0, _ := ret[0].(error)
 	return ret0
 }
 
-// DeleteUserSecret indicates an expected call of DeleteUserSecret.
-func (mr *MockStoreMockRecorder) DeleteUserSecret(ctx, id any) *gomock.Call {
+// DeleteUserSecretByUserIDAndName indicates an expected call of DeleteUserSecretByUserIDAndName.
+func (mr *MockStoreMockRecorder) DeleteUserSecretByUserIDAndName(ctx, arg any) *gomock.Call {
 	mr.mock.ctrl.T.Helper()
-	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteUserSecret", reflect.TypeOf((*MockStore)(nil).DeleteUserSecret), ctx, id)
+	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteUserSecretByUserIDAndName", reflect.TypeOf((*MockStore)(nil).DeleteUserSecretByUserIDAndName), ctx, arg)
 }
 
 // DeleteWebpushSubscriptionByUserIDAndEndpoint mocks base method.
@@ -4907,21 +4907,6 @@ func (mr *MockStoreMockRecorder) GetUserNotificationPreferences(ctx, userID any)
 	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetUserNotificationPreferences", reflect.TypeOf((*MockStore)(nil).GetUserNotificationPreferences), ctx, userID)
 }
 
-// GetUserSecret mocks base method.
-func (m *MockStore) GetUserSecret(ctx context.Context, id uuid.UUID) (database.UserSecret, error) {
-	m.ctrl.T.Helper()
-	ret := m.ctrl.Call(m, "GetUserSecret", ctx, id)
-	ret0, _ := ret[0].(database.UserSecret)
-	ret1, _ := ret[1].(error)
-	return ret0, ret1
-}
-
-// GetUserSecret indicates an expected call of GetUserSecret.
-func (mr *MockStoreMockRecorder) GetUserSecret(ctx, id any) *gomock.Call {
-	mr.mock.ctrl.T.Helper()
-	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetUserSecret", reflect.TypeOf((*MockStore)(nil).GetUserSecret), ctx, id)
-}
-
 // GetUserSecretByUserIDAndName mocks base method.
 func (m *MockStore) GetUserSecretByUserIDAndName(ctx context.Context, arg database.GetUserSecretByUserIDAndNameParams) (database.UserSecret, error) {
 	m.ctrl.T.Helper()
@@ -7412,10 +7397,10 @@ func (mr *MockStoreMockRecorder) ListUserChatCompactionThresholds(ctx, userID an
 }
 
 // ListUserSecrets mocks base method.
-func (m *MockStore) ListUserSecrets(ctx context.Context, userID uuid.UUID) ([]database.UserSecret, error) {
+func (m *MockStore) ListUserSecrets(ctx context.Context, userID uuid.UUID) ([]database.ListUserSecretsRow, error) {
 	m.ctrl.T.Helper()
 	ret := m.ctrl.Call(m, "ListUserSecrets", ctx, userID)
-	ret0, _ := ret[0].([]database.UserSecret)
+	ret0, _ := ret[0].([]database.ListUserSecretsRow)
 	ret1, _ := ret[1].(error)
 	return ret0, ret1
 }
@@ -7426,6 +7411,21 @@ func (mr *MockStoreMockRecorder) ListUserSecrets(ctx, userID any) *gomock.Call {
 	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ListUserSecrets", reflect.TypeOf((*MockStore)(nil).ListUserSecrets), ctx, userID)
 }
 
+// ListUserSecretsWithValues mocks base method.
+func (m *MockStore) ListUserSecretsWithValues(ctx context.Context, userID uuid.UUID) ([]database.UserSecret, error) {
+	m.ctrl.T.Helper()
+	ret := m.ctrl.Call(m, "ListUserSecretsWithValues", ctx, userID)
+	ret0, _ := ret[0].([]database.UserSecret)
+	ret1, _ := ret[1].(error)
+	return ret0, ret1
+}
+
+// ListUserSecretsWithValues indicates an expected call of ListUserSecretsWithValues.
+func (mr *MockStoreMockRecorder) ListUserSecretsWithValues(ctx, userID any) *gomock.Call {
+	mr.mock.ctrl.T.Helper()
+	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ListUserSecretsWithValues", reflect.TypeOf((*MockStore)(nil).ListUserSecretsWithValues), ctx, userID)
+}
+
 // ListWorkspaceAgentPortShares mocks base method.
 func (m *MockStore) ListWorkspaceAgentPortShares(ctx context.Context, workspaceID uuid.UUID) ([]database.WorkspaceAgentPortShare, error) {
 	m.ctrl.T.Helper()
@@ -7835,19 +7835,19 @@ func (mr *MockStoreMockRecorder) UpdateChatByID(ctx, arg any) *gomock.Call {
 	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateChatByID", reflect.TypeOf((*MockStore)(nil).UpdateChatByID), ctx, arg)
 }
 
-// UpdateChatHeartbeat mocks base method.
-func (m *MockStore) UpdateChatHeartbeat(ctx context.Context, arg database.UpdateChatHeartbeatParams) (int64, error) {
+// UpdateChatHeartbeats mocks base method.
+func (m *MockStore) UpdateChatHeartbeats(ctx context.Context, arg database.UpdateChatHeartbeatsParams) ([]uuid.UUID, error) {
 	m.ctrl.T.Helper()
-	ret := m.ctrl.Call(m, "UpdateChatHeartbeat", ctx, arg)
-	ret0, _ := ret[0].(int64)
+	ret := m.ctrl.Call(m, "UpdateChatHeartbeats", ctx, arg)
+	ret0, _ := ret[0].([]uuid.UUID)
 	ret1, _ := ret[1].(error)
 	return ret0, ret1
 }
 
-// UpdateChatHeartbeat indicates an expected call of UpdateChatHeartbeat.
-func (mr *MockStoreMockRecorder) UpdateChatHeartbeat(ctx, arg any) *gomock.Call {
+// UpdateChatHeartbeats indicates an expected call of UpdateChatHeartbeats.
+func (mr *MockStoreMockRecorder) UpdateChatHeartbeats(ctx, arg any) *gomock.Call {
 	mr.mock.ctrl.T.Helper()
-	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateChatHeartbeat", reflect.TypeOf((*MockStore)(nil).UpdateChatHeartbeat), ctx, arg)
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateChatHeartbeats", reflect.TypeOf((*MockStore)(nil).UpdateChatHeartbeats), ctx, arg)
}
// UpdateChatLabelsByID mocks base method.
@@ -8854,19 +8854,19 @@ func (mr *MockStoreMockRecorder) UpdateUserRoles(ctx, arg any) *gomock.Call {
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateUserRoles", reflect.TypeOf((*MockStore)(nil).UpdateUserRoles), ctx, arg)
}
// UpdateUserSecret mocks base method.
func (m *MockStore) UpdateUserSecret(ctx context.Context, arg database.UpdateUserSecretParams) (database.UserSecret, error) {
// UpdateUserSecretByUserIDAndName mocks base method.
func (m *MockStore) UpdateUserSecretByUserIDAndName(ctx context.Context, arg database.UpdateUserSecretByUserIDAndNameParams) (database.UserSecret, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "UpdateUserSecret", ctx, arg)
ret := m.ctrl.Call(m, "UpdateUserSecretByUserIDAndName", ctx, arg)
ret0, _ := ret[0].(database.UserSecret)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// UpdateUserSecret indicates an expected call of UpdateUserSecret.
func (mr *MockStoreMockRecorder) UpdateUserSecret(ctx, arg any) *gomock.Call {
// UpdateUserSecretByUserIDAndName indicates an expected call of UpdateUserSecretByUserIDAndName.
func (mr *MockStoreMockRecorder) UpdateUserSecretByUserIDAndName(ctx, arg any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateUserSecret", reflect.TypeOf((*MockStore)(nil).UpdateUserSecret), ctx, arg)
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateUserSecretByUserIDAndName", reflect.TypeOf((*MockStore)(nil).UpdateUserSecretByUserIDAndName), ctx, arg)
}
// UpdateUserStatus mocks base method.
+2
@@ -584,6 +584,7 @@ func (q *sqlQuerier) CountAuthorizedAuditLogs(ctx context.Context, arg CountAudi
arg.DateTo,
arg.BuildReason,
arg.RequestID,
arg.CountCap,
)
if err != nil {
return 0, err
@@ -720,6 +721,7 @@ func (q *sqlQuerier) CountAuthorizedConnectionLogs(ctx context.Context, arg Coun
arg.WorkspaceID,
arg.ConnectionID,
arg.Status,
arg.CountCap,
)
if err != nil {
return 0, err
@@ -145,5 +145,13 @@ func extractWhereClause(query string) string {
// Remove SQL comments
whereClause = regexp.MustCompile(`(?m)--.*$`).ReplaceAllString(whereClause, "")
// Normalize indentation so subquery wrapping doesn't cause
// mismatches.
lines := strings.Split(whereClause, "\n")
for i, line := range lines {
lines[i] = strings.TrimLeft(line, " \t")
}
whereClause = strings.Join(lines, "\n")
return strings.TrimSpace(whereClause)
}
+27 -17
@@ -81,8 +81,8 @@ func newMsgQueue(ctx context.Context, l Listener, le ListenerWithErr) *msgQueue
}
func (q *msgQueue) run() {
var batch [maxDrainBatch]msgOrErr
for {
// wait until there is something on the queue or we are closed
q.cond.L.Lock()
for q.size == 0 && !q.closed {
q.cond.Wait()
@@ -91,28 +91,32 @@ func (q *msgQueue) run() {
q.cond.L.Unlock()
return
}
item := q.q[q.front]
q.front = (q.front + 1) % BufferSize
q.size--
// Drain up to maxDrainBatch items while holding the lock.
n := min(q.size, maxDrainBatch)
for i := range n {
batch[i] = q.q[q.front]
q.front = (q.front + 1) % BufferSize
}
q.size -= n
q.cond.L.Unlock()
// process item without holding lock
if item.err == nil {
// real message
if q.l != nil {
q.l(q.ctx, item.msg)
continue
}
if q.le != nil {
q.le(q.ctx, item.msg, nil)
continue
}
// unhittable
continue
}
// if the listener wants errors, send it.
if q.le != nil {
q.le(q.ctx, nil, item.err)
}
// Dispatch each message individually without holding the lock.
for i := range n {
item := batch[i]
if item.err == nil {
if q.l != nil {
q.l(q.ctx, item.msg)
continue
}
if q.le != nil {
q.le(q.ctx, item.msg, nil)
continue
}
continue
}
if q.le != nil {
q.le(q.ctx, nil, item.err)
}
}
}
}
@@ -233,6 +237,12 @@ type PGPubsub struct {
// for a subscriber before dropping messages.
const BufferSize = 2048
// maxDrainBatch is the maximum number of messages to drain from the ring
// buffer per iteration. Batching amortizes the cost of mutex
// acquire/release and cond.Wait across many messages, improving drain
// throughput during bursts.
const maxDrainBatch = 256
// Subscribe calls the listener when an event matching the name is received.
func (p *PGPubsub) Subscribe(event string, listener Listener) (cancel func(), err error) {
return p.subscribeQueue(event, newMsgQueue(context.Background(), listener, nil))
+14 -7
@@ -152,7 +152,7 @@ type sqlcQuerier interface {
DeleteTask(ctx context.Context, arg DeleteTaskParams) (uuid.UUID, error)
DeleteUserChatCompactionThreshold(ctx context.Context, arg DeleteUserChatCompactionThresholdParams) error
DeleteUserChatProviderKey(ctx context.Context, arg DeleteUserChatProviderKeyParams) error
DeleteUserSecret(ctx context.Context, id uuid.UUID) error
DeleteUserSecretByUserIDAndName(ctx context.Context, arg DeleteUserSecretByUserIDAndNameParams) error
DeleteWebpushSubscriptionByUserIDAndEndpoint(ctx context.Context, arg DeleteWebpushSubscriptionByUserIDAndEndpointParams) error
DeleteWebpushSubscriptions(ctx context.Context, ids []uuid.UUID) error
DeleteWorkspaceACLByID(ctx context.Context, id uuid.UUID) error
@@ -598,7 +598,6 @@ type sqlcQuerier interface {
GetUserLinkByUserIDLoginType(ctx context.Context, arg GetUserLinkByUserIDLoginTypeParams) (UserLink, error)
GetUserLinksByUserID(ctx context.Context, userID uuid.UUID) ([]UserLink, error)
GetUserNotificationPreferences(ctx context.Context, userID uuid.UUID) ([]NotificationPreference, error)
GetUserSecret(ctx context.Context, id uuid.UUID) (UserSecret, error)
GetUserSecretByUserIDAndName(ctx context.Context, arg GetUserSecretByUserIDAndNameParams) (UserSecret, error)
// GetUserStatusCounts returns the count of users in each status over time.
// The time range is inclusively defined by the start_time and end_time parameters.
@@ -818,7 +817,13 @@ type sqlcQuerier interface {
ListProvisionerKeysByOrganizationExcludeReserved(ctx context.Context, organizationID uuid.UUID) ([]ProvisionerKey, error)
ListTasks(ctx context.Context, arg ListTasksParams) ([]Task, error)
ListUserChatCompactionThresholds(ctx context.Context, userID uuid.UUID) ([]UserConfig, error)
ListUserSecrets(ctx context.Context, userID uuid.UUID) ([]UserSecret, error)
// Returns metadata only (no value or value_key_id) for the
// REST API list and get endpoints.
ListUserSecrets(ctx context.Context, userID uuid.UUID) ([]ListUserSecretsRow, error)
// Returns all columns including the secret value. Used by the
// provisioner (build-time injection) and the agent manifest
// (runtime injection).
ListUserSecretsWithValues(ctx context.Context, userID uuid.UUID) ([]UserSecret, error)
ListWorkspaceAgentPortShares(ctx context.Context, workspaceID uuid.UUID) ([]WorkspaceAgentPortShare, error)
MarkAllInboxNotificationsAsRead(ctx context.Context, arg MarkAllInboxNotificationsAsReadParams) error
OIDCClaimFieldValues(ctx context.Context, arg OIDCClaimFieldValuesParams) ([]string, error)
@@ -870,9 +875,11 @@ type sqlcQuerier interface {
UpdateAPIKeyByID(ctx context.Context, arg UpdateAPIKeyByIDParams) error
UpdateChatBuildAgentBinding(ctx context.Context, arg UpdateChatBuildAgentBindingParams) (Chat, error)
UpdateChatByID(ctx context.Context, arg UpdateChatByIDParams) (Chat, error)
// Bumps the heartbeat timestamp for a running chat so that other
// replicas know the worker is still alive.
UpdateChatHeartbeat(ctx context.Context, arg UpdateChatHeartbeatParams) (int64, error)
// Bumps the heartbeat timestamp for the given set of chat IDs,
// provided they are still running and owned by the specified
// worker. Returns the IDs that were actually updated so the
// caller can detect stolen or completed chats via set-difference.
UpdateChatHeartbeats(ctx context.Context, arg UpdateChatHeartbeatsParams) ([]uuid.UUID, error)
UpdateChatLabelsByID(ctx context.Context, arg UpdateChatLabelsByIDParams) (Chat, error)
// Updates the cached injected context parts (AGENTS.md +
// skills) on the chat row. Called only when context changes
@@ -955,7 +962,7 @@ type sqlcQuerier interface {
UpdateUserProfile(ctx context.Context, arg UpdateUserProfileParams) (User, error)
UpdateUserQuietHoursSchedule(ctx context.Context, arg UpdateUserQuietHoursScheduleParams) (User, error)
UpdateUserRoles(ctx context.Context, arg UpdateUserRolesParams) (User, error)
UpdateUserSecret(ctx context.Context, arg UpdateUserSecretParams) (UserSecret, error)
UpdateUserSecretByUserIDAndName(ctx context.Context, arg UpdateUserSecretByUserIDAndNameParams) (UserSecret, error)
UpdateUserStatus(ctx context.Context, arg UpdateUserStatusParams) (User, error)
UpdateUserTaskNotificationAlertDismissed(ctx context.Context, arg UpdateUserTaskNotificationAlertDismissedParams) (bool, error)
UpdateUserTerminalFont(ctx context.Context, arg UpdateUserTerminalFontParams) (UserConfig, error)
+45 -30
@@ -7339,13 +7339,7 @@ func TestUserSecretsCRUDOperations(t *testing.T) {
require.NoError(t, err)
assert.Equal(t, secretID, createdSecret.ID)
// 2. READ by ID
readSecret, err := db.GetUserSecret(ctx, createdSecret.ID)
require.NoError(t, err)
assert.Equal(t, createdSecret.ID, readSecret.ID)
assert.Equal(t, "workflow-secret", readSecret.Name)
// 3. READ by UserID and Name
// 2. READ by UserID and Name
readByNameParams := database.GetUserSecretByUserIDAndNameParams{
UserID: testUser.ID,
Name: "workflow-secret",
@@ -7353,33 +7347,43 @@ func TestUserSecretsCRUDOperations(t *testing.T) {
readByNameSecret, err := db.GetUserSecretByUserIDAndName(ctx, readByNameParams)
require.NoError(t, err)
assert.Equal(t, createdSecret.ID, readByNameSecret.ID)
assert.Equal(t, "workflow-secret", readByNameSecret.Name)
// 4. LIST
// 3. LIST (metadata only)
secrets, err := db.ListUserSecrets(ctx, testUser.ID)
require.NoError(t, err)
require.Len(t, secrets, 1)
assert.Equal(t, createdSecret.ID, secrets[0].ID)
// 5. UPDATE
updateParams := database.UpdateUserSecretParams{
ID: createdSecret.ID,
Description: "Updated workflow description",
Value: "updated-workflow-value",
EnvName: "UPDATED_WORKFLOW_ENV",
FilePath: "/updated/workflow/path",
// 4. LIST with values
secretsWithValues, err := db.ListUserSecretsWithValues(ctx, testUser.ID)
require.NoError(t, err)
require.Len(t, secretsWithValues, 1)
assert.Equal(t, "workflow-value", secretsWithValues[0].Value)
// 5. UPDATE (partial - only description)
updateParams := database.UpdateUserSecretByUserIDAndNameParams{
UserID: testUser.ID,
Name: "workflow-secret",
UpdateDescription: true,
Description: "Updated workflow description",
}
updatedSecret, err := db.UpdateUserSecret(ctx, updateParams)
updatedSecret, err := db.UpdateUserSecretByUserIDAndName(ctx, updateParams)
require.NoError(t, err)
assert.Equal(t, "Updated workflow description", updatedSecret.Description)
assert.Equal(t, "updated-workflow-value", updatedSecret.Value)
assert.Equal(t, "workflow-value", updatedSecret.Value) // Value unchanged
assert.Equal(t, "WORKFLOW_ENV", updatedSecret.EnvName) // EnvName unchanged
// 6. DELETE
err = db.DeleteUserSecret(ctx, createdSecret.ID)
err = db.DeleteUserSecretByUserIDAndName(ctx, database.DeleteUserSecretByUserIDAndNameParams{
UserID: testUser.ID,
Name: "workflow-secret",
})
require.NoError(t, err)
// Verify deletion
_, err = db.GetUserSecret(ctx, createdSecret.ID)
_, err = db.GetUserSecretByUserIDAndName(ctx, readByNameParams)
require.Error(t, err)
assert.Contains(t, err.Error(), "no rows in result set")
@@ -7449,9 +7453,13 @@ func TestUserSecretsCRUDOperations(t *testing.T) {
})
// Verify both secrets exist
_, err = db.GetUserSecret(ctx, secret1.ID)
_, err = db.GetUserSecretByUserIDAndName(ctx, database.GetUserSecretByUserIDAndNameParams{
UserID: testUser.ID, Name: secret1.Name,
})
require.NoError(t, err)
_, err = db.GetUserSecret(ctx, secret2.ID)
_, err = db.GetUserSecretByUserIDAndName(ctx, database.GetUserSecretByUserIDAndNameParams{
UserID: testUser.ID, Name: secret2.Name,
})
require.NoError(t, err)
})
}
@@ -7474,14 +7482,14 @@ func TestUserSecretsAuthorization(t *testing.T) {
org := dbgen.Organization(t, db, database.Organization{})
// Create secrets for users
user1Secret := dbgen.UserSecret(t, db, database.UserSecret{
_ = dbgen.UserSecret(t, db, database.UserSecret{
UserID: user1.ID,
Name: "user1-secret",
Description: "User 1's secret",
Value: "user1-value",
})
user2Secret := dbgen.UserSecret(t, db, database.UserSecret{
_ = dbgen.UserSecret(t, db, database.UserSecret{
UserID: user2.ID,
Name: "user2-secret",
Description: "User 2's secret",
@@ -7491,7 +7499,8 @@ func TestUserSecretsAuthorization(t *testing.T) {
testCases := []struct {
name string
subject rbac.Subject
secretID uuid.UUID
lookupUserID uuid.UUID
lookupName string
expectedAccess bool
}{
{
@@ -7501,7 +7510,8 @@ func TestUserSecretsAuthorization(t *testing.T) {
Roles: rbac.RoleIdentifiers{rbac.RoleMember()},
Scope: rbac.ScopeAll,
},
secretID: user1Secret.ID,
lookupUserID: user1.ID,
lookupName: "user1-secret",
expectedAccess: true,
},
{
@@ -7511,7 +7521,8 @@ func TestUserSecretsAuthorization(t *testing.T) {
Roles: rbac.RoleIdentifiers{rbac.RoleMember()},
Scope: rbac.ScopeAll,
},
secretID: user2Secret.ID,
lookupUserID: user2.ID,
lookupName: "user2-secret",
expectedAccess: false,
},
{
@@ -7521,7 +7532,8 @@ func TestUserSecretsAuthorization(t *testing.T) {
Roles: rbac.RoleIdentifiers{rbac.RoleOwner()},
Scope: rbac.ScopeAll,
},
secretID: user1Secret.ID,
lookupUserID: user1.ID,
lookupName: "user1-secret",
expectedAccess: false,
},
{
@@ -7531,7 +7543,8 @@ func TestUserSecretsAuthorization(t *testing.T) {
Roles: rbac.RoleIdentifiers{rbac.ScopedRoleOrgAdmin(org.ID)},
Scope: rbac.ScopeAll,
},
secretID: user1Secret.ID,
lookupUserID: user1.ID,
lookupName: "user1-secret",
expectedAccess: false,
},
}
@@ -7543,8 +7556,10 @@ func TestUserSecretsAuthorization(t *testing.T) {
authCtx := dbauthz.As(ctx, tc.subject)
// Test GetUserSecret
_, err := authDB.GetUserSecret(authCtx, tc.secretID)
_, err := authDB.GetUserSecretByUserIDAndName(authCtx, database.GetUserSecretByUserIDAndNameParams{
UserID: tc.lookupUserID,
Name: tc.lookupName,
})
if tc.expectedAccess {
require.NoError(t, err, "expected access to be granted")
+362 -259
@@ -2275,93 +2275,105 @@ func (q *sqlQuerier) UpdateAPIKeyByID(ctx context.Context, arg UpdateAPIKeyByIDP
}
const countAuditLogs = `-- name: CountAuditLogs :one
SELECT COUNT(*)
FROM audit_logs
LEFT JOIN users ON audit_logs.user_id = users.id
LEFT JOIN organizations ON audit_logs.organization_id = organizations.id
-- First join on workspaces to get the initial workspace create
-- to workspace build 1 id. This is because the first create
-- is a different audit log than subsequent starts.
LEFT JOIN workspaces ON audit_logs.resource_type = 'workspace'
AND audit_logs.resource_id = workspaces.id
-- Get the reason from the build if the resource type
-- is a workspace_build
LEFT JOIN workspace_builds wb_build ON audit_logs.resource_type = 'workspace_build'
AND audit_logs.resource_id = wb_build.id
-- Get the reason from the build #1 if this is the first
-- workspace create.
LEFT JOIN workspace_builds wb_workspace ON audit_logs.resource_type = 'workspace'
AND audit_logs.action = 'create'
AND workspaces.id = wb_workspace.workspace_id
AND wb_workspace.build_number = 1
WHERE
-- Filter resource_type
CASE
WHEN $1::text != '' THEN resource_type = $1::resource_type
ELSE true
END
-- Filter resource_id
AND CASE
WHEN $2::uuid != '00000000-0000-0000-0000-000000000000'::uuid THEN resource_id = $2
ELSE true
END
-- Filter organization_id
AND CASE
WHEN $3::uuid != '00000000-0000-0000-0000-000000000000'::uuid THEN audit_logs.organization_id = $3
ELSE true
END
-- Filter by resource_target
AND CASE
WHEN $4::text != '' THEN resource_target = $4
ELSE true
END
-- Filter action
AND CASE
WHEN $5::text != '' THEN action = $5::audit_action
ELSE true
END
-- Filter by user_id
AND CASE
WHEN $6::uuid != '00000000-0000-0000-0000-000000000000'::uuid THEN user_id = $6
ELSE true
END
-- Filter by username
AND CASE
WHEN $7::text != '' THEN user_id = (
SELECT id
FROM users
WHERE lower(username) = lower($7)
AND deleted = false
)
ELSE true
END
-- Filter by user_email
AND CASE
WHEN $8::text != '' THEN users.email = $8
ELSE true
END
-- Filter by date_from
AND CASE
WHEN $9::timestamp with time zone != '0001-01-01 00:00:00Z' THEN "time" >= $9
ELSE true
END
-- Filter by date_to
AND CASE
WHEN $10::timestamp with time zone != '0001-01-01 00:00:00Z' THEN "time" <= $10
ELSE true
END
-- Filter by build_reason
AND CASE
WHEN $11::text != '' THEN COALESCE(wb_build.reason::text, wb_workspace.reason::text) = $11
ELSE true
END
-- Filter request_id
AND CASE
WHEN $12::uuid != '00000000-0000-0000-0000-000000000000'::uuid THEN audit_logs.request_id = $12
ELSE true
END
-- Authorize Filter clause will be injected below in CountAuthorizedAuditLogs
-- @authorize_filter
SELECT COUNT(*) FROM (
SELECT 1
FROM audit_logs
LEFT JOIN users ON audit_logs.user_id = users.id
LEFT JOIN organizations ON audit_logs.organization_id = organizations.id
-- First join on workspaces to get the initial workspace create
-- to workspace build 1 id. This is because the first create
-- is a different audit log than subsequent starts.
LEFT JOIN workspaces ON audit_logs.resource_type = 'workspace'
AND audit_logs.resource_id = workspaces.id
-- Get the reason from the build if the resource type
-- is a workspace_build
LEFT JOIN workspace_builds wb_build ON audit_logs.resource_type = 'workspace_build'
AND audit_logs.resource_id = wb_build.id
-- Get the reason from the build #1 if this is the first
-- workspace create.
LEFT JOIN workspace_builds wb_workspace ON audit_logs.resource_type = 'workspace'
AND audit_logs.action = 'create'
AND workspaces.id = wb_workspace.workspace_id
AND wb_workspace.build_number = 1
WHERE
-- Filter resource_type
CASE
WHEN $1::text != '' THEN resource_type = $1::resource_type
ELSE true
END
-- Filter resource_id
AND CASE
WHEN $2::uuid != '00000000-0000-0000-0000-000000000000'::uuid THEN resource_id = $2
ELSE true
END
-- Filter organization_id
AND CASE
WHEN $3::uuid != '00000000-0000-0000-0000-000000000000'::uuid THEN audit_logs.organization_id = $3
ELSE true
END
-- Filter by resource_target
AND CASE
WHEN $4::text != '' THEN resource_target = $4
ELSE true
END
-- Filter action
AND CASE
WHEN $5::text != '' THEN action = $5::audit_action
ELSE true
END
-- Filter by user_id
AND CASE
WHEN $6::uuid != '00000000-0000-0000-0000-000000000000'::uuid THEN user_id = $6
ELSE true
END
-- Filter by username
AND CASE
WHEN $7::text != '' THEN user_id = (
SELECT id
FROM users
WHERE lower(username) = lower($7)
AND deleted = false
)
ELSE true
END
-- Filter by user_email
AND CASE
WHEN $8::text != '' THEN users.email = $8
ELSE true
END
-- Filter by date_from
AND CASE
WHEN $9::timestamp with time zone != '0001-01-01 00:00:00Z' THEN "time" >= $9
ELSE true
END
-- Filter by date_to
AND CASE
WHEN $10::timestamp with time zone != '0001-01-01 00:00:00Z' THEN "time" <= $10
ELSE true
END
-- Filter by build_reason
AND CASE
WHEN $11::text != '' THEN COALESCE(wb_build.reason::text, wb_workspace.reason::text) = $11
ELSE true
END
-- Filter request_id
AND CASE
WHEN $12::uuid != '00000000-0000-0000-0000-000000000000'::uuid THEN audit_logs.request_id = $12
ELSE true
END
-- Authorize Filter clause will be injected below in CountAuthorizedAuditLogs
-- @authorize_filter
-- Avoid a slow scan on a large table with joins. The caller
-- passes the count cap and we add 1 so the frontend can detect
-- capping and show "... of N+". A cap of 0 means no limit
-- (NULLIF -> NULL, and NULL + 1 = NULL).
-- NOTE: This is parameterized so we can easily change the cap,
-- e.g. from 2000 to 5000. However, if capping is ever disabled
-- permanently on a large table, use a literal NULL (or drop the
-- LIMIT) here instead, so the PG planner can plan parallel
-- execution for potentially large wins.
LIMIT NULLIF($13::int, 0) + 1
) AS limited_count
`
type CountAuditLogsParams struct {
@@ -2377,6 +2389,7 @@ type CountAuditLogsParams struct {
DateTo time.Time `db:"date_to" json:"date_to"`
BuildReason string `db:"build_reason" json:"build_reason"`
RequestID uuid.UUID `db:"request_id" json:"request_id"`
CountCap int32 `db:"count_cap" json:"count_cap"`
}
func (q *sqlQuerier) CountAuditLogs(ctx context.Context, arg CountAuditLogsParams) (int64, error) {
@@ -2393,6 +2406,7 @@ func (q *sqlQuerier) CountAuditLogs(ctx context.Context, arg CountAuditLogsParam
arg.DateTo,
arg.BuildReason,
arg.RequestID,
arg.CountCap,
)
var count int64
err := row.Scan(&count)
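Because the query uses `LIMIT NULLIF($13::int, 0) + 1`, a returned count of cap+1 only proves there are *more than* cap rows. A caller-side helper to render that could look like the following hypothetical sketch (`formatCappedCount` is illustrative, not part of this changeset):

```go
package main

import (
	"fmt"
	"strconv"
)

// formatCappedCount renders a COUNT(*) computed with
// LIMIT NULLIF(capN, 0) + 1. Any count above capN means "at least
// capN", shown as "capN+". A capN of 0 disables capping, mirroring
// the SQL's NULLIF behavior (NULL LIMIT means no limit).
func formatCappedCount(count, capN int64) string {
	if capN > 0 && count > capN {
		return strconv.FormatInt(capN, 10) + "+"
	}
	return strconv.FormatInt(count, 10)
}

func main() {
	fmt.Println(formatCappedCount(2001, 2000)) // capped result
	fmt.Println(formatCappedCount(1500, 2000)) // exact result
	fmt.Println(formatCappedCount(5, 0))       // cap disabled
}
```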
@@ -6601,30 +6615,49 @@ func (q *sqlQuerier) UpdateChatByID(ctx context.Context, arg UpdateChatByIDParam
return i, err
}
const updateChatHeartbeat = `-- name: UpdateChatHeartbeat :execrows
const updateChatHeartbeats = `-- name: UpdateChatHeartbeats :many
UPDATE
chats
SET
heartbeat_at = NOW()
heartbeat_at = $1::timestamptz
WHERE
id = $1::uuid
AND worker_id = $2::uuid
id = ANY($2::uuid[])
AND worker_id = $3::uuid
AND status = 'running'::chat_status
RETURNING id
`
type UpdateChatHeartbeatParams struct {
ID uuid.UUID `db:"id" json:"id"`
WorkerID uuid.UUID `db:"worker_id" json:"worker_id"`
type UpdateChatHeartbeatsParams struct {
Now time.Time `db:"now" json:"now"`
IDs []uuid.UUID `db:"ids" json:"ids"`
WorkerID uuid.UUID `db:"worker_id" json:"worker_id"`
}
// Bumps the heartbeat timestamp for a running chat so that other
// replicas know the worker is still alive.
func (q *sqlQuerier) UpdateChatHeartbeat(ctx context.Context, arg UpdateChatHeartbeatParams) (int64, error) {
result, err := q.db.ExecContext(ctx, updateChatHeartbeat, arg.ID, arg.WorkerID)
// Bumps the heartbeat timestamp for the given set of chat IDs,
// provided they are still running and owned by the specified
// worker. Returns the IDs that were actually updated so the
// caller can detect stolen or completed chats via set-difference.
func (q *sqlQuerier) UpdateChatHeartbeats(ctx context.Context, arg UpdateChatHeartbeatsParams) ([]uuid.UUID, error) {
rows, err := q.db.QueryContext(ctx, updateChatHeartbeats, arg.Now, pq.Array(arg.IDs), arg.WorkerID)
if err != nil {
return 0, err
return nil, err
}
return result.RowsAffected()
defer rows.Close()
var items []uuid.UUID
for rows.Next() {
var id uuid.UUID
if err := rows.Scan(&id); err != nil {
return nil, err
}
items = append(items, id)
}
if err := rows.Close(); err != nil {
return nil, err
}
if err := rows.Err(); err != nil {
return nil, err
}
return items, nil
}
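The set-difference detection that the `UpdateChatHeartbeats` comment describes could be sketched on the caller side like this. The helper name and string IDs are illustrative (the real code uses `uuid.UUID`); this is not code from the changeset:

```go
package main

import "fmt"

// missingIDs returns the requested IDs that the heartbeat update did
// not return, i.e. chats that were stolen by another worker or
// completed since the last heartbeat tick.
func missingIDs(requested, updated []string) []string {
	seen := make(map[string]struct{}, len(updated))
	for _, id := range updated {
		seen[id] = struct{}{}
	}
	var missing []string
	for _, id := range requested {
		if _, ok := seen[id]; !ok {
			missing = append(missing, id)
		}
	}
	return missing
}

func main() {
	requested := []string{"chat-a", "chat-b", "chat-c"}
	updated := []string{"chat-a", "chat-c"} // chat-b stolen or completed
	fmt.Println(missingIDs(requested, updated))
}
```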
const updateChatLabelsByID = `-- name: UpdateChatLabelsByID :one
@@ -7571,110 +7604,113 @@ func (q *sqlQuerier) BatchUpsertConnectionLogs(ctx context.Context, arg BatchUps
}
const countConnectionLogs = `-- name: CountConnectionLogs :one
SELECT
COUNT(*) AS count
FROM
connection_logs
JOIN users AS workspace_owner ON
connection_logs.workspace_owner_id = workspace_owner.id
LEFT JOIN users ON
connection_logs.user_id = users.id
JOIN organizations ON
connection_logs.organization_id = organizations.id
WHERE
-- Filter organization_id
CASE
WHEN $1 :: uuid != '00000000-0000-0000-0000-000000000000'::uuid THEN
connection_logs.organization_id = $1
ELSE true
END
-- Filter by workspace owner username
AND CASE
WHEN $2 :: text != '' THEN
workspace_owner_id = (
SELECT id FROM users
WHERE lower(username) = lower($2) AND deleted = false
)
ELSE true
END
-- Filter by workspace_owner_id
AND CASE
WHEN $3 :: uuid != '00000000-0000-0000-0000-000000000000'::uuid THEN
workspace_owner_id = $3
ELSE true
END
-- Filter by workspace_owner_email
AND CASE
WHEN $4 :: text != '' THEN
workspace_owner_id = (
SELECT id FROM users
WHERE email = $4 AND deleted = false
)
ELSE true
END
-- Filter by type
AND CASE
WHEN $5 :: text != '' THEN
type = $5 :: connection_type
ELSE true
END
-- Filter by user_id
AND CASE
WHEN $6 :: uuid != '00000000-0000-0000-0000-000000000000'::uuid THEN
user_id = $6
ELSE true
END
-- Filter by username
AND CASE
WHEN $7 :: text != '' THEN
user_id = (
SELECT id FROM users
WHERE lower(username) = lower($7) AND deleted = false
)
ELSE true
END
-- Filter by user_email
AND CASE
WHEN $8 :: text != '' THEN
users.email = $8
ELSE true
END
-- Filter by connected_after
AND CASE
WHEN $9 :: timestamp with time zone != '0001-01-01 00:00:00Z' THEN
connect_time >= $9
ELSE true
END
-- Filter by connected_before
AND CASE
WHEN $10 :: timestamp with time zone != '0001-01-01 00:00:00Z' THEN
connect_time <= $10
ELSE true
END
-- Filter by workspace_id
AND CASE
WHEN $11 :: uuid != '00000000-0000-0000-0000-000000000000'::uuid THEN
connection_logs.workspace_id = $11
ELSE true
END
-- Filter by connection_id
AND CASE
WHEN $12 :: uuid != '00000000-0000-0000-0000-000000000000'::uuid THEN
connection_logs.connection_id = $12
ELSE true
END
-- Filter by whether the session has a disconnect_time
AND CASE
WHEN $13 :: text != '' THEN
(($13 = 'ongoing' AND disconnect_time IS NULL) OR
($13 = 'completed' AND disconnect_time IS NOT NULL)) AND
-- Exclude web events, since we don't know their close time.
"type" NOT IN ('workspace_app', 'port_forwarding')
ELSE true
END
-- Authorize Filter clause will be injected below in
-- CountAuthorizedConnectionLogs
-- @authorize_filter
SELECT COUNT(*) AS count FROM (
SELECT 1
FROM
connection_logs
JOIN users AS workspace_owner ON
connection_logs.workspace_owner_id = workspace_owner.id
LEFT JOIN users ON
connection_logs.user_id = users.id
JOIN organizations ON
connection_logs.organization_id = organizations.id
WHERE
-- Filter organization_id
CASE
WHEN $1 :: uuid != '00000000-0000-0000-0000-000000000000'::uuid THEN
connection_logs.organization_id = $1
ELSE true
END
-- Filter by workspace owner username
AND CASE
WHEN $2 :: text != '' THEN
workspace_owner_id = (
SELECT id FROM users
WHERE lower(username) = lower($2) AND deleted = false
)
ELSE true
END
-- Filter by workspace_owner_id
AND CASE
WHEN $3 :: uuid != '00000000-0000-0000-0000-000000000000'::uuid THEN
workspace_owner_id = $3
ELSE true
END
-- Filter by workspace_owner_email
AND CASE
WHEN $4 :: text != '' THEN
workspace_owner_id = (
SELECT id FROM users
WHERE email = $4 AND deleted = false
)
ELSE true
END
-- Filter by type
AND CASE
WHEN $5 :: text != '' THEN
type = $5 :: connection_type
ELSE true
END
-- Filter by user_id
AND CASE
WHEN $6 :: uuid != '00000000-0000-0000-0000-000000000000'::uuid THEN
user_id = $6
ELSE true
END
-- Filter by username
AND CASE
WHEN $7 :: text != '' THEN
user_id = (
SELECT id FROM users
WHERE lower(username) = lower($7) AND deleted = false
)
ELSE true
END
-- Filter by user_email
AND CASE
WHEN $8 :: text != '' THEN
users.email = $8
ELSE true
END
-- Filter by connected_after
AND CASE
WHEN $9 :: timestamp with time zone != '0001-01-01 00:00:00Z' THEN
connect_time >= $9
ELSE true
END
-- Filter by connected_before
AND CASE
WHEN $10 :: timestamp with time zone != '0001-01-01 00:00:00Z' THEN
connect_time <= $10
ELSE true
END
-- Filter by workspace_id
AND CASE
WHEN $11 :: uuid != '00000000-0000-0000-0000-000000000000'::uuid THEN
connection_logs.workspace_id = $11
ELSE true
END
-- Filter by connection_id
AND CASE
WHEN $12 :: uuid != '00000000-0000-0000-0000-000000000000'::uuid THEN
connection_logs.connection_id = $12
ELSE true
END
-- Filter by whether the session has a disconnect_time
AND CASE
WHEN $13 :: text != '' THEN
(($13 = 'ongoing' AND disconnect_time IS NULL) OR
($13 = 'completed' AND disconnect_time IS NOT NULL)) AND
-- Exclude web events, since we don't know their close time.
"type" NOT IN ('workspace_app', 'port_forwarding')
ELSE true
END
-- Authorize Filter clause will be injected below in
-- CountAuthorizedConnectionLogs
-- @authorize_filter
-- NOTE: See the CountAuditLogs LIMIT note.
LIMIT NULLIF($14::int, 0) + 1
) AS limited_count
`
type CountConnectionLogsParams struct {
@@ -7691,6 +7727,7 @@ type CountConnectionLogsParams struct {
WorkspaceID uuid.UUID `db:"workspace_id" json:"workspace_id"`
ConnectionID uuid.UUID `db:"connection_id" json:"connection_id"`
Status string `db:"status" json:"status"`
CountCap int32 `db:"count_cap" json:"count_cap"`
}
func (q *sqlQuerier) CountConnectionLogs(ctx context.Context, arg CountConnectionLogsParams) (int64, error) {
@@ -7708,6 +7745,7 @@ func (q *sqlQuerier) CountConnectionLogs(ctx context.Context, arg CountConnectio
arg.WorkspaceID,
arg.ConnectionID,
arg.Status,
arg.CountCap,
)
var count int64
err := row.Scan(&count)
@@ -22601,21 +22639,30 @@ INSERT INTO user_secrets (
name,
description,
value,
value_key_id,
env_name,
file_path
) VALUES (
$1, $2, $3, $4, $5, $6, $7
$1,
$2,
$3,
$4,
$5,
$6,
$7,
$8
) RETURNING id, user_id, name, description, value, env_name, file_path, created_at, updated_at, value_key_id
`
type CreateUserSecretParams struct {
ID uuid.UUID `db:"id" json:"id"`
UserID uuid.UUID `db:"user_id" json:"user_id"`
Name string `db:"name" json:"name"`
Description string `db:"description" json:"description"`
Value string `db:"value" json:"value"`
EnvName string `db:"env_name" json:"env_name"`
FilePath string `db:"file_path" json:"file_path"`
ID uuid.UUID `db:"id" json:"id"`
UserID uuid.UUID `db:"user_id" json:"user_id"`
Name string `db:"name" json:"name"`
Description string `db:"description" json:"description"`
Value string `db:"value" json:"value"`
ValueKeyID sql.NullString `db:"value_key_id" json:"value_key_id"`
EnvName string `db:"env_name" json:"env_name"`
FilePath string `db:"file_path" json:"file_path"`
}
func (q *sqlQuerier) CreateUserSecret(ctx context.Context, arg CreateUserSecretParams) (UserSecret, error) {
@@ -22625,6 +22672,7 @@ func (q *sqlQuerier) CreateUserSecret(ctx context.Context, arg CreateUserSecretP
arg.Name,
arg.Description,
arg.Value,
arg.ValueKeyID,
arg.EnvName,
arg.FilePath,
)
@@ -22644,41 +22692,24 @@ func (q *sqlQuerier) CreateUserSecret(ctx context.Context, arg CreateUserSecretP
return i, err
}
const deleteUserSecret = `-- name: DeleteUserSecret :exec
const deleteUserSecretByUserIDAndName = `-- name: DeleteUserSecretByUserIDAndName :exec
DELETE FROM user_secrets
WHERE id = $1
WHERE user_id = $1 AND name = $2
`
func (q *sqlQuerier) DeleteUserSecret(ctx context.Context, id uuid.UUID) error {
_, err := q.db.ExecContext(ctx, deleteUserSecret, id)
type DeleteUserSecretByUserIDAndNameParams struct {
UserID uuid.UUID `db:"user_id" json:"user_id"`
Name string `db:"name" json:"name"`
}
func (q *sqlQuerier) DeleteUserSecretByUserIDAndName(ctx context.Context, arg DeleteUserSecretByUserIDAndNameParams) error {
_, err := q.db.ExecContext(ctx, deleteUserSecretByUserIDAndName, arg.UserID, arg.Name)
return err
}
const getUserSecret = `-- name: GetUserSecret :one
SELECT id, user_id, name, description, value, env_name, file_path, created_at, updated_at, value_key_id FROM user_secrets
WHERE id = $1
`
func (q *sqlQuerier) GetUserSecret(ctx context.Context, id uuid.UUID) (UserSecret, error) {
row := q.db.QueryRowContext(ctx, getUserSecret, id)
var i UserSecret
err := row.Scan(
&i.ID,
&i.UserID,
&i.Name,
&i.Description,
&i.Value,
&i.EnvName,
&i.FilePath,
&i.CreatedAt,
&i.UpdatedAt,
&i.ValueKeyID,
)
return i, err
}
const getUserSecretByUserIDAndName = `-- name: GetUserSecretByUserIDAndName :one
SELECT id, user_id, name, description, value, env_name, file_path, created_at, updated_at, value_key_id FROM user_secrets
SELECT id, user_id, name, description, value, env_name, file_path, created_at, updated_at, value_key_id
FROM user_secrets
WHERE user_id = $1 AND name = $2
`
@@ -22706,17 +22737,76 @@ func (q *sqlQuerier) GetUserSecretByUserIDAndName(ctx context.Context, arg GetUs
}
const listUserSecrets = `-- name: ListUserSecrets :many
SELECT id, user_id, name, description, value, env_name, file_path, created_at, updated_at, value_key_id FROM user_secrets
SELECT
id, user_id, name, description,
env_name, file_path,
created_at, updated_at
FROM user_secrets
WHERE user_id = $1
ORDER BY name ASC
`
func (q *sqlQuerier) ListUserSecrets(ctx context.Context, userID uuid.UUID) ([]UserSecret, error) {
type ListUserSecretsRow struct {
ID uuid.UUID `db:"id" json:"id"`
UserID uuid.UUID `db:"user_id" json:"user_id"`
Name string `db:"name" json:"name"`
Description string `db:"description" json:"description"`
EnvName string `db:"env_name" json:"env_name"`
FilePath string `db:"file_path" json:"file_path"`
CreatedAt time.Time `db:"created_at" json:"created_at"`
UpdatedAt time.Time `db:"updated_at" json:"updated_at"`
}
// Returns metadata only (no value or value_key_id) for the
// REST API list and get endpoints.
func (q *sqlQuerier) ListUserSecrets(ctx context.Context, userID uuid.UUID) ([]ListUserSecretsRow, error) {
rows, err := q.db.QueryContext(ctx, listUserSecrets, userID)
if err != nil {
return nil, err
}
defer rows.Close()
var items []ListUserSecretsRow
for rows.Next() {
var i ListUserSecretsRow
if err := rows.Scan(
&i.ID,
&i.UserID,
&i.Name,
&i.Description,
&i.EnvName,
&i.FilePath,
&i.CreatedAt,
&i.UpdatedAt,
); err != nil {
return nil, err
}
items = append(items, i)
}
if err := rows.Close(); err != nil {
return nil, err
}
if err := rows.Err(); err != nil {
return nil, err
}
return items, nil
}
const listUserSecretsWithValues = `-- name: ListUserSecretsWithValues :many
SELECT id, user_id, name, description, value, env_name, file_path, created_at, updated_at, value_key_id
FROM user_secrets
WHERE user_id = $1
ORDER BY name ASC
`
// Returns all columns including the secret value. Used by the
// provisioner (build-time injection) and the agent manifest
// (runtime injection).
func (q *sqlQuerier) ListUserSecretsWithValues(ctx context.Context, userID uuid.UUID) ([]UserSecret, error) {
rows, err := q.db.QueryContext(ctx, listUserSecretsWithValues, userID)
if err != nil {
return nil, err
}
defer rows.Close()
var items []UserSecret
for rows.Next() {
var i UserSecret
@@ -22745,33 +22835,46 @@ func (q *sqlQuerier) ListUserSecrets(ctx context.Context, userID uuid.UUID) ([]U
return items, nil
}
const updateUserSecret = `-- name: UpdateUserSecret :one
const updateUserSecretByUserIDAndName = `-- name: UpdateUserSecretByUserIDAndName :one
UPDATE user_secrets
SET
description = $2,
value = $3,
env_name = $4,
file_path = $5,
updated_at = CURRENT_TIMESTAMP
WHERE id = $1
value = CASE WHEN $1::bool THEN $2 ELSE value END,
value_key_id = CASE WHEN $1::bool THEN $3 ELSE value_key_id END,
description = CASE WHEN $4::bool THEN $5 ELSE description END,
env_name = CASE WHEN $6::bool THEN $7 ELSE env_name END,
file_path = CASE WHEN $8::bool THEN $9 ELSE file_path END,
updated_at = CURRENT_TIMESTAMP
WHERE user_id = $10 AND name = $11
RETURNING id, user_id, name, description, value, env_name, file_path, created_at, updated_at, value_key_id
`
type UpdateUserSecretParams struct {
ID uuid.UUID `db:"id" json:"id"`
Description string `db:"description" json:"description"`
Value string `db:"value" json:"value"`
EnvName string `db:"env_name" json:"env_name"`
FilePath string `db:"file_path" json:"file_path"`
type UpdateUserSecretByUserIDAndNameParams struct {
UpdateValue bool `db:"update_value" json:"update_value"`
Value string `db:"value" json:"value"`
ValueKeyID sql.NullString `db:"value_key_id" json:"value_key_id"`
UpdateDescription bool `db:"update_description" json:"update_description"`
Description string `db:"description" json:"description"`
UpdateEnvName bool `db:"update_env_name" json:"update_env_name"`
EnvName string `db:"env_name" json:"env_name"`
UpdateFilePath bool `db:"update_file_path" json:"update_file_path"`
FilePath string `db:"file_path" json:"file_path"`
UserID uuid.UUID `db:"user_id" json:"user_id"`
Name string `db:"name" json:"name"`
}
func (q *sqlQuerier) UpdateUserSecret(ctx context.Context, arg UpdateUserSecretParams) (UserSecret, error) {
row := q.db.QueryRowContext(ctx, updateUserSecret,
arg.ID,
arg.Description,
func (q *sqlQuerier) UpdateUserSecretByUserIDAndName(ctx context.Context, arg UpdateUserSecretByUserIDAndNameParams) (UserSecret, error) {
row := q.db.QueryRowContext(ctx, updateUserSecretByUserIDAndName,
arg.UpdateValue,
arg.Value,
arg.ValueKeyID,
arg.UpdateDescription,
arg.Description,
arg.UpdateEnvName,
arg.EnvName,
arg.UpdateFilePath,
arg.FilePath,
arg.UserID,
arg.Name,
)
var i UserSecret
err := row.Scan(
@@ -149,94 +149,105 @@ VALUES (
RETURNING *;
-- name: CountAuditLogs :one
SELECT COUNT(*)
FROM audit_logs
LEFT JOIN users ON audit_logs.user_id = users.id
LEFT JOIN organizations ON audit_logs.organization_id = organizations.id
-- First join on workspaces to get the initial workspace create
-- to workspace build 1 id. This is because the first create is
-- a different audit log than subsequent starts.
LEFT JOIN workspaces ON audit_logs.resource_type = 'workspace'
AND audit_logs.resource_id = workspaces.id
-- Get the reason from the build if the resource type
-- is a workspace_build
LEFT JOIN workspace_builds wb_build ON audit_logs.resource_type = 'workspace_build'
AND audit_logs.resource_id = wb_build.id
-- Get the reason from the build #1 if this is the first
-- workspace create.
LEFT JOIN workspace_builds wb_workspace ON audit_logs.resource_type = 'workspace'
AND audit_logs.action = 'create'
AND workspaces.id = wb_workspace.workspace_id
AND wb_workspace.build_number = 1
WHERE
-- Filter resource_type
CASE
WHEN @resource_type::text != '' THEN resource_type = @resource_type::resource_type
ELSE true
END
-- Filter resource_id
AND CASE
WHEN @resource_id::uuid != '00000000-0000-0000-0000-000000000000'::uuid THEN resource_id = @resource_id
ELSE true
END
-- Filter organization_id
AND CASE
WHEN @organization_id::uuid != '00000000-0000-0000-0000-000000000000'::uuid THEN audit_logs.organization_id = @organization_id
ELSE true
END
-- Filter by resource_target
AND CASE
WHEN @resource_target::text != '' THEN resource_target = @resource_target
ELSE true
END
-- Filter action
AND CASE
WHEN @action::text != '' THEN action = @action::audit_action
ELSE true
END
-- Filter by user_id
AND CASE
WHEN @user_id::uuid != '00000000-0000-0000-0000-000000000000'::uuid THEN user_id = @user_id
ELSE true
END
-- Filter by username
AND CASE
WHEN @username::text != '' THEN user_id = (
SELECT id
FROM users
WHERE lower(username) = lower(@username)
AND deleted = false
)
ELSE true
END
-- Filter by user_email
AND CASE
WHEN @email::text != '' THEN users.email = @email
ELSE true
END
-- Filter by date_from
AND CASE
WHEN @date_from::timestamp with time zone != '0001-01-01 00:00:00Z' THEN "time" >= @date_from
ELSE true
END
-- Filter by date_to
AND CASE
WHEN @date_to::timestamp with time zone != '0001-01-01 00:00:00Z' THEN "time" <= @date_to
ELSE true
END
-- Filter by build_reason
AND CASE
WHEN @build_reason::text != '' THEN COALESCE(wb_build.reason::text, wb_workspace.reason::text) = @build_reason
ELSE true
END
-- Filter request_id
AND CASE
WHEN @request_id::uuid != '00000000-0000-0000-0000-000000000000'::uuid THEN audit_logs.request_id = @request_id
ELSE true
END
-- Authorize Filter clause will be injected below in CountAuthorizedAuditLogs
-- @authorize_filter
;
SELECT COUNT(*) FROM (
SELECT 1
FROM audit_logs
LEFT JOIN users ON audit_logs.user_id = users.id
LEFT JOIN organizations ON audit_logs.organization_id = organizations.id
-- First join on workspaces to get the initial workspace create
-- to workspace build 1 id. This is because the first create is
-- a different audit log than subsequent starts.
LEFT JOIN workspaces ON audit_logs.resource_type = 'workspace'
AND audit_logs.resource_id = workspaces.id
-- Get the reason from the build if the resource type
-- is a workspace_build
LEFT JOIN workspace_builds wb_build ON audit_logs.resource_type = 'workspace_build'
AND audit_logs.resource_id = wb_build.id
-- Get the reason from the build #1 if this is the first
-- workspace create.
LEFT JOIN workspace_builds wb_workspace ON audit_logs.resource_type = 'workspace'
AND audit_logs.action = 'create'
AND workspaces.id = wb_workspace.workspace_id
AND wb_workspace.build_number = 1
WHERE
-- Filter resource_type
CASE
WHEN @resource_type::text != '' THEN resource_type = @resource_type::resource_type
ELSE true
END
-- Filter resource_id
AND CASE
WHEN @resource_id::uuid != '00000000-0000-0000-0000-000000000000'::uuid THEN resource_id = @resource_id
ELSE true
END
-- Filter organization_id
AND CASE
WHEN @organization_id::uuid != '00000000-0000-0000-0000-000000000000'::uuid THEN audit_logs.organization_id = @organization_id
ELSE true
END
-- Filter by resource_target
AND CASE
WHEN @resource_target::text != '' THEN resource_target = @resource_target
ELSE true
END
-- Filter action
AND CASE
WHEN @action::text != '' THEN action = @action::audit_action
ELSE true
END
-- Filter by user_id
AND CASE
WHEN @user_id::uuid != '00000000-0000-0000-0000-000000000000'::uuid THEN user_id = @user_id
ELSE true
END
-- Filter by username
AND CASE
WHEN @username::text != '' THEN user_id = (
SELECT id
FROM users
WHERE lower(username) = lower(@username)
AND deleted = false
)
ELSE true
END
-- Filter by user_email
AND CASE
WHEN @email::text != '' THEN users.email = @email
ELSE true
END
-- Filter by date_from
AND CASE
WHEN @date_from::timestamp with time zone != '0001-01-01 00:00:00Z' THEN "time" >= @date_from
ELSE true
END
-- Filter by date_to
AND CASE
WHEN @date_to::timestamp with time zone != '0001-01-01 00:00:00Z' THEN "time" <= @date_to
ELSE true
END
-- Filter by build_reason
AND CASE
WHEN @build_reason::text != '' THEN COALESCE(wb_build.reason::text, wb_workspace.reason::text) = @build_reason
ELSE true
END
-- Filter request_id
AND CASE
WHEN @request_id::uuid != '00000000-0000-0000-0000-000000000000'::uuid THEN audit_logs.request_id = @request_id
ELSE true
END
-- Authorize Filter clause will be injected below in CountAuthorizedAuditLogs
-- @authorize_filter
-- Avoid a slow full scan on a large table with joins. The caller
-- passes the count cap and we add 1 so the frontend can detect
-- capping and show "... of N+". A cap of 0 means no limit
-- (NULLIF -> NULL, and NULL + 1 = NULL).
-- NOTE: This is parameterized so the cap can easily be changed
-- from, e.g., 2000 to 5000. However, if capping is ever disabled
-- permanently, use a literal NULL (or drop the LIMIT) so the PG
-- planner can plan parallel execution for potentially large wins.
LIMIT NULLIF(@count_cap::int, 0) + 1
) AS limited_count;
-- name: DeleteOldAuditLogConnectionEvents :exec
DELETE FROM audit_logs
@@ -674,17 +674,20 @@ WHERE
status = 'running'::chat_status
AND heartbeat_at < @stale_threshold::timestamptz;
-- name: UpdateChatHeartbeat :execrows
-- Bumps the heartbeat timestamp for a running chat so that other
-- replicas know the worker is still alive.
-- name: UpdateChatHeartbeats :many
-- Bumps the heartbeat timestamp for the given set of chat IDs,
-- provided they are still running and owned by the specified
-- worker. Returns the IDs that were actually updated so the
-- caller can detect stolen or completed chats via set-difference.
UPDATE
chats
SET
heartbeat_at = NOW()
heartbeat_at = @now::timestamptz
WHERE
id = @id::uuid
id = ANY(@ids::uuid[])
AND worker_id = @worker_id::uuid
AND status = 'running'::chat_status;
AND status = 'running'::chat_status
RETURNING id;
-- name: GetChatDiffStatusByChatID :one
SELECT
@@ -133,111 +133,113 @@ OFFSET
@offset_opt;
-- name: CountConnectionLogs :one
SELECT
COUNT(*) AS count
FROM
connection_logs
JOIN users AS workspace_owner ON
connection_logs.workspace_owner_id = workspace_owner.id
LEFT JOIN users ON
connection_logs.user_id = users.id
JOIN organizations ON
connection_logs.organization_id = organizations.id
WHERE
-- Filter organization_id
CASE
WHEN @organization_id :: uuid != '00000000-0000-0000-0000-000000000000'::uuid THEN
connection_logs.organization_id = @organization_id
ELSE true
END
-- Filter by workspace owner username
AND CASE
WHEN @workspace_owner :: text != '' THEN
workspace_owner_id = (
SELECT id FROM users
WHERE lower(username) = lower(@workspace_owner) AND deleted = false
)
ELSE true
END
-- Filter by workspace_owner_id
AND CASE
WHEN @workspace_owner_id :: uuid != '00000000-0000-0000-0000-000000000000'::uuid THEN
workspace_owner_id = @workspace_owner_id
ELSE true
END
-- Filter by workspace_owner_email
AND CASE
WHEN @workspace_owner_email :: text != '' THEN
workspace_owner_id = (
SELECT id FROM users
WHERE email = @workspace_owner_email AND deleted = false
)
ELSE true
END
-- Filter by type
AND CASE
WHEN @type :: text != '' THEN
type = @type :: connection_type
ELSE true
END
-- Filter by user_id
AND CASE
WHEN @user_id :: uuid != '00000000-0000-0000-0000-000000000000'::uuid THEN
user_id = @user_id
ELSE true
END
-- Filter by username
AND CASE
WHEN @username :: text != '' THEN
user_id = (
SELECT id FROM users
WHERE lower(username) = lower(@username) AND deleted = false
)
ELSE true
END
-- Filter by user_email
AND CASE
WHEN @user_email :: text != '' THEN
users.email = @user_email
ELSE true
END
-- Filter by connected_after
AND CASE
WHEN @connected_after :: timestamp with time zone != '0001-01-01 00:00:00Z' THEN
connect_time >= @connected_after
ELSE true
END
-- Filter by connected_before
AND CASE
WHEN @connected_before :: timestamp with time zone != '0001-01-01 00:00:00Z' THEN
connect_time <= @connected_before
ELSE true
END
-- Filter by workspace_id
AND CASE
WHEN @workspace_id :: uuid != '00000000-0000-0000-0000-000000000000'::uuid THEN
connection_logs.workspace_id = @workspace_id
ELSE true
END
-- Filter by connection_id
AND CASE
WHEN @connection_id :: uuid != '00000000-0000-0000-0000-000000000000'::uuid THEN
connection_logs.connection_id = @connection_id
ELSE true
END
-- Filter by whether the session has a disconnect_time
AND CASE
WHEN @status :: text != '' THEN
((@status = 'ongoing' AND disconnect_time IS NULL) OR
(@status = 'completed' AND disconnect_time IS NOT NULL)) AND
-- Exclude web events, since we don't know their close time.
"type" NOT IN ('workspace_app', 'port_forwarding')
ELSE true
END
-- Authorize Filter clause will be injected below in
-- CountAuthorizedConnectionLogs
-- @authorize_filter
;
SELECT COUNT(*) AS count FROM (
SELECT 1
FROM
connection_logs
JOIN users AS workspace_owner ON
connection_logs.workspace_owner_id = workspace_owner.id
LEFT JOIN users ON
connection_logs.user_id = users.id
JOIN organizations ON
connection_logs.organization_id = organizations.id
WHERE
-- Filter organization_id
CASE
WHEN @organization_id :: uuid != '00000000-0000-0000-0000-000000000000'::uuid THEN
connection_logs.organization_id = @organization_id
ELSE true
END
-- Filter by workspace owner username
AND CASE
WHEN @workspace_owner :: text != '' THEN
workspace_owner_id = (
SELECT id FROM users
WHERE lower(username) = lower(@workspace_owner) AND deleted = false
)
ELSE true
END
-- Filter by workspace_owner_id
AND CASE
WHEN @workspace_owner_id :: uuid != '00000000-0000-0000-0000-000000000000'::uuid THEN
workspace_owner_id = @workspace_owner_id
ELSE true
END
-- Filter by workspace_owner_email
AND CASE
WHEN @workspace_owner_email :: text != '' THEN
workspace_owner_id = (
SELECT id FROM users
WHERE email = @workspace_owner_email AND deleted = false
)
ELSE true
END
-- Filter by type
AND CASE
WHEN @type :: text != '' THEN
type = @type :: connection_type
ELSE true
END
-- Filter by user_id
AND CASE
WHEN @user_id :: uuid != '00000000-0000-0000-0000-000000000000'::uuid THEN
user_id = @user_id
ELSE true
END
-- Filter by username
AND CASE
WHEN @username :: text != '' THEN
user_id = (
SELECT id FROM users
WHERE lower(username) = lower(@username) AND deleted = false
)
ELSE true
END
-- Filter by user_email
AND CASE
WHEN @user_email :: text != '' THEN
users.email = @user_email
ELSE true
END
-- Filter by connected_after
AND CASE
WHEN @connected_after :: timestamp with time zone != '0001-01-01 00:00:00Z' THEN
connect_time >= @connected_after
ELSE true
END
-- Filter by connected_before
AND CASE
WHEN @connected_before :: timestamp with time zone != '0001-01-01 00:00:00Z' THEN
connect_time <= @connected_before
ELSE true
END
-- Filter by workspace_id
AND CASE
WHEN @workspace_id :: uuid != '00000000-0000-0000-0000-000000000000'::uuid THEN
connection_logs.workspace_id = @workspace_id
ELSE true
END
-- Filter by connection_id
AND CASE
WHEN @connection_id :: uuid != '00000000-0000-0000-0000-000000000000'::uuid THEN
connection_logs.connection_id = @connection_id
ELSE true
END
-- Filter by whether the session has a disconnect_time
AND CASE
WHEN @status :: text != '' THEN
((@status = 'ongoing' AND disconnect_time IS NULL) OR
(@status = 'completed' AND disconnect_time IS NOT NULL)) AND
-- Exclude web events, since we don't know their close time.
"type" NOT IN ('workspace_app', 'port_forwarding')
ELSE true
END
-- Authorize Filter clause will be injected below in
-- CountAuthorizedConnectionLogs
-- @authorize_filter
-- NOTE: See the CountAuditLogs LIMIT note.
LIMIT NULLIF(@count_cap::int, 0) + 1
) AS limited_count;
-- name: DeleteOldConnectionLogs :execrows
WITH old_logs AS (
@@ -1,14 +1,26 @@
-- name: GetUserSecretByUserIDAndName :one
SELECT * FROM user_secrets
WHERE user_id = $1 AND name = $2;
-- name: GetUserSecret :one
SELECT * FROM user_secrets
WHERE id = $1;
SELECT *
FROM user_secrets
WHERE user_id = @user_id AND name = @name;
-- name: ListUserSecrets :many
SELECT * FROM user_secrets
WHERE user_id = $1
-- Returns metadata only (no value or value_key_id) for the
-- REST API list and get endpoints.
SELECT
id, user_id, name, description,
env_name, file_path,
created_at, updated_at
FROM user_secrets
WHERE user_id = @user_id
ORDER BY name ASC;
-- name: ListUserSecretsWithValues :many
-- Returns all columns including the secret value. Used by the
-- provisioner (build-time injection) and the agent manifest
-- (runtime injection).
SELECT *
FROM user_secrets
WHERE user_id = @user_id
ORDER BY name ASC;
-- name: CreateUserSecret :one
@@ -18,23 +30,32 @@ INSERT INTO user_secrets (
name,
description,
value,
value_key_id,
env_name,
file_path
) VALUES (
$1, $2, $3, $4, $5, $6, $7
@id,
@user_id,
@name,
@description,
@value,
@value_key_id,
@env_name,
@file_path
) RETURNING *;
-- name: UpdateUserSecret :one
-- name: UpdateUserSecretByUserIDAndName :one
UPDATE user_secrets
SET
description = $2,
value = $3,
env_name = $4,
file_path = $5,
updated_at = CURRENT_TIMESTAMP
WHERE id = $1
value = CASE WHEN @update_value::bool THEN @value ELSE value END,
value_key_id = CASE WHEN @update_value::bool THEN @value_key_id ELSE value_key_id END,
description = CASE WHEN @update_description::bool THEN @description ELSE description END,
env_name = CASE WHEN @update_env_name::bool THEN @env_name ELSE env_name END,
file_path = CASE WHEN @update_file_path::bool THEN @file_path ELSE file_path END,
updated_at = CURRENT_TIMESTAMP
WHERE user_id = @user_id AND name = @name
RETURNING *;
-- name: DeleteUserSecret :exec
-- name: DeleteUserSecretByUserIDAndName :exec
DELETE FROM user_secrets
WHERE id = $1;
WHERE user_id = @user_id AND name = @name;
@@ -298,6 +298,40 @@ neq(input.object.owner, "");
ExpectedSQL: p("'' = 'org-id'"),
VariableConverter: regosql.ChatConverter(),
},
{
Name: "AuditLogUUID",
Queries: []string{
`"8c0b9bdc-a013-4b14-a49b-5747bc335708" = input.object.org_owner`,
`input.object.org_owner != ""`,
`neq(input.object.org_owner, "8c0b9bdc-a013-4b14-a49b-5747bc335708")`,
`input.object.org_owner in {"8c0b9bdc-a013-4b14-a49b-5747bc335708", "05f58202-4bfc-43ce-9ba4-5ff6e0174a71"}`,
`"read" in input.object.acl_group_list[input.object.org_owner]`,
},
ExpectedSQL: p(
p("audit_logs.organization_id = '8c0b9bdc-a013-4b14-a49b-5747bc335708'::uuid") + " OR " +
p("audit_logs.organization_id IS NOT NULL") + " OR " +
p("audit_logs.organization_id != '8c0b9bdc-a013-4b14-a49b-5747bc335708'::uuid") + " OR " +
p("audit_logs.organization_id = ANY(ARRAY ['05f58202-4bfc-43ce-9ba4-5ff6e0174a71'::uuid,'8c0b9bdc-a013-4b14-a49b-5747bc335708'::uuid])") + " OR " +
"(false)"),
VariableConverter: regosql.AuditLogConverter(),
},
{
Name: "ConnectionLogUUID",
Queries: []string{
`"8c0b9bdc-a013-4b14-a49b-5747bc335708" = input.object.org_owner`,
`input.object.org_owner != ""`,
`neq(input.object.org_owner, "8c0b9bdc-a013-4b14-a49b-5747bc335708")`,
`input.object.org_owner in {"8c0b9bdc-a013-4b14-a49b-5747bc335708"}`,
`"read" in input.object.acl_group_list[input.object.org_owner]`,
},
ExpectedSQL: p(
p("connection_logs.organization_id = '8c0b9bdc-a013-4b14-a49b-5747bc335708'::uuid") + " OR " +
p("connection_logs.organization_id IS NOT NULL") + " OR " +
p("connection_logs.organization_id != '8c0b9bdc-a013-4b14-a49b-5747bc335708'::uuid") + " OR " +
p("connection_logs.organization_id = ANY(ARRAY ['8c0b9bdc-a013-4b14-a49b-5747bc335708'::uuid])") + " OR " +
"(false)"),
VariableConverter: regosql.ConnectionLogConverter(),
},
}
for _, tc := range testCases {
@@ -53,7 +53,7 @@ func WorkspaceConverter() *sqltypes.VariableConverter {
func AuditLogConverter() *sqltypes.VariableConverter {
matcher := sqltypes.NewVariableConverter().RegisterMatcher(
resourceIDMatcher(),
sqltypes.StringVarMatcher("COALESCE(audit_logs.organization_id :: text, '')", []string{"input", "object", "org_owner"}),
sqltypes.UUIDVarMatcher("audit_logs.organization_id", []string{"input", "object", "org_owner"}),
// Audit logs have no user owner; they are owned only by an organization.
sqltypes.AlwaysFalse(userOwnerMatcher()),
)
@@ -67,7 +67,7 @@ func AuditLogConverter() *sqltypes.VariableConverter {
func ConnectionLogConverter() *sqltypes.VariableConverter {
matcher := sqltypes.NewVariableConverter().RegisterMatcher(
resourceIDMatcher(),
sqltypes.StringVarMatcher("COALESCE(connection_logs.organization_id :: text, '')", []string{"input", "object", "org_owner"}),
sqltypes.UUIDVarMatcher("connection_logs.organization_id", []string{"input", "object", "org_owner"}),
// Connection logs have no user owner; they are owned only by an organization.
sqltypes.AlwaysFalse(userOwnerMatcher()),
)
@@ -0,0 +1,114 @@
package sqltypes
import (
"fmt"
"strings"
"github.com/open-policy-agent/opa/ast"
"golang.org/x/xerrors"
)
var (
_ VariableMatcher = astUUIDVar{}
_ Node = astUUIDVar{}
_ SupportsEquality = astUUIDVar{}
)
// astUUIDVar is a variable that represents a UUID column. Unlike
// astStringVar it emits native UUID comparisons (column = 'val'::uuid)
// instead of text-based ones (COALESCE(column::text, '') = 'val').
// This allows PostgreSQL to use indexes on UUID columns.
type astUUIDVar struct {
Source RegoSource
FieldPath []string
ColumnString string
}
func UUIDVarMatcher(sqlColumn string, regoPath []string) VariableMatcher {
return astUUIDVar{FieldPath: regoPath, ColumnString: sqlColumn}
}
func (astUUIDVar) UseAs() Node { return astUUIDVar{} }
func (u astUUIDVar) ConvertVariable(rego ast.Ref) (Node, bool) {
left, err := RegoVarPath(u.FieldPath, rego)
if err == nil && len(left) == 0 {
return astUUIDVar{
Source: RegoSource(rego.String()),
FieldPath: u.FieldPath,
ColumnString: u.ColumnString,
}, true
}
return nil, false
}
func (u astUUIDVar) SQLString(_ *SQLGenerator) string {
return u.ColumnString
}
// EqualsSQLString handles equality comparisons for UUID columns.
// Rego always produces string literals, so we accept AstString and
// cast the literal to ::uuid in the output SQL. This lets PG use
// native UUID indexes instead of falling back to text comparisons.
// nolint:revive
func (u astUUIDVar) EqualsSQLString(cfg *SQLGenerator, not bool, other Node) (string, error) {
switch other.UseAs().(type) {
case AstString:
// The other side is a rego string literal like
// "8c0b9bdc-a013-4b14-a49b-5747bc335708". Emit a comparison
// that casts the literal to uuid so PG can use indexes:
// column = 'val'::uuid
// instead of the text-based:
// 'val' = COALESCE(column::text, '')
s, ok := other.(AstString)
if !ok {
return "", xerrors.Errorf("expected AstString, got %T", other)
}
if s.Value == "" {
// Empty string in rego means "no value". Compare the
// column against NULL since UUID columns represent
// absent values as NULL, not empty strings.
op := "IS NULL"
if not {
op = "IS NOT NULL"
}
return fmt.Sprintf("%s %s", u.ColumnString, op), nil
}
return fmt.Sprintf("%s %s '%s'::uuid",
u.ColumnString, equalsOp(not), s.Value), nil
case astUUIDVar:
return basicSQLEquality(cfg, not, u, other), nil
default:
return "", xerrors.Errorf("unsupported equality: %T %s %T",
u, equalsOp(not), other)
}
}
// ContainedInSQL implements SupportsContainedIn so that a UUID column
// can appear in membership checks like `col = ANY(ARRAY[...])`. The
// array elements are rego strings, so we cast each to ::uuid.
func (u astUUIDVar) ContainedInSQL(_ *SQLGenerator, haystack Node) (string, error) {
arr, ok := haystack.(ASTArray)
if !ok {
return "", xerrors.Errorf("unsupported containedIn: %T in %T", u, haystack)
}
if len(arr.Value) == 0 {
return "false", nil
}
// Build ARRAY['uuid1'::uuid, 'uuid2'::uuid, ...]
values := make([]string, 0, len(arr.Value))
for _, v := range arr.Value {
s, ok := v.(AstString)
if !ok {
return "", xerrors.Errorf("expected AstString array element, got %T", v)
}
values = append(values, fmt.Sprintf("'%s'::uuid", s.Value))
}
return fmt.Sprintf("%s = ANY(ARRAY [%s])",
u.ColumnString,
strings.Join(values, ",")), nil
}
@@ -66,7 +66,7 @@ func AuditLogs(ctx context.Context, db database.Store, query string) (database.G
}
// Prepare the count filter, which uses the same parameters as the GetAuditLogsOffsetParams.
// nolint:exhaustruct // UserID is not obtained from the query parameters.
// nolint:exhaustruct // UserID and CountCap are not obtained from the query parameters.
countFilter := database.CountAuditLogsParams{
RequestID: filter.RequestID,
ResourceID: filter.ResourceID,
@@ -123,6 +123,7 @@ func ConnectionLogs(ctx context.Context, db database.Store, query string, apiKey
}
// This MUST be kept in sync with the above
// nolint:exhaustruct // CountCap is not obtained from the query parameters.
countFilter := database.CountConnectionLogsParams{
OrganizationID: filter.OrganizationID,
WorkspaceOwner: filter.WorkspaceOwner,
@@ -19,6 +19,7 @@ import (
"github.com/google/uuid"
"github.com/prometheus/client_golang/prometheus"
"go.opentelemetry.io/otel/trace"
"golang.org/x/sync/singleflight"
"golang.org/x/xerrors"
"tailscale.com/derp"
"tailscale.com/tailcfg"
@@ -389,6 +390,7 @@ type MultiAgentController struct {
// connections to the destination
tickets map[uuid.UUID]map[uuid.UUID]struct{}
coordination *tailnet.BasicCoordination
sendGroup singleflight.Group
cancel context.CancelFunc
expireOldAgentsDone chan struct{}
@@ -418,28 +420,44 @@ func (m *MultiAgentController) New(client tailnet.CoordinatorClient) tailnet.Clo
func (m *MultiAgentController) ensureAgent(agentID uuid.UUID) error {
m.mu.Lock()
defer m.mu.Unlock()
_, ok := m.connectionTimes[agentID]
// If we don't have the agent, subscribe.
if !ok {
m.logger.Debug(context.Background(),
"subscribing to agent", slog.F("agent_id", agentID))
if m.coordination != nil {
err := m.coordination.Client.Send(&proto.CoordinateRequest{
AddTunnel: &proto.CoordinateRequest_Tunnel{Id: agentID[:]},
})
if err != nil {
err = xerrors.Errorf("subscribe agent: %w", err)
m.coordination.SendErr(err)
_ = m.coordination.Client.Close()
m.coordination = nil
return err
}
}
m.tickets[agentID] = map[uuid.UUID]struct{}{}
if ok {
m.connectionTimes[agentID] = time.Now()
m.mu.Unlock()
return nil
}
m.mu.Unlock()
m.logger.Debug(context.Background(),
"subscribing to agent", slog.F("agent_id", agentID))
_, err, _ := m.sendGroup.Do(agentID.String(), func() (interface{}, error) {
m.mu.Lock()
coord := m.coordination
m.mu.Unlock()
if coord == nil {
return nil, xerrors.New("no active coordination")
}
err := coord.Client.Send(&proto.CoordinateRequest{
AddTunnel: &proto.CoordinateRequest_Tunnel{Id: agentID[:]},
})
if err != nil {
return nil, err
}
m.mu.Lock()
m.tickets[agentID] = map[uuid.UUID]struct{}{}
m.mu.Unlock()
return nil, nil
})
if err != nil {
m.logger.Error(context.Background(), "ensureAgent send failed",
slog.F("agent_id", agentID), slog.Error(err))
return xerrors.Errorf("send AddTunnel: %w", err)
}
m.mu.Lock()
m.connectionTimes[agentID] = time.Now()
m.mu.Unlock()
return nil
}
@@ -730,7 +730,10 @@ func (s *Server) workspaceAgentPTY(rw http.ResponseWriter, r *http.Request) {
if !ok {
return
}
log := s.Logger.With(slog.F("agent_id", appToken.AgentID))
log := s.Logger.With(
slog.F("agent_id", appToken.AgentID),
slog.F("workspace_id", appToken.WorkspaceID),
)
log.Debug(ctx, "resolved PTY request")
values := r.URL.Query()
@@ -765,19 +768,21 @@ func (s *Server) workspaceAgentPTY(rw http.ResponseWriter, r *http.Request) {
})
return
}
go httpapi.HeartbeatClose(ctx, s.Logger, cancel, conn)
ctx, wsNetConn := WebsocketNetConn(ctx, conn, websocket.MessageBinary)
defer wsNetConn.Close() // Also closes conn.
go httpapi.HeartbeatClose(ctx, log, cancel, conn)
dialStart := time.Now()
agentConn, release, err := s.AgentProvider.AgentConn(ctx, appToken.AgentID)
if err != nil {
log.Debug(ctx, "dial workspace agent", slog.Error(err))
log.Debug(ctx, "dial workspace agent", slog.Error(err), slog.F("elapsed_ms", time.Since(dialStart).Milliseconds()))
_ = conn.Close(websocket.StatusInternalError, httpapi.WebsocketCloseSprintf("dial workspace agent: %s", err))
return
}
defer release()
log.Debug(ctx, "dialed workspace agent")
log.Debug(ctx, "dialed workspace agent", slog.F("elapsed_ms", time.Since(dialStart).Milliseconds()))
// #nosec G115 - Safe conversion for terminal height/width which are expected to be within uint16 range (0-65535)
ptNetConn, err := agentConn.ReconnectingPTY(ctx, reconnect, uint16(height), uint16(width), r.URL.Query().Get("command"), func(arp *workspacesdk.AgentReconnectingPTYInit) {
arp.Container = container
@@ -785,12 +790,12 @@ func (s *Server) workspaceAgentPTY(rw http.ResponseWriter, r *http.Request) {
arp.BackendType = backendType
})
if err != nil {
log.Debug(ctx, "dial reconnecting pty server in workspace agent", slog.Error(err))
log.Debug(ctx, "dial reconnecting pty server in workspace agent", slog.Error(err), slog.F("elapsed_ms", time.Since(dialStart).Milliseconds()))
_ = conn.Close(websocket.StatusInternalError, httpapi.WebsocketCloseSprintf("dial: %s", err))
return
}
defer ptNetConn.Close()
log.Debug(ctx, "obtained PTY")
log.Debug(ctx, "obtained PTY", slog.F("elapsed_ms", time.Since(dialStart).Milliseconds()))
report := newStatsReportFromSignedToken(*appToken)
s.collectStats(report)
@@ -800,7 +805,7 @@ func (s *Server) workspaceAgentPTY(rw http.ResponseWriter, r *http.Request) {
}()
agentssh.Bicopy(ctx, wsNetConn, ptNetConn)
log.Debug(ctx, "pty Bicopy finished")
log.Debug(ctx, "pty Bicopy finished", slog.F("elapsed_ms", time.Since(dialStart).Milliseconds()))
}
func (s *Server) collectStats(stats StatsReport) {
+124 -28
@@ -7,6 +7,7 @@ import (
"encoding/json"
"errors"
"fmt"
"maps"
"net/http"
"slices"
"strconv"
@@ -151,6 +152,12 @@ type Server struct {
inFlightChatStaleAfter time.Duration
chatHeartbeatInterval time.Duration
// heartbeatMu guards heartbeatRegistry.
heartbeatMu sync.Mutex
// heartbeatRegistry maps chat IDs to their cancel functions
// and workspace state for the centralized heartbeat loop.
heartbeatRegistry map[uuid.UUID]*heartbeatEntry
// wakeCh is signaled by SendMessage, EditMessage, CreateChat,
// and PromoteQueued so the run loop calls processOnce
// immediately instead of waiting for the next ticker.
@@ -706,6 +713,17 @@ type chatStreamState struct {
bufferRetainedAt time.Time
}
// heartbeatEntry tracks a single chat's cancel function and workspace
// state for the centralized heartbeat loop. Instead of spawning a
// per-chat goroutine, processChat registers an entry here and the
// single heartbeatLoop goroutine handles all chats.
type heartbeatEntry struct {
cancelWithCause context.CancelCauseFunc
chatID uuid.UUID
workspaceID uuid.NullUUID
logger slog.Logger
}
// resetDropCounters zeroes the rate-limiting state for both buffer
// and subscriber drop warnings. The caller must hold s.mu.
func (s *chatStreamState) resetDropCounters() {
@@ -2420,8 +2438,8 @@ func New(cfg Config) *Server {
clock: clk,
recordingSem: make(chan struct{}, maxConcurrentRecordingUploads),
wakeCh: make(chan struct{}, 1),
heartbeatRegistry: make(map[uuid.UUID]*heartbeatEntry),
}
//nolint:gocritic // The chat processor uses a scoped chatd context.
ctx = dbauthz.AsChatd(ctx)
@@ -2461,6 +2479,9 @@ func (p *Server) start(ctx context.Context) {
// to handle chats orphaned by crashed or redeployed workers.
p.recoverStaleChats(ctx)
// Single heartbeat loop for all chats on this replica.
go p.heartbeatLoop(ctx)
acquireTicker := p.clock.NewTicker(
p.pendingChatAcquireInterval,
"chatd",
@@ -2730,6 +2751,97 @@ func (p *Server) cleanupStreamIfIdle(chatID uuid.UUID, state *chatStreamState) {
p.workspaceMCPToolsCache.Delete(chatID)
}
// registerHeartbeat enrolls a chat in the centralized batch
// heartbeat loop. Must be called after chatCtx is created.
func (p *Server) registerHeartbeat(entry *heartbeatEntry) {
p.heartbeatMu.Lock()
defer p.heartbeatMu.Unlock()
if _, exists := p.heartbeatRegistry[entry.chatID]; exists {
p.logger.Warn(context.Background(),
"duplicate heartbeat registration, skipping",
slog.F("chat_id", entry.chatID))
return
}
p.heartbeatRegistry[entry.chatID] = entry
}
// unregisterHeartbeat removes a chat from the centralized
// heartbeat loop when chat processing finishes.
func (p *Server) unregisterHeartbeat(chatID uuid.UUID) {
p.heartbeatMu.Lock()
defer p.heartbeatMu.Unlock()
delete(p.heartbeatRegistry, chatID)
}
// heartbeatLoop runs in a single goroutine, issuing one batch
// heartbeat query per interval for all registered chats.
func (p *Server) heartbeatLoop(ctx context.Context) {
ticker := p.clock.NewTicker(p.chatHeartbeatInterval, "chatd", "batch-heartbeat")
defer ticker.Stop()
for {
select {
case <-ctx.Done():
return
case <-ticker.C:
p.heartbeatTick(ctx)
}
}
}
// heartbeatTick issues a single batch UPDATE for all running chats
// owned by this worker. Chats missing from the result set are
// interrupted (stolen by another replica or already completed).
func (p *Server) heartbeatTick(ctx context.Context) {
// Snapshot the registry under the lock.
p.heartbeatMu.Lock()
snapshot := maps.Clone(p.heartbeatRegistry)
p.heartbeatMu.Unlock()
if len(snapshot) == 0 {
return
}
// Collect the IDs we believe we own.
ids := slices.Collect(maps.Keys(snapshot))
//nolint:gocritic // AsChatd provides narrowly-scoped daemon
// access for batch-updating heartbeats.
chatdCtx := dbauthz.AsChatd(ctx)
updatedIDs, err := p.db.UpdateChatHeartbeats(chatdCtx, database.UpdateChatHeartbeatsParams{
IDs: ids,
WorkerID: p.workerID,
Now: p.clock.Now(),
})
if err != nil {
p.logger.Error(ctx, "batch heartbeat failed", slog.Error(err))
return
}
// Build a set of IDs that were successfully updated.
updated := make(map[uuid.UUID]struct{}, len(updatedIDs))
for _, id := range updatedIDs {
updated[id] = struct{}{}
}
// Interrupt registered chats that were not in the result
// (stolen by another replica or already completed).
for id, entry := range snapshot {
if _, ok := updated[id]; !ok {
entry.logger.Warn(ctx, "chat not in batch heartbeat result, interrupting")
entry.cancelWithCause(chatloop.ErrInterrupted)
continue
}
// Bump workspace usage for surviving chats.
newWsID := p.trackWorkspaceUsage(ctx, entry.chatID, entry.workspaceID, entry.logger)
// Update workspace ID in the registry for next tick.
p.heartbeatMu.Lock()
if current, exists := p.heartbeatRegistry[id]; exists {
current.workspaceID = newWsID
}
p.heartbeatMu.Unlock()
}
}
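The stolen-chat detection in `heartbeatTick` is a plain set difference between the IDs this worker registered and the IDs the batch UPDATE reported back. An illustrative sketch with plain strings in place of `uuid.UUID`s; `stolenIDs` is a hypothetical helper, not part of the real `Server`:

```go
package main

import "fmt"

// stolenIDs returns the registered IDs that the batch UPDATE did
// not report back: chats stolen by another replica or already
// completed, which the tick must interrupt.
func stolenIDs(registered, updated []string) []string {
	seen := make(map[string]struct{}, len(updated))
	for _, id := range updated {
		seen[id] = struct{}{}
	}
	var stolen []string
	for _, id := range registered {
		if _, ok := seen[id]; !ok {
			stolen = append(stolen, id)
		}
	}
	return stolen
}

func main() {
	fmt.Println(stolenIDs(
		[]string{"chat1", "chat2", "chat3"},
		[]string{"chat1", "chat2"},
	)) // → [chat3]
}
```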
func (p *Server) Subscribe(
ctx context.Context,
chatID uuid.UUID,
@@ -3575,33 +3687,17 @@ func (p *Server) processChat(ctx context.Context, chat database.Chat) {
}
}()
// Periodically update the heartbeat so other replicas know this
// worker is still alive. The goroutine stops when chatCtx is
// canceled (either by completion or interruption).
go func() {
ticker := p.clock.NewTicker(p.chatHeartbeatInterval, "chatd", "heartbeat")
defer ticker.Stop()
for {
select {
case <-chatCtx.Done():
return
case <-ticker.C:
rows, err := p.db.UpdateChatHeartbeat(chatCtx, database.UpdateChatHeartbeatParams{
ID: chat.ID,
WorkerID: p.workerID,
})
if err != nil {
logger.Warn(chatCtx, "failed to update chat heartbeat", slog.Error(err))
continue
}
if rows == 0 {
cancel(chatloop.ErrInterrupted)
return
}
chat.WorkspaceID = p.trackWorkspaceUsage(chatCtx, chat.ID, chat.WorkspaceID, logger)
}
}
}()
// Register with the centralized heartbeat loop instead of
// running a per-chat goroutine. The loop issues a single batch
// UPDATE for all chats on this worker and detects stolen chats
// via set-difference.
p.registerHeartbeat(&heartbeatEntry{
cancelWithCause: cancel,
chatID: chat.ID,
workspaceID: chat.WorkspaceID,
logger: logger,
})
defer p.unregisterHeartbeat(chat.ID)
// Start buffering stream events BEFORE publishing the running
// status. This closes a race where a subscriber sees
+129
@@ -21,6 +21,7 @@ import (
dbpubsub "github.com/coder/coder/v2/coderd/database/pubsub"
coderdpubsub "github.com/coder/coder/v2/coderd/pubsub"
"github.com/coder/coder/v2/coderd/x/chatd/chaterror"
"github.com/coder/coder/v2/coderd/x/chatd/chatloop"
"github.com/coder/coder/v2/coderd/x/chatd/chatprovider"
"github.com/coder/coder/v2/coderd/x/chatd/chattest"
"github.com/coder/coder/v2/coderd/x/chatd/chattool"
@@ -2071,6 +2072,7 @@ func TestProcessChat_IgnoresStaleControlNotification(t *testing.T) {
workerID: workerID,
chatHeartbeatInterval: time.Minute,
configCache: newChatConfigCache(ctx, db, clock),
heartbeatRegistry: make(map[uuid.UUID]*heartbeatEntry),
}
// Publish a stale "pending" notification on the control channel
@@ -2133,3 +2135,130 @@ func TestProcessChat_IgnoresStaleControlNotification(t *testing.T) {
require.Equal(t, database.ChatStatusError, finalStatus,
"processChat should have reached runChat (error), not been interrupted (waiting)")
}
// TestHeartbeatTick_StolenChatIsInterrupted verifies that when the
// batch heartbeat UPDATE does not return a registered chat's ID
// (because another replica stole it or it was completed), the
// heartbeat tick cancels that chat's context with ErrInterrupted
// while leaving surviving chats untouched.
func TestHeartbeatTick_StolenChatIsInterrupted(t *testing.T) {
t.Parallel()
ctx := testutil.Context(t, testutil.WaitShort)
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true})
ctrl := gomock.NewController(t)
db := dbmock.NewMockStore(ctrl)
clock := quartz.NewMock(t)
workerID := uuid.New()
server := &Server{
db: db,
logger: logger,
clock: clock,
workerID: workerID,
chatHeartbeatInterval: time.Minute,
heartbeatRegistry: make(map[uuid.UUID]*heartbeatEntry),
}
// Create three chats with independent cancel functions.
chat1 := uuid.New()
chat2 := uuid.New()
chat3 := uuid.New()
_, cancel1 := context.WithCancelCause(ctx)
_, cancel2 := context.WithCancelCause(ctx)
ctx3, cancel3 := context.WithCancelCause(ctx)
server.registerHeartbeat(&heartbeatEntry{
cancelWithCause: cancel1,
chatID: chat1,
logger: logger,
})
server.registerHeartbeat(&heartbeatEntry{
cancelWithCause: cancel2,
chatID: chat2,
logger: logger,
})
server.registerHeartbeat(&heartbeatEntry{
cancelWithCause: cancel3,
chatID: chat3,
logger: logger,
})
// The batch UPDATE returns only chat1 and chat2 —
// chat3 was "stolen" by another replica.
db.EXPECT().UpdateChatHeartbeats(gomock.Any(), gomock.Any()).DoAndReturn(
func(_ context.Context, params database.UpdateChatHeartbeatsParams) ([]uuid.UUID, error) {
require.Equal(t, workerID, params.WorkerID)
require.Len(t, params.IDs, 3)
// Return only chat1 and chat2 as surviving.
return []uuid.UUID{chat1, chat2}, nil
},
)
server.heartbeatTick(ctx)
// chat3's context should be canceled with ErrInterrupted.
require.ErrorIs(t, context.Cause(ctx3), chatloop.ErrInterrupted,
"stolen chat should be interrupted")
	// The heartbeat tick only cancels; it does not unregister.
	// In production, processChat's deferred unregisterHeartbeat
	// removes the entry once the canceled chat unwinds. Verify
	// chat3 is therefore still present in the registry.
server.heartbeatMu.Lock()
_, chat1Exists := server.heartbeatRegistry[chat1]
_, chat2Exists := server.heartbeatRegistry[chat2]
_, chat3Exists := server.heartbeatRegistry[chat3]
server.heartbeatMu.Unlock()
require.True(t, chat1Exists, "surviving chat1 should remain registered")
require.True(t, chat2Exists, "surviving chat2 should remain registered")
require.True(t, chat3Exists,
"stolen chat3 should still be in registry (processChat defer removes it)")
}
// TestHeartbeatTick_DBErrorDoesNotInterruptChats verifies that a
// transient database failure causes the tick to log and return
// without canceling any registered chats.
func TestHeartbeatTick_DBErrorDoesNotInterruptChats(t *testing.T) {
t.Parallel()
ctx := testutil.Context(t, testutil.WaitShort)
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true})
ctrl := gomock.NewController(t)
db := dbmock.NewMockStore(ctrl)
clock := quartz.NewMock(t)
server := &Server{
db: db,
logger: logger,
clock: clock,
workerID: uuid.New(),
chatHeartbeatInterval: time.Minute,
heartbeatRegistry: make(map[uuid.UUID]*heartbeatEntry),
}
chatID := uuid.New()
chatCtx, cancel := context.WithCancelCause(ctx)
server.registerHeartbeat(&heartbeatEntry{
cancelWithCause: cancel,
chatID: chatID,
logger: logger,
})
// Simulate a transient DB error.
db.EXPECT().UpdateChatHeartbeats(gomock.Any(), gomock.Any()).Return(
nil, xerrors.New("connection reset"),
)
server.heartbeatTick(ctx)
// Chat should NOT be interrupted — the tick logged and
// returned early.
require.NoError(t, chatCtx.Err(),
"chat context should not be canceled on transient DB error")
}
+12 -7
@@ -474,7 +474,7 @@ func TestArchiveChatInterruptsActiveProcessing(t *testing.T) {
require.Equal(t, 1, userMessages, "expected queued message to stay queued after archive")
}
func TestUpdateChatHeartbeatRequiresOwnership(t *testing.T) {
func TestUpdateChatHeartbeatsRequiresOwnership(t *testing.T) {
t.Parallel()
db, ps := dbtestutil.NewDB(t)
@@ -501,19 +501,24 @@ func TestUpdateChatHeartbeatRequiresOwnership(t *testing.T) {
})
require.NoError(t, err)
rows, err := db.UpdateChatHeartbeat(ctx, database.UpdateChatHeartbeatParams{
ID: chat.ID,
// Wrong worker_id should return no IDs.
ids, err := db.UpdateChatHeartbeats(ctx, database.UpdateChatHeartbeatsParams{
IDs: []uuid.UUID{chat.ID},
WorkerID: uuid.New(),
Now: time.Now(),
})
require.NoError(t, err)
require.Equal(t, int64(0), rows)
require.Empty(t, ids)
rows, err = db.UpdateChatHeartbeat(ctx, database.UpdateChatHeartbeatParams{
ID: chat.ID,
// Correct worker_id should return the chat's ID.
ids, err = db.UpdateChatHeartbeats(ctx, database.UpdateChatHeartbeatsParams{
IDs: []uuid.UUID{chat.ID},
WorkerID: workerID,
Now: time.Now(),
})
require.NoError(t, err)
require.Equal(t, int64(1), rows)
require.Len(t, ids, 1)
require.Equal(t, chat.ID, ids[0])
}
func TestSendMessageQueueBehaviorQueuesWhenBusy(t *testing.T) {
+34 -5
@@ -9,6 +9,7 @@ import (
"fmt"
"net/http"
"net/url"
"slices"
"strings"
"sync"
"time"
@@ -49,10 +50,11 @@ const connectTimeout = 10 * time.Second
const toolCallTimeout = 60 * time.Second
// ConnectAll connects to all configured MCP servers, discovers
// their tools, and returns them as fantasy.AgentTool values. It
// skips servers that fail to connect and logs warnings. The
// returned cleanup function must be called to close all
// connections.
// their tools, and returns them as fantasy.AgentTool values.
// Tools are sorted by their prefixed name so callers
// receive a deterministic order. It skips servers that fail to
// connect and logs warnings. The returned cleanup function
// must be called to close all connections.
func ConnectAll(
ctx context.Context,
logger slog.Logger,
@@ -108,7 +110,9 @@ func ConnectAll(
}
mu.Lock()
clients = append(clients, mcpClient)
if mcpClient != nil {
clients = append(clients, mcpClient)
}
tools = append(tools, serverTools...)
mu.Unlock()
return nil
@@ -119,6 +123,31 @@ func ConnectAll(
// discarded.
_ = eg.Wait()
	// Sort tools by prefixed name for deterministic ordering
	// regardless of goroutine completion order. Ties, which can
	// occur when the __ separator produces ambiguous prefixed
	// names, are broken by config ID. Stable prompt construction
	// depends on consistent tool ordering.
slices.SortFunc(tools, func(a, b fantasy.AgentTool) int {
// All tools in this slice are mcpToolWrapper values
// created by connectOne above, so these checked
// assertions should always succeed. The config ID
// tiebreaker resolves the __ separator ambiguity
// documented at the top of this file.
aTool, ok := a.(MCPToolIdentifier)
if !ok {
panic(fmt.Sprintf("unexpected tool type %T", a))
}
bTool, ok := b.(MCPToolIdentifier)
if !ok {
panic(fmt.Sprintf("unexpected tool type %T", b))
}
return cmp.Or(
cmp.Compare(a.Info().Name, b.Info().Name),
cmp.Compare(aTool.MCPServerConfigID().String(), bTool.MCPServerConfigID().String()),
)
})
return tools, cleanup
}
+126
@@ -63,6 +63,17 @@ func greetTool() mcpserver.ServerTool {
}
}
// makeTool returns a ServerTool with the given name and a
// no-op handler that always returns "ok".
func makeTool(name string) mcpserver.ServerTool {
return mcpserver.ServerTool{
Tool: mcp.NewTool(name),
Handler: func(_ context.Context, _ mcp.CallToolRequest) (*mcp.CallToolResult, error) {
return mcp.NewToolResultText("ok"), nil
},
}
}
// makeConfig builds a database.MCPServerConfig suitable for tests.
func makeConfig(slug, url string) database.MCPServerConfig {
return database.MCPServerConfig{
@@ -198,6 +209,121 @@ func TestConnectAll_MultipleServers(t *testing.T) {
assert.Contains(t, names, "beta__greet")
}
func TestConnectAll_NoToolsAfterFiltering(t *testing.T) {
t.Parallel()
ctx := context.Background()
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true})
ts := newTestMCPServer(t, echoTool())
cfg := makeConfig("filtered", ts.URL)
cfg.ToolAllowList = []string{"greet"}
tools, cleanup := mcpclient.ConnectAll(
ctx,
logger,
[]database.MCPServerConfig{cfg},
nil,
)
require.Empty(t, tools)
assert.NotPanics(t, cleanup)
}
func TestConnectAll_DeterministicOrder(t *testing.T) {
t.Parallel()
t.Run("AcrossServers", func(t *testing.T) {
t.Parallel()
ctx := context.Background()
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true})
ts1 := newTestMCPServer(t, makeTool("zebra"))
ts2 := newTestMCPServer(t, makeTool("alpha"))
ts3 := newTestMCPServer(t, makeTool("middle"))
tools, cleanup := mcpclient.ConnectAll(
ctx,
logger,
[]database.MCPServerConfig{
makeConfig("srv3", ts3.URL),
makeConfig("srv1", ts1.URL),
makeConfig("srv2", ts2.URL),
},
nil,
)
t.Cleanup(cleanup)
require.Len(t, tools, 3)
// Sorted by full prefixed name (slug__tool), so slug
// order determines the sequence, not the tool name.
assert.Equal(t,
[]string{"srv1__zebra", "srv2__alpha", "srv3__middle"},
toolNames(tools),
)
})
t.Run("WithMultiToolServer", func(t *testing.T) {
t.Parallel()
ctx := context.Background()
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true})
multi := newTestMCPServer(t, makeTool("zeta"), makeTool("beta"))
other := newTestMCPServer(t, makeTool("gamma"))
tools, cleanup := mcpclient.ConnectAll(
ctx,
logger,
[]database.MCPServerConfig{
makeConfig("zzz", multi.URL),
makeConfig("aaa", other.URL),
},
nil,
)
t.Cleanup(cleanup)
require.Len(t, tools, 3)
assert.Equal(t,
[]string{"aaa__gamma", "zzz__beta", "zzz__zeta"},
toolNames(tools),
)
})
t.Run("TiebreakByConfigID", func(t *testing.T) {
t.Parallel()
ctx := context.Background()
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true})
ts1 := newTestMCPServer(t, makeTool("b__z"))
ts2 := newTestMCPServer(t, makeTool("z"))
// Use fixed UUIDs so the tiebreaker order is
// predictable. Both servers produce the same prefixed
// name, a__b__z, due to the __ separator ambiguity.
cfg1 := makeConfig("a", ts1.URL)
cfg1.ID = uuid.MustParse("00000000-0000-0000-0000-000000000002")
cfg2 := makeConfig("a__b", ts2.URL)
cfg2.ID = uuid.MustParse("00000000-0000-0000-0000-000000000001")
tools, cleanup := mcpclient.ConnectAll(
ctx,
logger,
[]database.MCPServerConfig{cfg1, cfg2},
nil,
)
t.Cleanup(cleanup)
require.Len(t, tools, 2)
assert.Equal(t, []string{"a__b__z", "a__b__z"}, toolNames(tools))
id0 := tools[0].(mcpclient.MCPToolIdentifier).MCPServerConfigID()
id1 := tools[1].(mcpclient.MCPToolIdentifier).MCPServerConfigID()
assert.Equal(t, cfg2.ID, id0, "lower config ID should sort first")
assert.Equal(t, cfg1.ID, id1, "higher config ID should sort second")
})
}
func TestConnectAll_AuthHeaders(t *testing.T) {
t.Parallel()
ctx := context.Background()
+1
@@ -212,6 +212,7 @@ type AuditLogsRequest struct {
type AuditLogResponse struct {
AuditLogs []AuditLog `json:"audit_logs"`
Count int64 `json:"count"`
CountCap int64 `json:"count_cap"`
}
type CreateTestAuditLogRequest struct {
+1
@@ -96,6 +96,7 @@ type ConnectionLogsRequest struct {
type ConnectionLogResponse struct {
ConnectionLogs []ConnectionLog `json:"connection_logs"`
Count int64 `json:"count"`
CountCap int64 `json:"count_cap"`
}
func (c *Client) ConnectionLogs(ctx context.Context, req ConnectionLogsRequest) (ConnectionLogResponse, error) {
+2 -1
@@ -90,7 +90,8 @@ curl -X GET http://coder-server:8080/api/v2/audit?limit=0 \
"user_agent": "string"
}
],
"count": 0
"count": 0,
"count_cap": 0
}
```
+2 -1
@@ -291,7 +291,8 @@ curl -X GET http://coder-server:8080/api/v2/connectionlog?limit=0 \
"workspace_owner_username": "string"
}
],
"count": 0
"count": 0,
"count_cap": 0
}
```
+6 -2
@@ -1740,7 +1740,8 @@
"user_agent": "string"
}
],
"count": 0
"count": 0,
"count_cap": 0
}
```
@@ -1750,6 +1751,7 @@
|--------------|-------------------------------------------------|----------|--------------|-------------|
| `audit_logs` | array of [codersdk.AuditLog](#codersdkauditlog) | false | | |
| `count` | integer | false | | |
| `count_cap` | integer | false | | |
## codersdk.AuthMethod
@@ -2173,7 +2175,8 @@ AuthorizationObject can represent a "set" of objects, such as: all workspaces in
"workspace_owner_username": "string"
}
],
"count": 0
"count": 0,
"count_cap": 0
}
```
@@ -2183,6 +2186,7 @@ AuthorizationObject can represent a "set" of objects, such as: all workspaces in
|-------------------|-----------------------------------------------------------|----------|--------------|-------------|
| `connection_logs` | array of [codersdk.ConnectionLog](#codersdkconnectionlog) | false | | |
| `count` | integer | false | | |
| `count_cap` | integer | false | | |
## codersdk.ConnectionLogSSHInfo
+6
@@ -16,6 +16,9 @@ import (
"github.com/coder/coder/v2/codersdk"
)
// NOTE: See the auditLogCountCap note.
const connectionLogCountCap = 2000
// @Summary Get connection logs
// @ID get-connection-logs
// @Security CoderSessionToken
@@ -49,6 +52,7 @@ func (api *API) connectionLogs(rw http.ResponseWriter, r *http.Request) {
// #nosec G115 - Safe conversion as pagination limit is expected to be within int32 range
filter.LimitOpt = int32(page.Limit)
countFilter.CountCap = connectionLogCountCap
count, err := api.Database.CountConnectionLogs(ctx, countFilter)
if dbauthz.IsNotAuthorizedError(err) {
httpapi.Forbidden(rw)
@@ -63,6 +67,7 @@ func (api *API) connectionLogs(rw http.ResponseWriter, r *http.Request) {
httpapi.Write(ctx, rw, http.StatusOK, codersdk.ConnectionLogResponse{
ConnectionLogs: []codersdk.ConnectionLog{},
Count: 0,
CountCap: connectionLogCountCap,
})
return
}
@@ -80,6 +85,7 @@ func (api *API) connectionLogs(rw http.ResponseWriter, r *http.Request) {
httpapi.Write(ctx, rw, http.StatusOK, codersdk.ConnectionLogResponse{
ConnectionLogs: convertConnectionLogs(dblogs),
Count: count,
CountCap: connectionLogCountCap,
})
}
+27 -7
@@ -12,6 +12,7 @@ import (
"github.com/cenkalti/backoff/v4"
"github.com/google/uuid"
"golang.org/x/sync/singleflight"
"golang.org/x/xerrors"
gProto "google.golang.org/protobuf/proto"
@@ -33,9 +34,9 @@ const (
eventReadyForHandshake = "tailnet_ready_for_handshake"
HeartbeatPeriod = time.Second * 2
MissedHeartbeats = 3
numQuerierWorkers = 10
numQuerierWorkers = 40
numBinderWorkers = 10
numTunnelerWorkers = 10
numTunnelerWorkers = 20
numHandshakerWorkers = 5
dbMaxBackoff = 10 * time.Second
cleanupPeriod = time.Hour
@@ -770,6 +771,9 @@ func (m *mapper) bestToUpdate(best map[uuid.UUID]mapping) *proto.CoordinateRespo
for k := range m.sent {
if _, ok := best[k]; !ok {
m.logger.Debug(m.ctx, "peer no longer in best mappings, sending DISCONNECTED",
slog.F("peer_id", k),
)
resp.PeerUpdates = append(resp.PeerUpdates, &proto.CoordinateResponse_PeerUpdate{
Id: agpl.UUIDToByteSlice(k),
Kind: proto.CoordinateResponse_PeerUpdate_DISCONNECTED,
@@ -820,6 +824,8 @@ type querier struct {
mu sync.Mutex
mappers map[mKey]*mapper
healthy bool
resyncGroup singleflight.Group
}
func newQuerier(ctx context.Context,
@@ -958,7 +964,7 @@ func (q *querier) cleanupConn(c *connIO) {
// maxBatchSize is the maximum number of keys to process in a single batch
// query.
const maxBatchSize = 50
const maxBatchSize = 200
func (q *querier) peerUpdateWorker() {
defer q.wg.Done()
@@ -1207,8 +1213,13 @@ func (q *querier) subscribe() {
func (q *querier) listenPeer(_ context.Context, msg []byte, err error) {
if xerrors.Is(err, pubsub.ErrDroppedMessages) {
q.logger.Warn(q.ctx, "pubsub may have dropped peer updates")
// we need to schedule a full resync of peer mappings
q.resyncPeerMappings()
// Schedule a full resync asynchronously so we don't block the
// pubsub drain goroutine. Singleflight coalesces concurrent
// resync requests.
go q.resyncGroup.Do("resync", func() (any, error) {
q.resyncPeerMappings()
return nil, nil
})
return
}
if err != nil {
@@ -1234,8 +1245,13 @@ func (q *querier) listenPeer(_ context.Context, msg []byte, err error) {
func (q *querier) listenTunnel(_ context.Context, msg []byte, err error) {
if xerrors.Is(err, pubsub.ErrDroppedMessages) {
q.logger.Warn(q.ctx, "pubsub may have dropped tunnel updates")
// we need to schedule a full resync of peer mappings
q.resyncPeerMappings()
// Schedule a full resync asynchronously so we don't block the
// pubsub drain goroutine. Singleflight coalesces concurrent
// resync requests.
go q.resyncGroup.Do("resync", func() (any, error) {
q.resyncPeerMappings()
return nil, nil
})
return
}
if err != nil {
@@ -1601,6 +1617,10 @@ func (h *heartbeats) filter(mappings []mapping) []mapping {
// the only mapping available for it. Newer mappings will take
// precedence.
m.kind = proto.CoordinateResponse_PeerUpdate_LOST
h.logger.Debug(h.ctx, "mapping rewritten to LOST due to missed heartbeats",
slog.F("peer_id", m.peer),
slog.F("coordinator_id", m.coordinator),
)
}
}
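The `go q.resyncGroup.Do(...)` pattern above keeps the pubsub drain goroutine unblocked while collapsing a storm of resync requests into far fewer executions. A self-contained sketch of the coalescing idea using only the standard library; real code should prefer `golang.org/x/sync/singleflight`, which also lets concurrent callers share one result, whereas the hypothetical `coalescer` here is a simplified drop-duplicates variant:

```go
package main

import (
	"fmt"
	"sync"
)

// coalescer drops requests that arrive while one is already
// running, so a burst triggers at most one in-flight execution.
type coalescer struct {
	mu      sync.Mutex
	running bool
}

// Do runs fn and reports true, unless a previous call is still
// in flight, in which case it returns false without running fn.
func (c *coalescer) Do(fn func()) bool {
	c.mu.Lock()
	if c.running {
		c.mu.Unlock()
		return false
	}
	c.running = true
	c.mu.Unlock()

	fn()

	c.mu.Lock()
	c.running = false
	c.mu.Unlock()
	return true
}

func main() {
	var c coalescer
	calls := 0
	c.Do(func() {
		calls++
		// A request arriving mid-flight is coalesced away.
		fmt.Println("nested ran:", c.Do(func() { calls++ }))
	})
	fmt.Println("calls:", calls) // → calls: 1
}
```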
+2 -2
@@ -76,11 +76,11 @@ replace github.com/aquasecurity/trivy => github.com/coder/trivy v0.0.0-202603091
// https://github.com/spf13/afero/pull/487
replace github.com/spf13/afero => github.com/aslilac/afero v0.0.0-20250403163713-f06e86036696
// Forked from kylecarbs/fantasy (cj/go1.25 branch) which adds:
// Forked from coder/fantasy (cj/go1.25 branch) which adds:
// 1) Anthropic computer use + thinking effort
// 2) Go 1.25 downgrade for Windows CI compat
// 3) ibetitsmike/fantasy#4 skip ephemeral replay items when store=false
replace charm.land/fantasy => github.com/kylecarbs/fantasy v0.0.0-20260325145725-112927d9b6d8
replace charm.land/fantasy => github.com/coder/fantasy v0.0.0-20260325145725-112927d9b6d8
replace github.com/charmbracelet/anthropic-sdk-go => github.com/kylecarbs/anthropic-sdk-go v0.0.0-20260223140439-63879b0b8dab
+2 -2
@@ -322,6 +322,8 @@ github.com/coder/bubbletea v1.2.2-0.20241212190825-007a1cdb2c41 h1:SBN/DA63+ZHwu
github.com/coder/bubbletea v1.2.2-0.20241212190825-007a1cdb2c41/go.mod h1:I9ULxr64UaOSUv7hcb3nX4kowodJCVS7vt7VVJk/kW4=
github.com/coder/clistat v1.2.1 h1:P9/10njXMyj5cWzIU5wkRsSy5LVQH49+tcGMsAgWX0w=
github.com/coder/clistat v1.2.1/go.mod h1:m7SC0uj88eEERgvF8Kn6+w6XF21BeSr+15f7GoLAw0A=
github.com/coder/fantasy v0.0.0-20260325145725-112927d9b6d8 h1:n+6v+yT1B6V4oSGPmXFh7mul1E+RzG9rnqp50Vb7M/w=
github.com/coder/fantasy v0.0.0-20260325145725-112927d9b6d8/go.mod h1:ktfNX0xDpIKeggZbP/j5IYJci6pyMOR3WmZSfz9XLYw=
github.com/coder/flog v1.1.0 h1:kbAes1ai8fIS5OeV+QAnKBQE22ty1jRF/mcAwHpLBa4=
github.com/coder/flog v1.1.0/go.mod h1:UQlQvrkJBvnRGo69Le8E24Tcl5SJleAAR7gYEHzAmdQ=
github.com/coder/go-httpstat v0.0.0-20230801153223-321c88088322 h1:m0lPZjlQ7vdVpRBPKfYIFlmgevoTkBxB10wv6l2gOaU=
@@ -813,8 +815,6 @@ github.com/kylecarbs/anthropic-sdk-go v0.0.0-20260223140439-63879b0b8dab h1:5UMY
github.com/kylecarbs/anthropic-sdk-go v0.0.0-20260223140439-63879b0b8dab/go.mod h1:hqlYqR7uPKOKfnNeicUbZp0Ps0GeYFlKYtwh5HGDCx8=
github.com/kylecarbs/chroma/v2 v2.0.0-20240401211003-9e036e0631f3 h1:Z9/bo5PSeMutpdiKYNt/TTSfGM1Ll0naj3QzYX9VxTc=
github.com/kylecarbs/chroma/v2 v2.0.0-20240401211003-9e036e0631f3/go.mod h1:BUGjjsD+ndS6eX37YgTchSEG+Jg9Jv1GiZs9sqPqztk=
github.com/kylecarbs/fantasy v0.0.0-20260325145725-112927d9b6d8 h1:fZ0208U3B438fDSHCc/GNioPIyaFqn6eBsQTO61QtrI=
github.com/kylecarbs/fantasy v0.0.0-20260325145725-112927d9b6d8/go.mod h1:ktfNX0xDpIKeggZbP/j5IYJci6pyMOR3WmZSfz9XLYw=
github.com/kylecarbs/openai-go/v3 v3.0.0-20260319113850-9477dcaedcae h1:xlFZNX4nnxpj9Cf6mTwD3pirXGNtBJ/6COsf9iZmsL0=
github.com/kylecarbs/openai-go/v3 v3.0.0-20260319113850-9477dcaedcae/go.mod h1:cdufnVK14cWcT9qA1rRtrXx4FTRsgbDPW7Ia7SS5cZo=
github.com/kylecarbs/spinner v1.18.2-0.20220329160715-20702b5af89e h1:OP0ZMFeZkUnOzTFRfpuK3m7Kp4fNvC6qN+exwj7aI4M=
+44 -35
@@ -68,17 +68,17 @@ func runRelease(ctx context.Context, inv *serpent.Invocation, executor ReleaseEx
return xerrors.Errorf("detecting branch: %w", err)
}
// Match standard release branches (release/2.32) and RC
// branches (release/2.32-rc.0).
branchRe := regexp.MustCompile(`^release/(\d+)\.(\d+)(?:-rc\.(\d+))?$`)
// Match release branches (release/X.Y). RCs are tagged
// from main, not from release branches.
branchRe := regexp.MustCompile(`^release/(\d+)\.(\d+)$`)
m := branchRe.FindStringSubmatch(currentBranch)
if m == nil {
warnf(w, "Current branch %q is not a release branch (release/X.Y or release/X.Y-rc.N).", currentBranch)
warnf(w, "Current branch %q is not a release branch (release/X.Y).", currentBranch)
branchInput, err := cliui.Prompt(inv, cliui.PromptOptions{
Text: "Enter the release branch to use (e.g. release/2.21 or release/2.21-rc.0)",
Text: "Enter the release branch to use (e.g. release/2.21)",
Validate: func(s string) error {
if !branchRe.MatchString(s) {
return xerrors.New("must be in format release/X.Y or release/X.Y-rc.N (e.g. release/2.21 or release/2.21-rc.0)")
return xerrors.New("must be in format release/X.Y (e.g. release/2.21)")
}
return nil
},
@@ -91,10 +91,6 @@ func runRelease(ctx context.Context, inv *serpent.Invocation, executor ReleaseEx
}
branchMajor, _ := strconv.Atoi(m[1])
branchMinor, _ := strconv.Atoi(m[2])
branchRC := -1 // -1 means not an RC branch.
if m[3] != "" {
branchRC, _ = strconv.Atoi(m[3])
}
successf(w, "Using release branch: %s", currentBranch)
// --- Fetch & sync check ---
@@ -138,31 +134,44 @@ func runRelease(ctx context.Context, inv *serpent.Invocation, executor ReleaseEx
}
}
// changelogBaseRef is the git ref used as the starting point
// for release notes generation. When a tag already exists in
// this minor series we use it directly. For the first release
// on a new minor no matching tag exists, so we compute the
// merge-base with the previous minor's release branch instead.
// This works even when that branch has no tags yet (it was
// just cut and pushed). As a last resort we fall back to the
// latest reachable tag from a previous minor.
var changelogBaseRef string
if prevVersion != nil {
changelogBaseRef = prevVersion.String()
} else {
prevReleaseBranch := fmt.Sprintf("release/%d.%d", branchMajor, branchMinor-1)
if err := gitRun("fetch", "--quiet", "origin", prevReleaseBranch); err != nil {
warnf(w, "Could not fetch %s: %v", prevReleaseBranch, err)
}
if mb, mbErr := gitOutput("merge-base", "HEAD", "origin/"+prevReleaseBranch); mbErr == nil && mb != "" {
changelogBaseRef = mb
infof(w, "Using merge-base with %s as changelog base: %s", prevReleaseBranch, mb[:12])
} else {
// No previous release branch found; fall back to
// the latest reachable tag from a previous minor.
for _, t := range mergedTags {
if t.Major == branchMajor && t.Minor < branchMinor {
changelogBaseRef = t.String()
break
}
}
}
}
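The fallback chain above (existing tag in this minor, then the merge-base with the previous minor's branch, then an older minor's tag) reduces to picking the first non-empty ref. A sketch with `cmp.Or`, which returns its first non-zero argument; the function and the ref values are hypothetical:

```go
package main

import (
	"cmp"
	"fmt"
)

// changelogBase picks the first available ref in preference
// order: a tag in this minor series, the merge-base with the
// previous minor's release branch, then the newest reachable
// tag from an earlier minor. Empty string means "unavailable".
func changelogBase(prevTag, mergeBase, olderTag string) string {
	return cmp.Or(prevTag, mergeBase, olderTag)
}

func main() {
	// First release on a new minor: no tag yet, merge-base found.
	fmt.Println(changelogBase("", "abc123def456", "v2.31.3")) // → abc123def456
}
```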
var suggested version
if prevVersion == nil {
infof(w, "No previous release tag found on this branch.")
suggested = version{Major: branchMajor, Minor: branchMinor, Patch: 0}
if branchRC >= 0 {
suggested.Pre = fmt.Sprintf("rc.%d", branchRC)
}
} else {
infof(w, "Previous release tag: %s", prevVersion.String())
if branchRC >= 0 {
// On an RC branch, suggest the next RC for
// the same base version.
nextRC := 0
if prevVersion.IsRC() {
nextRC = prevVersion.rcNumber() + 1
}
suggested = version{
Major: prevVersion.Major,
Minor: prevVersion.Minor,
Patch: prevVersion.Patch,
Pre: fmt.Sprintf("rc.%d", nextRC),
}
} else {
suggested = version{Major: prevVersion.Major, Minor: prevVersion.Minor, Patch: prevVersion.Patch + 1}
}
suggested = version{Major: prevVersion.Major, Minor: prevVersion.Minor, Patch: prevVersion.Patch + 1}
}
fmt.Fprintln(w)
@@ -366,8 +375,8 @@ func runRelease(ctx context.Context, inv *serpent.Invocation, executor ReleaseEx
infof(w, "Generating release notes...")
commitRange := "HEAD"
if prevVersion != nil {
commitRange = prevVersion.String() + "..HEAD"
if changelogBaseRef != "" {
commitRange = changelogBaseRef + "..HEAD"
}
commits, err := commitLog(commitRange)
@@ -473,16 +482,16 @@ func runRelease(ctx context.Context, inv *serpent.Invocation, executor ReleaseEx
}
if !hasContent {
prevStr := "the beginning of time"
if prevVersion != nil {
prevStr = prevVersion.String()
if changelogBaseRef != "" {
prevStr = changelogBaseRef
}
fmt.Fprintf(&notes, "\n_No changes since %s._\n", prevStr)
}
// Compare link.
if prevVersion != nil {
if changelogBaseRef != "" {
fmt.Fprintf(&notes, "\nCompare: [`%s...%s`](https://github.com/%s/%s/compare/%s...%s)\n",
prevVersion, newVersion, owner, repo, prevVersion, newVersion)
changelogBaseRef, newVersion, owner, repo, changelogBaseRef, newVersion)
}
// Container image.
+2
@@ -913,6 +913,7 @@ export interface AuditLog {
export interface AuditLogResponse {
readonly audit_logs: readonly AuditLog[];
readonly count: number;
readonly count_cap: number;
}
// From codersdk/audit.go
@@ -2269,6 +2270,7 @@ export interface ConnectionLog {
export interface ConnectionLogResponse {
readonly connection_logs: readonly ConnectionLog[];
readonly count: number;
readonly count_cap: number;
}
// From codersdk/connectionlog.go
+31 -57
@@ -1,7 +1,6 @@
import type { Interpolation, Theme } from "@emotion/react";
import type { FC, HTMLAttributes } from "react";
import type { LogLevel } from "#/api/typesGenerated";
import { MONOSPACE_FONT_FAMILY } from "#/theme/constants";
import { cn } from "#/utils/cn";
const DEFAULT_LOG_LINE_SIDE_PADDING = 24;
@@ -17,65 +16,40 @@ type LogLineProps = {
level: LogLevel;
} & HTMLAttributes<HTMLPreElement>;
export const LogLine: FC<LogLineProps> = ({ level, ...divProps }) => {
export const LogLine: FC<LogLineProps> = ({ level, className, ...props }) => {
return (
<pre
css={styles.line}
className={`${level} ${divProps.className} logs-line`}
{...divProps}
{...props}
className={cn(
"logs-line",
"m-0 break-all flex items-center h-auto",
"text-[13px] text-content-primary font-mono",
level === "error" &&
"bg-surface-error text-content-error [&_.dashed-line]:bg-border-error",
level === "debug" &&
"bg-surface-sky text-content-sky [&_.dashed-line]:bg-border-sky",
level === "warn" &&
"bg-surface-warning text-content-warning [&_.dashed-line]:bg-border-warning",
className,
)}
style={{
padding: `0 var(--log-line-side-padding, ${DEFAULT_LOG_LINE_SIDE_PADDING}px)`,
}}
/>
);
};
export const LogLinePrefix: FC<HTMLAttributes<HTMLSpanElement>> = (props) => {
return <pre css={styles.prefix} {...props} />;
export const LogLinePrefix: FC<HTMLAttributes<HTMLSpanElement>> = ({
className,
...props
}) => {
return (
<pre
className={cn(
"select-none m-0 inline-block text-content-secondary mr-6",
className,
)}
{...props}
/>
);
};
const styles = {
line: (theme) => ({
margin: 0,
wordBreak: "break-all",
display: "flex",
alignItems: "center",
fontSize: 13,
color: theme.palette.text.primary,
fontFamily: MONOSPACE_FONT_FAMILY,
height: "auto",
padding: `0 var(--log-line-side-padding, ${DEFAULT_LOG_LINE_SIDE_PADDING}px)`,
"&.error": {
backgroundColor: theme.roles.error.background,
color: theme.roles.error.text,
"& .dashed-line": {
backgroundColor: theme.roles.error.outline,
},
},
"&.debug": {
backgroundColor: theme.roles.notice.background,
color: theme.roles.notice.text,
"& .dashed-line": {
backgroundColor: theme.roles.notice.outline,
},
},
"&.warn": {
backgroundColor: theme.roles.warning.background,
color: theme.roles.warning.text,
"& .dashed-line": {
backgroundColor: theme.roles.warning.outline,
},
},
}),
prefix: (theme) => ({
userSelect: "none",
margin: 0,
display: "inline-block",
color: theme.palette.text.secondary,
marginRight: 24,
}),
} satisfies Record<string, Interpolation<Theme>>;
+12 -17
@@ -1,6 +1,6 @@
import type { Interpolation, Theme } from "@emotion/react";
import dayjs from "dayjs";
import type { FC } from "react";
import { cn } from "#/utils/cn";
import { type Line, LogLine, LogLinePrefix } from "./LogLine";
export const DEFAULT_LOG_LINE_SIDE_PADDING = 24;
@@ -17,7 +17,17 @@ export const Logs: FC<LogsProps> = ({
className = "",
}) => {
return (
<div css={styles.root} className={`${className} logs-container`}>
<div
className={cn(
"logs-container",
"min-h-40 py-2 rounded-lg overflow-x-auto bg-surface-primary",
"[&:not(:last-child)]:border-0",
"[&:not(:last-child)]:border-solid",
"[&:not(:last-child)]:border-b-border",
"[&:not(:last-child)]:rounded-none",
className,
)}
>
<div className="min-w-fit">
{lines.map((line) => (
<LogLine key={line.id} level={line.level}>
@@ -33,18 +43,3 @@ export const Logs: FC<LogsProps> = ({
</div>
);
};
const styles = {
root: (theme) => ({
minHeight: 156,
padding: "8px 0",
borderRadius: 8,
overflowX: "auto",
background: theme.palette.background.default,
"&:not(:last-child)": {
borderBottom: `1px solid ${theme.palette.divider}`,
borderRadius: 0,
},
}),
} satisfies Record<string, Interpolation<Theme>>;
@@ -7,6 +7,7 @@ type PaginationHeaderProps = {
limit: number;
totalRecords: number | undefined;
currentOffsetStart: number | undefined;
countIsCapped?: boolean;
// Temporary escape hatch until Workspaces can be switched over to using
// PaginationContainer
@@ -18,6 +19,7 @@ export const PaginationAmount: FC<PaginationHeaderProps> = ({
limit,
totalRecords,
currentOffsetStart,
countIsCapped,
className,
}) => {
const theme = useTheme();
@@ -52,10 +54,16 @@ export const PaginationAmount: FC<PaginationHeaderProps> = ({
<strong>
{(
currentOffsetStart +
Math.min(limit - 1, totalRecords - currentOffsetStart)
(countIsCapped
? limit - 1
: Math.min(limit - 1, totalRecords - currentOffsetStart))
).toLocaleString()}
</strong>{" "}
of <strong>{totalRecords.toLocaleString()}</strong>{" "}
of{" "}
<strong>
{totalRecords.toLocaleString()}
{countIsCapped && "+"}
</strong>{" "}
{paginationUnitLabel}
</div>
)}
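The capped-count display above computes the end of the visible range differently depending on whether the server capped the count. A standalone sketch of that logic (the function name and shape are illustrative, not from the codebase):

```typescript
// Sketch: when the count is capped, assume the current page is full and
// render the end of the range as offset + (limit - 1); otherwise clamp
// to the true record count. A capped total is shown with a trailing "+".
type RangeDisplay = { start: number; end: number; totalLabel: string };

function paginationRange(
  currentOffsetStart: number,
  limit: number,
  totalRecords: number,
  countIsCapped: boolean,
): RangeDisplay {
  const end =
    currentOffsetStart +
    (countIsCapped
      ? limit - 1
      : Math.min(limit - 1, totalRecords - currentOffsetStart));
  return {
    start: currentOffsetStart,
    end,
    totalLabel: `${totalRecords.toLocaleString()}${countIsCapped ? "+" : ""}`,
  };
}
```

For example, on a capped result with offset 51 and limit 25 this yields "51 to 75 of 2,000+", even though records past the cap may not exist.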
@@ -18,6 +18,7 @@ export const mockPaginationResultBase: ResultBase = {
limit: 25,
hasNextPage: false,
hasPreviousPage: false,
countIsCapped: false,
goToPreviousPage: () => {},
goToNextPage: () => {},
goToFirstPage: () => {},
@@ -33,6 +34,7 @@ export const mockInitialRenderResult: PaginationResult = {
hasPreviousPage: false,
totalRecords: undefined,
totalPages: undefined,
countIsCapped: false,
};
export const mockSuccessResult: PaginationResult = {
@@ -94,7 +94,7 @@ export const FirstPageWithTonsOfData: Story = {
currentPage: 2,
currentOffsetStart: 1000,
totalRecords: 123_456,
totalPages: 1235,
totalPages: 4939,
hasPreviousPage: false,
hasNextPage: true,
isPlaceholderData: false,
@@ -135,3 +135,54 @@ export const SecondPageWithData: Story = {
children: <div>New data for page 2</div>,
},
};
export const CappedCountFirstPage: Story = {
args: {
query: {
...mockPaginationResultBase,
isSuccess: true,
currentPage: 1,
currentOffsetStart: 1,
totalRecords: 2000,
totalPages: 80,
hasPreviousPage: false,
hasNextPage: true,
isPlaceholderData: false,
countIsCapped: true,
},
},
};
export const CappedCountMiddlePage: Story = {
args: {
query: {
...mockPaginationResultBase,
isSuccess: true,
currentPage: 3,
currentOffsetStart: 51,
totalRecords: 2000,
totalPages: 80,
hasPreviousPage: true,
hasNextPage: true,
isPlaceholderData: false,
countIsCapped: true,
},
},
};
export const CappedCountBeyondKnownPages: Story = {
args: {
query: {
...mockPaginationResultBase,
isSuccess: true,
currentPage: 85,
currentOffsetStart: 2101,
totalRecords: 2000,
totalPages: 85,
hasPreviousPage: true,
hasNextPage: true,
isPlaceholderData: false,
countIsCapped: true,
},
},
};
@@ -27,12 +27,14 @@ export const PaginationContainer: FC<PaginationProps> = ({
totalRecords={query.totalRecords}
currentOffsetStart={query.currentOffsetStart}
paginationUnitLabel={paginationUnitLabel}
countIsCapped={query.countIsCapped}
className="justify-end"
/>
{query.isSuccess && (
<PaginationWidgetBase
totalRecords={query.totalRecords}
totalPages={query.totalPages}
currentPage={query.currentPage}
pageSize={query.limit}
onPageChange={query.onPageChange}
@@ -12,6 +12,10 @@ export type PaginationWidgetBaseProps = {
hasPreviousPage?: boolean;
hasNextPage?: boolean;
/** Override the computed totalPages.
* Used when, e.g., the row count is capped and the user navigates beyond
* the known range, so totalPages stays at least as high as currentPage. */
totalPages?: number;
};
export const PaginationWidgetBase: FC<PaginationWidgetBaseProps> = ({
@@ -21,8 +25,9 @@ export const PaginationWidgetBase: FC<PaginationWidgetBaseProps> = ({
onPageChange,
hasPreviousPage,
hasNextPage,
totalPages: totalPagesProp,
}) => {
const totalPages = Math.ceil(totalRecords / pageSize);
const totalPages = totalPagesProp ?? Math.ceil(totalRecords / pageSize);
if (totalPages < 2) {
return null;
-1
@@ -1 +0,0 @@
export { useTabOverflowKebabMenu } from "./useTabOverflowKebabMenu";
@@ -0,0 +1,274 @@
import {
type RefObject,
useCallback,
useEffect,
useRef,
useState,
} from "react";
type TabValue = {
value: string;
};
type UseKebabMenuOptions<T extends TabValue> = {
tabs: readonly T[];
enabled: boolean;
isActive: boolean;
overflowTriggerWidth?: number;
};
type UseKebabMenuResult<T extends TabValue> = {
containerRef: RefObject<HTMLDivElement | null>;
visibleTabs: T[];
overflowTabs: T[];
getTabMeasureProps: (tabValue: string) => Record<string, string>;
};
const ALWAYS_VISIBLE_TABS_COUNT = 1;
const DATA_ATTR_TAB_VALUE = "data-tab-overflow-item-value";
/**
* Splits tabs into visible and overflow groups based on container width.
*
* Tabs must render with `getTabMeasureProps()` so this hook can measure
* trigger widths from the DOM.
*/
export const useKebabMenu = <T extends TabValue>({
tabs,
enabled,
isActive,
overflowTriggerWidth = 44,
}: UseKebabMenuOptions<T>): UseKebabMenuResult<T> => {
const containerRef = useRef<HTMLDivElement>(null);
const tabsRef = useRef<readonly T[]>(tabs);
tabsRef.current = tabs;
const previousTabsRef = useRef<readonly T[]>(tabs);
const availableWidthRef = useRef<number | null>(null);
// Width cache prevents oscillation when overflow tabs are not mounted.
const tabWidthByValueRef = useRef<Record<string, number>>({});
const [overflowTabValues, setTabValues] = useState<string[]>([]);
const recalculateOverflow = useCallback(
(availableWidth: number) => {
if (!enabled || !isActive) {
// Keep this update idempotent to avoid render loops.
setTabValues((currentValues) => {
if (currentValues.length === 0) {
return currentValues;
}
return [];
});
return;
}
const container = containerRef.current;
if (!container) {
return;
}
const currentTabs = tabsRef.current;
const tabWidthByValue = measureTabWidths({
tabs: currentTabs,
container,
previousTabWidthByValue: tabWidthByValueRef.current,
});
tabWidthByValueRef.current = tabWidthByValue;
const nextOverflowValues = calculateTabValues({
tabs: currentTabs,
availableWidth,
tabWidthByValue,
overflowTriggerWidth,
});
setTabValues((currentValues) => {
// Avoid state updates when the computed overflow did not change.
if (areStringArraysEqual(currentValues, nextOverflowValues)) {
return currentValues;
}
return nextOverflowValues;
});
},
[enabled, isActive, overflowTriggerWidth],
);
useEffect(() => {
if (previousTabsRef.current === tabs) {
// No change in tabs, no need to recalculate.
return;
}
previousTabsRef.current = tabs;
if (availableWidthRef.current === null) {
// First mount, no width available yet.
return;
}
recalculateOverflow(availableWidthRef.current);
}, [recalculateOverflow, tabs]);
useEffect(() => {
const container = containerRef.current;
if (!container) {
return;
}
// Recompute whenever ResizeObserver reports a container width change.
const observer = new ResizeObserver(([entry]) => {
if (!entry) {
return;
}
availableWidthRef.current = entry.contentRect.width;
recalculateOverflow(entry.contentRect.width);
});
observer.observe(container);
return () => observer.disconnect();
}, [recalculateOverflow]);
const overflowTabValuesSet = new Set(overflowTabValues);
const { visibleTabs, overflowTabs } = tabs.reduce<{
visibleTabs: T[];
overflowTabs: T[];
}>(
(tabGroups, tab) => {
if (overflowTabValuesSet.has(tab.value)) {
tabGroups.overflowTabs.push(tab);
} else {
tabGroups.visibleTabs.push(tab);
}
return tabGroups;
},
{ visibleTabs: [], overflowTabs: [] },
);
const getTabMeasureProps = (tabValue: string) => {
return { [DATA_ATTR_TAB_VALUE]: tabValue };
};
return {
containerRef,
visibleTabs,
overflowTabs,
getTabMeasureProps,
};
};
const calculateTabValues = <T extends TabValue>({
tabs,
availableWidth,
tabWidthByValue,
overflowTriggerWidth,
}: {
tabs: readonly T[];
availableWidth: number;
tabWidthByValue: Readonly<Record<string, number>>;
overflowTriggerWidth: number;
}): string[] => {
const tabWidthByValueMap = new Map<string, number>();
for (const tab of tabs) {
tabWidthByValueMap.set(tab.value, tabWidthByValue[tab.value] ?? 0);
}
const firstOptionalTabIndex = Math.min(
ALWAYS_VISIBLE_TABS_COUNT,
tabs.length,
);
if (firstOptionalTabIndex >= tabs.length) {
return [];
}
const alwaysVisibleTabs = tabs.slice(0, firstOptionalTabIndex);
const optionalTabs = tabs.slice(firstOptionalTabIndex);
const alwaysVisibleWidth = alwaysVisibleTabs.reduce((total, tab) => {
return total + (tabWidthByValueMap.get(tab.value) ?? 0);
}, 0);
const firstTabIndex = findFirstTabIndex({
optionalTabs,
optionalTabWidths: optionalTabs.map((tab) => {
return tabWidthByValueMap.get(tab.value) ?? 0;
}),
startingUsedWidth: alwaysVisibleWidth,
availableWidth,
overflowTriggerWidth,
});
if (firstTabIndex === -1) {
return [];
}
return optionalTabs
.slice(firstTabIndex)
.map((overflowTab) => overflowTab.value);
};
const measureTabWidths = <T extends TabValue>({
tabs,
container,
previousTabWidthByValue,
}: {
tabs: readonly T[];
container: HTMLDivElement;
previousTabWidthByValue: Readonly<Record<string, number>>;
}): Record<string, number> => {
const nextTabWidthByValue = { ...previousTabWidthByValue };
for (const tab of tabs) {
const tabElement = container.querySelector<HTMLElement>(
`[${DATA_ATTR_TAB_VALUE}="${tab.value}"]`,
);
if (tabElement) {
nextTabWidthByValue[tab.value] = tabElement.offsetWidth;
}
}
return nextTabWidthByValue;
};
const findFirstTabIndex = ({
optionalTabs,
optionalTabWidths,
startingUsedWidth,
availableWidth,
overflowTriggerWidth,
}: {
optionalTabs: readonly TabValue[];
optionalTabWidths: readonly number[];
startingUsedWidth: number;
availableWidth: number;
overflowTriggerWidth: number;
}): number => {
const result = optionalTabs.reduce(
(acc, _tab, index) => {
if (acc.firstTabIndex !== -1) {
return acc;
}
const tabWidth = optionalTabWidths[index] ?? 0;
const hasMoreTabs = index < optionalTabs.length - 1;
// Reserve kebab trigger width whenever additional tabs remain.
const widthNeeded =
acc.usedWidth + tabWidth + (hasMoreTabs ? overflowTriggerWidth : 0);
if (widthNeeded <= availableWidth) {
return {
usedWidth: acc.usedWidth + tabWidth,
firstTabIndex: -1,
};
}
return {
usedWidth: acc.usedWidth,
firstTabIndex: index,
};
},
{ usedWidth: startingUsedWidth, firstTabIndex: -1 },
);
return result.firstTabIndex;
};
const areStringArraysEqual = (
left: readonly string[],
right: readonly string[],
): boolean => {
return (
left.length === right.length &&
left.every((value, index) => value === right[index])
);
};
@@ -1,157 +0,0 @@
import {
type RefObject,
useCallback,
useEffect,
useLayoutEffect,
useMemo,
useRef,
useState,
} from "react";
type TabLike = {
value: string;
};
type UseTabOverflowKebabMenuOptions<TTab extends TabLike> = {
tabs: readonly TTab[];
enabled: boolean;
isActive: boolean;
alwaysVisibleTabsCount?: number;
overflowTriggerWidthPx?: number;
};
type UseTabOverflowKebabMenuResult<TTab extends TabLike> = {
containerRef: RefObject<HTMLDivElement | null>;
visibleTabs: TTab[];
overflowTabs: TTab[];
getTabMeasureProps: (tabValue: string) => Record<string, string>;
};
const DATA_ATTR_TAB_VALUE = "data-tab-overflow-item-value";
export const useTabOverflowKebabMenu = <TTab extends TabLike>({
tabs,
enabled,
isActive,
alwaysVisibleTabsCount = 1,
overflowTriggerWidthPx = 44,
}: UseTabOverflowKebabMenuOptions<TTab>): UseTabOverflowKebabMenuResult<TTab> => {
const containerRef = useRef<HTMLDivElement>(null);
const tabWidthByValueRef = useRef<Record<string, number>>({});
const [overflowTabValues, setOverflowTabValues] = useState<string[]>([]);
const recalculateOverflow = useCallback(() => {
if (!enabled) {
setOverflowTabValues([]);
return;
}
const container = containerRef.current;
if (!container) {
return;
}
for (const tab of tabs) {
const tabElement = container.querySelector<HTMLElement>(
`[${DATA_ATTR_TAB_VALUE}="${tab.value}"]`,
);
if (tabElement) {
tabWidthByValueRef.current[tab.value] = tabElement.offsetWidth;
}
}
const alwaysVisibleTabs = tabs.slice(0, alwaysVisibleTabsCount);
const optionalTabs = tabs.slice(alwaysVisibleTabsCount);
if (optionalTabs.length === 0) {
setOverflowTabValues([]);
return;
}
const alwaysVisibleWidth = alwaysVisibleTabs.reduce((total, tab) => {
return total + (tabWidthByValueRef.current[tab.value] ?? 0);
}, 0);
const availableWidth = container.clientWidth;
let usedWidth = alwaysVisibleWidth;
const nextOverflowValues: string[] = [];
for (let i = 0; i < optionalTabs.length; i++) {
const tab = optionalTabs[i];
const tabWidth = tabWidthByValueRef.current[tab.value] ?? 0;
const hasMoreTabsAfterCurrent = i < optionalTabs.length - 1;
const widthNeeded =
usedWidth +
tabWidth +
(hasMoreTabsAfterCurrent ? overflowTriggerWidthPx : 0);
if (widthNeeded <= availableWidth) {
usedWidth += tabWidth;
continue;
}
nextOverflowValues.push(
...optionalTabs.slice(i).map((overflowTab) => overflowTab.value),
);
break;
}
setOverflowTabValues((currentValues) => {
if (
currentValues.length === nextOverflowValues.length &&
currentValues.every(
(value, index) => value === nextOverflowValues[index],
)
) {
return currentValues;
}
return nextOverflowValues;
});
}, [alwaysVisibleTabsCount, enabled, overflowTriggerWidthPx, tabs]);
useLayoutEffect(() => {
if (!isActive) {
return;
}
recalculateOverflow();
}, [isActive, recalculateOverflow]);
useEffect(() => {
if (!isActive) {
return;
}
const container = containerRef.current;
if (!container) {
return;
}
const observer = new ResizeObserver(() => {
recalculateOverflow();
});
observer.observe(container);
return () => observer.disconnect();
}, [isActive, recalculateOverflow]);
const overflowTabValuesSet = useMemo(
() => new Set(overflowTabValues),
[overflowTabValues],
);
const visibleTabs = useMemo(
() => tabs.filter((tab) => !overflowTabValuesSet.has(tab.value)),
[tabs, overflowTabValuesSet],
);
const overflowTabs = useMemo(
() => tabs.filter((tab) => overflowTabValuesSet.has(tab.value)),
[tabs, overflowTabValuesSet],
);
const getTabMeasureProps = useCallback((tabValue: string) => {
return { [DATA_ATTR_TAB_VALUE]: tabValue };
}, []);
return {
containerRef,
visibleTabs,
overflowTabs,
getTabMeasureProps,
};
};
+72
@@ -258,6 +258,78 @@ describe(usePaginatedQuery.name, () => {
});
});
describe("Capped count behavior", () => {
const mockQueryKey = vi.fn(() => ["mock"]);
// Returns count 2001 (capped) with items on pages up to page 84
// (84 * 25 = 2100 items total).
const mockCappedQueryFn = vi.fn(({ pageNumber, limit }) => {
const totalItems = 2100;
const offset = (pageNumber - 1) * limit;
// Returns 0 items when the requested page is past the end, simulating
// an empty server response.
const itemsOnPage = Math.max(0, Math.min(limit, totalItems - offset));
return Promise.resolve({
data: new Array(itemsOnPage).fill(pageNumber),
count: 2001,
count_cap: 2000,
});
});
it("Caps totalRecords at 2000 when count exceeds cap", async () => {
const { result } = await render({
queryKey: mockQueryKey,
queryFn: mockCappedQueryFn,
});
await waitFor(() => expect(result.current.isSuccess).toBe(true));
expect(result.current.totalRecords).toBe(2000);
});
it("hasNextPage is true when count is capped", async () => {
const { result } = await render(
{ queryKey: mockQueryKey, queryFn: mockCappedQueryFn },
"/?page=80",
);
await waitFor(() => expect(result.current.isSuccess).toBe(true));
expect(result.current.hasNextPage).toBe(true);
});
it("hasPreviousPage is true when count is capped and page is beyond cap", async () => {
const { result } = await render(
{ queryKey: mockQueryKey, queryFn: mockCappedQueryFn },
"/?page=83",
);
await waitFor(() => expect(result.current.isSuccess).toBe(true));
expect(result.current.hasPreviousPage).toBe(true);
});
it("Does not redirect to last page when count is capped and page is valid", async () => {
const { result } = await render(
{ queryKey: mockQueryKey, queryFn: mockCappedQueryFn },
"/?page=83",
);
await waitFor(() => expect(result.current.isSuccess).toBe(true));
// Should stay on page 83 — not redirect to page 80.
expect(result.current.currentPage).toBe(83);
});
it("Redirects to last known page when navigating beyond actual data", async () => {
const { result } = await render(
{ queryKey: mockQueryKey, queryFn: mockCappedQueryFn },
"/?page=999",
);
// Page 999 has no items. Should redirect to page 81
// (ceil(2001 / 25) = 81), the last page guaranteed to
// have data.
await waitFor(() => expect(result.current.currentPage).toBe(81));
});
});
describe("Passing in searchParams property", () => {
const mockQueryKey = vi.fn(() => ["mock"]);
const mockQueryFn = vi.fn(({ pageNumber, limit }) =>
+50 -8
@@ -144,16 +144,44 @@ export function usePaginatedQuery<
placeholderData: keepPreviousData,
});
const totalRecords = query.data?.count;
const totalPages =
totalRecords !== undefined ? Math.ceil(totalRecords / limit) : undefined;
const count = query.data?.count;
const countCap = query.data?.count_cap;
const countIsCapped =
countCap !== undefined &&
countCap > 0 &&
count !== undefined &&
count > countCap;
const totalRecords = countIsCapped ? countCap : count;
let totalPages =
totalRecords !== undefined
? Math.max(
Math.ceil(totalRecords / limit),
// True count is not known; let them navigate forward
// until they hit an empty page (checked below).
countIsCapped ? currentPage : 0,
)
: undefined;
// When the true count is unknown, the user can navigate past
// all actual data. If that happens, we redirect (via
// updatePageIfInvalid) to the last page guaranteed not to be
// empty.
const pageIsEmpty =
query.data != null &&
!Object.values(query.data).some((v) => Array.isArray(v) && v.length > 0);
if (pageIsEmpty) {
totalPages = count !== undefined ? Math.ceil(count / limit) : 1;
}
const hasNextPage =
totalRecords !== undefined && limit + currentPageOffset < totalRecords;
totalRecords !== undefined &&
((countIsCapped && !pageIsEmpty) ||
limit + currentPageOffset < totalRecords);
const hasPreviousPage =
totalRecords !== undefined &&
currentPage > 1 &&
currentPageOffset - limit < totalRecords;
((countIsCapped && !pageIsEmpty) ||
currentPageOffset - limit < totalRecords);
const queryClient = useQueryClient();
const prefetchPage = useEffectEvent((newPage: number) => {
@@ -224,10 +252,14 @@ export function usePaginatedQuery<
});
useEffect(() => {
if (!query.isFetching && totalPages !== undefined) {
if (
!query.isFetching &&
totalPages !== undefined &&
currentPage > totalPages
) {
void updatePageIfInvalid(totalPages);
}
}, [updatePageIfInvalid, query.isFetching, totalPages]);
}, [updatePageIfInvalid, query.isFetching, totalPages, currentPage]);
const onPageChange = (newPage: number) => {
// Page 1 is the only page that can be safely navigated to without knowing
@@ -236,7 +268,12 @@ export function usePaginatedQuery<
return;
}
const cleanedInput = clamp(Math.trunc(newPage), 1, totalPages ?? 1);
// If the true count is unknown, we allow navigating past the
// known page range.
const upperBound = countIsCapped
? Number.MAX_SAFE_INTEGER
: (totalPages ?? 1);
const cleanedInput = clamp(Math.trunc(newPage), 1, upperBound);
if (Number.isNaN(cleanedInput)) {
return;
}
@@ -274,6 +311,7 @@ export function usePaginatedQuery<
totalRecords: totalRecords as number,
totalPages: totalPages as number,
currentOffsetStart: currentPageOffset + 1,
countIsCapped,
}
: {
isSuccess: false,
@@ -282,6 +320,7 @@ export function usePaginatedQuery<
totalRecords: undefined,
totalPages: undefined,
currentOffsetStart: undefined,
countIsCapped: false as const,
}),
};
@@ -323,6 +362,7 @@ export type PaginationResultInfo = {
totalRecords: undefined;
totalPages: undefined;
currentOffsetStart: undefined;
countIsCapped: false;
}
| {
isSuccess: true;
@@ -331,6 +371,7 @@ export type PaginationResultInfo = {
totalRecords: number;
totalPages: number;
currentOffsetStart: number;
countIsCapped: boolean;
}
);
@@ -417,6 +458,7 @@ type QueryPageParamsWithPayload<TPayload = never> = QueryPageParams & {
*/
export type PaginatedData = {
count: number;
count_cap?: number;
};
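The capped-count derivation in `usePaginatedQuery` above can be summarized on its own. This is a sketch of the arithmetic only (hook state, empty-page detection, and redirects omitted; the function name is illustrative):

```typescript
// Sketch: when the reported count exceeds count_cap, clamp totalRecords
// to the cap, but keep totalPages at least as high as the current page
// so the user can keep navigating forward past the known range.
function derivePagination(
  count: number,
  countCap: number | undefined,
  limit: number,
  currentPage: number,
): { totalRecords: number; totalPages: number; countIsCapped: boolean } {
  const cap = countCap ?? 0;
  const countIsCapped = cap > 0 && count > cap;
  const totalRecords = countIsCapped ? cap : count;
  const totalPages = Math.max(
    Math.ceil(totalRecords / limit),
    // True count unknown: never report fewer pages than the one
    // the user is already on.
    countIsCapped ? currentPage : 0,
  );
  return { totalRecords, totalPages, countIsCapped };
}
```

So with `count: 2001`, `count_cap: 2000`, and a page size of 25, `totalRecords` is 2000 and `totalPages` is 80 on page 1, but grows to 85 if the user is on page 85.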
/**
@@ -2,7 +2,7 @@ import type { Meta, StoryObj } from "@storybook/react-vite";
import { expect, spyOn, userEvent, waitFor, within } from "storybook/test";
import { API } from "#/api/api";
import { workspaceAgentContainersKey } from "#/api/queries/workspaces";
import type * as TypesGen from "#/api/typesGenerated";
import type { WorkspaceAgentLogSource } from "#/api/typesGenerated";
import { getPreferredProxy } from "#/contexts/ProxyContext";
import { chromatic } from "#/testHelpers/chromatic";
import * as M from "#/testHelpers/entities";
@@ -76,6 +76,8 @@ const defaultAgentMetadata = [
},
];
const fixedLogTimestamp = "2021-05-05T00:00:00.000Z";
const logs = [
"\x1b[91mCloning Git repository...",
"\x1b[2;37;41mStarting Docker Daemon...",
@@ -87,10 +89,10 @@ const logs = [
level: "info",
output: line,
source_id: M.MockWorkspaceAgentLogSource.id,
created_at: new Date().toISOString(),
created_at: fixedLogTimestamp,
}));
const installScriptLogSource: TypesGen.WorkspaceAgentLogSource = {
const installScriptLogSource: WorkspaceAgentLogSource = {
...M.MockWorkspaceAgentLogSource,
id: "f2ee4b8d-b09d-4f4e-a1f1-5e4adf7d53bb",
display_name: "Install Script",
@@ -102,60 +104,24 @@ const tabbedLogs = [
level: "info",
output: "startup: preparing workspace",
source_id: M.MockWorkspaceAgentLogSource.id,
created_at: new Date().toISOString(),
created_at: fixedLogTimestamp,
},
{
id: 101,
level: "info",
output: "install: pnpm install",
source_id: installScriptLogSource.id,
created_at: new Date().toISOString(),
created_at: fixedLogTimestamp,
},
{
id: 102,
level: "info",
output: "install: setup complete",
source_id: installScriptLogSource.id,
created_at: new Date().toISOString(),
created_at: fixedLogTimestamp,
},
];
const overflowLogSources: TypesGen.WorkspaceAgentLogSource[] = [
M.MockWorkspaceAgentLogSource,
{
...M.MockWorkspaceAgentLogSource,
id: "58f5db69-5f78-496f-bce1-0686f5525aa1",
display_name: "code-server",
icon: "/icon/code.svg",
},
{
...M.MockWorkspaceAgentLogSource,
id: "f39d758c-bce2-4f41-8d70-58fdb1f0f729",
display_name: "Install and start AgentAPI",
icon: "/icon/claude.svg",
},
{
...M.MockWorkspaceAgentLogSource,
id: "bf7529b8-1787-4a20-b54f-eb894680e48f",
display_name: "Mux",
icon: "/icon/mux.svg",
},
{
...M.MockWorkspaceAgentLogSource,
id: "0d6ebde6-c534-4551-9f91-bfd98bfb04f4",
display_name: "Portable Desktop",
icon: "/icon/portable-desktop.svg",
},
];
const overflowLogs = overflowLogSources.map((source, index) => ({
id: 200 + index,
level: "info",
output: `${source.display_name}: line`,
source_id: source.id,
created_at: new Date().toISOString(),
}));
const meta: Meta<typeof AgentRow> = {
title: "components/AgentRow",
component: AgentRow,
@@ -438,44 +404,3 @@ export const LogsTabs: Story = {
await expect(canvas.getByText("install: pnpm install")).toBeVisible();
},
};
export const LogsTabsOverflow: Story = {
args: {
agent: {
...M.MockWorkspaceAgentReady,
logs_length: overflowLogs.length,
log_sources: overflowLogSources,
},
},
parameters: {
webSocket: [
{
event: "message",
data: JSON.stringify(overflowLogs),
},
],
},
render: (args) => (
<div className="max-w-[320px]">
<AgentRow {...args} />
</div>
),
play: async ({ canvasElement }) => {
const canvas = within(canvasElement);
const page = within(canvasElement.ownerDocument.body);
await userEvent.click(canvas.getByRole("button", { name: "Logs" }));
await userEvent.click(
canvas.getByRole("button", { name: "More log tabs" }),
);
const overflowItems = await page.findAllByRole("menuitemradio");
const selectedItem = overflowItems[0];
const selectedSource = selectedItem.textContent;
if (!selectedSource) {
throw new Error("Overflow menu item must have text content.");
}
await userEvent.click(selectedItem);
await waitFor(() =>
expect(canvas.getByText(`${selectedSource}: line`)).toBeVisible(),
);
},
};
+69 -60
@@ -8,10 +8,9 @@ import {
} from "lucide-react";
import {
type FC,
useCallback,
type ReactNode,
useEffect,
useLayoutEffect,
useMemo,
useRef,
useState,
} from "react";
@@ -42,7 +41,7 @@ import {
TabsList,
TabsTrigger,
} from "#/components/Tabs/Tabs";
import { useTabOverflowKebabMenu } from "#/components/Tabs/utils";
import { useKebabMenu } from "#/components/Tabs/utils/useKebabMenu";
import { useProxy } from "#/contexts/ProxyContext";
import { useClipboard } from "#/hooks/useClipboard";
import { useFeatureVisibility } from "#/modules/dashboard/useFeatureVisibility";
@@ -162,7 +161,7 @@ export const AgentRow: FC<AgentRowProps> = ({
// This is a bit of a hack on the react-window API to get the scroll position.
// If we're scrolled to the bottom, we want to keep the list scrolled to the bottom.
// This makes it feel similar to a terminal that auto-scrolls downwards!
const handleLogScroll = useCallback((props: ListOnScrollProps) => {
const handleLogScroll = (props: ListOnScrollProps) => {
if (
props.scrollOffset === 0 ||
props.scrollUpdateWasRequested ||
@@ -179,7 +178,7 @@ export const AgentRow: FC<AgentRowProps> = ({
logListDivRef.current.scrollHeight -
(props.scrollOffset + parent.clientHeight);
setBottomOfLogs(distanceFromBottom < AGENT_LOG_LINE_HEIGHT);
}, []);
};
const devcontainers = useAgentContainers(agent);
@@ -211,59 +210,56 @@ export const AgentRow: FC<AgentRowProps> = ({
);
const [selectedLogTab, setSelectedLogTab] = useState("all");
const logTabs = useMemo(() => {
const sourceLogTabs = agent.log_sources
.filter((logSource) => {
// Remove the logSources that have no entries.
return agentLogs.some(
(log) =>
log.source_id === logSource.id && (log.output?.length ?? 0) > 0,
);
})
.map((logSource) => ({
// Show the icon for the log source if it has one.
// In the startup script case, we show a bespoke play icon.
startIcon: logSource.icon ? (
<ExternalImage
src={logSource.icon}
alt=""
className="size-icon-xs shrink-0"
/>
) : logSource.display_name === STARTUP_SCRIPT_DISPLAY_NAME ? (
<PlayIcon className="size-icon-xs shrink-0" />
) : null,
title: logSource.display_name,
value: logSource.id,
}));
const startupScriptLogTab = sourceLogTabs.find(
(tab) => tab.title === STARTUP_SCRIPT_DISPLAY_NAME,
);
const sortedSourceLogTabs = sourceLogTabs
.filter((tab) => tab !== startupScriptLogTab)
.sort((a, b) => a.title.localeCompare(b.title));
return [
{
title: "All Logs",
value: "all",
},
...(startupScriptLogTab ? [startupScriptLogTab] : []),
...sortedSourceLogTabs,
] as {
startIcon?: React.ReactNode;
title: string;
value: string;
}[];
}, [agent.log_sources, agentLogs]);
const sourceLogTabs = agent.log_sources
.filter((logSource) => {
// Remove the logSources that have no entries.
return agentLogs.some(
(log) =>
log.source_id === logSource.id && (log.output?.length ?? 0) > 0,
);
})
.map((logSource) => ({
// Show the icon for the log source if it has one.
// In the startup script case, we show a bespoke play icon.
startIcon: logSource.icon ? (
<ExternalImage
src={logSource.icon}
alt=""
className="size-icon-xs shrink-0"
/>
) : logSource.display_name === STARTUP_SCRIPT_DISPLAY_NAME ? (
<PlayIcon className="size-icon-xs shrink-0" />
) : null,
title: logSource.display_name,
value: logSource.id,
}));
const startupScriptLogTab = sourceLogTabs.find(
(tab) => tab.title === STARTUP_SCRIPT_DISPLAY_NAME,
);
const sortedSourceLogTabs = sourceLogTabs
.filter((tab) => tab !== startupScriptLogTab)
.sort((a, b) => a.title.localeCompare(b.title));
const logTabs: {
startIcon?: ReactNode;
title: string;
value: string;
}[] = [
{
title: "All Logs",
value: "all",
},
...(startupScriptLogTab ? [startupScriptLogTab] : []),
...sortedSourceLogTabs,
];
const {
containerRef: logTabsListContainerRef,
visibleTabs: visibleLogTabs,
overflowTabs: overflowLogTabs,
getTabMeasureProps,
} = useTabOverflowKebabMenu({
} = useKebabMenu({
tabs: logTabs,
enabled: true,
isActive: showLogs,
alwaysVisibleTabsCount: 1,
});
const overflowLogTabValuesSet = new Set(
overflowLogTabs.map((tab) => tab.value),
@@ -279,16 +275,29 @@ export const AgentRow: FC<AgentRowProps> = ({
level: log.level,
sourceId: log.source_id,
}));
const allLogsText = agentLogs.map((log) => log.output).join("\n");
const selectedLogsText = selectedLogs.map((log) => log.output).join("\n");
const hasSelectedLogs = selectedLogs.length > 0;
const hasAnyLogs = agentLogs.length > 0;
const { showCopiedSuccess, copyToClipboard } = useClipboard();
const selectedLogTabTitle =
logTabs.find((tab) => tab.value === selectedLogTab)?.title ?? "Logs";
const sanitizedTabTitle = selectedLogTabTitle
.toLowerCase()
.replaceAll(/[^a-z0-9]+/g, "-")
.replaceAll(/(^-|-$)/g, "");
const logFilenameSuffix = sanitizedTabTitle || "logs";
const downloadableLogSets = logTabs
.filter((tab) => tab.value !== "all")
.map((tab) => {
const logsText = agentLogs
.filter((log) => log.source_id === tab.value)
.map((log) => log.output)
.join("\n");
const filenameSuffix = tab.title
.toLowerCase()
.replaceAll(/[^a-z0-9]+/g, "-")
.replaceAll(/(^-|-$)/g, "");
return {
label: tab.title,
filenameSuffix: filenameSuffix || tab.value,
logsText,
startIcon: tab.startIcon,
};
});
return (
<div
@@ -547,9 +556,9 @@ export const AgentRow: FC<AgentRowProps> = ({
</Button>
<DownloadSelectedAgentLogsButton
agentName={agent.name}
filenameSuffix={logFilenameSuffix}
logsText={selectedLogsText}
disabled={!hasSelectedLogs}
logSets={downloadableLogSets}
allLogsText={allLogsText}
disabled={!hasAnyLogs}
/>
</div>
</div>
@@ -1,14 +1,27 @@
import { saveAs } from "file-saver";
import { DownloadIcon } from "lucide-react";
import { type FC, useState } from "react";
import { ChevronDownIcon, DownloadIcon, PackageIcon } from "lucide-react";
import { type FC, type ReactNode, useState } from "react";
import { toast } from "sonner";
import { getErrorDetail } from "#/api/errors";
import { Button } from "#/components/Button/Button";
import {
DropdownMenu,
DropdownMenuContent,
DropdownMenuItem,
DropdownMenuTrigger,
} from "#/components/DropdownMenu/DropdownMenu";
type DownloadableLogSet = {
label: string;
filenameSuffix: string;
logsText: string;
startIcon?: ReactNode;
};
type DownloadSelectedAgentLogsButtonProps = {
agentName: string;
filenameSuffix: string;
logsText: string;
logSets: readonly DownloadableLogSet[];
allLogsText: string;
disabled?: boolean;
download?: (file: Blob, filename: string) => void | Promise<void>;
};
@@ -17,13 +30,13 @@ export const DownloadSelectedAgentLogsButton: FC<
DownloadSelectedAgentLogsButtonProps
> = ({
agentName,
filenameSuffix,
logsText,
logSets,
allLogsText,
disabled = false,
download = saveAs,
}) => {
const [isDownloading, setIsDownloading] = useState(false);
const handleDownload = async () => {
const downloadLogs = async (logsText: string, filenameSuffix: string) => {
try {
setIsDownloading(true);
const file = new Blob([logsText], { type: "text/plain" });
@@ -37,15 +50,40 @@ export const DownloadSelectedAgentLogsButton: FC<
}
};
const hasAllLogs = allLogsText.length > 0;
return (
<Button
variant="subtle"
size="sm"
disabled={disabled || isDownloading}
onClick={handleDownload}
>
<DownloadIcon />
{isDownloading ? "Downloading..." : "Download logs"}
</Button>
<DropdownMenu>
<DropdownMenuTrigger asChild>
<Button variant="subtle" size="sm" disabled={disabled || isDownloading}>
<DownloadIcon />
{isDownloading ? "Downloading..." : "Download logs"}
<ChevronDownIcon className="size-icon-sm" />
</Button>
</DropdownMenuTrigger>
<DropdownMenuContent align="end">
<DropdownMenuItem
disabled={!hasAllLogs}
onSelect={() => {
downloadLogs(allLogsText, "all-logs");
}}
>
<PackageIcon />
Download all logs
</DropdownMenuItem>
{logSets.map((logSet) => (
<DropdownMenuItem
key={logSet.filenameSuffix}
disabled={logSet.logsText.length === 0}
onSelect={() => {
downloadLogs(logSet.logsText, logSet.filenameSuffix);
}}
>
{logSet.startIcon}
<span>Download {logSet.label}</span>
</DropdownMenuItem>
))}
</DropdownMenuContent>
</DropdownMenu>
);
};
@@ -646,17 +646,24 @@ const AgentChatPage: FC = () => {
const isRegenerateTitleDisabled = isArchived || isRegeneratingThisChat;
const chatLastModelConfigID = chatRecord?.last_model_config_id;
const sendMutation = useMutation(
// Destructure mutation results directly so the React Compiler
// tracks stable primitives/functions instead of the whole result
// object (TanStack Query v5 recreates it every render via object
// spread). Keeping no intermediate variable prevents future code
// from accidentally closing over the unstable object.
const { isPending: isSendPending, mutateAsync: sendMessage } = useMutation(
createChatMessage(queryClient, agentId ?? ""),
);
const editMutation = useMutation(editChatMessage(queryClient, agentId ?? ""));
const interruptMutation = useMutation(
const { isPending: isEditPending, mutateAsync: editMessage } = useMutation(
editChatMessage(queryClient, agentId ?? ""),
);
const { isPending: isInterruptPending, mutateAsync: interrupt } = useMutation(
interruptChat(queryClient, agentId ?? ""),
);
const deleteQueuedMutation = useMutation(
const { mutateAsync: deleteQueuedMessage } = useMutation(
deleteChatQueuedMessage(queryClient, agentId ?? ""),
);
const promoteQueuedMutation = useMutation(
const { mutateAsync: promoteQueuedMessage } = useMutation(
promoteChatQueuedMessage(queryClient, agentId ?? ""),
);
@@ -754,9 +761,7 @@ const AgentChatPage: FC = () => {
hasUserFixableModelProviders,
});
const isSubmissionPending =
sendMutation.isPending ||
editMutation.isPending ||
interruptMutation.isPending;
isSendPending || isEditPending || isInterruptPending;
const isInputDisabled = !hasModelOptions || isArchived;
const handleUsageLimitError = (error: unknown): void => {
@@ -842,7 +847,7 @@ const AgentChatPage: FC = () => {
setPendingEditMessageId(editedMessageID);
scrollToBottomRef.current?.();
try {
await editMutation.mutateAsync({
await editMessage({
messageId: editedMessageID,
req: request,
});
@@ -873,9 +878,9 @@ const AgentChatPage: FC = () => {
// For queued sends the WebSocket status events handle
// clearing; for non-queued sends we clear explicitly
// below. Clearing eagerly causes a visible cutoff.
let response: Awaited<ReturnType<typeof sendMutation.mutateAsync>>;
let response: Awaited<ReturnType<typeof sendMessage>>;
try {
response = await sendMutation.mutateAsync(request);
response = await sendMessage(request);
} catch (error) {
handleUsageLimitError(error);
throw error;
@@ -908,10 +913,10 @@ const AgentChatPage: FC = () => {
};
const handleInterrupt = () => {
if (!agentId || interruptMutation.isPending) {
if (!agentId || isInterruptPending) {
return;
}
void interruptMutation.mutateAsync();
void interrupt();
};
const handleDeleteQueuedMessage = async (id: number) => {
@@ -920,7 +925,7 @@ const AgentChatPage: FC = () => {
previousQueuedMessages.filter((message) => message.id !== id),
);
try {
await deleteQueuedMutation.mutateAsync(id);
await deleteQueuedMessage(id);
} catch (error) {
store.setQueuedMessages(previousQueuedMessages);
throw error;
@@ -941,7 +946,7 @@ const AgentChatPage: FC = () => {
store.clearStreamError();
store.setChatStatus("pending");
try {
const promotedMessage = await promoteQueuedMutation.mutateAsync(id);
const promotedMessage = await promoteQueuedMessage(id);
// Insert the promoted message into the store and cache
// immediately so it appears in the timeline without
// waiting for the WebSocket to deliver it.
@@ -990,7 +995,8 @@ const AgentChatPage: FC = () => {
? `ssh ${workspaceAgent.name}.${workspace.name}.${workspace.owner_name}.${sshConfigQuery.data.hostname_suffix}`
: undefined;
const generateKeyMutation = useMutation({
// See mutation destructuring comment above (React Compiler).
const { mutate: generateKey } = useMutation({
mutationFn: () => API.getApiKey(),
});
@@ -1005,7 +1011,7 @@ const AgentChatPage: FC = () => {
const repoRoots = Array.from(gitWatcher.repositories.keys()).sort();
const folder = repoRoots[0] ?? workspaceAgent.expanded_directory;
generateKeyMutation.mutate(undefined, {
generateKey(undefined, {
onSuccess: ({ key }) => {
location.href = getVSCodeHref(editor, {
owner: workspace.owner_name,
@@ -1141,7 +1147,7 @@ const AgentChatPage: FC = () => {
compressionThreshold={compressionThreshold}
isInputDisabled={isInputDisabled}
isSubmissionPending={isSubmissionPending}
isInterruptPending={interruptMutation.isPending}
isInterruptPending={isInterruptPending}
isSidebarCollapsed={isSidebarCollapsed}
onToggleSidebarCollapsed={onToggleSidebarCollapsed}
showSidebarPanel={showSidebarPanel}
@@ -0,0 +1,63 @@
import { BookOpenIcon, LoaderIcon, TriangleAlertIcon } from "lucide-react";
import type React from "react";
import { ScrollArea } from "#/components/ScrollArea/ScrollArea";
import {
Tooltip,
TooltipContent,
TooltipTrigger,
} from "#/components/Tooltip/Tooltip";
import { cn } from "#/utils/cn";
import { Response } from "../Response";
import { ToolCollapsible } from "./ToolCollapsible";
import type { ToolStatus } from "./utils";
export const ReadSkillTool: React.FC<{
label: string;
body: string;
status: ToolStatus;
isError: boolean;
errorMessage?: string;
}> = ({ label, body, status, isError, errorMessage }) => {
const hasContent = body.length > 0;
const isRunning = status === "running";
return (
<ToolCollapsible
className="w-full"
hasContent={hasContent}
header={
<>
<BookOpenIcon className="h-4 w-4 shrink-0 text-content-secondary" />
<span className={cn("text-sm", "text-content-secondary")}>
{isRunning ? `Reading ${label}` : `Read ${label}`}
</span>
{isError && (
<Tooltip>
<TooltipTrigger asChild>
<TriangleAlertIcon className="h-3.5 w-3.5 shrink-0 text-content-secondary" />
</TooltipTrigger>
<TooltipContent>
{errorMessage || "Failed to read skill"}
</TooltipContent>
</Tooltip>
)}
{isRunning && (
<LoaderIcon className="h-3.5 w-3.5 shrink-0 animate-spin motion-reduce:animate-none text-content-secondary" />
)}
</>
}
>
{body && (
<ScrollArea
className="mt-1.5 rounded-md border border-solid border-border-default"
viewportClassName="max-h-64"
scrollBarClassName="w-1.5"
>
<div className="px-3 py-2">
<Response>{body}</Response>
</div>
</ScrollArea>
)}
</ToolCollapsible>
);
};
@@ -1342,6 +1342,12 @@ export const ReadSkillCompleted: Story = {
play: async ({ canvasElement }) => {
const canvas = within(canvasElement);
expect(canvas.getByText(/Read skill deep-review/)).toBeInTheDocument();
// Expand the collapsible to verify markdown body renders.
const toggle = canvas.getByRole("button");
await userEvent.click(toggle);
await waitFor(() => {
expect(canvas.getByText("Deep Review Skill")).toBeInTheDocument();
});
},
};
@@ -1393,6 +1399,12 @@ export const ReadSkillFileCompleted: Story = {
expect(
canvas.getByText(/Read deep-review\/roles\/security-reviewer\.md/),
).toBeInTheDocument();
// Expand the collapsible to verify markdown content renders.
const toggle = canvas.getByRole("button");
await userEvent.click(toggle);
await waitFor(() => {
expect(canvas.getByText("Security Reviewer Role")).toBeInTheDocument();
});
},
};
@@ -23,6 +23,7 @@ import { ListTemplatesTool } from "./ListTemplatesTool";
import { ProcessOutputTool } from "./ProcessOutputTool";
import { ProposePlanTool } from "./ProposePlanTool";
import { ReadFileTool } from "./ReadFileTool";
import { ReadSkillTool } from "./ReadSkillTool";
import { ReadTemplateTool } from "./ReadTemplateTool";
import { SubagentTool } from "./SubagentTool";
import { ToolCollapsible } from "./ToolCollapsible";
@@ -210,6 +211,55 @@ const ReadFileRenderer: FC<ToolRendererProps> = ({
);
};
const ReadSkillRenderer: FC<ToolRendererProps> = ({
status,
args,
result,
isError,
}) => {
const parsedArgs = parseArgs(args);
const skillName = parsedArgs ? asString(parsedArgs.name) : "";
const rec = asRecord(result);
const body = rec ? asString(rec.body) : "";
return (
<ReadSkillTool
label={skillName ? `skill ${skillName}` : "skill"}
body={body}
status={status}
isError={isError}
errorMessage={rec ? asString(rec.error || rec.message) : undefined}
/>
);
};
const ReadSkillFileRenderer: FC<ToolRendererProps> = ({
status,
args,
result,
isError,
}) => {
const parsedArgs = parseArgs(args);
const skillName = parsedArgs ? asString(parsedArgs.name) : "";
const filePath = parsedArgs ? asString(parsedArgs.path) : "";
const label =
skillName && filePath
? `${skillName}/${filePath}`
: skillName || filePath || "skill file";
const rec = asRecord(result);
const content = rec ? asString(rec.content) : "";
return (
<ReadSkillTool
label={label}
body={content}
status={status}
isError={isError}
errorMessage={rec ? asString(rec.error || rec.message) : undefined}
/>
);
};
const WriteFileRenderer: FC<ToolRendererProps> = ({
status,
args,
@@ -667,6 +717,8 @@ const toolRenderers: Record<string, FC<ToolRendererProps>> = {
create_workspace: CreateWorkspaceRenderer,
list_templates: ListTemplatesRenderer,
read_template: ReadTemplateRenderer,
read_skill: ReadSkillRenderer,
read_skill_file: ReadSkillFileRenderer,
spawn_agent: SubagentRenderer,
wait_agent: SubagentRenderer,
message_agent: SubagentRenderer,
@@ -255,7 +255,7 @@ export const ProviderAccordionCards: Story = {
expect(body.queryByText("OpenAI")).not.toBeInTheDocument();
await userEvent.click(body.getByRole("button", { name: /OpenRouter/i }));
await expect(body.getByLabelText("Base URL")).toBeInTheDocument();
await expect(await body.findByLabelText("Base URL")).toBeInTheDocument();
},
};
@@ -462,6 +462,9 @@ export const CreateAndUpdateProvider: Story = {
await waitFor(() => {
expect(args.onCreateProvider).toHaveBeenCalledTimes(1);
});
await waitFor(() => {
expect(body.getByRole("button", { name: "Save changes" })).toBeDisabled();
});
expect(args.onCreateProvider).toHaveBeenCalledWith(
expect.objectContaining({
provider: "openai",
@@ -71,6 +71,7 @@ describe("AuditPage", () => {
const getAuditLogsSpy = vi.spyOn(API, "getAuditLogs").mockResolvedValue({
audit_logs: [MockAuditLog, MockAuditLog2],
count: 2,
count_cap: 0,
});
// When
@@ -90,6 +91,7 @@ describe("AuditPage", () => {
vi.spyOn(API, "getAuditLogs").mockResolvedValue({
audit_logs: [MockAuditLog],
count: 1,
count_cap: 0,
});
await renderPage();
@@ -114,6 +116,7 @@ describe("AuditPage", () => {
vi.spyOn(API, "getAuditLogs").mockResolvedValue({
audit_logs: [MockAuditLog],
count: 1,
count_cap: 0,
});
await renderPage();
@@ -140,9 +143,11 @@ describe("AuditPage", () => {
describe("Filtering", () => {
it("filters by URL", async () => {
const getAuditLogsSpy = vi
.spyOn(API, "getAuditLogs")
.mockResolvedValue({ audit_logs: [MockAuditLog], count: 1 });
const getAuditLogsSpy = vi.spyOn(API, "getAuditLogs").mockResolvedValue({
audit_logs: [MockAuditLog],
count: 1,
count_cap: 0,
});
const query = "resource_type:workspace action:create";
await renderPage({ filter: query });
@@ -173,4 +178,29 @@ describe("AuditPage", () => {
);
});
});
describe("Capped count", () => {
it("shows capped count indicator and navigates to next page with correct offset", async () => {
vi.spyOn(API, "getAuditLogs").mockResolvedValue({
audit_logs: [MockAuditLog, MockAuditLog2],
count: 2001,
count_cap: 2000,
});
const user = userEvent.setup();
await renderPage();
await screen.findByText(/2,000\+/);
await user.click(screen.getByRole("button", { name: /next page/i }));
await waitFor(() =>
expect(API.getAuditLogs).toHaveBeenLastCalledWith<[AuditLogsRequest]>({
limit: DEFAULT_RECORDS_PER_PAGE,
offset: DEFAULT_RECORDS_PER_PAGE,
q: "",
}),
);
});
});
});
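The capped-count rendering exercised by this test (count 2001 with a cap of 2000 displays as "2,000+") can be sketched as follows. This is an illustrative Go sketch, not the actual frontend code; the function names and thousands formatting are assumptions:

```go
package main

import "fmt"

// formatCount renders a possibly capped total. When countCap is
// non-zero and count exceeds it, the display is clamped to the cap
// with a trailing "+", e.g. count=2001, cap=2000 -> "2,000+".
func formatCount(count, countCap int64) string {
	capped := countCap > 0 && count > countCap
	n := count
	if capped {
		n = countCap
	}
	s := withCommas(n)
	if capped {
		return s + "+"
	}
	return s
}

// withCommas inserts thousands separators into a non-negative count.
func withCommas(n int64) string {
	s := fmt.Sprintf("%d", n)
	out := ""
	for i, r := range s {
		if i > 0 && (len(s)-i)%3 == 0 {
			out += ","
		}
		out += string(r)
	}
	return out
}

func main() {
	fmt.Println(formatCount(2001, 2000)) // 2,000+
	fmt.Println(formatCount(2, 0))       // 2
}
```

A cap of 0 (as the other tests stub) means "no cap", so the exact count is shown.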
@@ -69,6 +69,7 @@ describe("ConnectionLogPage", () => {
MockDisconnectedSSHConnectionLog,
],
count: 2,
count_cap: 0,
});
// When
@@ -95,6 +96,7 @@ describe("ConnectionLogPage", () => {
.mockResolvedValue({
connection_logs: [MockConnectedSSHConnectionLog],
count: 1,
count_cap: 0,
});
const query = "type:ssh status:ongoing";
@@ -732,7 +732,16 @@ func (l *peerLifecycle) setLostTimer(c *configMaps) {
if l.lostTimer != nil {
l.lostTimer.Stop()
}
ttl := lostTimeout - c.clock.Since(l.lastHandshake)
var ttl time.Duration
if l.lastHandshake.IsZero() {
// Peer has never completed a handshake. Give it the full
// lostTimeout to establish one rather than deleting it
// immediately. A zero lastHandshake just means WireGuard
// hasn't connected yet, not that the peer is gone.
ttl = lostTimeout
} else {
ttl = lostTimeout - c.clock.Since(l.lastHandshake)
}
if ttl <= 0 {
ttl = time.Nanosecond
}
+91
View File
@@ -641,6 +641,97 @@ func TestConfigMaps_updatePeers_lost(t *testing.T) {
_ = testutil.TryReceive(ctx, t, done)
}
func TestConfigMaps_updatePeers_lost_zero_handshake(t *testing.T) {
t.Parallel()
ctx := testutil.Context(t, testutil.WaitShort)
logger := testutil.Logger(t)
fEng := newFakeEngineConfigurable()
nodePrivateKey := key.NewNode()
nodeID := tailcfg.NodeID(5)
discoKey := key.NewDisco()
uut := newConfigMaps(logger, fEng, nodeID, nodePrivateKey, discoKey.Public(), CoderDNSSuffixFQDN)
defer uut.close()
mClock := quartz.NewMock(t)
uut.clock = mClock
p1ID := uuid.UUID{1}
p1Node := newTestNode(1)
p1n, err := NodeToProto(p1Node)
require.NoError(t, err)
// Respond to the status request from updatePeers(NODE) with no
// handshake information, so lastHandshake stays zero.
expectNoStatus := func() <-chan struct{} {
called := make(chan struct{})
go func() {
select {
case <-ctx.Done():
t.Error("timeout waiting for status")
return
case b := <-fEng.status:
_ = b // don't add any peer
}
select {
case <-ctx.Done():
t.Error("timeout sending done")
case fEng.statusDone <- struct{}{}:
close(called)
}
}()
return called
}
// Add the peer via NODE update — no handshake in status.
s1 := expectNoStatus()
updates := []*proto.CoordinateResponse_PeerUpdate{
{
Id: p1ID[:],
Kind: proto.CoordinateResponse_PeerUpdate_NODE,
Node: p1n,
},
}
uut.updatePeers(updates)
nm := testutil.TryReceive(ctx, t, fEng.setNetworkMap)
r := testutil.TryReceive(ctx, t, fEng.reconfig)
require.Len(t, nm.Peers, 1)
require.Len(t, r.wg.Peers, 1)
_ = testutil.TryReceive(ctx, t, s1)
// Mark the peer as LOST, still with no handshake.
s2 := expectNoStatus()
updates[0].Kind = proto.CoordinateResponse_PeerUpdate_LOST
updates[0].Node = nil
uut.updatePeers(updates)
_ = testutil.TryReceive(ctx, t, s2)
// Peer should NOT be removed immediately.
select {
case <-fEng.setNetworkMap:
t.Fatal("should not reprogram")
default:
// OK!
}
// Prepare a status response for when the lost timer fires after
// lostTimeout. Return empty status (no handshake ever happened).
s3 := expectNoStatus()
mClock.Advance(lostTimeout).MustWait(ctx)
_ = testutil.TryReceive(ctx, t, s3)
// Now the peer should be removed.
nm = testutil.TryReceive(ctx, t, fEng.setNetworkMap)
r = testutil.TryReceive(ctx, t, fEng.reconfig)
require.Len(t, nm.Peers, 0)
require.Len(t, r.wg.Peers, 0)
done := make(chan struct{})
go func() {
defer close(done)
uut.close()
}()
_ = testutil.TryReceive(ctx, t, done)
}
func TestConfigMaps_updatePeers_lost_and_found(t *testing.T) {
t.Parallel()
ctx := testutil.Context(t, testutil.WaitShort)
@@ -12,6 +12,8 @@ import (
"sync"
"time"
"golang.org/x/sync/singleflight"
"github.com/cenkalti/backoff/v4"
"github.com/google/uuid"
"github.com/tailscale/wireguard-go/tun"
@@ -463,6 +465,8 @@ type Conn struct {
trafficStats *connstats.Statistics
lastNetInfo *tailcfg.NetInfo
awaitReachableGroup singleflight.Group
}
func (c *Conn) GetNetInfo() *tailcfg.NetInfo {
@@ -599,56 +603,60 @@ func (c *Conn) DERPMap() *tailcfg.DERPMap {
// address is reachable. It's the caller's responsibility to provide
// a timeout; otherwise this function will block forever.
func (c *Conn) AwaitReachable(ctx context.Context, ip netip.Addr) bool {
ctx, cancel := context.WithCancel(ctx)
defer cancel() // Cancel all pending pings on exit.
result, _, _ := c.awaitReachableGroup.Do(ip.String(), func() (interface{}, error) {
ctx, cancel := context.WithCancel(ctx)
defer cancel() // Cancel all pending pings on exit.
completedCtx, completed := context.WithCancel(context.Background())
defer completed()
completedCtx, completed := context.WithCancel(context.Background())
defer completed()
run := func() {
// Safety timeout, initially we'll have around 10-20 goroutines
// running in parallel. The exponential backoff will converge
// around ~1 ping / 30s, this means we'll have around 10-20
// goroutines pending towards the end as well.
ctx, cancel := context.WithTimeout(ctx, 5*time.Minute)
defer cancel()
run := func() {
// Safety timeout, initially we'll have around 10-20 goroutines
// running in parallel. The exponential backoff will converge
// around ~1 ping / 30s, this means we'll have around 10-20
// goroutines pending towards the end as well.
ctx, cancel := context.WithTimeout(ctx, 5*time.Minute)
defer cancel()
// For reachability, we use TSMP ping, which pings at the IP layer, and
// therefore requires that wireguard and the netstack are up. If we
// don't wait for wireguard to be up, we could miss a handshake, and it
// might take 5 seconds for the handshake to be retried. A 5s initial
// round trip can set us up for poor TCP performance, since the initial
// round-trip-time sets the initial retransmit timeout.
_, _, _, err := c.pingWithType(ctx, ip, tailcfg.PingTSMP)
if err == nil {
completed()
// For reachability, we use TSMP ping, which pings at the IP layer,
// and therefore requires that wireguard and the netstack are up.
// If we don't wait for wireguard to be up, we could miss a
// handshake, and it might take 5 seconds for the handshake to be
// retried. A 5s initial round trip can set us up for poor TCP
// performance, since the initial round-trip-time sets the initial
// retransmit timeout.
_, _, _, err := c.pingWithType(ctx, ip, tailcfg.PingTSMP)
if err == nil {
completed()
}
}
}
eb := backoff.NewExponentialBackOff()
eb.MaxElapsedTime = 0
eb.InitialInterval = 50 * time.Millisecond
eb.MaxInterval = 30 * time.Second
// Consume the first interval since
// we'll fire off a ping immediately.
_ = eb.NextBackOff()
eb := backoff.NewExponentialBackOff()
eb.MaxElapsedTime = 0
eb.InitialInterval = 50 * time.Millisecond
eb.MaxInterval = 5 * time.Second
// Consume the first interval since
// we'll fire off a ping immediately.
_ = eb.NextBackOff()
t := backoff.NewTicker(eb)
defer t.Stop()
t := backoff.NewTicker(eb)
defer t.Stop()
go run()
for {
select {
case <-completedCtx.Done():
return true
case <-t.C:
// Pings can take a while, so we can run multiple
// in parallel to return ASAP.
go run()
case <-ctx.Done():
return false
go run()
for {
select {
case <-completedCtx.Done():
return true, nil
case <-t.C:
// Pings can take a while, so we can run multiple
// in parallel to return ASAP.
go run()
case <-ctx.Done():
return false, nil
}
}
}
})
return result.(bool)
}
// Closed is a channel that ends when the connection has
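The deduplication that `awaitReachableGroup` provides can be illustrated with a minimal, stdlib-only approximation of `singleflight.Group`. The real diff uses golang.org/x/sync/singleflight; this sketch omits its error and shared-flag plumbing and exists only to show how concurrent callers keyed on the same IP coalesce onto one in-flight ping loop:

```go
package main

import (
	"fmt"
	"sync"
	"sync/atomic"
	"time"
)

// group is a stdlib-only approximation of singleflight.Group:
// concurrent callers with the same key share one execution of fn
// and all receive its result.
type group struct {
	mu    sync.Mutex
	calls map[string]*call
}

type call struct {
	done chan struct{}
	val  bool
}

func (g *group) Do(key string, fn func() bool) bool {
	g.mu.Lock()
	if g.calls == nil {
		g.calls = make(map[string]*call)
	}
	if c, ok := g.calls[key]; ok {
		// Duplicate caller: wait for the in-flight call's result.
		g.mu.Unlock()
		<-c.done
		return c.val
	}
	c := &call{done: make(chan struct{})}
	g.calls[key] = c
	g.mu.Unlock()

	c.val = fn()
	g.mu.Lock()
	delete(g.calls, key)
	g.mu.Unlock()
	close(c.done) // publishes c.val to all waiters

	return c.val
}

// coalescedCalls launches n concurrent callers against one key and
// reports how many times fn actually executed.
func coalescedCalls(n int) int32 {
	var g group
	var executions int32
	var wg sync.WaitGroup
	for i := 0; i < n; i++ {
		wg.Add(1)
		go func() {
			defer wg.Done()
			g.Do("peer-ip", func() bool {
				atomic.AddInt32(&executions, 1)
				time.Sleep(200 * time.Millisecond) // simulate a slow reachability ping
				return true
			})
		}()
	}
	wg.Wait()
	return atomic.LoadInt32(&executions)
}

func main() {
	fmt.Println("executions:", coalescedCalls(10))
}
```

In the patched AwaitReachable, the key is `ip.String()`, so a ping storm from many concurrent dials to the same agent collapses into a single backoff-driven ping loop, and every caller gets the same boolean result.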