Compare commits

75 Commits

Ben Potter 2574576426 fix(coderd): bump embedded PostgreSQL from 13 to 16
PostgreSQL 13 reached EOL in November 2025 and is no longer receiving
updates from the upstream binary provider (zonkyio/embedded-postgres-binaries).

This also fixes embedded PostgreSQL failing to start on ARM64 systems
like Raspberry Pi with:

  error while loading shared libraries: libcrypto.so.1.1:
  ELF load command address/offset not page-aligned

The ARM64 binaries were being corrupted by an outdated patchelf (0.9)
in zonkyio's build pipeline, which broke ELF page alignment. This was
fixed upstream (zonkyio/embedded-postgres-binaries#105) by upgrading
patchelf to 0.15.5, but only for versions 14.21.0+. Since PG 13 is
EOL, no fixed v13 build was ever published.

The CI test helper (scripts/embedded-pg) already uses V16.
2026-03-24 16:42:52 +00:00
Cian Johnston fd1e2f0dd9 fix(coderd/database/dbauthz): skip Accounting check when sub-test filtering (#23281)
- Detect `-testify.m` sub-test filtering in `SetupSuite` and skip the `Accounting` check.

> 🤖 This PR was created with the help of Coder Agents, and was reviewed by my human. 🧑‍💻
2026-03-24 14:58:04 +00:00
Michael Suchacz be5e080de6 fix(site/src/pages/AgentsPage): preserve chat scroll position when away from bottom (#23451)
## Summary

Stabilizes the /agents chat viewport so users can read older messages
without being yanked to the bottom when new content arrives.

## Architecture

Replaces the implicit scroll-follow behavior with
**ResizeObserver-driven scroll anchoring**:

- **`autoScrollRef`** is the single source of truth. User scrolling away
  from bottom turns it off; scrolling back near bottom or clicking the
  button turns it back on.
- A **content ResizeObserver** on an inner wrapper detects transcript
  growth. When auto-scroll is on, it re-pins to bottom via double-RAF
  (waiting for React commit + layout to settle). When off, it
  compensates `scrollTop` by the height delta to preserve the reading
  position. Sign-aware for both Chrome-style negative and Firefox-style
  positive `flex-col-reverse` scrollTop conventions.
- Compensation is **skipped during pagination** (older messages prepend
  into the overflow direction; the browser preserves scrollTop) and
  **during reflow** from width changes.
- A **container ResizeObserver** re-pins to bottom after viewport
  resizes (composer growth, panel changes) when auto-scroll is on.
- **`isRestoringScrollRef`** guards against feedback loops from
  programmatic scroll writes. The smooth-scroll guard stays active
  until the scroll handler detects arrival at bottom.

## Files changed

- **AgentDetailView.tsx**: Rewrote `ScrollAnchoredContainer` with the
  new approach.
- **AgentDetailView.stories.tsx**: Refactored `ScrollToBottomButton`
  story scroll helpers into shared utilities.

## Behavior

- **At bottom + new content**: stays pinned, button hidden.
- **Scrolled up + new content**: reading position preserved, no jump.
- **Viewport resize while pinned**: re-pins to bottom.
- Scroll-to-bottom button and smooth scroll still work.
2026-03-24 15:55:41 +01:00
Michael Suchacz 19e86628da feat: add propose_plan tool for markdown plan proposals (#23452)
Adds a `propose_plan` tool that presents a workspace markdown file as a
dedicated plan card in the agent UI.

The workflow is: the agent uses `write_file`/`edit_files` to build a
plan file (e.g. `/home/coder/PLAN.md`), then calls `propose_plan(path)`
to present it. The backend reads the file via `ReadFile` and the
frontend renders it as an expanded markdown preview card.

**Backend** (`coderd/x/chatd/chattool/proposeplan.go`): new tool
registered as root-chat-only. Validates the `.md` suffix, requires an
absolute path, and reads raw file content from the workspace agent.
Includes a 1 MiB size cap.
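
The validation rules above can be sketched as a small helper. The function and constant names here are hypothetical, not the actual chattool implementation:

```go
package main

import (
	"errors"
	"fmt"
	"path/filepath"
	"strings"
)

// maxPlanSize mirrors the 1 MiB cap on plan file content noted above.
const maxPlanSize = 1 << 20

// validatePlanPath applies the propose_plan input checks described
// above: the path must be absolute and end in .md.
func validatePlanPath(path string) error {
	if !filepath.IsAbs(path) {
		return errors.New("plan path must be absolute")
	}
	if !strings.HasSuffix(path, ".md") {
		return errors.New("plan file must have a .md suffix")
	}
	return nil
}

func main() {
	fmt.Println(validatePlanPath("/home/coder/PLAN.md")) // <nil>
	fmt.Println(validatePlanPath("PLAN.md"))             // rejected: not absolute
}
```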

**Frontend** (`site/src/components/ai-elements/tool/`): dedicated
`ProposePlanTool` component with `ToolCollapsible` + `ScrollArea` +
`Response` markdown renderer, expanded by default. Custom icon
(`ClipboardListIcon`) and filename-based label.

**System prompt** (`coderd/x/chatd/prompt.go`): added `<planning>`
section guiding the agent to research → write plan file → iterate → call
`propose_plan`.
2026-03-24 15:06:22 +01:00
Michael Suchacz 02356c61f6 fix: use previous_response_id chaining for OpenAI store=true follow-ups (#23450)
OpenAI Responses follow-up turns were replaying full assistant/tool
history even when `store=true`, which breaks after reasoning +
provider-executed `web_search` output.

This change persists the OpenAI response ID on assistant messages, then
in `coderd/x/chatd` switches `store=true` follow-ups to
`previous_response_id` chaining with a system + new-user-only prompt.
`store=false` and missing-ID cases still fall back to manual replay.

It also updates the fake OpenAI server and integration coverage for the
chaining contract, and carries the rebased path move to `coderd/x/chatd`
plus the migration renumber needed after rebasing onto `main`.
2026-03-24 14:57:40 +01:00
Steven Masley b9f0c479ac test: migrate TestResourcesMonitor to mocked db instances (#23464) 2026-03-24 08:49:54 -05:00
Michael Suchacz 803cfeb882 fix(site/src/pages/AgentsPage): stabilize remote diff cache keys (#23487)
## Summary
- use React Query's `dataUpdatedAt` as the remote diff cache
  invalidation token instead of a component-local counter
- keep the `@pierre/diffs` cache key stable across remounts without a
  custom hashing implementation
- preserve targeted coverage for the cache-key helper used by the
  /agents remote diff viewer

## Testing
- `cd site && pnpm exec biome check src/pages/AgentsPage/components/DiffViewer/RemoteDiffPanel.tsx src/pages/AgentsPage/components/DiffViewer/diffCacheKey.ts src/pages/AgentsPage/components/DiffViewer/diffCacheKey.test.ts`
- `cd site && pnpm exec vitest run src/pages/AgentsPage/components/DiffViewer/diffCacheKey.test.ts --project=unit`
- `cd site && pnpm exec tsc -p .`
2026-03-24 14:29:53 +01:00
Matt Vollmer 08577006c6 fix(site): improve Workspace Autostop Fallback UX on agents settings page (#23465)
https://github.com/user-attachments/assets/a482ef45-402a-4d86-af59-b1526b2ce3e2

## Summary

Redesigns the **Default Autostop** section on the `/agents` settings
page to clarify that it is a fallback for chat-linked workspaces whose
templates do not define their own autostop policy. Template-level
settings always take priority — this is a backstop, not an override.

## Changes

### UX
- Renamed to **Workspace Autostop Fallback** with clearer description
- Replaced always-visible duration field (confusing `0` in an hours box)
  with a **toggle-to-enable** pattern matching the Virtual Desktop section
- Toggle ON auto-saves with a 1-hour default; toggle OFF auto-saves with 0
- Save button is always visible when the toggle is on but disabled until
  the user changes the duration value
- Per-section disabled flags — toggling autostop no longer freezes the
  Virtual Desktop switch or prompt textareas during the save round-trip

### Reliability
- `onError` rollback on toggle auto-saves so the UI snaps back to
  server truth on failure
- Stateful mocks in Storybook stories to prevent race conditions from
  instant mock resolution

### Accessibility
- Added `aria-label="Autostop duration"` to the DurationField input
- Updated `DurationField` component to merge external `inputProps` with
  internal ones (preserves `step: 1`)

### Stories
- Updated all existing autostop stories for the new toggle-based flow
- Added `DefaultAutostopToggleOff` — tests disabling from an enabled state
- Added `DefaultAutostopSaveDisabled` — verifies Save button is visible
  but disabled when no duration change

---

PR generated with Coder Agents
2026-03-24 09:28:10 -04:00
Kyle Carberry 13241a58ba fix(coderd/x/chatd/mcpclient): use dedicated HTTP transport per MCP connection (#23494)
## Problem

`TestConnectAll_MultipleServers` flakes with:

```
net/http: HTTP/1.x transport connection broken: http: CloseIdleConnections called
```

Each MCP client connection implicitly uses `http.DefaultTransport`. When
`httptest.Server.Close()` runs during parallel test cleanup, it calls
`CloseIdleConnections` on `http.DefaultTransport`, breaking in-flight
connections from other goroutines or parallel tests sharing that
transport.

## Fix

Clone the default transport for each MCP connection via
`http.DefaultTransport.(*http.Transport).Clone()`, passed through
`WithHTTPBasicClient` (StreamableHTTP) and `WithHTTPClient` (SSE). This
scopes idle connection cleanup to a single MCP server so it cannot
disrupt unrelated connections.

Fixes coder/internal#1420
2026-03-24 09:22:45 -04:00
Kyle Carberry 631e4449bb fix: use actual config ID in MCP OAuth2 redirect URI during auto-discovery (#23491)
## Problem

During OAuth2 auto-discovery for MCP servers, the callback URL
registered with the remote authorization server via Dynamic Client
Registration (RFC 7591) contained the literal string `{id}` instead of
the actual config UUID:

```
https://coder.example.com/api/experimental/mcp/servers/{id}/oauth2/callback
```

This happened because the discovery and registration occurred **before**
the database insert that generates the ID. When the user later initiated
the OAuth2 connect flow, the redirect URL used the real UUID, causing
the authorization server to reject it with:

> The provided redirect URIs are not approved for use by this
> authorization server

## Fix

Restructure the auto-discovery flow in `createMCPServerConfig` to:

1. **Insert** the MCP server config first (with empty OAuth2 fields) to
   get the database-generated UUID
2. **Build** the callback URL with the actual UUID
3. **Perform** OAuth2 discovery and dynamic client registration with
   the correct URL
4. **Update** the record with the discovered OAuth2 credentials
5. **Clean up** the record if discovery fails

## Testing

Added regression test
`TestMCPServerConfigsOAuth2AutoDiscovery/RedirectURIContainsRealConfigID`
that:
- Stands up mock auth + MCP servers
- Captures the `redirect_uris` sent during dynamic client registration
- Asserts the URI contains the real config UUID, not `{id}`
- Verifies the full callback path structure

All existing MCP server config tests continue to pass.
2026-03-24 13:04:55 +00:00
Matt Vollmer 76eac82e5b docs: soften security implications intro wording (#23492) 2026-03-24 08:59:33 -04:00
Michael Suchacz 405d81be09 fix(coderd/database): fall back to model names in PR insights (#23490)
Fall back to the configured model name in PR Insights when a model
config has a blank display name.

This updates both the by-model breakdown and recent PR rows, and adds a
regression test for blank display names.
2026-03-24 13:58:29 +01:00
Mathias Fredriksson 1c0442c247 fix(agent/agentfiles): fix replace_all in fuzzy matching mode (#23480)
replace_all in fuzzy mode (passes 2 and 3 of fuzzyReplace) only
replaced the first match. seekLines returned the first match,
spliceLines replaced one range, and there was no loop.

Extract fuzzy pass logic into fuzzyReplaceLines which:
- Returns a 3-tuple (result, matched, error) for clean caller flow
- When replaceAll is true, collects all non-overlapping matches
  then applies replacements from last to first to preserve indices
- When replaceAll is false with multiple matches, returns an error

Add test cases for replace_all with fuzzy trailing whitespace and
fuzzy indent matching.
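
The splice-ordering trick can be shown in isolation. This is a simplified sketch over a flat string rather than the line-based fuzzy passes, but it demonstrates the same "collect non-overlapping matches, then replace last to first" approach:

```go
package main

import (
	"fmt"
	"strings"
)

// replaceAllRanges collects all non-overlapping match offsets, then
// applies replacements from last to first so earlier offsets remain
// valid after each splice.
func replaceAllRanges(s, old, repl string) string {
	var starts []int
	for i := 0; i+len(old) <= len(s); {
		j := strings.Index(s[i:], old)
		if j < 0 {
			break
		}
		starts = append(starts, i+j)
		i += j + len(old) // non-overlapping: skip past this match
	}
	for k := len(starts) - 1; k >= 0; k-- { // last to first
		p := starts[k]
		s = s[:p] + repl + s[p+len(old):]
	}
	return s
}

func main() {
	fmt.Println(replaceAllRanges("aaa bbb aaa", "aaa", "X")) // X bbb X
}
```

Replacing front to back would shift every later offset whenever the replacement length differs from the match length; going back to front sidesteps that bookkeeping entirely.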
2026-03-24 14:41:45 +02:00
Mathias Fredriksson 16edcbdd5b fix(agent/agentfiles): follow symlinks in write_file and edit_files (#23478)
Both write_file and edit_files use atomic writes (write to temp
file, then rename). Since rename operates on directory entries, it
replaces symlinks with regular files instead of writing through
the link to the target.

Add resolveSymlink() that uses afero.Lstater/LinkReader to resolve
symlink chains (up to 10 levels) before the atomic write. Both
writeFile and editFile resolve the path before any filesystem
operations, matching the behavior of 'echo content > symlink'.

Gracefully no-ops on filesystems that don't support symlinks (e.g.
MemMapFs used in existing tests).
2026-03-24 12:39:55 +00:00
Kyle Carberry f62f2ffe6a feat(site): add MCP server picker to agent chat UI (#23470)
## Summary

Adds a user-facing MCP server configuration panel to the chat input
toolbar. Users can toggle which MCP servers provide tools for their chat
sessions, and authenticate with OAuth2 servers via popup windows.

## Changes

### New Components
- **`MCPServerPicker`** (`MCPServerPicker.tsx`): Popover-based picker
that appears in the chat input toolbar next to the model selector. Shows
all enabled MCP servers with toggles.
- **`MCPServerPicker.stories.tsx`**: 13 Storybook stories covering all
states.

### Availability Policies
Respects the admin-configured availability for each server:
- **`force_on`**: Always active, toggle disabled, lock icon shown.
  User cannot disable.
- **`default_on`**: Pre-selected by default, user can opt out via toggle.
- **`default_off`**: Not selected by default, user must opt in via toggle.

### OAuth2 Authentication
For servers with `auth_type: "oauth2"`:
- Shows auth status (connected/not connected)
- "Connect to authenticate" link opens a popup window to
  `/api/experimental/mcp/servers/{id}/oauth2/connect`
- Listens for `postMessage` with `{type: "mcp-oauth2-complete"}` from
  the callback page
- Same UX pattern as external auth on the Create Workspace screen

### Integration Points
- `AgentChatInput`: MCP picker appears in the toolbar after the model
  selector
- `AgentDetail`: Manages MCP selection state, initializes from
  `chat.mcp_server_ids` or defaults
- `AgentDetailView` / `AgentDetailContent`: Props plumbed through to input
- `AgentCreatePage` / `AgentCreateForm`: MCP selection for new chats
- `mcp_server_ids` now sent with `CreateChatMessageRequest` and
  `CreateChatRequest`

### Helper
- `getDefaultMCPSelection()`: Computes default selection from
  availability policies (`force_on` + `default_on`)

## Storybook Stories
| Story | Description |
|-------|-------------|
| NoServers | No servers - picker hidden |
| AllDisabled | All disabled servers - picker hidden |
| SingleForceOn | Force-on server with locked toggle |
| SingleDefaultOnNoAuth | Default-on with no auth required |
| SingleDefaultOff | Optional server not selected |
| OAuthNeedsAuth | OAuth2 server needing authentication |
| OAuthConnected | OAuth2 server already connected |
| MixedServers | Multiple servers with mixed availability/auth |
| AllConnected | All OAuth2 servers authenticated |
| Disabled | Picker in disabled state |
| WithDisabledServer | Disabled servers filtered out |
| AllOptedOut | All toggled off except force_on |
| OptionalOAuthNeedsAuth | Optional OAuth2 needing auth |
2026-03-24 08:13:18 -04:00
Vlad 2dc3466f07 docs: update JetBrains client downloader link (#23287) 2026-03-24 11:36:20 +00:00
Cian Johnston cbd56d33d4 ci: disable go cache for build jobs to prevent disk space exhaustion (#23484)
Disables Go cache for the setup-go step to workaround depot runner disk space issues.
2026-03-24 11:17:39 +00:00
Mathias Fredriksson b23aed034f fix: make terraform ConvertState fully deterministic (#23459)
All map iterations in ConvertState now use sorted helpers instead of
ranging over Go maps directly. Previously only coder_env and
coder_script were sorted (via sortedResourcesByType). This extends
the pattern to coder_agent, coder_devcontainer, coder_agent_instance,
coder_app, coder_metadata, coder_external_auth, and the main
resource output list.

Also fixes generate.sh writing version.txt to the wrong directory
(resources/ instead of testdata/), which caused the Makefile version
check to silently desync and trigger unnecessary regeneration.

Adds TestConvertStateDeterministic that calls ConvertState 10 times
per fixture and asserts byte-identical JSON output without any
post-hoc sorting.
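
The pattern behind those sorted helpers is simple; a minimal sketch (names illustrative, not the actual ConvertState code):

```go
package main

import (
	"fmt"
	"sort"
)

// sortedKeys returns the map's keys in a stable order. Ranging over a
// Go map directly yields a different order on each run, which is what
// made the JSON output non-deterministic.
func sortedKeys(m map[string]int) []string {
	keys := make([]string, 0, len(m))
	for k := range m {
		keys = append(keys, k)
	}
	sort.Strings(keys)
	return keys
}

func main() {
	resources := map[string]int{"coder_app": 2, "coder_agent": 1, "coder_script": 3}
	for _, k := range sortedKeys(resources) {
		fmt.Println(k, resources[k]) // identical output on every run
	}
}
```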
2026-03-24 11:02:45 +00:00
Ethan 56e80b0a27 fix(site): use HttpResponse constructor for binary mock response (#23474)
## Context

`./scripts/develop.sh` was failing to build in my dogfood workspace
with:

```
src/testHelpers/handlers.ts(346,35): error TS2345: Argument of type 'NonSharedBuffer'
is not assignable to parameter of type 'ArrayBuffer'.
  Type 'Buffer<ArrayBuffer>' is missing the following properties from type 'ArrayBuffer':
  maxByteLength, resizable, resize, detached, and 2 more.
```

## Alternatives considered

**`fileBuffer.buffer`** — `.buffer` gives you the underlying
`ArrayBuffer`, but Node pools small buffers into a shared 8 KB slab. A
`Buffer.from("hello")` has `byteOffset: 1472` and `.buffer.byteLength:
8192` — passing `.buffer` to a `Response` sends all 8,192 bytes instead
of 5. It happens to work for `readFileSync` (dedicated allocation,
offset 0), but breaks silently if someone refactors how the buffer is
constructed.

**`fileBuffer.buffer.slice(byteOffset, byteOffset + byteLength)`** — the
safe version of the above. Always correct, but unnecessarily complex.

**`new HttpResponse(fileBuffer)`** (chosen) — `HttpResponse` extends
`Response`, whose constructor accepts `BodyInit` which includes
`Uint8Array`. When you pass a typed array view, `Response` reads only
the bytes within that view (respecting `byteOffset`/`byteLength`), so
it's safe regardless of pooling. `Buffer` is a `Uint8Array` subclass, so
this just works:

```
pooled = Buffer.from("hello")   → byteOffset: 1472, .buffer: 8192 bytes
new Response(pooled.buffer)     → body: 8192 bytes ✗
new Response(pooled)            → body: 5 bytes    ✓
```
2026-03-24 21:53:56 +11:00
Danny Kopping dba9f68b11 chore!: remove members' ability to read their own interceptions; rationalize RBAC requirements (#23320)
_Disclaimer: produced by Claude Opus 4.6, reviewed by me._

**This is a breaking change.** Users who do not have the `owner` or sitewide `auditor` role will no longer be able to view interceptions.
Regular users should not need to view this information; in fact, a malicious insider could use it to learn what information we do and don't track, and then exfiltrate data or perform actions unobserved.

---

Changed authorization for AI Bridge interception-related operations from system-level permissions to resource-specific permissions. The following functions now authorize against `rbac.ResourceAibridgeInterception` instead of `rbac.ResourceSystem`:

- `ListAIBridgeTokenUsagesByInterceptionIDs`
- `ListAIBridgeToolUsagesByInterceptionIDs`
- `ListAIBridgeUserPromptsByInterceptionIDs`

Updated RBAC roles to grant AI Bridge interception permissions:

- **User/Member roles**: Can create and update AI Bridge interceptions but cannot read them back
- **Service accounts**: Same create/update permissions without read access
- **Owners/Auditors**: Retain full read access to all interceptions

Removed system-level authorization bypass in `populatedAndConvertAIBridgeInterceptions` function, allowing proper resource-level authorization checks.

Updated tests to reflect the new permission model where members cannot view AI Bridge interceptions, even their own, while owners and auditors maintain full visibility.
2026-03-24 12:03:20 +02:00
Jaayden Halko 245ce91199 feat: add bar charts for premium and AI governance add-on license usage (#23442)
Implemented with the help of Cursor agents using Figma MCP

Figma design:
https://www.figma.com/design/klGTlHSPQwI4KBvAMdebrx/Customer-Usage-Controls-for-AI-Governance-Add-On?node-id=448-7658&m=dev

<img width="1143" height="639" alt="Screenshot 2026-03-23 at 20 10 05"
src="https://github.com/user-attachments/assets/300d4d5d-aad2-49a9-bfdd-a329312e5fa8"
/>
2026-03-24 09:07:06 +00:00
Danielle Maywood 5d0734e005 fix(site): diff viewer virtualizer buffer fix and styling polish (#23462) 2026-03-24 09:04:14 +00:00
Danny Kopping 43a1af3cd6 feat: session list API (#23202)
_Disclaimer: initially produced by Claude Opus 4.6, heavily modified and reviewed by me._

Closes https://github.com/coder/internal/issues/1360

Adds a new `/api/v2/aibridge/sessions` API which returns "sessions".

Sessions, as defined in the [RFC](https://www.notion.so/coderhq/AI-Bridge-Sessions-Threads-2ccd579be59280f28021d3baf7472fbe?source=copy_link), are a set of interceptions logically grouped by a session key issued by the client.  
The API design for this endpoint was done in [this doc](https://github.com/coder/internal/issues/1360).

If the client has not provided a session ID, we fall back to the thread root ID, and if that's not present we use the interception's own ID (i.e. a session of a single interception, which is effectively what we show currently in our `/api/v2/aibridge/interceptions` API).
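
That fallback order is a simple coalesce; sketched in Go for clarity (the real implementation lives in SQL):

```go
package main

import "fmt"

// sessionKey picks the grouping key using the fallback order described
// above: client-issued session ID, else thread root ID, else the
// interception's own ID.
func sessionKey(sessionID, threadRootID, interceptionID string) string {
	switch {
	case sessionID != "":
		return sessionID
	case threadRootID != "":
		return threadRootID
	default:
		return interceptionID
	}
}

func main() {
	fmt.Println(sessionKey("sess-1", "root-1", "icpt-1")) // sess-1
	fmt.Println(sessionKey("", "root-1", "icpt-1"))       // root-1
	fmt.Println(sessionKey("", "", "icpt-1"))             // icpt-1
}
```

The last case reproduces today's behavior: with no session or thread information, each interception is its own single-entry session.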

The SQL query looks gnarly but it's relatively simple, and seems to perform well (~200ms) even when I import dogfood's `aibridge_*` tables into my workspace. If we need to improve performance on this later we can investigate materialized views, perhaps, but for now I don't think it's warranted.

---

_The PR looks large but it's got a lot of generated code; the actual changes aren't huge._
2026-03-24 08:58:47 +02:00
Jaayden Halko 3d5d58ec2b fix: make LicenseCard stories use deterministic dates (#23437)
## Summary
- replace dynamic dayjs() date generation in LicenseCard stories with
fixed deterministic timestamps
- preserve story behavior while preventing day-over-day visual drift in
Chromatic
- use shared constants for expired and future date scenarios
2026-03-24 04:38:23 +00:00
dependabot[bot] 37d937554e ci: bump dorny/paths-filter from 3.0.2 to 4.0.1 in the github-actions group (#23435)
Bumps the github-actions group with 1 update:
[dorny/paths-filter](https://github.com/dorny/paths-filter).

Updates `dorny/paths-filter` from 3.0.2 to 4.0.1
<details>
<summary>Release notes</summary>
<p><em>Sourced from <a
href="https://github.com/dorny/paths-filter/releases">dorny/paths-filter's
releases</a>.</em></p>
<blockquote>
<h2>v4.0.1</h2>
<h2>What's Changed</h2>
<ul>
<li>Support merge queue by <a
href="https://github.com/masaru-iritani"><code>@​masaru-iritani</code></a>
in <a
href="https://redirect.github.com/dorny/paths-filter/pull/255">dorny/paths-filter#255</a></li>
</ul>
<h2>New Contributors</h2>
<ul>
<li><a
href="https://github.com/masaru-iritani"><code>@​masaru-iritani</code></a>
made their first contribution in <a
href="https://redirect.github.com/dorny/paths-filter/pull/255">dorny/paths-filter#255</a></li>
</ul>
<p><strong>Full Changelog</strong>: <a
href="https://github.com/dorny/paths-filter/compare/v4.0.0...v4.0.1">https://github.com/dorny/paths-filter/compare/v4.0.0...v4.0.1</a></p>
<h2>v4.0.0</h2>
<h2>What's Changed</h2>
<ul>
<li>feat: update action runtime to node24 by <a
href="https://github.com/saschabratton"><code>@​saschabratton</code></a>
in <a
href="https://redirect.github.com/dorny/paths-filter/pull/294">dorny/paths-filter#294</a></li>
</ul>
<h2>New Contributors</h2>
<ul>
<li><a
href="https://github.com/saschabratton"><code>@​saschabratton</code></a>
made their first contribution in <a
href="https://redirect.github.com/dorny/paths-filter/pull/294">dorny/paths-filter#294</a></li>
</ul>
<p><strong>Full Changelog</strong>: <a
href="https://github.com/dorny/paths-filter/compare/v3.0.3...v4.0.0">https://github.com/dorny/paths-filter/compare/v3.0.3...v4.0.0</a></p>
<h2>v3.0.3</h2>
<h2>What's Changed</h2>
<ul>
<li>Add missing predicate-quantifier by <a
href="https://github.com/wardpeet"><code>@​wardpeet</code></a> in <a
href="https://redirect.github.com/dorny/paths-filter/pull/279">dorny/paths-filter#279</a></li>
</ul>
<h2>New Contributors</h2>
<ul>
<li><a href="https://github.com/wardpeet"><code>@​wardpeet</code></a>
made their first contribution in <a
href="https://redirect.github.com/dorny/paths-filter/pull/279">dorny/paths-filter#279</a></li>
</ul>
<p><strong>Full Changelog</strong>: <a
href="https://github.com/dorny/paths-filter/compare/v3...v3.0.3">https://github.com/dorny/paths-filter/compare/v3...v3.0.3</a></p>
</blockquote>
</details>
<details>
<summary>Changelog</summary>
<p><em>Sourced from <a
href="https://github.com/dorny/paths-filter/blob/master/CHANGELOG.md">dorny/paths-filter's
changelog</a>.</em></p>
<blockquote>
<h1>Changelog</h1>
<h2>v4.0.0</h2>
<ul>
<li><a
href="https://redirect.github.com/dorny/paths-filter/pull/294">Update
action runtime to node24</a></li>
</ul>
<h2>v3.0.3</h2>
<ul>
<li><a
href="https://redirect.github.com/dorny/paths-filter/pull/279">Add
missing predicate-quantifier</a></li>
</ul>
<h2>v3.0.2</h2>
<ul>
<li><a
href="https://redirect.github.com/dorny/paths-filter/pull/224">Add
config parameter for predicate quantifier</a></li>
</ul>
<h2>v3.0.1</h2>
<ul>
<li><a
href="https://redirect.github.com/dorny/paths-filter/pull/133">Compare
base and ref when token is empty</a></li>
</ul>
<h2>v3.0.0</h2>
<ul>
<li><a
href="https://redirect.github.com/dorny/paths-filter/pull/210">Update to
Node.js 20</a></li>
<li><a
href="https://redirect.github.com/dorny/paths-filter/pull/215">Update
all dependencies</a></li>
</ul>
<h2>v2.11.1</h2>
<ul>
<li><a
href="https://redirect.github.com/dorny/paths-filter/pull/167">Update
<code>@​actions/core</code> to v1.10.0 - Fixes warning about deprecated
set-output</a></li>
<li><a
href="https://redirect.github.com/dorny/paths-filter/pull/168">Document
need for pull-requests: read permission</a></li>
<li><a
href="https://redirect.github.com/dorny/paths-filter/pull/164">Updating
to actions/checkout@v3</a></li>
</ul>
<h2>v2.11.0</h2>
<ul>
<li><a
href="https://redirect.github.com/dorny/paths-filter/pull/157">Set
list-files input parameter as not required</a></li>
<li><a
href="https://redirect.github.com/dorny/paths-filter/pull/161">Update
Node.js</a></li>
<li><a
href="https://redirect.github.com/dorny/paths-filter/pull/162">Fix
incorrect handling of Unicode characters in exec()</a></li>
<li><a
href="https://redirect.github.com/dorny/paths-filter/pull/163">Use
Octokit pagination</a></li>
<li><a
href="https://redirect.github.com/dorny/paths-filter/pull/160">Updates
real world links</a></li>
</ul>
<h2>v2.10.2</h2>
<ul>
<li><a href="https://redirect.github.com/dorny/paths-filter/pull/91">Fix
getLocalRef() returns wrong ref</a></li>
</ul>
<h2>v2.10.1</h2>
<ul>
<li><a
href="https://redirect.github.com/dorny/paths-filter/pull/85">Improve
robustness of change detection</a></li>
</ul>
<h2>v2.10.0</h2>
<ul>
<li><a href="https://redirect.github.com/dorny/paths-filter/pull/82">Add
ref input parameter</a></li>
<li><a href="https://redirect.github.com/dorny/paths-filter/pull/83">Fix
change detection in PR when pullRequest.changed_files is
incorrect</a></li>
</ul>
<h2>v2.9.3</h2>
<ul>
<li><a href="https://redirect.github.com/dorny/paths-filter/pull/78">Fix
change detection when base is a tag</a></li>
</ul>
<h2>v2.9.2</h2>
<ul>
<li><a href="https://redirect.github.com/dorny/paths-filter/pull/75">Fix
fetching git history</a></li>
</ul>
<h2>v2.9.1</h2>
<ul>
<li><a href="https://redirect.github.com/dorny/paths-filter/pull/74">Fix
fetching git history + fallback to unshallow repo</a></li>
</ul>
<h2>v2.9.0</h2>
<!-- raw HTML omitted -->
</blockquote>
<p>... (truncated)</p>
</details>
<details>
<summary>Commits</summary>
<ul>
<li><a
href="https://github.com/dorny/paths-filter/commit/fbd0ab8f3e69293af611ebaee6363fc25e6d187d"><code>fbd0ab8</code></a>
feat: add merge_group event support</li>
<li><a
href="https://github.com/dorny/paths-filter/commit/efb1da7ce8d89bbc261191e5a2dc1453c3837339"><code>efb1da7</code></a>
feat: add dist/ freshness check to PR workflow</li>
<li><a
href="https://github.com/dorny/paths-filter/commit/d8f7b061b24c30a325ff314b76c37adb05b041ce"><code>d8f7b06</code></a>
Merge pull request <a
href="https://redirect.github.com/dorny/paths-filter/issues/302">#302</a>
from dorny/issue-299</li>
<li><a
href="https://github.com/dorny/paths-filter/commit/addbc147a95845176e1bc013a012fbf1d366389a"><code>addbc14</code></a>
Update README for v4</li>
<li><a
href="https://github.com/dorny/paths-filter/commit/9d7afb8d214ad99e78fbd4247752c4caed2b6e4c"><code>9d7afb8</code></a>
Update CHANGELOG for v4.0.0</li>
<li><a
href="https://github.com/dorny/paths-filter/commit/782470c5d953cae2693d643172b14e01bacb71f3"><code>782470c</code></a>
Merge branch 'releases/v3'</li>
<li><a
href="https://github.com/dorny/paths-filter/commit/d1c1ffe0248fe513906c8e24db8ea791d46f8590"><code>d1c1ffe</code></a>
Update CHANGELOG for v3.0.3</li>
<li><a
href="https://github.com/dorny/paths-filter/commit/ce10459c8b92cd8901166c0a222fbb033ef39365"><code>ce10459</code></a>
Merge pull request <a
href="https://redirect.github.com/dorny/paths-filter/issues/294">#294</a>
from saschabratton/master</li>
<li><a
href="https://github.com/dorny/paths-filter/commit/5f40380c5482e806c81cec080f5192e7234d8fe9"><code>5f40380</code></a>
feat: update action runtime to node24</li>
<li><a
href="https://github.com/dorny/paths-filter/commit/668c092af3649c4b664c54e4b704aa46782f6f7c"><code>668c092</code></a>
Merge pull request <a
href="https://redirect.github.com/dorny/paths-filter/issues/279">#279</a>
from wardpeet/patch-1</li>
<li>Additional commits viewable in <a
href="https://github.com/dorny/paths-filter/compare/de90cc6fb38fc0963ad72b210f1f284cd68cea36...fbd0ab8f3e69293af611ebaee6363fc25e6d187d">compare
view</a></li>
</ul>
</details>

Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2026-03-24 15:06:13 +11:00
dependabot[bot] 796190d435 chore: bump github.com/gohugoio/hugo from 0.157.0 to 0.158.0 (#23432)
Bumps [github.com/gohugoio/hugo](https://github.com/gohugoio/hugo) from
0.157.0 to 0.158.0.
<details>
<summary>Release notes</summary>
<p><em>Sourced from <a
href="https://github.com/gohugoio/hugo/releases">github.com/gohugoio/hugo's
releases</a>.</em></p>
<blockquote>
<h2>v0.158.0</h2>
<p>This release adds <a
href="https://gohugo.io/functions/css/build/">css.Build</a>, native and
very fast bundling/transformation/minifying of CSS resources. Also see
the new <a
href="https://gohugo.io/functions/strings/replacepairs/">strings.ReplacePairs</a>,
a very fast option if you need to do many string replacements.</p>
<h2>Notes</h2>
<ul>
<li>Upgrade to Go 1.26.1 (<a
href="https://redirect.github.com/gohugoio/hugo/issues/14597">#14597</a>)
(note) 1f578f16 <a href="https://github.com/bep"><code>@​bep</code></a>
<a
href="https://redirect.github.com/gohugoio/hugo/issues/14595">#14595</a>.
This fixes a security issue in Go's template package used by Hugo: <a
href="https://www.cve.org/CVERecord?id=CVE-2026-27142">https://www.cve.org/CVERecord?id=CVE-2026-27142</a></li>
</ul>
<h2>Deprecations</h2>
<p>The following methods and config options are deprecated and will be
removed in a future Hugo release.</p>
<p>Also see <a
href="https://discourse.gohugo.io/t/deprecations-in-v0-158-0/56869">this
article</a></p>
<h3>Language configuration</h3>
<ul>
<li><code>languageCode</code> → Use <code>locale</code> instead.</li>
<li><code>languages.&lt;lang&gt;.languageCode</code> → Use
<code>languages.&lt;lang&gt;.locale</code> instead.</li>
<li><code>languages.&lt;lang&gt;.languageName</code> → Use
<code>languages.&lt;lang&gt;.label</code> instead.</li>
<li><code>languages.&lt;lang&gt;.languageDirection</code> → Use
<code>languages.&lt;lang&gt;.direction</code> instead.</li>
</ul>
<h3>Language methods</h3>
<ul>
<li><code>.Site.LanguageCode</code> → Use
<code>.Site.Language.Locale</code> instead.</li>
<li><code>.Language.LanguageCode</code> → Use
<code>.Language.Locale</code> instead.</li>
<li><code>.Language.LanguageName</code> → Use
<code>.Language.Label</code> instead.</li>
<li><code>.Language.LanguageDirection</code> → Use
<code>.Language.Direction</code> instead.</li>
</ul>
<h2>Bug fixes</h2>
<ul>
<li>tpl/css: Fix external source maps e431f90b <a
href="https://github.com/bep"><code>@​bep</code></a> <a
href="https://redirect.github.com/gohugoio/hugo/issues/14620">#14620</a></li>
<li>hugolib: Fix server no watch 59e0446f <a
href="https://github.com/jmooring"><code>@​jmooring</code></a> <a
href="https://redirect.github.com/gohugoio/hugo/issues/14615">#14615</a></li>
<li>resources: Fix context canceled on GetRemote with per-request
timeout 842d8f10 <a href="https://github.com/bep"><code>@​bep</code></a>
<a
href="https://redirect.github.com/gohugoio/hugo/issues/14611">#14611</a></li>
<li>tpl/tplimpl: Prefer early suffixes when media type matches 4eafd9eb
<a href="https://github.com/bep"><code>@​bep</code></a> <a
href="https://redirect.github.com/gohugoio/hugo/issues/13877">#13877</a>
<a
href="https://redirect.github.com/gohugoio/hugo/issues/14601">#14601</a></li>
<li>all: Run go fix ./... e3108225 <a
href="https://github.com/bep"><code>@​bep</code></a></li>
<li>internal/warpc: Fix SIGSEGV in Close() when dispatcher fails to
start c9b88e4d <a href="https://github.com/bep"><code>@​bep</code></a>
<a
href="https://redirect.github.com/gohugoio/hugo/issues/14536">#14536</a></li>
<li>Fix index out of range panic in fileEventsContentPaths f797f849 <a
href="https://github.com/bep"><code>@​bep</code></a> <a
href="https://redirect.github.com/gohugoio/hugo/issues/14573">#14573</a></li>
</ul>
<h2>Improvements</h2>
<ul>
<li>resources: Re-publish on transformation cache hit 3c980c07 <a
href="https://github.com/bep"><code>@​bep</code></a> <a
href="https://redirect.github.com/gohugoio/hugo/issues/14629">#14629</a></li>
<li>create/skeletons: Use css.Build in theme skeleton 404ac000 <a
href="https://github.com/jmooring"><code>@​jmooring</code></a> <a
href="https://redirect.github.com/gohugoio/hugo/issues/14626">#14626</a></li>
<li>tpl/css: Add a test case for rebuilds on CSS options changes
06fcb724 <a href="https://github.com/bep"><code>@​bep</code></a></li>
<li>hugolib: Allow regular pages to cascade to self 9b5f1d49 <a
href="https://github.com/jmooring"><code>@​jmooring</code></a> <a
href="https://redirect.github.com/gohugoio/hugo/issues/14627">#14627</a></li>
<li>tpl/css: Allow the user to override single loader entries 623722bb
<a href="https://github.com/bep"><code>@​bep</code></a> <a
href="https://redirect.github.com/gohugoio/hugo/issues/14623">#14623</a></li>
<li>tpl/css: Make default loader resolution for CSS <a
href="https://github.com/import"><code>@​import</code></a> and url()
always behave the same a7cbcf15 <a
href="https://github.com/bep"><code>@​bep</code></a> <a
href="https://redirect.github.com/gohugoio/hugo/issues/14619">#14619</a></li>
<li>internal/js: Add default mainFields for CSS builds 36cdb2c7 <a
href="https://github.com/jmooring"><code>@​jmooring</code></a> <a
href="https://redirect.github.com/gohugoio/hugo/issues/14614">#14614</a></li>
<li>Add css.Build 3e3b849c <a
href="https://github.com/bep"><code>@​bep</code></a> <a
href="https://redirect.github.com/gohugoio/hugo/issues/14609">#14609</a>
<a
href="https://redirect.github.com/gohugoio/hugo/issues/14613">#14613</a></li>
<li>resources: Use full path for Exif etc. decoding error/warning
messages c47ec233 <a
href="https://github.com/bep"><code>@​bep</code></a> <a
href="https://redirect.github.com/gohugoio/hugo/issues/12693">#12693</a></li>
<li>Move to new locales library and upgrade CLDR from v36.1 to v48.1
4652ae4a <a href="https://github.com/bep"><code>@​bep</code></a></li>
<li>tpl/strings: Add strings.ReplacePairs function 13a95b9c <a
href="https://github.com/jmooring"><code>@​jmooring</code></a> <a
href="https://redirect.github.com/gohugoio/hugo/issues/14594">#14594</a></li>
</ul>
<!-- raw HTML omitted -->
</blockquote>
<p>... (truncated)</p>
</details>
<details>
<summary>Commits</summary>
<ul>
<li><a
href="https://github.com/gohugoio/hugo/commit/f41be7959a44108641f1e081adf5c4be7fc1bb63"><code>f41be79</code></a>
releaser: Bump versions for release of 0.158.0</li>
<li><a
href="https://github.com/gohugoio/hugo/commit/0e46a97e8a0d5b7ad1dbea1a39dace7a3ee29fcf"><code>0e46a97</code></a>
deps: Upgrade github.com/evanw/esbuild v0.27.3 =&gt; v0.27.4</li>
<li><a
href="https://github.com/gohugoio/hugo/commit/c27d9e8fcfa5aad6cfedd0552add2a6c8ec74525"><code>c27d9e8</code></a>
build(deps): bump github.com/getkin/kin-openapi from 0.133.0 to
0.134.0</li>
<li><a
href="https://github.com/gohugoio/hugo/commit/098eac59a9d4f4567acb16018453c0d389677690"><code>098eac5</code></a>
build(deps): bump golang.org/x/tools from 0.42.0 to 0.43.0</li>
<li><a
href="https://github.com/gohugoio/hugo/commit/3c980c072ee6a9c37a1c6028a7d328696f745836"><code>3c980c0</code></a>
resources: Re-publish on transformation cache hit</li>
<li><a
href="https://github.com/gohugoio/hugo/commit/404ac00001de49c0ccbff4131be40fa2651e4a06"><code>404ac00</code></a>
create/skeletons: Use css.Build in theme skeleton</li>
<li><a
href="https://github.com/gohugoio/hugo/commit/06fcb724219eecdc20367e86e1a8134d3d7e0e5b"><code>06fcb72</code></a>
tpl/css: Add a test case for rebuilds on CSS options changes</li>
<li><a
href="https://github.com/gohugoio/hugo/commit/9b5f1d491d2b7cde198dd2fd858de92e9e97700f"><code>9b5f1d4</code></a>
hugolib: Allow regular pages to cascade to self</li>
<li><a
href="https://github.com/gohugoio/hugo/commit/87f8de8c7ab10516614180080f97490645bbfdec"><code>87f8de8</code></a>
build(deps): bump gocloud.dev from 0.44.0 to 0.45.0</li>
<li><a
href="https://github.com/gohugoio/hugo/commit/67ef6c68deb031f2dcff926b0cc236a07dcca334"><code>67ef6c6</code></a>
build(deps): bump golang.org/x/sync from 0.19.0 to 0.20.0</li>
<li>Additional commits viewable in <a
href="https://github.com/gohugoio/hugo/compare/v0.157.0...v0.158.0">compare
view</a></li>
</ul>
</details>
<br />


[![Dependabot compatibility
score](https://dependabot-badges.githubapp.com/badges/compatibility_score?dependency-name=github.com/gohugoio/hugo&package-manager=go_modules&previous-version=0.157.0&new-version=0.158.0)](https://docs.github.com/en/github/managing-security-vulnerabilities/about-dependabot-security-updates#about-compatibility-scores)

Dependabot will resolve any conflicts with this PR as long as you don't
alter it yourself. You can also trigger a rebase manually by commenting
`@dependabot rebase`.

[//]: # (dependabot-automerge-start)
[//]: # (dependabot-automerge-end)

---

<details>
<summary>Dependabot commands and options</summary>
<br />

You can trigger Dependabot actions by commenting on this PR:
- `@dependabot rebase` will rebase this PR
- `@dependabot recreate` will recreate this PR, overwriting any edits
that have been made to it
- `@dependabot show <dependency name> ignore conditions` will show all
of the ignore conditions of the specified dependency
- `@dependabot ignore this major version` will close this PR and stop
Dependabot creating any more for this major version (unless you reopen
the PR or upgrade to it yourself)
- `@dependabot ignore this minor version` will close this PR and stop
Dependabot creating any more for this minor version (unless you reopen
the PR or upgrade to it yourself)
- `@dependabot ignore this dependency` will close this PR and stop
Dependabot creating any more for this dependency (unless you reopen the
PR or upgrade to it yourself)


</details>

Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2026-03-24 03:59:55 +00:00
Ethan c1474c7ee2 fix(coderd/httpmw): return 500 for internal auth errors (#23352)
## Issue context
On `dev.coder.com`, users could successfully log in, briefly see the web
UI, and then get redirected back to `/login`.

We traced the most reliable repro to viewing Tracy's workspaces on the
`/workspaces` page. That page eagerly issues authenticated per-row
requests such as:
- `POST /api/v2/authcheck`
- `GET /api/v2/workspacebuilds/:workspacebuild/parameters`

One confirmed failing request was for Tracy's workspace
`nav-scroll-fix-1f6b`:
- route: `GET
/api/v2/workspacebuilds/f2104ae6-7d53-457c-a8df-de831bee76db/parameters`
- build owner/workspace: `tracy/nav-scroll-fix-1f6b`

The failing response body was:
- message: `An internal error occurred. Please try again or contact the
system administrator.`
- detail: `Internal error fetching API key by id. fetch object: pq:
password authentication failed for user "coder"`

That showed the request was not actually unauthorized. The server hit an
internal database/authentication problem while resolving the session API
key. The underlying issue was that DB password rotation had been
enabled; it has since been disabled.

However, the logout cascade happened because:
1. `APIKeyFromRequest()` returned `ok=false` for both genuine auth
failures and internal backend failures.
2. `ValidateAPIKey()` wrapped every `!ok` result as `401 Unauthorized`.
3. `RequireAuth.tsx` signs the user out on any `401` response.

So a transient backend/database failure was being misreported as an auth
failure, which made the client forcibly log the user out.

A useful extra clue was that the installed PWA did not repro. The PWA
starts on `/agents`, which avoids the `/workspaces` request fan-out.
That helped narrow the problem to the eager authenticated requests on
the workspace list rather than to cookies or the login flow itself.

## What changed
This PR now fixes the bug without changing the exported
`APIKeyFromRequest()` surface:
- `ValidateAPIKey()` now uses a new internal helper that returns a typed
`ValidateAPIKeyError`
- the exported `APIKeyFromRequest()` helper remains compatible for
existing callers like `userauth.go`
- internal API-key lookup failures are classified as `500 Internal
Server Error` plus `Hard: true`
- internal `UserRBACSubject()` failures now return `500 Internal Server
Error` instead of `401 Unauthorized`
- a focused regression test verifies that an internal `GetAPIKeyByID`
failure surfaces as `500`

This removes the brittle message-based classification and makes the
internal-auth-failure path robust for all API-key lookup failures
handled by auth middleware.
2026-03-24 12:37:17 +11:00
Danielle Maywood a8e7cc10b6 fix(site): isolate draft prompts per conversation (#23469) 2026-03-24 01:05:19 +00:00
Michael Suchacz 82f965a0ae feat: per-user per-model chat compaction threshold overrides (#23412)
## What

Adds per-user per-model auto-compaction threshold overrides. Users can
now customize the percentage of context window usage that triggers chat
compaction, independently for each enabled model.

## Why

The compaction threshold was previously only configurable at the
deployment level (`chat_model_configs.compression_threshold`). Different
users have different preferences — some want aggressive compaction to
keep costs low, others prefer higher thresholds to retain more context.
This gives users control without requiring admin intervention.

## Architecture

**Storage:** Reuses the existing `user_configs` table (no migration
needed). Overrides are stored as key/value pairs with keys of the form
`chat_compaction_threshold:<modelConfigID>` and integer percentage values.

**API:** Three new experimental endpoints under
`/api/experimental/chats/config/`:
- `GET /user-compaction-thresholds` — list all overrides for the current
user
- `PUT /user-compaction-thresholds/{modelConfig}` — upsert an override
(validates model exists and is enabled, validates 0–100 range)
- `DELETE /user-compaction-thresholds/{modelConfig}` — clear an override
(idempotent)

**Runtime resolution:** In `coderd/chatd/chatd.go`, a new
`resolveUserCompactionThreshold()` helper runs at the start of each chat
turn (inside `runChat()`), after the model config is resolved but before
`CompactionOptions` is built. If a valid override exists, it replaces
`modelConfig.CompressionThreshold`. The threshold source
(`user_override` vs `model_default`) is logged with each compaction
event.

**Precedence:** `effectiveThreshold = userOverride ??
modelConfig.CompressionThreshold`
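That precedence rule can be sketched as a small helper. This is an illustrative stand-in for `resolveUserCompactionThreshold()`, not the actual coderd code; in the real API an out-of-range value is rejected at `PUT` time, so the fallback branch here is purely defensive.

```go
package main

import "fmt"

// effectiveThreshold applies the precedence above: a stored user
// override in the valid 0-100 range wins; otherwise the model
// config's deployment-level CompressionThreshold applies.
func effectiveThreshold(userOverride *int, modelDefault int) int {
	if userOverride != nil && *userOverride >= 0 && *userOverride <= 100 {
		return *userOverride
	}
	return modelDefault
}

func main() {
	override := 40
	fmt.Println(effectiveThreshold(&override, 80)) // user override wins: 40
	fmt.Println(effectiveThreshold(nil, 80))       // no override, model default: 80
}
```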

**UI:** New "Context Compaction" subsection in the Agents → Settings →
Behavior tab, placed after Personal Instructions. Shows one row per
enabled model with the system default, a number input for the override,
and Save/Reset controls.

## Testing

- 9 API subtests covering CRUD, validation (boundary values 0/100,
out-of-range rejection), upsert behavior, idempotent delete, user
isolation, and non-existent model config
- 4 dbauthz tests (16 scenarios) verifying `ActionReadPersonal` /
`ActionUpdatePersonal` on all query methods
- 4 Storybook stories with play functions (Default, WithOverrides,
Loading, Error)

<details>
<summary>Implementation plan</summary>

### Phase 1 — Tests
- Backend API tests in `coderd/chats_test.go` (9 subtests)
- Database auth wrapper tests in
`coderd/database/dbauthz/dbauthz_test.go` (4 methods)
- Frontend stories in `UserCompactionThresholdSettings.stories.tsx` (4
stories)

### Phase 2 — Backend preference surface
- 4 SQL queries in `coderd/database/queries/users.sql` (list, get,
upsert, delete)
- `make gen` to propagate into generated artifacts
- Auth/metrics wrappers in dbauthz and dbmetrics
- SDK types and client methods in `codersdk/chats.go`
- HTTP handlers and routes in `coderd/chats.go` and `coderd/coderd.go`
- Key prefix constant shared between handlers and runtime

### Phase 3 — Runtime override
- `resolveUserCompactionThreshold()` helper in `coderd/chatd/chatd.go`
- Override injection in `runChat()` before building `CompactionOptions`
- `threshold_source` field added to compaction log

### Phase 4 — Settings UI
- API client methods and React Query hooks in `site/src/api/`
- `UserCompactionThresholdSettings` component extracted from
`SettingsPageContent`
- Per-model mutation tracking (only the active row disables during save)
- 100% warning, "System default" label, helpful empty state copy

### Phase 5 — Refactor and review fixes
- Consolidated key prefix constant in `codersdk`
- Explicit PUT range validation (not just struct tags)
- GET handler gracefully skips malformed rows instead of 500
- Boundary value, upsert, and non-existent model config tests
- UX improvements: per-model mutation state, aria-live on errors

</details>
2026-03-24 00:48:18 +01:00
Kyle Carberry acbfb90c30 feat: auto-discover OAuth2 config for MCP servers via RFC 7591 DCR (#23406)
## Problem

When adding an external MCP server with `auth_type=oauth2`, admins
currently must manually provide:
- `oauth2_client_id`
- `oauth2_client_secret`
- `oauth2_auth_url`
- `oauth2_token_url`

This requires the admin to manually register an OAuth2 client with the
external MCP server's authorization server first — a friction-heavy
process that contradicts the MCP spec's vision of plug-and-play
discovery.

## Solution

When an admin creates an MCP server config with `auth_type=oauth2` and
omits the OAuth2 fields, Coder now automatically discovers and registers
credentials following the MCP authorization spec:

1. **Protected Resource Metadata (RFC 9728)** — Fetches
`/.well-known/oauth-protected-resource` from the MCP server to discover
its authorization server. Falls back to probing the server URL for a
`WWW-Authenticate` header with a `resource_metadata` parameter.

2. **Authorization Server Metadata (RFC 8414)** — Fetches
`/.well-known/oauth-authorization-server` from the discovered auth
server to find all endpoints.

3. **Dynamic Client Registration (RFC 7591)** — Registers Coder as an
OAuth2 client at the auth server's registration endpoint, obtaining a
`client_id` and `client_secret` automatically.

The discovered/generated credentials are stored in the MCP server
config, and the existing per-user OAuth2 connect flow works unchanged.

### Backward compatibility

- **Manual config still works**: If all three fields
(`oauth2_client_id`, `oauth2_auth_url`, `oauth2_token_url`) are
provided, the existing behavior is unchanged.
- **Partial config is rejected**: Providing some but not all fields
returns a clear error explaining the two options.
- **Discovery failure is clear**: If auto-discovery fails, the error
message explains what went wrong and suggests manual configuration.

## Changes

- **New package `coderd/mcpauth`** — Self-contained discovery and DCR
logic with no `codersdk` dependency
- **Modified `coderd/mcp.go`** — `createMCPServerConfig` handler now
attempts auto-discovery when OAuth2 fields are omitted
- **Tests** — Unit tests for discovery (happy path, WWW-Authenticate
fallback, no registration endpoint, registration failure) and
`parseResourceMetadataParam` helper
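The `WWW-Authenticate` fallback in step 1 hinges on pulling the `resource_metadata` parameter out of a Bearer challenge. A minimal sketch of that extraction follows; the real `parseResourceMetadataParam` in `coderd/mcpauth` may handle quoting and parameter ordering differently.

```go
package main

import (
	"fmt"
	"strings"
)

// parseResourceMetadataParam extracts the resource_metadata URL from a
// WWW-Authenticate header such as:
//   Bearer resource_metadata="https://mcp.example.com/.well-known/oauth-protected-resource"
// It returns "" when the parameter is absent or malformed.
func parseResourceMetadataParam(header string) string {
	const key = `resource_metadata="`
	i := strings.Index(header, key)
	if i < 0 {
		return ""
	}
	rest := header[i+len(key):]
	j := strings.Index(rest, `"`)
	if j < 0 {
		return ""
	}
	return rest[:j]
}

func main() {
	h := `Bearer resource_metadata="https://mcp.example.com/.well-known/oauth-protected-resource"`
	fmt.Println(parseResourceMetadataParam(h))
}
```

The discovered metadata URL then feeds steps 2 and 3: fetching the RFC 8414 authorization-server metadata and performing RFC 7591 dynamic client registration.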
2026-03-23 19:26:47 -04:00
Danielle Maywood c344d7c00e fix(site): improve mobile layout for settings and analytics (#23460) 2026-03-23 22:00:23 +00:00
david-fraley 53350377b3 docs: add Agents Getting Started enablement page (#23244) 2026-03-23 16:56:46 -05:00
Mathias Fredriksson 147df5c971 refactor: replace sort.Strings with slices.Sort (#23457)
The slices package provides type-safe generic replacements for the
old typed sort convenience functions. The codebase already uses
slices.Sort in 43 call sites; this finishes the migration for the
remaining 29.

- sort.Strings(x)          -> slices.Sort(x)
- sort.Float64s(x)         -> slices.Sort(x)
- sort.StringsAreSorted(x) -> slices.IsSorted(x)
2026-03-23 23:19:23 +02:00
Cian Johnston 9e4c283370 test: share coderdtest instances in OAuth2 validation tests (#23455)
Consolidates invocations of `coderdtest.New` to a single shared instance per
parent for the following tests:

- `TestOAuth2ClientMetadataValidation`
- `TestOAuth2ClientNameValidation`
- `TestOAuth2ClientScopeValidation`
- `TestOAuth2ClientMetadataEdgeCases`

> 🤖 This PR was created with the help of Coder Agents, and was
reviewed by my human. 🧑‍💻
2026-03-23 21:03:34 +00:00
Mathias Fredriksson 145817e8d3 fix(Makefile): install playwright browsers before storybook tests (#23456)
The test-storybook target uses @vitest/browser-playwright with
Chromium but never installs the browser binaries. pnpm install
only fetches the npm package; the actual browser must be
downloaded separately via playwright install. This mirrors what
test-e2e already does.
2026-03-23 20:57:03 +00:00
Cian Johnston 956f6b2473 test: share coderdtest instances to stop paying the startup tax 22 times (#23454)
Consolidates 6 tests that spun up separate coderdtest instances per sub-test into a single shared instance per parent. 

> 🤖 This PR was created with the help of Coder Agents, and has been
reviewed by my human. 🧑‍💻
2026-03-23 19:54:43 +00:00
Kayla はな d2afda8191 feat: allow restricting sharing to service accounts (#23327) 2026-03-23 13:18:49 -06:00
Michael Suchacz c389c2bc5c fix(coderd/x/chatd): stabilize auto-promotion flake (#23448)
TestInterruptAutoPromotionIgnoresLaterUsageLimitIncrease still relied on
wall-clock polling after the acquire loop moved to a mock clock, so it
could assert before chatd finished its asynchronous cleanup and
auto-promotion work.

Wait on explicit request-start signals and on the server's in-flight
chat work before asserting the intermediate and final database state.
This keeps the test synchronized with the actual processor lifecycle
instead of scheduler timing.

Closes https://github.com/coder/internal/issues/1406
2026-03-23 19:17:58 +00:00
Kayla はな 4c9e37b659 feat: add page for editing users (#23328) 2026-03-23 12:42:50 -06:00
Cian Johnston 3b268c95d3 chore(dogfood): evict 22 freeloading tools from the Dockerfile (#23378)
Removes unused tools from dogfood Dockerfile:
- Go tools `moq`, `go-swagger`, `goreleaser`, `goveralls`, `kind`,
`helm-docs`, `gcr-cleaner-cli`
- curl-installed `cloud_sql_proxy`, `dive`, `docker-credential-gcr`, `grype`,
`kube-linter`, `stripe` CLI, `terragrunt`, `yq` v3, GoLand 2021.2, ANTLR v4 jar
- apt packages `cmake`, `google-cloud-sdk-datastore-emulator`, `graphviz`, `packer`

> 🤖 This PR was created with the help of Coder Agents, and was reviewed by my human. 🧑‍💻
2026-03-23 18:25:58 +00:00
Mathias Fredriksson 138bc41563 fix: improve process tool descriptions to prefer foreground execution (#23395)
The tool descriptions pushed agents toward backgrounding anything over
5 seconds, including builds, tests, and installs where you actually
want to wait for the result. This led to unnecessary process_output
round-trips and missed the foreground timeout-to-reattach workflow
entirely.

Reframe background mode as the exception (persistent processes with
no natural exit) and foreground with an appropriate timeout as the
default. Replace "background process" with "tracked process" in
process_output, process_list, and process_signal since they work on
all tracked processes regardless of how they were started.
2026-03-23 17:54:30 +00:00
Cian Johnston 80a172f932 chore: move chatd and related packages to /x/ subpackage (#23445)
- Moves `coderd/chatd/`, `coderd/gitsync/`, `enterprise/coderd/chatd/`
under `x/` parent directories to signal instability
- Adds `Experimental:` glue code comments in `coderd/coderd.go`

> 🤖 This PR was created with the help of Coder Agents, and was
reviewed by my human. 🧑‍💻
2026-03-23 17:34:43 +00:00
Danielle Maywood 86d8b6daee fix(site/src/pages/AgentsPage): add collapse button to settings sidebar panel (#23438) 2026-03-23 17:22:08 +00:00
Danielle Maywood 470e6c7217 feat(site): enable intra-file virtualization in DiffViewer (#23363) 2026-03-23 16:37:55 +00:00
Danielle Maywood ed19a3a08e refactor(site): move experimental endpoints to ExperimentalApiMethods (#23449) 2026-03-23 16:29:07 +00:00
Danielle Maywood 975373704f fix(site): unify diff header styling between conversation and panel viewers (#23422) 2026-03-23 16:21:53 +00:00
Danielle Maywood 522288c9d5 fix(site): add chat input skeleton to prevent layout shift on agent detail (#23439) 2026-03-23 14:41:09 +00:00
Danielle Maywood edd13482a0 fix(site): focus chat input after submitting diff comment (#23440) 2026-03-23 14:40:10 +00:00
Cian Johnston ef14654078 chore: move chat methods to ExperimentalClient (#23441)
- Changes all 41 chat method receivers in `codersdk/chats.go` from
`*Client` to `*ExperimentalClient` to ensure that callers are aware that
these reference potentially unstable `/api/experimental` endpoints.


> 🤖 This PR was created with the help of Coder Agents, and has been
reviewed by my human. 🧑‍💻
2026-03-23 14:32:11 +00:00
Thomas Kosiewski ea37f1ff86 feat: pass session token as query param on agent chat WebSockets (#23405)
## Problem

When the Coder chat UI is embedded in a VS Code webview, the session
token is set via the Coder-Session-Token header for HTTP requests.
However, browsers cannot attach custom headers to WebSocket connections,
and the VS Code Electron webview environment does not support cookies set
via Set-Cookie from iframe origins. This causes all chat WebSocket
connections to fail with authorization errors.

## Solution

Pass the session token as a coder_session_token query parameter on all
chat-related WebSocket connections. The backend already accepts this
parameter (see APITokenFromRequest in coderd/httpmw/apikey.go).

The token is only included when API.getSessionToken() returns a value,
which only happens in the embed bootstrap flow. Normal browser sessions
use cookies and are unaffected.
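The conditional query-parameter behavior can be sketched as follows. This is an illustrative helper (the actual change lives in the TypeScript site code); `chatSocketURL` is a hypothetical name, but `coder_session_token` is the parameter the backend's `APITokenFromRequest` already accepts.

```go
package main

import (
	"fmt"
	"net/url"
)

// chatSocketURL appends the session token as a coder_session_token
// query parameter only when an explicit token is available (the embed
// bootstrap case). Cookie-based browser sessions pass an empty token
// and the URL is left untouched.
func chatSocketURL(base, token string) (string, error) {
	u, err := url.Parse(base)
	if err != nil {
		return "", err
	}
	if token != "" {
		q := u.Query()
		q.Set("coder_session_token", token)
		u.RawQuery = q.Encode()
	}
	return u.String(), nil
}

func main() {
	s, _ := chatSocketURL("wss://coder.example.com/api/experimental/chats/42/watch", "abc123")
	fmt.Println(s)
}
```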

> Built with [Coder Agents](https://coder.com/agents)
2026-03-23 15:27:55 +01:00
Mathias Fredriksson c49170b6b3 fix(scaletest): handle ignored io.ReadAll error in bridge runner (#22850)
Surface the io.ReadAll error in the error message when an HTTP
request fails with a non-200 status, instead of silently
discarding it.
2026-03-23 15:58:14 +02:00
Danielle Maywood ee9b46fe08 fix(site/src/pages/AgentsPage): replace navigating buttons with anchor tags (#23426) 2026-03-23 12:20:56 +00:00
Mathias Fredriksson 1ad3c898a0 fix(coderd/chatd): preserve identifiers in chat title generation (#23436)
The prompt told the model to "describe the primary intent" and gave
only generic examples, so it stripped PR numbers, repo names, and
other distinguishing details. Added explicit GOOD/BAD examples to
steer away from generic titles like "Review pull request changes".
Also removed "no special characters" which prevented # and / in
identifiers.
2026-03-23 12:02:05 +00:00
Jakub Domeracki b8e09d09b0 chore: remove trivy GHA job (#23415)
Action taken in response to an ongoing incident:

https://www.aquasec.com/blog/trivy-supply-chain-attack-what-you-need-to-know/

> We've not been compromised due to a combination of pinning [GitHub
Actions by commit
SHA](https://github.com/coder/coder/blob/c8e58575e0ee44fad37b5f2ffe1ef0f220c3cf23/.github/workflows/security.yaml#L149)
coupled with a [dependabot cooldown
period](https://github.com/coder/coder/pull/21079)
2026-03-23 12:52:28 +01:00
dependabot[bot] 0900a44ff3 chore: bump github.com/fatih/color from 1.18.0 to 1.19.0 (#23431)
Bumps [github.com/fatih/color](https://github.com/fatih/color) from
1.18.0 to 1.19.0.
<details>
<summary>Release notes</summary>
<p><em>Sourced from <a
href="https://github.com/fatih/color/releases">github.com/fatih/color's
releases</a>.</em></p>
<blockquote>
<h2>v1.19.0</h2>
<h2>What's Changed</h2>
<ul>
<li>Bump golang.org/x/sys from 0.25.0 to 0.28.0 by <a
href="https://github.com/dependabot"><code>@​dependabot</code></a>[bot]
in <a
href="https://redirect.github.com/fatih/color/pull/246">fatih/color#246</a></li>
<li>Fix for issue <a
href="https://redirect.github.com/fatih/color/issues/230">#230</a>
set/unsetwriter symmetric wrt color support detection by <a
href="https://github.com/ataypamart"><code>@​ataypamart</code></a> in <a
href="https://redirect.github.com/fatih/color/pull/243">fatih/color#243</a></li>
<li>chore: go mod cleanup by <a
href="https://github.com/sashamelentyev"><code>@​sashamelentyev</code></a>
in <a
href="https://redirect.github.com/fatih/color/pull/244">fatih/color#244</a></li>
<li>Bump golang.org/x/sys from 0.28.0 to 0.30.0 by <a
href="https://github.com/dependabot"><code>@​dependabot</code></a>[bot]
in <a
href="https://redirect.github.com/fatih/color/pull/249">fatih/color#249</a></li>
<li>Bump github.com/mattn/go-colorable from 0.1.13 to 0.1.14 by <a
href="https://github.com/dependabot"><code>@​dependabot</code></a>[bot]
in <a
href="https://redirect.github.com/fatih/color/pull/248">fatih/color#248</a></li>
<li>Update CI and go deps by <a
href="https://github.com/fatih"><code>@​fatih</code></a> in <a
href="https://redirect.github.com/fatih/color/pull/254">fatih/color#254</a></li>
<li>Bump golang.org/x/sys from 0.31.0 to 0.37.0 by <a
href="https://github.com/dependabot"><code>@​dependabot</code></a>[bot]
in <a
href="https://redirect.github.com/fatih/color/pull/268">fatih/color#268</a></li>
<li>fix: include escape codes in byte counts from <code>Fprint</code>,
<code>Fprintf</code> by <a
href="https://github.com/qualidafial"><code>@​qualidafial</code></a> in
<a
href="https://redirect.github.com/fatih/color/pull/282">fatih/color#282</a></li>
<li>Bump golang.org/x/sys from 0.37.0 to 0.40.0 by <a
href="https://github.com/dependabot"><code>@​dependabot</code></a>[bot]
in <a
href="https://redirect.github.com/fatih/color/pull/277">fatih/color#277</a></li>
<li>fix: add nil check for os.Stdout to prevent panic on Windows
services by <a
href="https://github.com/majiayu000"><code>@​majiayu000</code></a> in <a
href="https://redirect.github.com/fatih/color/pull/275">fatih/color#275</a></li>
<li>Bump dominikh/staticcheck-action from 1.3.1 to 1.4.0 by <a
href="https://github.com/dependabot"><code>@​dependabot</code></a>[bot]
in <a
href="https://redirect.github.com/fatih/color/pull/259">fatih/color#259</a></li>
<li>Bump actions/checkout from 4 to 6 by <a
href="https://github.com/dependabot"><code>@​dependabot</code></a>[bot]
in <a
href="https://redirect.github.com/fatih/color/pull/273">fatih/color#273</a></li>
<li>Optimize Color.Equals performance (O(n²) → O(n)) by <a
href="https://github.com/UnSubble"><code>@​UnSubble</code></a> in <a
href="https://redirect.github.com/fatih/color/pull/269">fatih/color#269</a></li>
<li>Bump actions/setup-go from 5 to 6 by <a
href="https://github.com/dependabot"><code>@​dependabot</code></a>[bot]
in <a
href="https://redirect.github.com/fatih/color/pull/266">fatih/color#266</a></li>
</ul>
<h2>New Contributors</h2>
<ul>
<li><a
href="https://github.com/ataypamart"><code>@​ataypamart</code></a> made
their first contribution in <a
href="https://redirect.github.com/fatih/color/pull/243">fatih/color#243</a></li>
<li><a
href="https://github.com/sashamelentyev"><code>@​sashamelentyev</code></a>
made their first contribution in <a
href="https://redirect.github.com/fatih/color/pull/244">fatih/color#244</a></li>
<li><a
href="https://github.com/qualidafial"><code>@​qualidafial</code></a>
made their first contribution in <a
href="https://redirect.github.com/fatih/color/pull/282">fatih/color#282</a></li>
<li><a
href="https://github.com/majiayu000"><code>@​majiayu000</code></a> made
their first contribution in <a
href="https://redirect.github.com/fatih/color/pull/275">fatih/color#275</a></li>
<li><a href="https://github.com/UnSubble"><code>@​UnSubble</code></a>
made their first contribution in <a
href="https://redirect.github.com/fatih/color/pull/269">fatih/color#269</a></li>
</ul>
<p><strong>Full Changelog</strong>: <a
href="https://github.com/fatih/color/compare/v1.18.0...v1.19.0">https://github.com/fatih/color/compare/v1.18.0...v1.19.0</a></p>
</blockquote>
</details>
<details>
<summary>Commits</summary>
<ul>
<li><a
href="https://github.com/fatih/color/commit/ca25f6e17f118a5a259f3c2c0d395949d1103a5a"><code>ca25f6e</code></a>
Merge pull request <a
href="https://redirect.github.com/fatih/color/issues/266">#266</a> from
fatih/dependabot/github_actions/actions/setup-go-6</li>
<li><a
href="https://github.com/fatih/color/commit/120598440a16510564204450092d1e7925fad9ae"><code>1205984</code></a>
Bump actions/setup-go from 5 to 6</li>
<li><a
href="https://github.com/fatih/color/commit/5715c20323d8c79f60d4944831fcfa3b76cd5734"><code>5715c20</code></a>
Merge pull request <a
href="https://redirect.github.com/fatih/color/issues/269">#269</a> from
UnSubble/main</li>
<li><a
href="https://github.com/fatih/color/commit/2f6e2003760028129f34c4ad5c3728b904811d3c"><code>2f6e200</code></a>
Merge branch 'main' into main</li>
<li><a
href="https://github.com/fatih/color/commit/f72ec947d0c34504dfd08b0db68d89f37503fc90"><code>f72ec94</code></a>
Merge pull request <a
href="https://redirect.github.com/fatih/color/issues/273">#273</a> from
fatih/dependabot/github_actions/actions/checkout-6</li>
<li><a
href="https://github.com/fatih/color/commit/848e6330af5690fa24bb038d5330839a33f1f0e5"><code>848e633</code></a>
Merge branch 'main' into main</li>
<li><a
href="https://github.com/fatih/color/commit/4c2cd3443934693bd8892fc0f7bb5bbec8e3788a"><code>4c2cd34</code></a>
Add tests</li>
<li><a
href="https://github.com/fatih/color/commit/7f812f029c41eddd3ac7fbbdf6cc78e4b175944b"><code>7f812f0</code></a>
Bump actions/checkout from 4 to 6</li>
<li><a
href="https://github.com/fatih/color/commit/b7fc9f9557629556aff702751b5268cefcbafa15"><code>b7fc9f9</code></a>
Merge pull request <a
href="https://redirect.github.com/fatih/color/issues/259">#259</a> from
fatih/dependabot/github_actions/dominikh/staticc...</li>
<li><a
href="https://github.com/fatih/color/commit/239a88f715e8e35f40492da7a1e08f7173e78e05"><code>239a88f</code></a>
Bump dominikh/staticcheck-action from 1.3.1 to 1.4.0</li>
<li>Additional commits viewable in <a
href="https://github.com/fatih/color/compare/v1.18.0...v1.19.0">compare
view</a></li>
</ul>
</details>
<br />


[![Dependabot compatibility
score](https://dependabot-badges.githubapp.com/badges/compatibility_score?dependency-name=github.com/fatih/color&package-manager=go_modules&previous-version=1.18.0&new-version=1.19.0)](https://docs.github.com/en/github/managing-security-vulnerabilities/about-dependabot-security-updates#about-compatibility-scores)

Dependabot will resolve any conflicts with this PR as long as you don't
alter it yourself. You can also trigger a rebase manually by commenting
`@dependabot rebase`.

[//]: # (dependabot-automerge-start)
[//]: # (dependabot-automerge-end)

---

<details>
<summary>Dependabot commands and options</summary>
<br />

You can trigger Dependabot actions by commenting on this PR:
- `@dependabot rebase` will rebase this PR
- `@dependabot recreate` will recreate this PR, overwriting any edits
that have been made to it
- `@dependabot show <dependency name> ignore conditions` will show all
of the ignore conditions of the specified dependency
- `@dependabot ignore this major version` will close this PR and stop
Dependabot creating any more for this major version (unless you reopen
the PR or upgrade to it yourself)
- `@dependabot ignore this minor version` will close this PR and stop
Dependabot creating any more for this minor version (unless you reopen
the PR or upgrade to it yourself)
- `@dependabot ignore this dependency` will close this PR and stop
Dependabot creating any more for this dependency (unless you reopen the
PR or upgrade to it yourself)


</details>

Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2026-03-23 11:41:47 +00:00
dependabot[bot] 4537413315 chore: bump google.golang.org/api from 0.271.0 to 0.272.0 (#23430)
Bumps
[google.golang.org/api](https://github.com/googleapis/google-api-go-client)
from 0.271.0 to 0.272.0.
<details>
<summary>Release notes</summary>
<p><em>Sourced from <a
href="https://github.com/googleapis/google-api-go-client/releases">google.golang.org/api's
releases</a>.</em></p>
<blockquote>
<h2>v0.272.0</h2>
<h2><a
href="https://github.com/googleapis/google-api-go-client/compare/v0.271.0...v0.272.0">0.272.0</a>
(2026-03-16)</h2>
<h3>Features</h3>
<ul>
<li><strong>all:</strong> Auto-regenerate discovery clients (<a
href="https://redirect.github.com/googleapis/google-api-go-client/issues/3534">#3534</a>)
(<a
href="https://github.com/googleapis/google-api-go-client/commit/b4d37a1279665d52b8b4672a6a91732ae8eb3cf6">b4d37a1</a>)</li>
<li><strong>all:</strong> Auto-regenerate discovery clients (<a
href="https://redirect.github.com/googleapis/google-api-go-client/issues/3536">#3536</a>)
(<a
href="https://github.com/googleapis/google-api-go-client/commit/549ef3e69575edbe4fee27bc485a093dc88b90b3">549ef3e</a>)</li>
<li><strong>all:</strong> Auto-regenerate discovery clients (<a
href="https://redirect.github.com/googleapis/google-api-go-client/issues/3537">#3537</a>)
(<a
href="https://github.com/googleapis/google-api-go-client/commit/6def284013185ab4ac2fa389594ee6013086d5d0">6def284</a>)</li>
<li><strong>all:</strong> Auto-regenerate discovery clients (<a
href="https://redirect.github.com/googleapis/google-api-go-client/issues/3538">#3538</a>)
(<a
href="https://github.com/googleapis/google-api-go-client/commit/319b5abcbc42b77f6acc861e45365b65695e8096">319b5ab</a>)</li>
<li><strong>all:</strong> Auto-regenerate discovery clients (<a
href="https://redirect.github.com/googleapis/google-api-go-client/issues/3539">#3539</a>)
(<a
href="https://github.com/googleapis/google-api-go-client/commit/73bcfcf9b2fd8def3aec1cdff10e6d4ee646af41">73bcfcf</a>)</li>
<li><strong>all:</strong> Auto-regenerate discovery clients (<a
href="https://redirect.github.com/googleapis/google-api-go-client/issues/3541">#3541</a>)
(<a
href="https://github.com/googleapis/google-api-go-client/commit/6374c496fde577aa9f5b32470e45676ff4f69dde">6374c49</a>)</li>
</ul>
</blockquote>
</details>
<details>
<summary>Changelog</summary>
<p><em>Sourced from <a
href="https://github.com/googleapis/google-api-go-client/blob/main/CHANGES.md">google.golang.org/api's
changelog</a>.</em></p>
<blockquote>
<h2><a
href="https://github.com/googleapis/google-api-go-client/compare/v0.271.0...v0.272.0">0.272.0</a>
(2026-03-16)</h2>
<h3>Features</h3>
<ul>
<li><strong>all:</strong> Auto-regenerate discovery clients (<a
href="https://redirect.github.com/googleapis/google-api-go-client/issues/3534">#3534</a>)
(<a
href="https://github.com/googleapis/google-api-go-client/commit/b4d37a1279665d52b8b4672a6a91732ae8eb3cf6">b4d37a1</a>)</li>
<li><strong>all:</strong> Auto-regenerate discovery clients (<a
href="https://redirect.github.com/googleapis/google-api-go-client/issues/3536">#3536</a>)
(<a
href="https://github.com/googleapis/google-api-go-client/commit/549ef3e69575edbe4fee27bc485a093dc88b90b3">549ef3e</a>)</li>
<li><strong>all:</strong> Auto-regenerate discovery clients (<a
href="https://redirect.github.com/googleapis/google-api-go-client/issues/3537">#3537</a>)
(<a
href="https://github.com/googleapis/google-api-go-client/commit/6def284013185ab4ac2fa389594ee6013086d5d0">6def284</a>)</li>
<li><strong>all:</strong> Auto-regenerate discovery clients (<a
href="https://redirect.github.com/googleapis/google-api-go-client/issues/3538">#3538</a>)
(<a
href="https://github.com/googleapis/google-api-go-client/commit/319b5abcbc42b77f6acc861e45365b65695e8096">319b5ab</a>)</li>
<li><strong>all:</strong> Auto-regenerate discovery clients (<a
href="https://redirect.github.com/googleapis/google-api-go-client/issues/3539">#3539</a>)
(<a
href="https://github.com/googleapis/google-api-go-client/commit/73bcfcf9b2fd8def3aec1cdff10e6d4ee646af41">73bcfcf</a>)</li>
<li><strong>all:</strong> Auto-regenerate discovery clients (<a
href="https://redirect.github.com/googleapis/google-api-go-client/issues/3541">#3541</a>)
(<a
href="https://github.com/googleapis/google-api-go-client/commit/6374c496fde577aa9f5b32470e45676ff4f69dde">6374c49</a>)</li>
</ul>
</blockquote>
</details>
<details>
<summary>Commits</summary>
<ul>
<li><a
href="https://github.com/googleapis/google-api-go-client/commit/e7df9fe0b92461f87b6d267a600e6825d1221e75"><code>e7df9fe</code></a>
chore(main): release 0.272.0 (<a
href="https://redirect.github.com/googleapis/google-api-go-client/issues/3535">#3535</a>)</li>
<li><a
href="https://github.com/googleapis/google-api-go-client/commit/5d8b2662ac4cd19ac978d9f08bedb59dc41c8247"><code>5d8b266</code></a>
chore(all): update all (<a
href="https://redirect.github.com/googleapis/google-api-go-client/issues/3540">#3540</a>)</li>
<li><a
href="https://github.com/googleapis/google-api-go-client/commit/6374c496fde577aa9f5b32470e45676ff4f69dde"><code>6374c49</code></a>
feat(all): auto-regenerate discovery clients (<a
href="https://redirect.github.com/googleapis/google-api-go-client/issues/3541">#3541</a>)</li>
<li><a
href="https://github.com/googleapis/google-api-go-client/commit/73bcfcf9b2fd8def3aec1cdff10e6d4ee646af41"><code>73bcfcf</code></a>
feat(all): auto-regenerate discovery clients (<a
href="https://redirect.github.com/googleapis/google-api-go-client/issues/3539">#3539</a>)</li>
<li><a
href="https://github.com/googleapis/google-api-go-client/commit/319b5abcbc42b77f6acc861e45365b65695e8096"><code>319b5ab</code></a>
feat(all): auto-regenerate discovery clients (<a
href="https://redirect.github.com/googleapis/google-api-go-client/issues/3538">#3538</a>)</li>
<li><a
href="https://github.com/googleapis/google-api-go-client/commit/6def284013185ab4ac2fa389594ee6013086d5d0"><code>6def284</code></a>
feat(all): auto-regenerate discovery clients (<a
href="https://redirect.github.com/googleapis/google-api-go-client/issues/3537">#3537</a>)</li>
<li><a
href="https://github.com/googleapis/google-api-go-client/commit/549ef3e69575edbe4fee27bc485a093dc88b90b3"><code>549ef3e</code></a>
feat(all): auto-regenerate discovery clients (<a
href="https://redirect.github.com/googleapis/google-api-go-client/issues/3536">#3536</a>)</li>
<li><a
href="https://github.com/googleapis/google-api-go-client/commit/b4d37a1279665d52b8b4672a6a91732ae8eb3cf6"><code>b4d37a1</code></a>
feat(all): auto-regenerate discovery clients (<a
href="https://redirect.github.com/googleapis/google-api-go-client/issues/3534">#3534</a>)</li>
<li>See full diff in <a
href="https://github.com/googleapis/google-api-go-client/compare/v0.271.0...v0.272.0">compare
view</a></li>
</ul>
</details>
<br />


[![Dependabot compatibility
score](https://dependabot-badges.githubapp.com/badges/compatibility_score?dependency-name=google.golang.org/api&package-manager=go_modules&previous-version=0.271.0&new-version=0.272.0)](https://docs.github.com/en/github/managing-security-vulnerabilities/about-dependabot-security-updates#about-compatibility-scores)


Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2026-03-23 11:33:26 +00:00
Cian Johnston ab86ed0df8 fix(site): stop hijacking navigation after archive-and-delete settles (#23372)
- Guard both `onSettled` callbacks in
`archiveAndDeleteMutation.mutate()` with `shouldNavigateAfterArchive()`,
which checks whether the user is still viewing the archived chat (or a
sub-agent of it) before calling `navigate("/agents")`
- Extract `shouldNavigateAfterArchive` into `agentWorkspaceUtils.ts`
with 6 unit test cases covering: direct match, different chat, no active
chat, sub-agent of archived parent, sub-agent of different parent, and
cache-cleared fallback
- Look up the active chat's `root_chat_id` from the per-chat query cache
(stable across WebSocket eviction of sub-agents) to handle the sub-agent
case
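The guard described above can be sketched as a pure function (the parameter names and shape below are assumptions from the PR description, not the actual implementation):

```typescript
// Hypothetical sketch of shouldNavigateAfterArchive: navigate away only
// when the user is still viewing the archived chat or one of its
// sub-agents. All names here are illustrative.
function shouldNavigateAfterArchive(
  archivedChatId: string,
  activeChatId: string | undefined,
  activeRootChatId: string | undefined,
): boolean {
  // No active chat: nothing to hijack, stay put.
  if (activeChatId === undefined) {
    return false;
  }
  // Direct match, or the active chat is a sub-agent whose root is the
  // archived parent (root_chat_id looked up from the query cache).
  return (
    activeChatId === archivedChatId || activeRootChatId === archivedChatId
  );
}
```

Keeping the check pure makes the unit-test cases listed above straightforward to express as plain input/output assertions.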

> 🤖 This PR was created with the help of Coder Agents, and has been
reviewed by my human. 🧑‍💻
2026-03-23 11:28:06 +00:00
dependabot[bot] f2b9d5f8f7 chore: bump github.com/fergusstrange/embedded-postgres from 1.32.0 to 1.34.0 (#23428)
Bumps
[github.com/fergusstrange/embedded-postgres](https://github.com/fergusstrange/embedded-postgres)
from 1.32.0 to 1.34.0.
<details>
<summary>Release notes</summary>
<p><em>Sourced from <a
href="https://github.com/fergusstrange/embedded-postgres/releases">github.com/fergusstrange/embedded-postgres's
releases</a>.</em></p>
<blockquote>
<h2>v1.34.0</h2>
<h2>What's Changed</h2>
<ul>
<li>Bump V18 from 18.0.0 to 18.3.0 to fix darwin/arm64 by <a
href="https://github.com/nzoschke"><code>@​nzoschke</code></a> in <a
href="https://redirect.github.com/fergusstrange/embedded-postgres/pull/166">fergusstrange/embedded-postgres#166</a></li>
</ul>
<h2>New Contributors</h2>
<ul>
<li><a href="https://github.com/nzoschke"><code>@​nzoschke</code></a>
made their first contribution in <a
href="https://redirect.github.com/fergusstrange/embedded-postgres/pull/166">fergusstrange/embedded-postgres#166</a></li>
</ul>
<p><strong>Full Changelog</strong>: <a
href="https://github.com/fergusstrange/embedded-postgres/compare/v1.33.0...v1.34.0">https://github.com/fergusstrange/embedded-postgres/compare/v1.33.0...v1.34.0</a></p>
<h2>v1.33.0</h2>
<h2>What's Changed</h2>
<ul>
<li>Add support for Postgres 18 and update default version by <a
href="https://github.com/otakakot"><code>@​otakakot</code></a> in <a
href="https://redirect.github.com/fergusstrange/embedded-postgres/pull/162">fergusstrange/embedded-postgres#162</a></li>
</ul>
<h2>New Contributors</h2>
<ul>
<li><a href="https://github.com/otakakot"><code>@​otakakot</code></a>
made their first contribution in <a
href="https://redirect.github.com/fergusstrange/embedded-postgres/pull/162">fergusstrange/embedded-postgres#162</a></li>
</ul>
<p><strong>Full Changelog</strong>: <a
href="https://github.com/fergusstrange/embedded-postgres/compare/v1.32.0...v1.33.0">https://github.com/fergusstrange/embedded-postgres/compare/v1.32.0...v1.33.0</a></p>
</blockquote>
</details>
<details>
<summary>Commits</summary>
<ul>
<li><a
href="https://github.com/fergusstrange/embedded-postgres/commit/490777eebf4d3fe8615496cd4fc8430f5b93379d"><code>490777e</code></a>
Bump V18 from 18.0.0 to 18.3.0 to fix darwin/arm64 (<a
href="https://redirect.github.com/fergusstrange/embedded-postgres/issues/166">#166</a>)</li>
<li><a
href="https://github.com/fergusstrange/embedded-postgres/commit/f351010461d7666dff82b7bf88986d1e4d5824af"><code>f351010</code></a>
Update README.md</li>
<li><a
href="https://github.com/fergusstrange/embedded-postgres/commit/cf5b3570ca7fc727fae6e4874ec08b4818b705b1"><code>cf5b357</code></a>
Update CircleCI config: add Rosetta installation step for macOS
executor</li>
<li><a
href="https://github.com/fergusstrange/embedded-postgres/commit/a2782271984af1c658bc68ec5ead130968be4071"><code>a278227</code></a>
Update CircleCI config: specify Go version 1.18 for macOS executor</li>
<li><a
href="https://github.com/fergusstrange/embedded-postgres/commit/e96b8985a6cf932ee40a412ab8403dc13073420e"><code>e96b898</code></a>
Update CircleCI config: change Apple executor from m2 to m4</li>
<li><a
href="https://github.com/fergusstrange/embedded-postgres/commit/10719368a4343cc494f84db42b1a8a3199b6cc4f"><code>1071936</code></a>
Update CircleCI config: rename cache steps for Go modules</li>
<li><a
href="https://github.com/fergusstrange/embedded-postgres/commit/2bb06046c7b832f9bd54034f2a665b01f6f037b5"><code>2bb0604</code></a>
Update CircleCI config: modify macOS executor, upgrade xcode and go
orb</li>
<li><a
href="https://github.com/fergusstrange/embedded-postgres/commit/8b9ced41d43db993baf672c7a3ac308c9822d99c"><code>8b9ced4</code></a>
Add OSSI_TOKEN and OSSI_USERNAME to Nancy action environment</li>
<li><a
href="https://github.com/fergusstrange/embedded-postgres/commit/482d9032341eeede28e7f69637d3c0856721aae7"><code>482d903</code></a>
Bump Nancy Vulnerability Checker to v1.0.52</li>
<li><a
href="https://github.com/fergusstrange/embedded-postgres/commit/3578d6e73071963906311f846e6cf51470203bdc"><code>3578d6e</code></a>
Add support for Postgres 18 and update default version (<a
href="https://redirect.github.com/fergusstrange/embedded-postgres/issues/162">#162</a>)</li>
<li>See full diff in <a
href="https://github.com/fergusstrange/embedded-postgres/compare/v1.32.0...v1.34.0">compare
view</a></li>
</ul>
</details>
<br />


[![Dependabot compatibility
score](https://dependabot-badges.githubapp.com/badges/compatibility_score?dependency-name=github.com/fergusstrange/embedded-postgres&package-manager=go_modules&previous-version=1.32.0&new-version=1.34.0)](https://docs.github.com/en/github/managing-security-vulnerabilities/about-dependabot-security-updates#about-compatibility-scores)


Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2026-03-23 11:27:49 +00:00
dependabot[bot] b73983e309 chore: bump ubuntu from 3ba65aa to ce4a593 in /dogfood/coder (#23434)
Bumps ubuntu from `3ba65aa` to `ce4a593`.


[![Dependabot compatibility
score](https://dependabot-badges.githubapp.com/badges/compatibility_score?dependency-name=ubuntu&package-manager=docker&previous-version=jammy&new-version=jammy)](https://docs.github.com/en/github/managing-security-vulnerabilities/about-dependabot-security-updates#about-compatibility-scores)


Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2026-03-23 11:23:30 +00:00
dependabot[bot] c11cc0ba30 chore: bump rust from 7d37016 to f7bf1c2 in /dogfood/coder (#23433)
Bumps rust from `7d37016` to `f7bf1c2`.


[![Dependabot compatibility
score](https://dependabot-badges.githubapp.com/badges/compatibility_score?dependency-name=rust&package-manager=docker&previous-version=slim&new-version=slim)](https://docs.github.com/en/github/managing-security-vulnerabilities/about-dependabot-security-updates#about-compatibility-scores)


Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2026-03-23 11:23:14 +00:00
Hugo Dutka 3163e74b77 fix: bump agents desktop resolution to 1920x1080 (#23425)
This PR changes the agents desktop resolution from 1366x768 to 1920x1080.
Anthropic requires that desktop screenshots fit within 1,150,000 total
pixels, so we downscale screenshots to 1280x720 before sending them to
the LLM provider.

Resolution scaling was already implemented, but our code didn't exercise
it. The resolution bump surfaced some bugs in the scaling logic, which
this PR also fixes.
2026-03-23 11:51:10 +01:00
Danielle Maywood eca2257c26 fix(site): enable word-level inline diff highlighting in DiffViewer (#23423) 2026-03-23 10:30:38 +00:00
Mathias Fredriksson 75f5b60eb6 fix: return 409 Conflict instead of 502 when task agent is busy (#23424)
The "Task app is not ready to accept input" error occurs when the
agent responds successfully but its status is not "stable" (e.g.
"running"). This is a state conflict, not a gateway error. 502 was
semantically wrong because the gateway communication succeeded.

409 Conflict is correct because the request conflicts with the
agent's current state. This is consistent with how
authAndDoWithTaskAppClient already returns 409 for pending,
initializing, and paused agent states.
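As a rough sketch (the status names and the function below are illustrative, not taken from the actual Go handler), the mapping reads:

```typescript
// Illustrative mapping only. The key point: a reachable agent in a
// non-"stable" state is a state conflict (409), not a gateway
// failure (502).
function statusForTaskAgent(agentReachable: boolean, state: string): number {
  if (!agentReachable) {
    return 502; // genuine gateway problem: no response from the agent
  }
  switch (state) {
    case "stable":
      return 200; // ready to accept input
    case "pending":
    case "initializing":
    case "paused":
    case "running":
    default:
      return 409; // agent responded, but its state conflicts with the request
  }
}
```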
2026-03-23 09:52:34 +00:00
Ethan 69d430f51b fix(site): fix flaky UsageUserDrillIn story assertion (#23416)
## Problem

The `UsageUserDrillIn` play function in `AgentSettingsPageView.stories.tsx`
flakes in Chromatic (noticed in #23282). After clicking a user row to
drill into the detail view, sync assertions fire before React finishes
the state transition, failing with "element not found".

<img width="1110" height="649" alt="image"
src="https://github.com/user-attachments/assets/8b5c36c2-09c4-4dd6-a280-ab6379c1464e"
/>


### Root cause

The play function clicks "Alice Liddell" and then waits with
`findByText("Alice Liddell")` before asserting on detail-view content.
But "Alice Liddell" appears in **both** the list row and the detail
header, so `findByText` resolves immediately against the stale list-row
text that is still in the DOM. The same is true for `"@alice"` —
`UserRow` renders `@${user.username}` as a subtitle in the list, and
`AvatarData` renders it again in the detail view.

### Fix

Gate on `"User ID: ..."` instead — text that **only** renders in the
detail panel. Once it is in the DOM, the detail view is fully mounted
and all sync assertions are safe.

Applied to both `UsageUserDrillIn` and `UsageUserDrillInAndBack`, which
had the same issue.
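The race can be illustrated with a minimal stand-in for `findByText` (the `TestNode` shape and helper below are invented for illustration; the real code uses Testing Library queries):

```typescript
// Simulated DOM nodes: the detail view has not mounted yet, so only the
// stale list-row content is present.
type TestNode = { text: string; region: "list" | "detail" };

function findFirstByText(dom: TestNode[], text: string): TestNode | undefined {
  return dom.find((node) => node.text.startsWith(text));
}

// DOM right after the click, before React commits the detail view:
const staleDom: TestNode[] = [
  { text: "Alice Liddell", region: "list" },
  { text: "@alice", region: "list" },
];

// Gating on text shared by both views resolves against the stale row:
const tooEarly = findFirstByText(staleDom, "Alice Liddell");

// Gating on detail-only text keeps waiting until the detail view mounts:
const notYet = findFirstByText(staleDom, "User ID:");
```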
2026-03-23 19:45:30 +11:00
Ethan 0f3d40b97f fix(site): stabilize date params to break infinite query loop on agents/analytics (#23414)
## Problem

`/agents/analytics` showed an infinite loading spinner. The browser
devtools revealed repeated requests to the chat cost summary endpoint
with `start_date` and `end_date` shifting by a few milliseconds on each
request.

`AgentAnalyticsPage` called `createDateRange(now)` on every render. When
`now` is not passed (production), `createDateRange` falls through to
`dayjs()`, which produces a new millisecond-precision timestamp each
time. Those timestamps became part of the React Query key via
`chatCostSummary()`, so every render created a new query identity, fired
a new fetch, updated state, re-rendered, and the cycle repeated. The
page never left the loading branch because no query result was ever
observed for the `current` key before it changed.

The same pattern existed in `InsightsContent`, where
`timeRangeToDates()` called `dayjs()` on every render and fed the result
into `prInsights()`.

Storybook didn't catch this because stories pass a fixed `now` prop,
keeping the date range stable.

## Fix

Anchor the date window once using `useState`'s lazy initializer, then
derive `start_date`/`end_date` from the stable anchor during render — no
`useEffect`, no memoization for correctness, just stable input → stable
query key.

- **`AgentAnalyticsPage`**: `const [anchor] = useState<Dayjs>(() =>
dayjs())`, then `createDateRange(now ?? anchor)`. The `now` prop still
takes priority so Storybook snapshots remain deterministic.
- **`InsightsContent`**: Collapses `timeRange` and its anchor into a
single `TimeRangeSelection` state object. A fresh anchor is captured
only when the user changes the selected range (event handler), not on
render. Clicking the already-selected range is a no-op.
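Stripped of React specifics, the bug and the fix reduce to where the timestamp is captured — a sketch with `Date.now()` standing in for `dayjs()` (the function names are invented):

```typescript
// Bug: the key is derived from a fresh timestamp on every call
// ("render"), so the query identity never settles.
function unstableQueryKey(): string {
  return `chatCostSummary:${Date.now()}`;
}

// Fix: capture the anchor once per component instance (what useState's
// lazy initializer does) and derive the key from the stable anchor.
function makeStableQueryKey(): () => string {
  const anchor = Date.now(); // captured once, like useState(() => dayjs())
  return () => `chatCostSummary:${anchor}`;
}
```

With the factory, repeated renders observe the same key, so React Query can resolve the fetch for that key and leave the loading branch.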
2026-03-23 18:52:10 +11:00
dependabot[bot] 3729ff46fb chore: bump the coder-modules group across 2 directories with 1 update (#23413)
Dependabot will resolve any conflicts with this PR as long as you don't
alter it yourself. You can also trigger a rebase manually by commenting
`@dependabot rebase`.

[//]: # (dependabot-automerge-start)
[//]: # (dependabot-automerge-end)

---

<details>
<summary>Dependabot commands and options</summary>
<br />

You can trigger Dependabot actions by commenting on this PR:
- `@dependabot rebase` will rebase this PR
- `@dependabot recreate` will recreate this PR, overwriting any edits
that have been made to it
- `@dependabot show <dependency name> ignore conditions` will show all
of the ignore conditions of the specified dependency
- `@dependabot ignore <dependency name> major version` will close this
group update PR and stop Dependabot creating any more for the specific
dependency's major version (unless you unignore this specific
dependency's major version or upgrade to it yourself)
- `@dependabot ignore <dependency name> minor version` will close this
group update PR and stop Dependabot creating any more for the specific
dependency's minor version (unless you unignore this specific
dependency's minor version or upgrade to it yourself)
- `@dependabot ignore <dependency name>` will close this group update PR
and stop Dependabot creating any more for the specific dependency
(unless you unignore this specific dependency or upgrade to it yourself)
- `@dependabot unignore <dependency name>` will remove all of the ignore
conditions of the specified dependency
- `@dependabot unignore <dependency name> <ignore condition>` will
remove the ignore condition of the specified dependency and ignore
conditions


</details>

Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2026-03-23 00:28:19 +00:00
Danielle Maywood b87171086c refactor(site): restructure agents routing and directory layout (#23408) 2026-03-22 23:58:58 +00:00
Kerem Kacel b763b72b53 feat: add user:read scope (#23348)
Enables [23270](https://github.com/coder/coder/discussions/23270).

Makes it possible for admin users to create API tokens scoped for
reading users' data.
2026-03-22 09:06:03 -05:00
Danielle Maywood a08b6848f2 fix(site): fix desktop reconnect loop by moving connection lifecycle into hook (#23404) 2026-03-22 02:01:33 +00:00
Danielle Maywood bf702cc3b9 chore(site): update streamdown from 2.2.0 to 2.5.0 (#23407) 2026-03-21 21:50:20 -04:00
Asher 47daca6eea feat: add filtering to org members (#23334)
Continuation of https://github.com/coder/coder/pull/23067

Add filtering to the paginated org member endpoint (pretty much the same
as what I did in the previous PR with group members, except there I also
had to add pagination since it was missing).
2026-03-21 16:58:45 -08:00
Asher 4b707515c0 feat: add filtering and pagination to group members page (#23392)
Makes use of the new group members endpoint added in
https://github.com/coder/coder/pull/23067
2026-03-21 16:58:08 -08:00
Danielle Maywood ecc28a6650 fix(site): prevent infinite desktop reconnect loop on exit code 1006 (#23401) 2026-03-21 13:34:00 +00:00
Michael Suchacz cf24c59b56 feat(site): add date filtering to settings usage page (#23381)
## What

Replace the hardcoded 30-day date window on the Agents Settings Usage
page (`/agents/settings/usage`) with an interactive date-range picker.

## Why

The usage page previously showed a static 30-day lookback with no way
for admins to adjust the time window. The backend API already supports
`start_date`/`end_date` parameters — only the frontend was missing the
controls.

## How

- Reuse the existing `DateRange` picker component from Template Insights
- Store selected dates in URL search params (`startDate`/`endDate`) for
persistence across navigation
- Default to last 30 days when no params are present
- Memoize date range for stable React Query keys
- Both the user list and per-user drill-in views respect the selected
range
- Normalize exclusive end-date boundaries for display
- Preset clicks (Last 7 days, etc.) apply immediately with a single
click
- Semi-transparent loading overlay during data refetch

## Changes

- `site/src/pages/AgentsPage/SettingsPageContent.tsx` — Replace
hardcoded range with interactive picker, URL param state, memoized
params, refetch overlay
- `site/src/pages/AgentsPage/SettingsPageContent.stories.tsx` — Add
stories for date filter interaction, preset single-click, and refetch
overlay
- `site/src/pages/TemplatePage/TemplateInsightsPage/DateRange.tsx` —
Detect preset clicks and apply immediately (single-click) instead of
requiring two clicks

## Validation

- TypeScript 
- Biome lint 
- Storybook tests 13/13 
- Visual verification via Storybook 
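The URL-param handling described under "How" can be sketched as follows (the `startDate`/`endDate` param names come from the PR; the helper itself is hypothetical):

```typescript
// Resolve the usage date window from URL search params, defaulting to
// the last 30 days when no params are present. Hypothetical helper.
function resolveDateRange(
  search: string,
  defaultDays = 30,
): { startDate: string; endDate: string } {
  const params = new URLSearchParams(search);
  const now = new Date();
  const fallbackStart = new Date(now.getTime() - defaultDays * 86_400_000);
  return {
    startDate:
      params.get("startDate") ?? fallbackStart.toISOString().slice(0, 10),
    endDate: params.get("endDate") ?? now.toISOString().slice(0, 10),
  };
}
```

Because the resolved strings come from the URL (or a fixed default), they survive navigation and are stable inputs for memoized React Query keys.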
2026-03-20 23:38:43 +01:00
Michael Suchacz a85800c90b docs: remove hardcoded AI attribution template from PR style guide (#23384)
The attribution footer in the PR style guide assumed all AI-generated
PRs come from Claude Code using Claude Sonnet 4.5. PRs can be generated
through different tools and models (e.g. Coder Agents), so a hardcoded
template is misleading.

Co-authored-by: Michael Suchacz <ibetitsmike@users.noreply.github.com>
2026-03-20 22:44:52 +01:00
396 changed files with 16584 additions and 6715 deletions
+140
@@ -0,0 +1,140 @@
---
name: refine-plan
description: Iteratively refine development plans using TDD methodology. Ensures plans are clear, actionable, and include red-green-refactor cycles with proper test coverage.
---
# Refine Development Plan
## Overview
Good plans eliminate ambiguity through clear requirements, break work into distinct phases, and always include refactoring to capture implementation insights.
## When to Use This Skill
| Symptom | Example |
|-----------------------------|----------------------------------------|
| Unclear acceptance criteria | No definition of "done" |
| Vague implementation | Missing concrete steps or file changes |
| Missing/undefined tests | Tests mentioned only as an afterthought |
| Absent refactor phase | No plan to improve code after it works |
| Ambiguous requirements | Multiple interpretations possible |
| Missing verification | No way to confirm the change works |
## Planning Principles
### 1. Plans Must Be Actionable and Unambiguous
Every step should be concrete enough that another agent could execute it without guessing.
- ❌ "Improve error handling" → ✓ "Add try-catch to API calls in user-service.ts, return 400 with error message"
- ❌ "Update tests" → ✓ "Add test case to auth.test.ts: 'should reject expired tokens with 401'"
NEVER include thinking output or other stream-of-consciousness prose mid-plan.
### 2. Push Back on Unclear Requirements
When requirements are ambiguous, ask questions before proceeding.
### 3. Tests Define Requirements
Writing test cases forces disambiguation. Use test definition as a requirements clarification tool.
### 4. TDD is Non-Negotiable
All plans follow: **Red → Green → Refactor**. The refactor phase is MANDATORY.
## The TDD Workflow
### Red Phase: Write Failing Tests First
**Purpose:** Define success criteria through concrete test cases.
**What to test:**
- Happy path (normal usage), edge cases (boundaries, empty/null), error conditions (invalid input, failures), integration points
**Test types:**
- Unit tests: Individual functions in isolation (most tests should be these - fast, focused)
- Integration tests: Component interactions (use for critical paths)
- E2E tests: Complete workflows (use sparingly)
**Write descriptive test cases.**
**If you can't write the test, you don't understand the requirement and MUST ask for clarification.**
### Green Phase: Make Tests Pass
**Purpose:** Implement minimal working solution.
Focus on correctness first. Hardcode if needed. Add just enough logic. Resist urge to "improve" code. Run tests frequently.
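The red-green split can be sketched with a hypothetical `slugify` helper (the function, its cases, and all names here are illustrative, not from this repo). The test cases come first and define the requirement; the implementation is only as much logic as the cases demand:

```go
package main

import "fmt"

// Red: these cases were written first and failed until slugify existed.
// Green: slugify is the minimal implementation that makes them pass;
// any cleanup belongs to the refactor phase.
func slugify(title string) string {
	out := make([]rune, 0, len(title))
	for _, r := range title {
		switch {
		case r >= 'a' && r <= 'z', r >= '0' && r <= '9':
			out = append(out, r)
		case r >= 'A' && r <= 'Z':
			// Lowercase ASCII letters; just enough for the cases below.
			out = append(out, r+('a'-'A'))
		case r == ' ':
			out = append(out, '-')
		}
	}
	return string(out)
}

func main() {
	// Happy path plus an edge case (empty input) double as the spec.
	fmt.Println(slugify("Hello World")) // hello-world
	fmt.Println(slugify("") == "")      // true
}
```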
### Refactor Phase: Improve the Implementation
**Purpose:** Apply insights gained during implementation.
**This phase is MANDATORY.** During implementation you'll discover better structure, repeated patterns, and simplification opportunities.
**When to Extract vs Keep Duplication:**
This is highly subjective, so use the following rules of thumb combined with good judgement:
1) Follow the "rule of three": if the same 10+ lines appear verbatim in 3+ places, extract them.
2) The "wrong abstraction" is harder to fix than duplication.
3) If extraction would harm readability, prefer duplication.
**Common refactorings:**
- Rename for clarity
- Simplify complex conditionals
- Extract repeated code (if meets criteria above)
- Apply design patterns
**Constraints:**
- All tests must still pass after refactoring
- Don't add new features (that's a new Red phase)
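A behavior-preserving refactor might look like the hypothetical sketch below (the `discount` function and its tiers are invented for illustration): a nested if/else chain is replaced by a table lookup, and the existing assertions still pass unchanged.

```go
package main

import "fmt"

// discount previously used an if/else chain per tier. The refactor
// replaces it with an equivalent lookup table; unknown tiers still
// fall back to 0, so observable behavior is identical.
func discount(tier string) int {
	rates := map[string]int{"gold": 20, "silver": 10, "bronze": 5}
	return rates[tier]
}

func main() {
	// Same checks that passed before the refactor must pass after it.
	fmt.Println(discount("gold"), discount("bronze"), discount("unknown")) // 20 5 0
}
```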
## Plan Refinement Process
### Step 1: Review Current Plan for Completeness
- [ ] Clear context explaining why
- [ ] Specific, unambiguous requirements
- [ ] Test cases defined before implementation
- [ ] Step-by-step implementation approach
- [ ] Explicit refactor phase
- [ ] Verification steps
### Step 2: Identify Gaps
Look for missing tests, vague steps, no refactor phase, ambiguous requirements, missing verification.
### Step 3: Handle Unclear Requirements
If you can't write the plan without this information, ask the user. Otherwise, make reasonable assumptions and note them in the plan.
### Step 4: Define Test Cases
For each requirement, write concrete test cases. If you struggle to write test cases, you need more clarification.
### Step 5: Structure with Red-Green-Refactor
Organize the plan into three explicit phases.
### Step 6: Add Verification Steps
Specify how to confirm the change works (automated tests + manual checks).
## Tips for Success
1. **Start with tests:** If you can't write the test, you don't understand the requirement.
2. **Be specific:** "Update API" is not a step. "Add error handling to POST /users endpoint" is.
3. **Always refactor:** Even if code looks good, ask "How could this be clearer?"
4. **Question everything:** Ambiguity is the enemy.
5. **Think in phases:** Red → Green → Refactor.
6. **Keep plans manageable:** If a plan exceeds ~10 files or 5 phases, consider splitting it.
---
**Remember:** A good plan makes implementation straightforward. A vague plan leads to confusion, rework, and bugs.
+2 -11
View File
@@ -177,16 +177,6 @@ Dependabot PRs are auto-generated - don't try to match their verbose style for m
Changes from https://github.com/upstream/repo/pull/XXX/
```
## Attribution Footer
For AI-generated PRs, end with:
```markdown
🤖 Generated with [Claude Code](https://claude.com/claude-code)
Co-Authored-By: Claude Sonnet 4.5 <noreply@anthropic.com>
```
## Creating PRs as Draft
**IMPORTANT**: Unless explicitly told otherwise, always create PRs as drafts using the `--draft` flag:
@@ -197,11 +187,12 @@ gh pr create --draft --title "..." --body "..."
After creating the PR, encourage the user to review it before marking as ready:
```
```text
I've created draft PR #XXXX. Please review the changes and mark it as ready for review when you're satisfied.
```
This allows the user to:
- Review the code changes before requesting reviews from maintainers
- Make additional adjustments if needed
- Ensure CI passes before notifying reviewers
+3 -1
View File
@@ -45,7 +45,7 @@ jobs:
fetch-depth: 1
persist-credentials: false
- name: check changed files
uses: dorny/paths-filter@de90cc6fb38fc0963ad72b210f1f284cd68cea36 # v3.0.2
uses: dorny/paths-filter@fbd0ab8f3e69293af611ebaee6363fc25e6d187d # v4.0.1
id: filter
with:
filters: |
@@ -1119,6 +1119,8 @@ jobs:
- name: Setup Go
uses: ./.github/actions/setup-go
with:
use-cache: false
- name: Install rcodesign
run: |
+1 -1
View File
@@ -135,7 +135,7 @@ jobs:
PR_NUMBER: ${{ steps.pr_info.outputs.PR_NUMBER }}
- name: Check changed files
uses: dorny/paths-filter@de90cc6fb38fc0963ad72b210f1f284cd68cea36 # v3.0.2
uses: dorny/paths-filter@fbd0ab8f3e69293af611ebaee6363fc25e6d187d # v4.0.1
id: filter
with:
base: ${{ github.ref }}
+2
View File
@@ -163,6 +163,8 @@ jobs:
- name: Setup Go
uses: ./.github/actions/setup-go
with:
use-cache: false
- name: Setup Node
uses: ./.github/actions/setup-node
-113
View File
@@ -63,116 +63,3 @@ jobs:
--data "{\"content\": \"$msg\"}" \
"${{ secrets.SLACK_SECURITY_FAILURE_WEBHOOK_URL }}"
trivy:
permissions:
security-events: write
runs-on: ${{ github.repository_owner == 'coder' && 'depot-ubuntu-22.04-8' || 'ubuntu-latest' }}
steps:
- name: Harden Runner
uses: step-security/harden-runner@fa2e9d605c4eeb9fcad4c99c224cee0c6c7f3594 # v2.16.0
with:
egress-policy: audit
- name: Checkout
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
with:
fetch-depth: 0
persist-credentials: false
- name: Setup Go
uses: ./.github/actions/setup-go
- name: Setup Node
uses: ./.github/actions/setup-node
- name: Setup sqlc
uses: ./.github/actions/setup-sqlc
- name: Install cosign
uses: ./.github/actions/install-cosign
- name: Install syft
uses: ./.github/actions/install-syft
- name: Install yq
run: go run github.com/mikefarah/yq/v4@v4.44.3
- name: Install mockgen
run: ./.github/scripts/retry.sh -- go install go.uber.org/mock/mockgen@v0.6.0
- name: Install protoc-gen-go
run: ./.github/scripts/retry.sh -- go install google.golang.org/protobuf/cmd/protoc-gen-go@v1.30
- name: Install protoc-gen-go-drpc
run: ./.github/scripts/retry.sh -- go install storj.io/drpc/cmd/protoc-gen-go-drpc@v0.0.34
- name: Install Protoc
run: |
# protoc must be in lockstep with our dogfood Dockerfile or the
# version in the comments will differ. This is also defined in
# ci.yaml.
set -euxo pipefail
cd dogfood/coder
mkdir -p /usr/local/bin
mkdir -p /usr/local/include
DOCKER_BUILDKIT=1 docker build . --target proto -t protoc
protoc_path=/usr/local/bin/protoc
docker run --rm --entrypoint cat protoc /tmp/bin/protoc > $protoc_path
chmod +x $protoc_path
protoc --version
# Copy the generated files to the include directory.
docker run --rm -v /usr/local/include:/target protoc cp -r /tmp/include/google /target/
ls -la /usr/local/include/google/protobuf/
stat /usr/local/include/google/protobuf/timestamp.proto
- name: Build Coder linux amd64 Docker image
id: build
run: |
set -euo pipefail
version="$(./scripts/version.sh)"
image_job="build/coder_${version}_linux_amd64.tag"
# This environment variable force make to not build packages and
# archives (which the Docker image depends on due to technical reasons
# related to concurrent FS writes).
export DOCKER_IMAGE_NO_PREREQUISITES=true
# This environment variable forces scripts/build_docker.sh to build
# the base image tag locally instead of using the cached version from
# the registry.
CODER_IMAGE_BUILD_BASE_TAG="$(CODER_IMAGE_BASE=coder-base ./scripts/image_tag.sh --version "$version")"
export CODER_IMAGE_BUILD_BASE_TAG
# We would like to use make -j here, but it doesn't work with some recent additions
# to our code generation.
make "$image_job"
echo "image=$(cat "$image_job")" >> "$GITHUB_OUTPUT"
- name: Run Trivy vulnerability scanner
uses: aquasecurity/trivy-action@57a97c7e7821a5776cebc9bb87c984fa69cba8f1 # v0.34.0
with:
image-ref: ${{ steps.build.outputs.image }}
format: sarif
output: trivy-results.sarif
severity: "CRITICAL,HIGH"
- name: Upload Trivy scan results to GitHub Security tab
uses: github/codeql-action/upload-sarif@5d4e8d1aca955e8d8589aabd499c5cae939e33c7 # v3.29.5
with:
sarif_file: trivy-results.sarif
category: "Trivy"
- name: Upload Trivy scan results as an artifact
uses: actions/upload-artifact@bbbca2ddaa5d8feaa63e36b76fdaad77386f024f # v7.0.0
with:
name: trivy
path: trivy-results.sarif
retention-days: 7
- name: Send Slack notification on failure
if: ${{ failure() }}
run: |
msg="❌ Trivy Failed\n\nhttps://github.com/${{ github.repository }}/actions/runs/${{ github.run_id }}"
curl \
-qfsSL \
-X POST \
-H "Content-Type: application/json" \
--data "{\"content\": \"$msg\"}" \
"${{ secrets.SLACK_SECURITY_FAILURE_WEBHOOK_URL }}"
+2 -1
View File
@@ -1255,7 +1255,7 @@ coderd/notifications/.gen-golden: $(wildcard coderd/notifications/testdata/*/*.g
TZ=UTC go test ./coderd/notifications -run="Test.*Golden$$" -update
touch "$@"
provisioner/terraform/testdata/.gen-golden: $(wildcard provisioner/terraform/testdata/*/*.golden) $(GO_SRC_FILES) $(wildcard provisioner/terraform/*_test.go)
provisioner/terraform/testdata/.gen-golden: $(wildcard provisioner/terraform/testdata/*/*.golden) $(wildcard provisioner/terraform/testdata/*/*/*.golden) $(GO_SRC_FILES) $(wildcard provisioner/terraform/*_test.go)
TZ=UTC go test ./provisioner/terraform -run="Test.*Golden$$" -update
touch "$@"
@@ -1343,6 +1343,7 @@ test-js: site/node_modules/.installed
test-storybook: site/node_modules/.installed
cd site/
pnpm playwright:install
pnpm exec vitest run --project=storybook
.PHONY: test-storybook
+2 -5
View File
@@ -16,7 +16,6 @@ import (
"os/user"
"path/filepath"
"slices"
"sort"
"strconv"
"strings"
"sync"
@@ -463,9 +462,7 @@ func (a *agent) runLoop() {
// messages.
ctx := a.hardCtx
defer a.logger.Info(ctx, "agent main loop exited")
retrier := retry.New(100*time.Millisecond, 10*time.Second)
retrier.Jitter = 0.5
for ; retrier.Wait(ctx); {
for retrier := retry.New(100*time.Millisecond, 10*time.Second); retrier.Wait(ctx); {
a.logger.Info(ctx, "connecting to coderd")
err := a.run()
if err == nil {
@@ -1879,7 +1876,7 @@ func (a *agent) Collect(ctx context.Context, networkStats map[netlogtype.Connect
}()
}
wg.Wait()
sort.Float64s(durations)
slices.Sort(durations)
durationsLength := len(durations)
switch {
case durationsLength == 0:
@@ -433,7 +433,7 @@ func convertDockerInspect(raw []byte) ([]codersdk.WorkspaceAgentContainer, []str
}
portKeys := maps.Keys(in.NetworkSettings.Ports)
// Sort the ports for deterministic output.
sort.Strings(portKeys)
slices.Sort(portKeys)
// If we see the same port bound to both ipv4 and ipv6 loopback or unspecified
// interfaces to the same container port, there is no point in adding it multiple times.
loopbackHostPortContainerPorts := make(map[int]uint16, 0)
+31 -46
View File
@@ -2,7 +2,6 @@ package agentdesktop
import (
"encoding/json"
"math"
"net/http"
"strconv"
"time"
@@ -13,6 +12,7 @@ import (
"github.com/coder/coder/v2/agent/agentssh"
"github.com/coder/coder/v2/coderd/httpapi"
"github.com/coder/coder/v2/codersdk"
"github.com/coder/coder/v2/codersdk/workspacesdk"
"github.com/coder/quartz"
"github.com/coder/websocket"
)
@@ -26,9 +26,9 @@ type DesktopAction struct {
Duration *int `json:"duration,omitempty"`
ScrollAmount *int `json:"scroll_amount,omitempty"`
ScrollDirection *string `json:"scroll_direction,omitempty"`
// ScaledWidth and ScaledHeight are the coordinate space the
// model is using. When provided, coordinates are linearly
// mapped from scaled → native before dispatching.
// ScaledWidth and ScaledHeight describe the declared model-facing desktop
// geometry. When provided, input coordinates are mapped from declared space
// to native desktop pixels before dispatching.
ScaledWidth *int `json:"scaled_width,omitempty"`
ScaledHeight *int `json:"scaled_height,omitempty"`
}
@@ -144,17 +144,8 @@ func (a *API) handleAction(rw http.ResponseWriter, r *http.Request) {
slog.F("elapsed_ms", a.clock.Since(handlerStart).Milliseconds()),
)
// Helper to scale a coordinate pair from the model's space to
// native display pixels.
scaleXY := func(x, y int) (int, int) {
if action.ScaledWidth != nil && *action.ScaledWidth > 0 {
x = scaleCoordinate(x, *action.ScaledWidth, cfg.Width)
}
if action.ScaledHeight != nil && *action.ScaledHeight > 0 {
y = scaleCoordinate(y, *action.ScaledHeight, cfg.Height)
}
return x, y
}
geometry := desktopGeometryForAction(cfg, action)
scaleXY := geometry.DeclaredPointToNative
var resp DesktopActionResponse
@@ -192,7 +183,7 @@ func (a *API) handleAction(rw http.ResponseWriter, r *http.Request) {
resp.Output = "type action performed"
case "cursor_position":
x, y, err := a.desktop.CursorPosition(ctx)
nativeX, nativeY, err := a.desktop.CursorPosition(ctx)
if err != nil {
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
Message: "Cursor position failed.",
@@ -200,6 +191,7 @@ func (a *API) handleAction(rw http.ResponseWriter, r *http.Request) {
})
return
}
x, y := geometry.NativePointToDeclared(nativeX, nativeY)
resp.Output = "x=" + strconv.Itoa(x) + ",y=" + strconv.Itoa(y)
case "mouse_move":
@@ -447,14 +439,10 @@ func (a *API) handleAction(rw http.ResponseWriter, r *http.Request) {
resp.Output = "hold_key action performed"
case "screenshot":
var opts ScreenshotOptions
if action.ScaledWidth != nil && *action.ScaledWidth > 0 {
opts.TargetWidth = *action.ScaledWidth
}
if action.ScaledHeight != nil && *action.ScaledHeight > 0 {
opts.TargetHeight = *action.ScaledHeight
}
result, err := a.desktop.Screenshot(ctx, opts)
result, err := a.desktop.Screenshot(ctx, ScreenshotOptions{
TargetWidth: geometry.DeclaredWidth,
TargetHeight: geometry.DeclaredHeight,
})
if err != nil {
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
Message: "Screenshot failed.",
@@ -464,16 +452,8 @@ func (a *API) handleAction(rw http.ResponseWriter, r *http.Request) {
}
resp.Output = "screenshot"
resp.ScreenshotData = result.Data
if action.ScaledWidth != nil && *action.ScaledWidth > 0 && *action.ScaledWidth != cfg.Width {
resp.ScreenshotWidth = *action.ScaledWidth
} else {
resp.ScreenshotWidth = cfg.Width
}
if action.ScaledHeight != nil && *action.ScaledHeight > 0 && *action.ScaledHeight != cfg.Height {
resp.ScreenshotHeight = *action.ScaledHeight
} else {
resp.ScreenshotHeight = cfg.Height
}
resp.ScreenshotWidth = geometry.DeclaredWidth
resp.ScreenshotHeight = geometry.DeclaredHeight
default:
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
@@ -512,6 +492,23 @@ func coordFromAction(action DesktopAction) (x, y int, err error) {
return action.Coordinate[0], action.Coordinate[1], nil
}
func desktopGeometryForAction(cfg DisplayConfig, action DesktopAction) workspacesdk.DesktopGeometry {
declaredWidth := cfg.Width
declaredHeight := cfg.Height
if action.ScaledWidth != nil && *action.ScaledWidth > 0 {
declaredWidth = *action.ScaledWidth
}
if action.ScaledHeight != nil && *action.ScaledHeight > 0 {
declaredHeight = *action.ScaledHeight
}
return workspacesdk.NewDesktopGeometryWithDeclared(
cfg.Width,
cfg.Height,
declaredWidth,
declaredHeight,
)
}
// missingFieldError is returned when a required field is absent from
// a DesktopAction.
type missingFieldError struct {
@@ -522,15 +519,3 @@ type missingFieldError struct {
func (e *missingFieldError) Error() string {
return "Missing \"" + e.field + "\" for " + e.action + " action."
}
// scaleCoordinate maps a coordinate from scaled → native space.
func scaleCoordinate(scaled, scaledDim, nativeDim int) int {
if scaledDim == 0 || scaledDim == nativeDim {
return scaled
}
native := (float64(scaled)+0.5)*float64(nativeDim)/float64(scaledDim) - 0.5
// Clamp to valid range.
native = math.Max(native, 0)
native = math.Min(native, float64(nativeDim-1))
return int(native)
}
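The removed `scaleCoordinate` helper captures the declared-to-native mapping that `workspacesdk.DesktopGeometry` now owns: pixel centers are mapped linearly, then clamped to the native range. A standalone sketch of that same formula (note the new geometry code may round edge cases differently, as the clamp test above suggests):

```go
package main

import (
	"fmt"
	"math"
)

// scaleCoordinate maps a coordinate from the model's declared space to
// native display pixels, treating each coordinate as a pixel center
// and clamping the result to [0, nativeDim-1].
func scaleCoordinate(scaled, scaledDim, nativeDim int) int {
	if scaledDim == 0 || scaledDim == nativeDim {
		return scaled
	}
	native := (float64(scaled)+0.5)*float64(nativeDim)/float64(scaledDim) - 0.5
	native = math.Max(native, 0)
	native = math.Min(native, float64(nativeDim-1))
	return int(native)
}

func main() {
	// Midpoint of 1280x720 declared space lands on the midpoint of a
	// 1920x1080 native display, matching the CoordinateScaling test.
	fmt.Println(scaleCoordinate(640, 1280, 1920), scaleCoordinate(360, 720, 1080)) // 960 540
	// Downscaling clamps at the edges so results stay in range.
	fmt.Println(scaleCoordinate(0, 1920, 1280)) // 0
}
```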
+125 -16
View File
@@ -27,10 +27,12 @@ var _ agentdesktop.Desktop = (*fakeDesktop)(nil)
// fakeDesktop is a minimal Desktop implementation for unit tests.
type fakeDesktop struct {
startErr error
cursorPos [2]int
startCfg agentdesktop.DisplayConfig
vncConnErr error
screenshotErr error
screenshotRes agentdesktop.ScreenshotResult
lastShotOpts agentdesktop.ScreenshotOptions
closed bool
// Track calls for assertions.
@@ -51,7 +53,8 @@ func (f *fakeDesktop) VNCConn(context.Context) (net.Conn, error) {
return nil, f.vncConnErr
}
func (f *fakeDesktop) Screenshot(_ context.Context, _ agentdesktop.ScreenshotOptions) (agentdesktop.ScreenshotResult, error) {
func (f *fakeDesktop) Screenshot(_ context.Context, opts agentdesktop.ScreenshotOptions) (agentdesktop.ScreenshotResult, error) {
f.lastShotOpts = opts
return f.screenshotRes, f.screenshotErr
}
@@ -100,8 +103,8 @@ func (f *fakeDesktop) Type(_ context.Context, text string) error {
return nil
}
func (*fakeDesktop) CursorPosition(context.Context) (x int, y int, err error) {
return 10, 20, nil
func (f *fakeDesktop) CursorPosition(context.Context) (x int, y int, err error) {
return f.cursorPos[0], f.cursorPos[1], nil
}
func (f *fakeDesktop) Close() error {
@@ -135,8 +138,12 @@ func TestHandleAction_Screenshot(t *testing.T) {
t.Parallel()
logger := slogtest.Make(t, nil)
geometry := workspacesdk.DefaultDesktopGeometry()
fake := &fakeDesktop{
startCfg: agentdesktop.DisplayConfig{Width: workspacesdk.DesktopDisplayWidth, Height: workspacesdk.DesktopDisplayHeight},
startCfg: agentdesktop.DisplayConfig{
Width: geometry.NativeWidth,
Height: geometry.NativeHeight,
},
screenshotRes: agentdesktop.ScreenshotResult{Data: "base64data"},
}
api := agentdesktop.NewAPI(logger, fake, nil)
@@ -158,11 +165,52 @@ func TestHandleAction_Screenshot(t *testing.T) {
var result agentdesktop.DesktopActionResponse
err = json.NewDecoder(rr.Body).Decode(&result)
require.NoError(t, err)
// Dimensions come from DisplayConfig, not the screenshot CLI.
assert.Equal(t, "screenshot", result.Output)
assert.Equal(t, "base64data", result.ScreenshotData)
assert.Equal(t, workspacesdk.DesktopDisplayWidth, result.ScreenshotWidth)
assert.Equal(t, workspacesdk.DesktopDisplayHeight, result.ScreenshotHeight)
assert.Equal(t, geometry.NativeWidth, result.ScreenshotWidth)
assert.Equal(t, geometry.NativeHeight, result.ScreenshotHeight)
assert.Equal(t, agentdesktop.ScreenshotOptions{
TargetWidth: geometry.NativeWidth,
TargetHeight: geometry.NativeHeight,
}, fake.lastShotOpts)
}
func TestHandleAction_ScreenshotUsesDeclaredDimensionsFromRequest(t *testing.T) {
t.Parallel()
logger := slogtest.Make(t, nil)
fake := &fakeDesktop{
startCfg: agentdesktop.DisplayConfig{Width: 1920, Height: 1080},
screenshotRes: agentdesktop.ScreenshotResult{Data: "base64data"},
}
api := agentdesktop.NewAPI(logger, fake, nil)
defer api.Close()
sw := 1280
sh := 720
body := agentdesktop.DesktopAction{
Action: "screenshot",
ScaledWidth: &sw,
ScaledHeight: &sh,
}
b, err := json.Marshal(body)
require.NoError(t, err)
rr := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodPost, "/action", bytes.NewReader(b))
req.Header.Set("Content-Type", "application/json")
handler := api.Routes()
handler.ServeHTTP(rr, req)
assert.Equal(t, http.StatusOK, rr.Code)
assert.Equal(t, agentdesktop.ScreenshotOptions{TargetWidth: 1280, TargetHeight: 720}, fake.lastShotOpts)
var result agentdesktop.DesktopActionResponse
err = json.NewDecoder(rr.Body).Decode(&result)
require.NoError(t, err)
assert.Equal(t, 1280, result.ScreenshotWidth)
assert.Equal(t, 720, result.ScreenshotHeight)
}
func TestHandleAction_LeftClick(t *testing.T) {
@@ -315,7 +363,6 @@ func TestHandleAction_HoldKey(t *testing.T) {
handler.ServeHTTP(rr, req)
}()
// Wait for the timer to be created, then advance past it.
trap.MustWait(req.Context()).MustRelease(req.Context())
mClk.Advance(time.Duration(dur) * time.Millisecond).MustWait(req.Context())
@@ -389,7 +436,6 @@ func TestHandleAction_ScrollDown(t *testing.T) {
handler.ServeHTTP(rr, req)
assert.Equal(t, http.StatusOK, rr.Code)
// dy should be positive 5 for "down".
assert.Equal(t, [4]int{500, 400, 0, 5}, fake.lastScroll)
}
@@ -398,13 +444,11 @@ func TestHandleAction_CoordinateScaling(t *testing.T) {
logger := slogtest.Make(t, nil)
fake := &fakeDesktop{
// Native display is 1920x1080.
startCfg: agentdesktop.DisplayConfig{Width: 1920, Height: 1080},
}
api := agentdesktop.NewAPI(logger, fake, nil)
defer api.Close()
// Model is working in a 1280x720 coordinate space.
sw := 1280
sh := 720
body := agentdesktop.DesktopAction{
@@ -424,12 +468,43 @@ func TestHandleAction_CoordinateScaling(t *testing.T) {
handler.ServeHTTP(rr, req)
assert.Equal(t, http.StatusOK, rr.Code)
// 640 in 1280-space → 960 in 1920-space (midpoint maps to
// midpoint).
assert.Equal(t, 960, fake.lastMove[0])
assert.Equal(t, 540, fake.lastMove[1])
}
func TestHandleAction_CoordinateScalingClampsToLastPixel(t *testing.T) {
t.Parallel()
logger := slogtest.Make(t, nil)
fake := &fakeDesktop{
startCfg: agentdesktop.DisplayConfig{Width: 1920, Height: 1080},
}
api := agentdesktop.NewAPI(logger, fake, nil)
defer api.Close()
sw := 1366
sh := 768
body := agentdesktop.DesktopAction{
Action: "mouse_move",
Coordinate: &[2]int{1365, 767},
ScaledWidth: &sw,
ScaledHeight: &sh,
}
b, err := json.Marshal(body)
require.NoError(t, err)
rr := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodPost, "/action", bytes.NewReader(b))
req.Header.Set("Content-Type", "application/json")
handler := api.Routes()
handler.ServeHTTP(rr, req)
assert.Equal(t, http.StatusOK, rr.Code)
assert.Equal(t, 1919, fake.lastMove[0])
assert.Equal(t, 1079, fake.lastMove[1])
}
func TestClose_DelegatesToDesktop(t *testing.T) {
t.Parallel()
@@ -446,15 +521,12 @@ func TestClose_PreventsNewSessions(t *testing.T) {
t.Parallel()
logger := slogtest.Make(t, nil)
// After Close(), Start() will return an error because the
// underlying Desktop is closed.
fake := &fakeDesktop{}
api := agentdesktop.NewAPI(logger, fake, nil)
err := api.Close()
require.NoError(t, err)
// Simulate the closed desktop returning an error on Start().
fake.startErr = xerrors.New("desktop is closed")
rr := httptest.NewRecorder()
@@ -465,3 +537,40 @@ func TestClose_PreventsNewSessions(t *testing.T) {
assert.Equal(t, http.StatusInternalServerError, rr.Code)
}
func TestHandleAction_CursorPositionReturnsDeclaredCoordinates(t *testing.T) {
t.Parallel()
logger := slogtest.Make(t, nil)
fake := &fakeDesktop{
startCfg: agentdesktop.DisplayConfig{Width: 1920, Height: 1080},
cursorPos: [2]int{960, 540},
}
api := agentdesktop.NewAPI(logger, fake, nil)
defer api.Close()
sw := 1280
sh := 720
body := agentdesktop.DesktopAction{
Action: "cursor_position",
ScaledWidth: &sw,
ScaledHeight: &sh,
}
b, err := json.Marshal(body)
require.NoError(t, err)
rr := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodPost, "/action", bytes.NewReader(b))
req.Header.Set("Content-Type", "application/json")
handler := api.Routes()
handler.ServeHTTP(rr, req)
assert.Equal(t, http.StatusOK, rr.Code)
var resp agentdesktop.DesktopActionResponse
err = json.NewDecoder(rr.Body).Decode(&resp)
require.NoError(t, err)
// Native (960,540) in 1920x1080 should map to declared space in 1280x720.
assert.Equal(t, "x=640,y=360", resp.Output)
}
+1 -1
View File
@@ -111,7 +111,7 @@ func (p *portableDesktop) Start(ctx context.Context) (DisplayConfig, error) {
//nolint:gosec // portabledesktop is a trusted binary resolved via ensureBinary.
cmd := p.execer.CommandContext(sessionCtx, p.binPath, "up", "--json",
"--geometry", fmt.Sprintf("%dx%d", workspacesdk.DesktopDisplayWidth, workspacesdk.DesktopDisplayHeight))
"--geometry", fmt.Sprintf("%dx%d", workspacesdk.DesktopNativeWidth, workspacesdk.DesktopNativeHeight))
stdout, err := cmd.StdoutPipe()
if err != nil {
sessionCancel()
+135 -22
View File
@@ -14,6 +14,7 @@ import (
"syscall"
"github.com/google/uuid"
"github.com/spf13/afero"
"golang.org/x/xerrors"
"cdr.dev/slog/v3"
@@ -319,8 +320,14 @@ func (api *API) writeFile(ctx context.Context, r *http.Request, path string) (HT
return http.StatusBadRequest, xerrors.Errorf("file path must be absolute: %q", path)
}
resolved, err := api.resolveSymlink(path)
if err != nil {
return http.StatusInternalServerError, xerrors.Errorf("resolve symlink %q: %w", path, err)
}
path = resolved
dir := filepath.Dir(path)
err := api.filesystem.MkdirAll(dir, 0o755)
err = api.filesystem.MkdirAll(dir, 0o755)
if err != nil {
status := http.StatusInternalServerError
switch {
@@ -410,6 +417,12 @@ func (api *API) editFile(ctx context.Context, path string, edits []workspacesdk.
return http.StatusBadRequest, xerrors.New("must specify at least one edit")
}
resolved, err := api.resolveSymlink(path)
if err != nil {
return http.StatusInternalServerError, xerrors.Errorf("resolve symlink %q: %w", path, err)
}
path = resolved
f, err := api.filesystem.Open(path)
if err != nil {
status := http.StatusInternalServerError
@@ -510,6 +523,52 @@ func (api *API) atomicWrite(ctx context.Context, path string, mode *os.FileMode,
return 0, nil
}
// resolveSymlink resolves a path through any symlinks so that
// subsequent operations (such as atomic rename) target the real
// file instead of replacing the symlink itself.
//
// The filesystem must implement afero.Lstater and afero.LinkReader
// for resolution to occur; if it does not (e.g. MemMapFs), the
// path is returned unchanged.
func (api *API) resolveSymlink(path string) (string, error) {
const maxDepth = 10
lstater, hasLstat := api.filesystem.(afero.Lstater)
if !hasLstat {
return path, nil
}
reader, hasReadlink := api.filesystem.(afero.LinkReader)
if !hasReadlink {
return path, nil
}
for range maxDepth {
info, _, err := lstater.LstatIfPossible(path)
if err != nil {
// If the file does not exist yet (new file write),
// there is nothing to resolve.
if errors.Is(err, os.ErrNotExist) {
return path, nil
}
return "", err
}
if info.Mode()&os.ModeSymlink == 0 {
return path, nil
}
target, err := reader.ReadlinkIfPossible(path)
if err != nil {
return "", err
}
if !filepath.IsAbs(target) {
target = filepath.Join(filepath.Dir(path), target)
}
path = target
}
return "", xerrors.Errorf("too many levels of symlinks resolving %q", path)
}
// fuzzyReplace attempts to find `search` inside `content` and replace it
// with `replace`. It uses a cascading match strategy inspired by
// openai/codex's apply_patch:
@@ -567,30 +626,15 @@ func fuzzyReplace(content string, edit workspacesdk.FileEdit) (string, error) {
}
// Pass 2 trim trailing whitespace on each line.
if start, end, ok := seekLines(contentLines, searchLines, trimRight); ok {
if !edit.ReplaceAll {
if count := countLineMatches(contentLines, searchLines, trimRight); count > 1 {
return "", xerrors.Errorf("search string matches %d occurrences "+
"(expected exactly 1). Include more surrounding "+
"context to make the match unique, or set "+
"replace_all to true", count)
}
}
return spliceLines(contentLines, start, end, replace), nil
if result, matched, err := fuzzyReplaceLines(contentLines, searchLines, replace, trimRight, edit.ReplaceAll); matched {
return result, err
}
// Pass 3 trim all leading and trailing whitespace
// (indentation-tolerant).
if start, end, ok := seekLines(contentLines, searchLines, trimAll); ok {
if !edit.ReplaceAll {
if count := countLineMatches(contentLines, searchLines, trimAll); count > 1 {
return "", xerrors.Errorf("search string matches %d occurrences "+
"(expected exactly 1). Include more surrounding "+
"context to make the match unique, or set "+
"replace_all to true", count)
}
}
return spliceLines(contentLines, start, end, replace), nil
// (indentation-tolerant). The replacement is inserted verbatim;
// callers must provide correctly indented replacement text.
if result, matched, err := fuzzyReplaceLines(contentLines, searchLines, replace, trimAll, edit.ReplaceAll); matched {
return result, err
}
return "", xerrors.New("search string not found in file. Verify the search " +
@@ -653,3 +697,72 @@ func spliceLines(contentLines []string, start, end int, replacement string) stri
}
return b.String()
}
// fuzzyReplaceLines handles fuzzy matching passes (2 and 3) for
// fuzzyReplace. When replaceAll is false and there are multiple
// matches, an error is returned. When replaceAll is true, all
// non-overlapping matches are replaced.
//
// Returns (result, true, nil) on success, ("", false, nil) when
// searchLines don't match at all, or ("", true, err) when the match
// is ambiguous.
//
//nolint:revive // replaceAll is a direct pass-through of the user's flag, not a control coupling.
func fuzzyReplaceLines(
contentLines, searchLines []string,
replace string,
eq func(a, b string) bool,
replaceAll bool,
) (string, bool, error) {
start, end, ok := seekLines(contentLines, searchLines, eq)
if !ok {
return "", false, nil
}
if !replaceAll {
if count := countLineMatches(contentLines, searchLines, eq); count > 1 {
return "", true, xerrors.Errorf("search string matches %d occurrences "+
"(expected exactly 1). Include more surrounding "+
"context to make the match unique, or set "+
"replace_all to true", count)
}
return spliceLines(contentLines, start, end, replace), true, nil
}
// Replace all: collect all match positions, then apply from last
// to first to preserve indices.
type lineMatch struct{ start, end int }
var matches []lineMatch
for i := 0; i <= len(contentLines)-len(searchLines); {
found := true
for j, sLine := range searchLines {
if !eq(contentLines[i+j], sLine) {
found = false
break
}
}
if found {
matches = append(matches, lineMatch{i, i + len(searchLines)})
i += len(searchLines) // skip past this match
} else {
i++
}
}
// Apply replacements from last to first.
repLines := strings.SplitAfter(replace, "\n")
for i := len(matches) - 1; i >= 0; i-- {
m := matches[i]
newLines := make([]string, 0, m.start+len(repLines)+(len(contentLines)-m.end))
newLines = append(newLines, contentLines[:m.start]...)
newLines = append(newLines, repLines...)
newLines = append(newLines, contentLines[m.end:]...)
contentLines = newLines
}
var b strings.Builder
for _, l := range contentLines {
_, _ = b.WriteString(l)
}
return b.String(), true, nil
}
+139
View File
@@ -881,6 +881,43 @@ func TestEditFiles(t *testing.T) {
},
expected: map[string]string{filepath.Join(tmpdir, "ra-exact"): "qux bar qux baz qux"},
},
{
// replace_all with fuzzy trailing-whitespace match.
name: "ReplaceAllFuzzyTrailing",
contents: map[string]string{filepath.Join(tmpdir, "ra-fuzzy-trail"): "hello \nworld\nhello \nagain"},
edits: []workspacesdk.FileEdits{
{
Path: filepath.Join(tmpdir, "ra-fuzzy-trail"),
Edits: []workspacesdk.FileEdit{
{
Search: "hello\n",
Replace: "bye\n",
ReplaceAll: true,
},
},
},
},
expected: map[string]string{filepath.Join(tmpdir, "ra-fuzzy-trail"): "bye\nworld\nbye\nagain"},
},
{
// replace_all with fuzzy indent match (pass 3).
name: "ReplaceAllFuzzyIndent",
contents: map[string]string{filepath.Join(tmpdir, "ra-fuzzy-indent"): "\t\talpha\n\t\tbeta\n\t\talpha\n\t\tgamma"},
edits: []workspacesdk.FileEdits{
{
Path: filepath.Join(tmpdir, "ra-fuzzy-indent"),
Edits: []workspacesdk.FileEdit{
{
// Search uses different indentation (spaces instead of tabs).
Search: " alpha\n",
Replace: "\t\tREPLACED\n",
ReplaceAll: true,
},
},
},
},
expected: map[string]string{filepath.Join(tmpdir, "ra-fuzzy-indent"): "\t\tREPLACED\n\t\tbeta\n\t\tREPLACED\n\t\tgamma"},
},
{
name: "MixedWhitespaceMultiline",
contents: map[string]string{filepath.Join(tmpdir, "mixed-ws"): "func main() {\n\tresult := compute()\n\tfmt.Println(result)\n}"},
@@ -1395,3 +1432,105 @@ func TestReadFileLines(t *testing.T) {
})
}
}
func TestWriteFile_FollowsSymlinks(t *testing.T) {
t.Parallel()
if runtime.GOOS == "windows" {
t.Skip("symlinks are not reliably supported on Windows")
}
dir := t.TempDir()
logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug)
osFs := afero.NewOsFs()
api := agentfiles.NewAPI(logger, osFs, nil)
// Create a real file and a symlink pointing to it.
realPath := filepath.Join(dir, "real.txt")
err := afero.WriteFile(osFs, realPath, []byte("original"), 0o644)
require.NoError(t, err)
linkPath := filepath.Join(dir, "link.txt")
err = os.Symlink(realPath, linkPath)
require.NoError(t, err)
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitShort)
defer cancel()
// Write through the symlink.
w := httptest.NewRecorder()
r := httptest.NewRequestWithContext(ctx, http.MethodPost,
fmt.Sprintf("/write-file?path=%s", linkPath),
bytes.NewReader([]byte("updated")))
api.Routes().ServeHTTP(w, r)
require.Equal(t, http.StatusOK, w.Code)
// The symlink must still be a symlink.
fi, err := os.Lstat(linkPath)
require.NoError(t, err)
require.NotZero(t, fi.Mode()&os.ModeSymlink, "symlink was replaced")
// The real file must have the new content.
data, err := os.ReadFile(realPath)
require.NoError(t, err)
require.Equal(t, "updated", string(data))
}
func TestEditFiles_FollowsSymlinks(t *testing.T) {
t.Parallel()
if runtime.GOOS == "windows" {
t.Skip("symlinks are not reliably supported on Windows")
}
dir := t.TempDir()
logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug)
osFs := afero.NewOsFs()
api := agentfiles.NewAPI(logger, osFs, nil)
// Create a real file and a symlink pointing to it.
realPath := filepath.Join(dir, "real.txt")
err := afero.WriteFile(osFs, realPath, []byte("hello world"), 0o644)
require.NoError(t, err)
linkPath := filepath.Join(dir, "link.txt")
err = os.Symlink(realPath, linkPath)
require.NoError(t, err)
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitShort)
defer cancel()
body := workspacesdk.FileEditRequest{
Files: []workspacesdk.FileEdits{
{
Path: linkPath,
Edits: []workspacesdk.FileEdit{
{
Search: "hello",
Replace: "goodbye",
},
},
},
},
}
buf := bytes.NewBuffer(nil)
enc := json.NewEncoder(buf)
enc.SetEscapeHTML(false)
err = enc.Encode(body)
require.NoError(t, err)
w := httptest.NewRecorder()
r := httptest.NewRequestWithContext(ctx, http.MethodPost, "/edit-files", buf)
api.Routes().ServeHTTP(w, r)
require.Equal(t, http.StatusOK, w.Code)
// The symlink must still be a symlink.
fi, err := os.Lstat(linkPath)
require.NoError(t, err)
require.NotZero(t, fi.Mode()&os.ModeSymlink, "symlink was replaced")
// The real file must have the edited content.
data, err := os.ReadFile(realPath)
require.NoError(t, err)
require.Equal(t, "goodbye world", string(data))
}
+2 -2
@@ -1,7 +1,7 @@
package agentgit
import (
"sort"
"slices"
"sync"
"github.com/google/uuid"
@@ -99,7 +99,7 @@ func (ps *PathStore) GetPaths(chatID uuid.UUID) []string {
for p := range m {
out = append(out, p)
}
sort.Strings(out)
slices.Sort(out)
return out
}
+2 -2
@@ -4,7 +4,7 @@ import (
"context"
"os"
"path/filepath"
"sort"
"slices"
"testing"
"github.com/stretchr/testify/require"
@@ -228,6 +228,6 @@ func resultPaths(results []filefinder.Result) []string {
for i, r := range results {
paths[i] = r.Path
}
sort.Strings(paths)
slices.Sort(paths)
return paths
}
+1 -1
@@ -104,7 +104,7 @@ func (b *Builder) Build(inv *serpent.Invocation) (log slog.Logger, closeLog func
addSinkIfProvided := func(sinkFn func(io.Writer) slog.Sink, loc string) error {
switch loc {
case "", "/dev/null":
case "":
case "/dev/stdout":
sinks = append(sinks, sinkFn(inv.Stdout))
+3 -3
@@ -5,7 +5,7 @@ import (
"os/exec"
"path/filepath"
"runtime"
"sort"
"slices"
"strings"
"testing"
@@ -376,8 +376,8 @@ func Test_sshConfigOptions_addOption(t *testing.T) {
return
}
require.NoError(t, err)
sort.Strings(tt.Expect)
sort.Strings(o.sshOptions)
slices.Sort(tt.Expect)
slices.Sort(o.sshOptions)
require.Equal(t, tt.Expect, o.sshOptions)
})
}
+12 -31
@@ -732,7 +732,6 @@ func (r *RootCmd) scaletestCreateWorkspaces() *serpent.Command {
if err != nil {
return xerrors.Errorf("create tracer provider: %w", err)
}
client.Trace = tracingFlags.tracePropagate
defer func() {
// Allow time for traces to flush even if command context is
// canceled. This is a no-op if tracing is not enabled.
@@ -1080,7 +1079,6 @@ func (r *RootCmd) scaletestWorkspaceUpdates() *serpent.Command {
if err != nil {
return xerrors.Errorf("create tracer provider: %w", err)
}
client.Trace = tracingFlags.tracePropagate
tracer := tracerProvider.Tracer(scaletestTracerName)
reg := prometheus.NewRegistry()
@@ -1339,7 +1337,6 @@ func (r *RootCmd) scaletestWorkspaceTraffic() *serpent.Command {
if err != nil {
return xerrors.Errorf("create tracer provider: %w", err)
}
client.Trace = tracingFlags.tracePropagate
defer func() {
// Allow time for traces to flush even if command context is
// canceled. This is a no-op if tracing is not enabled.
@@ -1404,9 +1401,6 @@ func (r *RootCmd) scaletestWorkspaceTraffic() *serpent.Command {
// Setup our workspace agent connection.
config := workspacetraffic.Config{
AgentID: agent.ID,
WorkspaceID: ws.ID,
WorkspaceName: ws.Name,
AgentName: agent.Name,
BytesPerTick: bytesPerTick,
Duration: strategy.timeout,
TickInterval: tickInterval,
@@ -1446,35 +1440,24 @@ func (r *RootCmd) scaletestWorkspaceTraffic() *serpent.Command {
_, _ = fmt.Fprintln(inv.Stderr, "Running load test...")
testCtx, testCancel := strategy.toContext(ctx)
defer testCancel()
runErr := th.Run(testCtx)
res := th.Results()
// Write full results to the configured output destination
// (default: text to stdout via --output flag).
// for _, o := range outputs {
_ = outputs
// if writeErr := o.write(res, os.Stdout); writeErr != nil {
// _, _ = fmt.Fprintf(os.Stderr, "Failed to write output %q to %q: %v\n", o.format, o.path, writeErr)
// }
// }
// Always write a summary to stderr for visibility in
// container logs. Full output goes to --output above.
// Limit to 10 failures to avoid exceeding kubelet log
// rotation limits.
res.PrintSummary(os.Stderr, 10)
if runErr != nil {
return xerrors.Errorf("run test harness (harness failure, not a test failure): %w", runErr)
err = th.Run(testCtx)
if err != nil {
return xerrors.Errorf("run test harness (harness failure, not a test failure): %w", err)
}
// Check for interrupt after printing results so we always
// have visibility into what happened.
// If the command was interrupted, skip stats.
if notifyCtx.Err() != nil {
return notifyCtx.Err()
}
res := th.Results()
for _, o := range outputs {
err = o.write(res, inv.Stdout)
if err != nil {
return xerrors.Errorf("write output %q to %q: %w", o.format, o.path, err)
}
}
if res.TotalFail > 0 {
return xerrors.New("load test failed, see above for more details")
}
@@ -1580,7 +1563,6 @@ func (r *RootCmd) scaletestDashboard() *serpent.Command {
if err != nil {
return xerrors.Errorf("create tracer provider: %w", err)
}
client.Trace = tracingFlags.tracePropagate
tracer := tracerProvider.Tracer(scaletestTracerName)
outputs, err := output.parse()
if err != nil {
@@ -1818,7 +1800,6 @@ func (r *RootCmd) scaletestAutostart() *serpent.Command {
if err != nil {
return xerrors.Errorf("create tracer provider: %w", err)
}
client.Trace = tracingFlags.tracePropagate
tracer := tracerProvider.Tracer(scaletestTracerName)
setupBarrier := new(sync.WaitGroup)
+3 -3
@@ -24,7 +24,7 @@ import (
"os/user"
"path/filepath"
"regexp"
"sort"
"slices"
"strconv"
"strings"
"sync"
@@ -2291,7 +2291,7 @@ func startBuiltinPostgres(ctx context.Context, cfg config.Root, logger slog.Logg
ep := embeddedpostgres.NewDatabase(
embeddedpostgres.DefaultConfig().
Version(embeddedpostgres.V13).
Version(embeddedpostgres.V16).
BinariesPath(filepath.Join(cfg.PostgresPath(), "bin")).
// Default BinaryRepositoryURL repo1.maven.org is flaky.
BinaryRepositoryURL("https://repo.maven.apache.org/maven2").
@@ -2825,7 +2825,7 @@ func ReadExternalAuthProvidersFromEnv(environ []string) ([]codersdk.ExternalAuth
// parsing of `GITAUTH` environment variables.
func parseExternalAuthProvidersFromEnv(prefix string, environ []string) ([]codersdk.ExternalAuthConfig, error) {
// The index numbers must be in-order.
sort.Strings(environ)
slices.Sort(environ)
var providers []codersdk.ExternalAuthConfig
for _, v := range serpent.ParseEnviron(environ, prefix) {
+3 -3
@@ -7,7 +7,7 @@ import (
"io"
"os"
"path/filepath"
"sort"
"slices"
"golang.org/x/exp/maps"
"golang.org/x/xerrors"
@@ -31,7 +31,7 @@ func (*RootCmd) templateInit() *serpent.Command {
for _, ex := range exampleList {
templateIDs = append(templateIDs, ex.ID)
}
sort.Strings(templateIDs)
slices.Sort(templateIDs)
cmd := &serpent.Command{
Use: "init [directory]",
Short: "Get started with a templated template.",
@@ -50,7 +50,7 @@ func (*RootCmd) templateInit() *serpent.Command {
optsToID[name] = example.ID
}
opts := maps.Keys(optsToID)
sort.Strings(opts)
slices.Sort(opts)
_, _ = fmt.Fprintln(
inv.Stdout,
pretty.Sprint(
+2 -2
@@ -4,7 +4,7 @@ import (
"bytes"
"context"
"encoding/json"
"sort"
"slices"
"testing"
"github.com/stretchr/testify/require"
@@ -47,7 +47,7 @@ func TestTemplateList(t *testing.T) {
// expect that templates are listed alphabetically
templatesList := []string{firstTemplate.Name, secondTemplate.Name}
sort.Strings(templatesList)
slices.Sort(templatesList)
require.NoError(t, <-errC)
@@ -6,7 +6,7 @@ USAGE:
List all organization members
OPTIONS:
-c, --column [username|name|user id|organization id|created at|updated at|organization roles] (default: username,organization roles)
-c, --column [username|name|last seen at|user created at|user updated at|user id|organization id|created at|updated at|organization roles] (default: username,organization roles)
Columns to display in table output.
-o, --output table|json (default: table)
-12
@@ -195,18 +195,6 @@ autobuildPollInterval: 1m0s
# Interval to poll for hung and pending jobs and automatically terminate them.
# (default: 1m0s, type: duration)
jobHangDetectorInterval: 1m0s
# Number of querier workers for the PG coordinator. 0 uses the default.
# (default: 0, type: int)
tailnetQuerierWorkers: 0
# Number of binder workers for the PG coordinator. 0 uses the default.
# (default: 0, type: int)
tailnetBinderWorkers: 0
# Number of tunneler workers for the PG coordinator. 0 uses the default.
# (default: 0, type: int)
tailnetTunnelerWorkers: 0
# Number of handshaker workers for the PG coordinator. 0 uses the default.
# (default: 0, type: int)
tailnetHandshakerWorkers: 0
introspection:
statsCollection:
usageStats:
+2 -3
@@ -4,7 +4,6 @@ import (
"fmt"
"os"
"slices"
"sort"
"strings"
"time"
@@ -194,7 +193,7 @@ func joinScopes(scopes []codersdk.APIKeyScope) string {
return ""
}
vals := slice.ToStrings(scopes)
sort.Strings(vals)
slices.Sort(vals)
return strings.Join(vals, ", ")
}
@@ -206,7 +205,7 @@ func joinAllowList(entries []codersdk.APIAllowListTarget) string {
for i, entry := range entries {
vals[i] = entry.String()
}
sort.Strings(vals)
slices.Sort(vals)
return strings.Join(vals, ", ")
}
-263
@@ -1,263 +0,0 @@
package agentconnectionbatcher
import (
"context"
"database/sql"
"time"
"github.com/google/uuid"
"cdr.dev/slog/v3"
"github.com/coder/coder/v2/coderd/database"
"github.com/coder/coder/v2/coderd/database/dbauthz"
"github.com/coder/quartz"
)
const (
// defaultBatchSize is the maximum number of agent connection updates
// to batch before forcing a flush. With one entry per agent, this
// accommodates 500 concurrently connected agents per batch.
defaultBatchSize = 500
// defaultChannelBufferMultiplier is the multiplier for the channel
// buffer size relative to the batch size. A 5x multiplier provides
// significant headroom for bursts while the batch is being flushed.
defaultChannelBufferMultiplier = 5
// defaultFlushInterval is how frequently to flush batched connection
// updates to the database. 5 seconds provides a good balance between
// reducing database load and keeping connection state reasonably
// current.
defaultFlushInterval = 5 * time.Second
// finalFlushTimeout is the timeout for the final flush when the
// batcher is shutting down.
finalFlushTimeout = 15 * time.Second
)
// Update represents a single agent connection state update to be batched.
type Update struct {
ID uuid.UUID
FirstConnectedAt sql.NullTime
LastConnectedAt sql.NullTime
LastConnectedReplicaID uuid.NullUUID
DisconnectedAt sql.NullTime
UpdatedAt time.Time
}
// Batcher accumulates agent connection updates and periodically flushes
// them to the database in a single batch query. This reduces per-heartbeat
// database write pressure from O(n) queries to O(1).
type Batcher struct {
store database.Store
log slog.Logger
updateCh chan Update
batch map[uuid.UUID]Update
maxBatchSize int
clock quartz.Clock
timer *quartz.Timer
interval time.Duration
ctx context.Context
cancel context.CancelFunc
done chan struct{}
}
// Option is a functional option for configuring a Batcher.
type Option func(b *Batcher)
// WithBatchSize sets the maximum number of updates to accumulate before
// forcing a flush.
func WithBatchSize(size int) Option {
return func(b *Batcher) {
b.maxBatchSize = size
}
}
// WithInterval sets how frequently the batcher flushes to the database.
func WithInterval(d time.Duration) Option {
return func(b *Batcher) {
b.interval = d
}
}
// WithLogger sets the logger for the batcher.
func WithLogger(log slog.Logger) Option {
return func(b *Batcher) {
b.log = log
}
}
// WithClock sets the clock for the batcher, useful for testing.
func WithClock(clock quartz.Clock) Option {
return func(b *Batcher) {
b.clock = clock
}
}
// New creates a new Batcher and starts its background processing loop.
// The provided context controls the lifetime of the batcher.
func New(ctx context.Context, store database.Store, opts ...Option) *Batcher {
b := &Batcher{
store: store,
done: make(chan struct{}),
log: slog.Logger{},
clock: quartz.NewReal(),
}
for _, opt := range opts {
opt(b)
}
if b.interval == 0 {
b.interval = defaultFlushInterval
}
if b.maxBatchSize == 0 {
b.maxBatchSize = defaultBatchSize
}
b.timer = b.clock.NewTimer(b.interval)
channelSize := b.maxBatchSize * defaultChannelBufferMultiplier
b.updateCh = make(chan Update, channelSize)
b.batch = make(map[uuid.UUID]Update)
b.ctx, b.cancel = context.WithCancel(ctx)
go func() {
b.run(b.ctx)
close(b.done)
}()
return b
}
// Close cancels the batcher context and waits for the final flush to
// complete.
func (b *Batcher) Close() {
b.cancel()
if b.timer != nil {
b.timer.Stop()
}
<-b.done
}
// Add enqueues an agent connection update for batching. If the internal
// channel is full, the update is dropped and a warning is logged.
func (b *Batcher) Add(u Update) {
select {
case b.updateCh <- u:
default:
b.log.Warn(context.Background(), "connection batcher channel full, dropping update",
slog.F("agent_id", u.ID),
)
}
}
func (b *Batcher) processUpdate(u Update) {
existing, exists := b.batch[u.ID]
if exists && u.UpdatedAt.Before(existing.UpdatedAt) {
return
}
b.batch[u.ID] = u
}
func (b *Batcher) run(ctx context.Context) {
//nolint:gocritic // System-level batch operation for agent connections.
authCtx := dbauthz.AsSystemRestricted(ctx)
for {
select {
case u := <-b.updateCh:
b.processUpdate(u)
if len(b.batch) >= b.maxBatchSize {
b.flush(authCtx)
b.timer.Reset(b.interval, "connectionBatcher", "capacityFlush")
}
case <-b.timer.C:
b.flush(authCtx)
b.timer.Reset(b.interval, "connectionBatcher", "scheduledFlush")
case <-ctx.Done():
b.log.Debug(ctx, "context done, flushing before exit")
ctxTimeout, cancel := context.WithTimeout(context.Background(), finalFlushTimeout)
defer cancel() //nolint:revive // Returning after this.
//nolint:gocritic // System-level batch operation for agent connections.
b.flush(dbauthz.AsSystemRestricted(ctxTimeout))
return
}
}
}
func (b *Batcher) flush(ctx context.Context) {
count := len(b.batch)
if count == 0 {
return
}
b.log.Debug(ctx, "flushing connection batch", slog.F("count", count))
var (
ids = make([]uuid.UUID, 0, count)
firstConnectedAt = make([]time.Time, 0, count)
lastConnectedAt = make([]time.Time, 0, count)
lastConnectedReplicaID = make([]uuid.UUID, 0, count)
disconnectedAt = make([]time.Time, 0, count)
updatedAt = make([]time.Time, 0, count)
)
for _, u := range b.batch {
ids = append(ids, u.ID)
firstConnectedAt = append(firstConnectedAt, nullTimeToTime(u.FirstConnectedAt))
lastConnectedAt = append(lastConnectedAt, nullTimeToTime(u.LastConnectedAt))
lastConnectedReplicaID = append(lastConnectedReplicaID, nullUUIDToUUID(u.LastConnectedReplicaID))
disconnectedAt = append(disconnectedAt, nullTimeToTime(u.DisconnectedAt))
updatedAt = append(updatedAt, u.UpdatedAt)
}
// Clear batch before the DB call. Losing a batch of heartbeat
// timestamps is acceptable; the next heartbeat will update them.
b.batch = make(map[uuid.UUID]Update)
err := b.store.BatchUpdateWorkspaceAgentConnections(ctx, database.BatchUpdateWorkspaceAgentConnectionsParams{
ID: ids,
FirstConnectedAt: firstConnectedAt,
LastConnectedAt: lastConnectedAt,
LastConnectedReplicaID: lastConnectedReplicaID,
DisconnectedAt: disconnectedAt,
UpdatedAt: updatedAt,
})
if err != nil {
if database.IsQueryCanceledError(err) {
b.log.Debug(ctx, "query canceled, skipping connection batch update")
return
}
b.log.Error(ctx, "failed to batch update agent connections", slog.Error(err))
return
}
b.log.Debug(ctx, "connection batch flush complete", slog.F("count", count))
}
// nullTimeToTime converts a sql.NullTime to a time.Time. When the
// NullTime is not valid, the zero time (0001-01-01 UTC) is returned
// as a sentinel. The batch query uses unnest over plain time arrays,
// so we cannot pass NULL directly.
func nullTimeToTime(nt sql.NullTime) time.Time {
if nt.Valid {
return nt.Time
}
return time.Time{}
}
// nullUUIDToUUID converts a uuid.NullUUID to a uuid.UUID. When the
// NullUUID is not valid, uuid.Nil is returned.
func nullUUIDToUUID(nu uuid.NullUUID) uuid.UUID {
if nu.Valid {
return nu.UUID
}
return uuid.Nil
}
+1 -1
@@ -773,7 +773,7 @@ func (api *API) taskSend(rw http.ResponseWriter, r *http.Request) {
}
if statusResp.Status != agentapisdk.StatusStable {
return httperror.NewResponseError(http.StatusBadGateway, codersdk.Response{
return httperror.NewResponseError(http.StatusConflict, codersdk.Response{
Message: "Task app is not ready to accept input.",
Detail: fmt.Sprintf("Status: %s", statusResp.Status),
})
+5
@@ -789,6 +789,11 @@ func TestTasks(t *testing.T) {
})
require.Error(t, err, "wanted error due to bad status")
var sdkErr *codersdk.Error
require.ErrorAs(t, err, &sdkErr)
require.Equal(t, http.StatusConflict, sdkErr.StatusCode())
require.Contains(t, sdkErr.Message, "not ready to accept input")
statusResponse = agentapisdk.StatusStable
//nolint:tparallel // Not intended to run in parallel.
+162 -35
@@ -163,6 +163,57 @@ const docTemplate = `{
]
}
},
"/aibridge/sessions": {
"get": {
"produces": [
"application/json"
],
"tags": [
"AI Bridge"
],
"summary": "List AI Bridge sessions",
"operationId": "list-ai-bridge-sessions",
"parameters": [
{
"type": "string",
"description": "Search query in the format ` + "`" + `key:value` + "`" + `. Available keys are: initiator, provider, model, client, session_id, started_after, started_before.",
"name": "q",
"in": "query"
},
{
"type": "integer",
"description": "Page limit",
"name": "limit",
"in": "query"
},
{
"type": "string",
"description": "Cursor pagination after session ID (cannot be used with offset)",
"name": "after_session_id",
"in": "query"
},
{
"type": "integer",
"description": "Offset pagination (cannot be used with after_session_id)",
"name": "offset",
"in": "query"
}
],
"responses": {
"200": {
"description": "OK",
"schema": {
"$ref": "#/definitions/codersdk.AIBridgeListSessionsResponse"
}
}
},
"security": [
{
"CoderSessionToken": []
}
]
}
},
"/appearance": {
"get": {
"produces": [
@@ -4082,6 +4133,19 @@ const docTemplate = `{
"in": "path",
"required": true
},
{
"type": "string",
"description": "Member search query",
"name": "q",
"in": "query"
},
{
"type": "string",
"format": "uuid",
"description": "After ID",
"name": "after_id",
"in": "query"
},
{
"type": "integer",
"description": "Page limit, if 0 returns all members",
@@ -7958,29 +8022,6 @@ const docTemplate = `{
]
}
},
"/users/me/session/token-to-cookie": {
"post": {
"description": "Converts the current session token into a Set-Cookie response.\nThis is used by embedded iframes (e.g. VS Code chat) that\nreceive a session token out-of-band via postMessage but need\ncookie-based auth for WebSocket connections.",
"tags": [
"Authorization"
],
"summary": "Set session token cookie",
"operationId": "set-session-token-cookie",
"responses": {
"204": {
"description": "No Content"
}
},
"security": [
{
"CoderSessionToken": []
}
],
"x-apidocgen": {
"skip": true
}
}
},
"/users/oauth2/github/callback": {
"get": {
"tags": [
@@ -12788,6 +12829,20 @@ const docTemplate = `{
}
}
},
"codersdk.AIBridgeListSessionsResponse": {
"type": "object",
"properties": {
"count": {
"type": "integer"
},
"sessions": {
"type": "array",
"items": {
"$ref": "#/definitions/codersdk.AIBridgeSession"
}
}
}
},
"codersdk.AIBridgeOpenAIConfig": {
"type": "object",
"properties": {
@@ -12840,6 +12895,64 @@ const docTemplate = `{
}
}
},
"codersdk.AIBridgeSession": {
"type": "object",
"properties": {
"client": {
"type": "string"
},
"ended_at": {
"type": "string",
"format": "date-time"
},
"id": {
"type": "string"
},
"initiator": {
"$ref": "#/definitions/codersdk.MinimalUser"
},
"last_prompt": {
"type": "string"
},
"metadata": {
"type": "object",
"additionalProperties": {}
},
"models": {
"type": "array",
"items": {
"type": "string"
}
},
"providers": {
"type": "array",
"items": {
"type": "string"
}
},
"started_at": {
"type": "string",
"format": "date-time"
},
"threads": {
"type": "integer"
},
"token_usage_summary": {
"$ref": "#/definitions/codersdk.AIBridgeSessionTokenUsageSummary"
}
}
},
"codersdk.AIBridgeSessionTokenUsageSummary": {
"type": "object",
"properties": {
"input_tokens": {
"type": "integer"
},
"output_tokens": {
"type": "integer"
}
}
},
"codersdk.AIBridgeTokenUsage": {
"type": "object",
"properties": {
@@ -15273,18 +15386,6 @@ const docTemplate = `{
"swagger": {
"$ref": "#/definitions/codersdk.SwaggerConfig"
},
"tailnet_binder_workers": {
"type": "integer"
},
"tailnet_handshaker_workers": {
"type": "integer"
},
"tailnet_querier_workers": {
"type": "integer"
},
"tailnet_tunneler_workers": {
"type": "integer"
},
"telemetry": {
"$ref": "#/definitions/codersdk.TelemetryConfig"
},
@@ -17325,6 +17426,13 @@ const docTemplate = `{
"$ref": "#/definitions/codersdk.SlimRole"
}
},
"last_seen_at": {
"type": "string",
"format": "date-time"
},
"login_type": {
"$ref": "#/definitions/codersdk.LoginType"
},
"name": {
"type": "string"
},
@@ -17338,14 +17446,33 @@ const docTemplate = `{
"$ref": "#/definitions/codersdk.SlimRole"
}
},
"status": {
"enum": [
"active",
"suspended"
],
"allOf": [
{
"$ref": "#/definitions/codersdk.UserStatus"
}
]
},
"updated_at": {
"type": "string",
"format": "date-time"
},
"user_created_at": {
"type": "string",
"format": "date-time"
},
"user_id": {
"type": "string",
"format": "uuid"
},
"user_updated_at": {
"type": "string",
"format": "date-time"
},
"username": {
"type": "string"
}
+155 -33
@@ -136,6 +136,53 @@
]
}
},
"/aibridge/sessions": {
"get": {
"produces": ["application/json"],
"tags": ["AI Bridge"],
"summary": "List AI Bridge sessions",
"operationId": "list-ai-bridge-sessions",
"parameters": [
{
"type": "string",
"description": "Search query in the format `key:value`. Available keys are: initiator, provider, model, client, session_id, started_after, started_before.",
"name": "q",
"in": "query"
},
{
"type": "integer",
"description": "Page limit",
"name": "limit",
"in": "query"
},
{
"type": "string",
"description": "Cursor pagination after session ID (cannot be used with offset)",
"name": "after_session_id",
"in": "query"
},
{
"type": "integer",
"description": "Offset pagination (cannot be used with after_session_id)",
"name": "offset",
"in": "query"
}
],
"responses": {
"200": {
"description": "OK",
"schema": {
"$ref": "#/definitions/codersdk.AIBridgeListSessionsResponse"
}
}
},
"security": [
{
"CoderSessionToken": []
}
]
}
},
"/appearance": {
"get": {
"produces": ["application/json"],
@@ -3603,6 +3650,19 @@
"in": "path",
"required": true
},
{
"type": "string",
"description": "Member search query",
"name": "q",
"in": "query"
},
{
"type": "string",
"format": "uuid",
"description": "After ID",
"name": "after_id",
"in": "query"
},
{
"type": "integer",
"description": "Page limit, if 0 returns all members",
@@ -7051,27 +7111,6 @@
]
}
},
"/users/me/session/token-to-cookie": {
"post": {
"description": "Converts the current session token into a Set-Cookie response.\nThis is used by embedded iframes (e.g. VS Code chat) that\nreceive a session token out-of-band via postMessage but need\ncookie-based auth for WebSocket connections.",
"tags": ["Authorization"],
"summary": "Set session token cookie",
"operationId": "set-session-token-cookie",
"responses": {
"204": {
"description": "No Content"
}
},
"security": [
{
"CoderSessionToken": []
}
],
"x-apidocgen": {
"skip": true
}
}
},
"/users/oauth2/github/callback": {
"get": {
"tags": ["Users"],
@@ -11376,6 +11415,20 @@
}
}
},
"codersdk.AIBridgeListSessionsResponse": {
"type": "object",
"properties": {
"count": {
"type": "integer"
},
"sessions": {
"type": "array",
"items": {
"$ref": "#/definitions/codersdk.AIBridgeSession"
}
}
}
},
"codersdk.AIBridgeOpenAIConfig": {
"type": "object",
"properties": {
@@ -11428,6 +11481,64 @@
}
}
},
"codersdk.AIBridgeSession": {
"type": "object",
"properties": {
"client": {
"type": "string"
},
"ended_at": {
"type": "string",
"format": "date-time"
},
"id": {
"type": "string"
},
"initiator": {
"$ref": "#/definitions/codersdk.MinimalUser"
},
"last_prompt": {
"type": "string"
},
"metadata": {
"type": "object",
"additionalProperties": {}
},
"models": {
"type": "array",
"items": {
"type": "string"
}
},
"providers": {
"type": "array",
"items": {
"type": "string"
}
},
"started_at": {
"type": "string",
"format": "date-time"
},
"threads": {
"type": "integer"
},
"token_usage_summary": {
"$ref": "#/definitions/codersdk.AIBridgeSessionTokenUsageSummary"
}
}
},
"codersdk.AIBridgeSessionTokenUsageSummary": {
"type": "object",
"properties": {
"input_tokens": {
"type": "integer"
},
"output_tokens": {
"type": "integer"
}
}
},
"codersdk.AIBridgeTokenUsage": {
"type": "object",
"properties": {
@@ -13780,18 +13891,6 @@
"swagger": {
"$ref": "#/definitions/codersdk.SwaggerConfig"
},
"tailnet_binder_workers": {
"type": "integer"
},
"tailnet_handshaker_workers": {
"type": "integer"
},
"tailnet_querier_workers": {
"type": "integer"
},
"tailnet_tunneler_workers": {
"type": "integer"
},
"telemetry": {
"$ref": "#/definitions/codersdk.TelemetryConfig"
},
@@ -15752,6 +15851,13 @@
"$ref": "#/definitions/codersdk.SlimRole"
}
},
"last_seen_at": {
"type": "string",
"format": "date-time"
},
"login_type": {
"$ref": "#/definitions/codersdk.LoginType"
},
"name": {
"type": "string"
},
@@ -15765,14 +15871,30 @@
"$ref": "#/definitions/codersdk.SlimRole"
}
},
"status": {
"enum": ["active", "suspended"],
"allOf": [
{
"$ref": "#/definitions/codersdk.UserStatus"
}
]
},
"updated_at": {
"type": "string",
"format": "date-time"
},
"user_created_at": {
"type": "string",
"format": "date-time"
},
"user_id": {
"type": "string",
"format": "uuid"
},
"user_updated_at": {
"type": "string",
"format": "date-time"
},
"username": {
"type": "string"
}
-186
@@ -1,186 +0,0 @@
package chattool_test
import (
"context"
"testing"
"charm.land/fantasy"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"go.uber.org/mock/gomock"
"golang.org/x/xerrors"
"github.com/coder/coder/v2/coderd/chatd/chattool"
"github.com/coder/coder/v2/codersdk/workspacesdk"
"github.com/coder/coder/v2/codersdk/workspacesdk/agentconnmock"
"github.com/coder/quartz"
)
func TestComputerUseTool_Info(t *testing.T) {
t.Parallel()
tool := chattool.NewComputerUseTool(workspacesdk.DesktopDisplayWidth, workspacesdk.DesktopDisplayHeight, nil, quartz.NewReal())
info := tool.Info()
assert.Equal(t, "computer", info.Name)
assert.NotEmpty(t, info.Description)
}
func TestComputerUseProviderTool(t *testing.T) {
t.Parallel()
def := chattool.ComputerUseProviderTool(workspacesdk.DesktopDisplayWidth, workspacesdk.DesktopDisplayHeight)
pdt, ok := def.(fantasy.ProviderDefinedTool)
require.True(t, ok, "ComputerUseProviderTool should return a ProviderDefinedTool")
assert.Contains(t, pdt.ID, "computer")
assert.Equal(t, "computer", pdt.Name)
// Verify display dimensions are passed through.
assert.Equal(t, int64(workspacesdk.DesktopDisplayWidth), pdt.Args["display_width_px"])
assert.Equal(t, int64(workspacesdk.DesktopDisplayHeight), pdt.Args["display_height_px"])
}
func TestComputerUseTool_Run_Screenshot(t *testing.T) {
t.Parallel()
ctrl := gomock.NewController(t)
mockConn := agentconnmock.NewMockAgentConn(ctrl)
mockConn.EXPECT().ExecuteDesktopAction(
gomock.Any(),
gomock.Any(),
).Return(workspacesdk.DesktopActionResponse{
Output: "screenshot",
ScreenshotData: "base64png",
ScreenshotWidth: 1024,
		ScreenshotHeight: 768,
	}, nil)

	tool := chattool.NewComputerUseTool(workspacesdk.DesktopDisplayWidth, workspacesdk.DesktopDisplayHeight, func(_ context.Context) (workspacesdk.AgentConn, error) {
		return mockConn, nil
	}, quartz.NewReal())

	call := fantasy.ToolCall{
		ID:    "test-1",
		Name:  "computer",
		Input: `{"action":"screenshot"}`,
	}
	resp, err := tool.Run(context.Background(), call)
	require.NoError(t, err)
	assert.Equal(t, "image", resp.Type)
	assert.Equal(t, "image/png", resp.MediaType)
	assert.Equal(t, []byte("base64png"), resp.Data)
	assert.False(t, resp.IsError)
}

func TestComputerUseTool_Run_LeftClick(t *testing.T) {
	t.Parallel()

	ctrl := gomock.NewController(t)
	mockConn := agentconnmock.NewMockAgentConn(ctrl)

	// Expect the action call first.
	mockConn.EXPECT().ExecuteDesktopAction(
		gomock.Any(),
		gomock.Any(),
	).Return(workspacesdk.DesktopActionResponse{
		Output: "left_click performed",
	}, nil)
	// Then expect a screenshot (auto-screenshot after action).
	mockConn.EXPECT().ExecuteDesktopAction(
		gomock.Any(),
		gomock.Any(),
	).Return(workspacesdk.DesktopActionResponse{
		Output:           "screenshot",
		ScreenshotData:   "after-click",
		ScreenshotWidth:  1024,
		ScreenshotHeight: 768,
	}, nil)

	tool := chattool.NewComputerUseTool(workspacesdk.DesktopDisplayWidth, workspacesdk.DesktopDisplayHeight, func(_ context.Context) (workspacesdk.AgentConn, error) {
		return mockConn, nil
	}, quartz.NewReal())

	call := fantasy.ToolCall{
		ID:    "test-2",
		Name:  "computer",
		Input: `{"action":"left_click","coordinate":[100,200]}`,
	}
	resp, err := tool.Run(context.Background(), call)
	require.NoError(t, err)
	assert.Equal(t, "image", resp.Type)
	assert.Equal(t, []byte("after-click"), resp.Data)
}

func TestComputerUseTool_Run_Wait(t *testing.T) {
	t.Parallel()

	ctrl := gomock.NewController(t)
	mockConn := agentconnmock.NewMockAgentConn(ctrl)

	// Expect a screenshot after the wait completes.
	mockConn.EXPECT().ExecuteDesktopAction(
		gomock.Any(),
		gomock.Any(),
	).Return(workspacesdk.DesktopActionResponse{
		Output:           "screenshot",
		ScreenshotData:   "after-wait",
		ScreenshotWidth:  1024,
		ScreenshotHeight: 768,
	}, nil)

	tool := chattool.NewComputerUseTool(workspacesdk.DesktopDisplayWidth, workspacesdk.DesktopDisplayHeight, func(_ context.Context) (workspacesdk.AgentConn, error) {
		return mockConn, nil
	}, quartz.NewReal())

	call := fantasy.ToolCall{
		ID:    "test-3",
		Name:  "computer",
		Input: `{"action":"wait","duration":10}`,
	}
	resp, err := tool.Run(context.Background(), call)
	require.NoError(t, err)
	assert.Equal(t, "image", resp.Type)
	assert.Equal(t, "image/png", resp.MediaType)
	assert.Equal(t, []byte("after-wait"), resp.Data)
	assert.False(t, resp.IsError)
}

func TestComputerUseTool_Run_ConnError(t *testing.T) {
	t.Parallel()

	tool := chattool.NewComputerUseTool(workspacesdk.DesktopDisplayWidth, workspacesdk.DesktopDisplayHeight, func(_ context.Context) (workspacesdk.AgentConn, error) {
		return nil, xerrors.New("workspace not available")
	}, quartz.NewReal())

	call := fantasy.ToolCall{
		ID:    "test-4",
		Name:  "computer",
		Input: `{"action":"screenshot"}`,
	}
	resp, err := tool.Run(context.Background(), call)
	require.NoError(t, err)
	assert.True(t, resp.IsError)
	assert.Contains(t, resp.Content, "workspace not available")
}

func TestComputerUseTool_Run_InvalidInput(t *testing.T) {
	t.Parallel()

	tool := chattool.NewComputerUseTool(workspacesdk.DesktopDisplayWidth, workspacesdk.DesktopDisplayHeight, func(_ context.Context) (workspacesdk.AgentConn, error) {
		return nil, xerrors.New("should not be called")
	}, quartz.NewReal())

	call := fantasy.ToolCall{
		ID:    "test-5",
		Name:  "computer",
		Input: `{invalid json`,
	}
	resp, err := tool.Run(context.Background(), call)
	require.NoError(t, err)
	assert.True(t, resp.IsError)
	assert.Contains(t, resp.Content, "invalid computer use input")
}
@@ -45,14 +45,12 @@ import (
"github.com/coder/coder/v2/buildinfo"
"github.com/coder/coder/v2/coderd/agentapi"
"github.com/coder/coder/v2/coderd/agentapi/metadatabatcher"
"github.com/coder/coder/v2/coderd/agentconnectionbatcher"
"github.com/coder/coder/v2/coderd/aiseats"
_ "github.com/coder/coder/v2/coderd/apidoc" // Used for swagger docs.
"github.com/coder/coder/v2/coderd/appearance"
"github.com/coder/coder/v2/coderd/audit"
"github.com/coder/coder/v2/coderd/awsidentity"
"github.com/coder/coder/v2/coderd/boundaryusage"
"github.com/coder/coder/v2/coderd/chatd"
"github.com/coder/coder/v2/coderd/connectionlog"
"github.com/coder/coder/v2/coderd/cryptokeys"
"github.com/coder/coder/v2/coderd/database"
@@ -64,7 +62,6 @@ import (
"github.com/coder/coder/v2/coderd/externalauth"
"github.com/coder/coder/v2/coderd/files"
"github.com/coder/coder/v2/coderd/gitsshkey"
"github.com/coder/coder/v2/coderd/gitsync"
"github.com/coder/coder/v2/coderd/healthcheck"
"github.com/coder/coder/v2/coderd/healthcheck/derphealth"
"github.com/coder/coder/v2/coderd/httpapi"
@@ -95,6 +92,8 @@ import (
"github.com/coder/coder/v2/coderd/workspaceapps/appurl"
"github.com/coder/coder/v2/coderd/workspacestats"
"github.com/coder/coder/v2/coderd/wsbuilder"
"github.com/coder/coder/v2/coderd/x/chatd"
"github.com/coder/coder/v2/coderd/x/gitsync"
"github.com/coder/coder/v2/codersdk"
"github.com/coder/coder/v2/codersdk/drpcsdk"
"github.com/coder/coder/v2/codersdk/healthsdk"
@@ -251,8 +250,7 @@ type Options struct {
UpdateAgentMetrics func(ctx context.Context, labels prometheusmetrics.AgentMetricLabels, metrics []*agentproto.Stats_Metric)
StatsBatcher workspacestats.Batcher
MetadataBatcherOptions []metadatabatcher.Option
ConnectionBatcherOptions []agentconnectionbatcher.Option
MetadataBatcherOptions []metadatabatcher.Option
ProvisionerdServerMetrics *provisionerdserver.Metrics
WorkspaceBuilderMetrics *wsbuilder.Metrics
@@ -769,43 +767,45 @@ func New(options *Options) *API {
}
api.agentProvider = stn
maxChatsPerAcquire := options.DeploymentValues.AI.Chat.AcquireBatchSize.Value()
if maxChatsPerAcquire > math.MaxInt32 {
maxChatsPerAcquire = math.MaxInt32
}
if maxChatsPerAcquire < math.MinInt32 {
maxChatsPerAcquire = math.MinInt32
}
{ // Experimental: agents — chat daemon and git sync worker initialization.
maxChatsPerAcquire := options.DeploymentValues.AI.Chat.AcquireBatchSize.Value()
if maxChatsPerAcquire > math.MaxInt32 {
maxChatsPerAcquire = math.MaxInt32
}
if maxChatsPerAcquire < math.MinInt32 {
maxChatsPerAcquire = math.MinInt32
}
api.chatDaemon = chatd.New(chatd.Config{
Logger: options.Logger.Named("chatd"),
Database: options.Database,
ReplicaID: api.ID,
SubscribeFn: options.ChatSubscribeFn,
MaxChatsPerAcquire: int32(maxChatsPerAcquire), //nolint:gosec // maxChatsPerAcquire is clamped to int32 range above.
ProviderAPIKeys: chatProviderAPIKeysFromDeploymentValues(options.DeploymentValues),
AgentConn: api.agentProvider.AgentConn,
CreateWorkspace: api.chatCreateWorkspace,
StartWorkspace: api.chatStartWorkspace,
Pubsub: options.Pubsub,
WebpushDispatcher: options.WebPushDispatcher,
UsageTracker: options.WorkspaceUsageTracker,
})
gitSyncLogger := options.Logger.Named("gitsync")
refresher := gitsync.NewRefresher(
api.resolveGitProvider,
api.resolveChatGitAccessToken,
gitSyncLogger.Named("refresher"),
quartz.NewReal(),
)
api.gitSyncWorker = gitsync.NewWorker(options.Database,
refresher,
api.chatDaemon.PublishDiffStatusChange,
quartz.NewReal(),
gitSyncLogger,
)
// nolint:gocritic // chat diff worker needs to be able to CRUD chats.
go api.gitSyncWorker.Start(dbauthz.AsChatd(api.ctx))
api.chatDaemon = chatd.New(chatd.Config{
Logger: options.Logger.Named("chatd"),
Database: options.Database,
ReplicaID: api.ID,
SubscribeFn: options.ChatSubscribeFn,
MaxChatsPerAcquire: int32(maxChatsPerAcquire), //nolint:gosec // maxChatsPerAcquire is clamped to int32 range above.
ProviderAPIKeys: chatProviderAPIKeysFromDeploymentValues(options.DeploymentValues),
AgentConn: api.agentProvider.AgentConn,
CreateWorkspace: api.chatCreateWorkspace,
StartWorkspace: api.chatStartWorkspace,
Pubsub: options.Pubsub,
WebpushDispatcher: options.WebPushDispatcher,
UsageTracker: options.WorkspaceUsageTracker,
})
gitSyncLogger := options.Logger.Named("gitsync")
refresher := gitsync.NewRefresher(
api.resolveGitProvider,
api.resolveChatGitAccessToken,
gitSyncLogger.Named("refresher"),
quartz.NewReal(),
)
api.gitSyncWorker = gitsync.NewWorker(options.Database,
refresher,
api.chatDaemon.PublishDiffStatusChange,
quartz.NewReal(),
gitSyncLogger,
)
// nolint:gocritic // chat diff worker needs to be able to CRUD chats.
go api.gitSyncWorker.Start(dbauthz.AsChatd(api.ctx))
}
if options.DeploymentValues.Prometheus.Enable {
options.PrometheusRegistry.MustRegister(stn)
api.lifecycleMetrics = agentapi.NewLifecycleMetrics(options.PrometheusRegistry)
@@ -860,17 +860,6 @@ func New(options *Options) *API {
api.Logger.Fatal(context.Background(), "failed to initialize metadata batcher", slog.Error(err))
}
// Initialize the connection batcher for batching agent heartbeat writes.
connBatcherOpts := []agentconnectionbatcher.Option{
agentconnectionbatcher.WithLogger(options.Logger.Named("connection_batcher")),
}
connBatcherOpts = append(connBatcherOpts, options.ConnectionBatcherOptions...)
api.connectionBatcher = agentconnectionbatcher.New(
api.ctx,
options.Database,
connBatcherOpts...,
)
workspaceAppsLogger := options.Logger.Named("workspaceapps")
if options.WorkspaceAppsStatsCollectorOptions.Logger == nil {
named := workspaceAppsLogger.Named("stats_collector")
@@ -1159,6 +1148,7 @@ func New(options *Options) *API {
})
})
})
// Experimental(agents): chat API routes gated by ExperimentAgents.
r.Route("/chats", func(r chi.Router) {
r.Use(
apiKeyMiddleware,
@@ -1190,6 +1180,9 @@ func New(options *Options) *API {
r.Put("/desktop-enabled", api.putChatDesktopEnabled)
r.Get("/user-prompt", api.getUserChatCustomPrompt)
r.Put("/user-prompt", api.putUserChatCustomPrompt)
r.Get("/user-compaction-thresholds", api.getUserChatCompactionThresholds)
r.Put("/user-compaction-thresholds/{modelConfig}", api.putUserChatCompactionThreshold)
r.Delete("/user-compaction-thresholds/{modelConfig}", api.deleteUserChatCompactionThreshold)
r.Get("/workspace-ttl", api.getChatWorkspaceTTL)
r.Put("/workspace-ttl", api.putChatWorkspaceTTL)
})
@@ -1530,7 +1523,6 @@ func New(options *Options) *API {
r.Post("/", api.postUser)
r.Get("/", api.users)
r.Post("/logout", api.postLogout)
r.Post("/me/session/token-to-cookie", api.postSessionTokenCookie)
r.Get("/oidc-claims", api.userOIDCClaims)
// These routes query information about site wide roles.
r.Route("/roles", func(r chi.Router) {
@@ -2093,21 +2085,19 @@ type API struct {
healthCheckProgress healthcheck.Progress
statsReporter *workspacestats.Reporter
metadataBatcher *metadatabatcher.Batcher
connectionBatcher *agentconnectionbatcher.Batcher
metadataBatcher *metadatabatcher.Batcher
lifecycleMetrics *agentapi.LifecycleMetrics
Acquirer *provisionerdserver.Acquirer
// dbRolluper rolls up template usage stats from raw agent and app
// stats. This is used to provide insights in the WebUI.
dbRolluper *dbrollup.Rolluper
// chatDaemon handles background processing of pending chats.
// Experimental(agents): chatDaemon handles background processing of pending chats.
chatDaemon *chatd.Server
// Experimental(agents): gitSyncWorker refreshes stale chat diff statuses in the background.
gitSyncWorker *gitsync.Worker
// AISeatTracker records AI seat usage.
AISeatTracker aiseats.SeatTracker
// gitSyncWorker refreshes stale chat diff statuses in the
// background.
gitSyncWorker *gitsync.Worker
// ProfileCollector abstracts the runtime/pprof and runtime/trace
// calls used by the /debug/profile endpoint. Tests override this
@@ -2175,9 +2165,6 @@ func (api *API) Close() error {
if api.metadataBatcher != nil {
api.metadataBatcher.Close()
}
if api.connectionBatcher != nil {
api.connectionBatcher.Close()
}
_ = api.NetworkTelemetryBatcher.Close()
_ = api.OIDCConvertKeyCache.Close()
_ = api.AppSigningKeyCache.Close()
@@ -1,344 +0,0 @@
package connectionlogbatcher

import (
	"context"
	"database/sql"
	"time"

	"github.com/google/uuid"
	"github.com/sqlc-dev/pqtype"

	"cdr.dev/slog/v3"
	"github.com/coder/coder/v2/coderd/database"
	"github.com/coder/coder/v2/coderd/database/dbauthz"
	"github.com/coder/quartz"
)

const (
	// defaultBatchSize is the maximum number of connection log entries
	// to batch before forcing a flush.
	defaultBatchSize = 500
	// defaultChannelBufferMultiplier is the multiplier for the channel
	// buffer size relative to the batch size. A 5x multiplier provides
	// significant headroom for bursts while the batch is being flushed.
	defaultChannelBufferMultiplier = 5
	// defaultFlushInterval is how frequently to flush batched connection
	// log entries to the database. 1 second keeps audit logs near
	// real-time.
	defaultFlushInterval = time.Second
	// finalFlushTimeout is the timeout for the final flush when the
	// batcher is shutting down.
	finalFlushTimeout = 15 * time.Second
)

// Batcher accumulates connection log upserts and periodically flushes
// them to the database in a single batch query. This reduces per-event
// database write pressure from O(n) queries to O(1).
type Batcher struct {
	store database.Store
	log   slog.Logger

	itemCh       chan database.UpsertConnectionLogParams
	batch        []database.UpsertConnectionLogParams
	maxBatchSize int

	clock    quartz.Clock
	timer    *quartz.Timer
	interval time.Duration

	ctx    context.Context
	cancel context.CancelFunc
	done   chan struct{}
}

// Option is a functional option for configuring a Batcher.
type Option func(b *Batcher)

// WithBatchSize sets the maximum number of entries to accumulate before
// forcing a flush.
func WithBatchSize(size int) Option {
	return func(b *Batcher) {
		b.maxBatchSize = size
	}
}

// WithInterval sets how frequently the batcher flushes to the database.
func WithInterval(d time.Duration) Option {
	return func(b *Batcher) {
		b.interval = d
	}
}

// WithLogger sets the logger for the batcher.
func WithLogger(log slog.Logger) Option {
	return func(b *Batcher) {
		b.log = log
	}
}

// WithClock sets the clock for the batcher, useful for testing.
func WithClock(clock quartz.Clock) Option {
	return func(b *Batcher) {
		b.clock = clock
	}
}

// New creates a new Batcher and starts its background processing loop.
// The provided context controls the lifetime of the batcher.
func New(ctx context.Context, store database.Store, opts ...Option) *Batcher {
	b := &Batcher{
		store: store,
		done:  make(chan struct{}),
		log:   slog.Logger{},
		clock: quartz.NewReal(),
	}
	for _, opt := range opts {
		opt(b)
	}
	if b.interval == 0 {
		b.interval = defaultFlushInterval
	}
	if b.maxBatchSize == 0 {
		b.maxBatchSize = defaultBatchSize
	}
	b.timer = b.clock.NewTimer(b.interval)
	channelSize := b.maxBatchSize * defaultChannelBufferMultiplier
	b.itemCh = make(chan database.UpsertConnectionLogParams, channelSize)
	b.batch = make([]database.UpsertConnectionLogParams, 0, b.maxBatchSize)
	b.ctx, b.cancel = context.WithCancel(ctx)
	go func() {
		b.run(b.ctx)
		close(b.done)
	}()
	return b
}

// Close cancels the batcher context and waits for the final flush to
// complete.
func (b *Batcher) Close() {
	b.cancel()
	if b.timer != nil {
		b.timer.Stop()
	}
	<-b.done
}

// Add enqueues a connection log upsert for batching. If the internal
// channel is full, the entry is dropped and a warning is logged.
func (b *Batcher) Add(item database.UpsertConnectionLogParams) {
	select {
	case b.itemCh <- item:
	default:
		b.log.Warn(context.Background(), "connection log batcher channel full, dropping entry",
			slog.F("connection_id", item.ConnectionID),
		)
	}
}

func (b *Batcher) run(ctx context.Context) {
	//nolint:gocritic // System-level batch operation for connection logs.
	authCtx := dbauthz.AsConnectionLogger(ctx)
	for {
		select {
		case item := <-b.itemCh:
			b.batch = append(b.batch, item)
			if len(b.batch) >= b.maxBatchSize {
				b.flush(authCtx)
				b.timer.Reset(b.interval, "connectionLogBatcher", "capacityFlush")
			}
		case <-b.timer.C:
			b.flush(authCtx)
			b.timer.Reset(b.interval, "connectionLogBatcher", "scheduledFlush")
		case <-ctx.Done():
			b.log.Debug(ctx, "context done, flushing before exit")
			ctxTimeout, cancel := context.WithTimeout(context.Background(), finalFlushTimeout)
			defer cancel() //nolint:revive // Returning after this.
			//nolint:gocritic // System-level batch operation for connection logs.
			b.flush(dbauthz.AsConnectionLogger(ctxTimeout))
			return
		}
	}
}
// conflictKey represents the unique constraint columns used by
// the upsert query. Entries sharing the same key cannot appear
// in a single INSERT … ON CONFLICT DO UPDATE statement.
type conflictKey struct {
	ConnectionID uuid.UUID
	WorkspaceID  uuid.UUID
	AgentName    string
}

func (b *Batcher) flush(ctx context.Context) {
	count := len(b.batch)
	if count == 0 {
		return
	}
	b.log.Debug(ctx, "flushing connection log batch", slog.F("count", count))

	// Deduplicate by conflict key so PostgreSQL never sees the
	// same row twice in one INSERT … ON CONFLICT DO UPDATE.
	// Entries with a NULL connection_id (web events) are exempt
	// because NULL != NULL in SQL unique constraints.
	deduped := make(map[conflictKey]database.UpsertConnectionLogParams, count)
	var nullConnIDEntries []database.UpsertConnectionLogParams
	for _, item := range b.batch {
		if !item.ConnectionID.Valid {
			nullConnIDEntries = append(nullConnIDEntries, item)
			continue
		}
		key := conflictKey{
			ConnectionID: item.ConnectionID.UUID,
			WorkspaceID:  item.WorkspaceID,
			AgentName:    item.AgentName,
		}
		existing, ok := deduped[key]
		if !ok {
			deduped[key] = item
			continue
		}
		// Prefer disconnect over connect (superset of info).
		// If same status, prefer the later event.
		if item.ConnectionStatus == database.ConnectionStatusDisconnected &&
			existing.ConnectionStatus != database.ConnectionStatusDisconnected {
			deduped[key] = item
		} else if item.Time.After(existing.Time) {
			deduped[key] = item
		}
	}

	// Rebuild batch from deduplicated entries.
	items := make([]database.UpsertConnectionLogParams, 0, len(deduped)+len(nullConnIDEntries))
	for _, item := range deduped {
		items = append(items, item)
	}
	items = append(items, nullConnIDEntries...)
	dedupedCount := len(items)
	if dedupedCount < count {
		b.log.Debug(ctx, "deduplicated connection log batch",
			slog.F("original", count),
			slog.F("deduped", dedupedCount),
		)
	}

	var (
		ids              = make([]uuid.UUID, 0, dedupedCount)
		connectTime      = make([]time.Time, 0, dedupedCount)
		organizationID   = make([]uuid.UUID, 0, dedupedCount)
		workspaceOwnerID = make([]uuid.UUID, 0, dedupedCount)
		workspaceID      = make([]uuid.UUID, 0, dedupedCount)
		workspaceName    = make([]string, 0, dedupedCount)
		agentName        = make([]string, 0, dedupedCount)
		connType         = make([]database.ConnectionType, 0, dedupedCount)
		code             = make([]int32, 0, dedupedCount)
		ip               = make([]pqtype.Inet, 0, dedupedCount)
		userAgent        = make([]string, 0, dedupedCount)
		userID           = make([]uuid.UUID, 0, dedupedCount)
		slugOrPort       = make([]string, 0, dedupedCount)
		connectionID     = make([]uuid.UUID, 0, dedupedCount)
		disconnectReason = make([]string, 0, dedupedCount)
		disconnectTime   = make([]time.Time, 0, dedupedCount)
	)
	for _, item := range items {
		ids = append(ids, item.ID)
		connectTime = append(connectTime, item.Time)
		organizationID = append(organizationID, item.OrganizationID)
		workspaceOwnerID = append(workspaceOwnerID, item.WorkspaceOwnerID)
		workspaceID = append(workspaceID, item.WorkspaceID)
		workspaceName = append(workspaceName, item.WorkspaceName)
		agentName = append(agentName, item.AgentName)
		connType = append(connType, item.Type)
		code = append(code, nullInt32ToInt32(item.Code))
		ip = append(ip, item.Ip)
		userAgent = append(userAgent, nullStringToString(item.UserAgent))
		userID = append(userID, nullUUIDToUUID(item.UserID))
		slugOrPort = append(slugOrPort, nullStringToString(item.SlugOrPort))
		connectionID = append(connectionID, nullUUIDToUUID(item.ConnectionID))
		disconnectReason = append(disconnectReason, nullStringToString(item.DisconnectReason))
		// Pre-compute disconnect_time: if status is "disconnected",
		// use the event time; otherwise use zero time (epoch) which
		// the SQL CASE will treat as no disconnect.
		if item.ConnectionStatus == database.ConnectionStatusDisconnected {
			disconnectTime = append(disconnectTime, item.Time)
		} else {
			disconnectTime = append(disconnectTime, time.Time{})
		}
	}

	// Clear batch before the DB call. Losing a batch of connection
	// log entries is acceptable; the next event will be recorded.
	b.batch = make([]database.UpsertConnectionLogParams, 0, b.maxBatchSize)

	err := b.store.BatchUpsertConnectionLogs(ctx, database.BatchUpsertConnectionLogsParams{
		ID:               ids,
		ConnectTime:      connectTime,
		OrganizationID:   organizationID,
		WorkspaceOwnerID: workspaceOwnerID,
		WorkspaceID:      workspaceID,
		WorkspaceName:    workspaceName,
		AgentName:        agentName,
		Type:             connType,
		Code:             code,
		Ip:               ip,
		UserAgent:        userAgent,
		UserID:           userID,
		SlugOrPort:       slugOrPort,
		ConnectionID:     connectionID,
		DisconnectReason: disconnectReason,
		DisconnectTime:   disconnectTime,
	})
	if err != nil {
		if database.IsQueryCanceledError(err) {
			b.log.Debug(ctx, "query canceled, skipping connection log batch update")
			return
		}
		b.log.Error(ctx, "failed to batch upsert connection logs", slog.Error(err))
		return
	}
	b.log.Debug(ctx, "connection log batch flush complete", slog.F("count", count))
}

// nullStringToString converts a sql.NullString to a string. When the
// NullString is not valid, an empty string is returned.
func nullStringToString(ns sql.NullString) string {
	if ns.Valid {
		return ns.String
	}
	return ""
}

// nullInt32ToInt32 converts a sql.NullInt32 to an int32. When the
// NullInt32 is not valid, zero is returned.
func nullInt32ToInt32(ni sql.NullInt32) int32 {
	if ni.Valid {
		return ni.Int32
	}
	return 0
}

// nullUUIDToUUID converts a uuid.NullUUID to a uuid.UUID. When the
// NullUUID is not valid, uuid.Nil is returned.
func nullUUIDToUUID(nu uuid.NullUUID) uuid.UUID {
	if nu.Valid {
		return nu.UUID
	}
	return uuid.Nil
}
@@ -19,7 +19,6 @@ import (
"tailscale.com/tailcfg"
agentproto "github.com/coder/coder/v2/agent/proto"
"github.com/coder/coder/v2/coderd/chatd/chatprompt"
"github.com/coder/coder/v2/coderd/database"
"github.com/coder/coder/v2/coderd/externalauth/gitprovider"
"github.com/coder/coder/v2/coderd/rbac"
@@ -28,6 +27,7 @@ import (
"github.com/coder/coder/v2/coderd/util/ptr"
"github.com/coder/coder/v2/coderd/util/slice"
"github.com/coder/coder/v2/coderd/workspaceapps/appurl"
"github.com/coder/coder/v2/coderd/x/chatd/chatprompt"
"github.com/coder/coder/v2/codersdk"
"github.com/coder/coder/v2/provisionersdk/proto"
"github.com/coder/coder/v2/tailnet"
@@ -223,6 +223,7 @@ func UserFromGroupMember(member database.GroupMember) database.User {
QuietHoursSchedule: member.UserQuietHoursSchedule,
Name: member.UserName,
GithubComUserID: member.UserGithubComUserID,
IsServiceAccount: member.UserIsServiceAccount,
}
}
@@ -251,6 +252,7 @@ func UserFromGroupMemberRow(member database.GetGroupMembersByGroupIDPaginatedRow
QuietHoursSchedule: member.UserQuietHoursSchedule,
Name: member.UserName,
GithubComUserID: member.UserGithubComUserID,
IsServiceAccount: member.UserIsServiceAccount,
}
}
@@ -1019,6 +1021,44 @@ func AIBridgeInterception(interception database.AIBridgeInterception, initiator
return intc
}
func AIBridgeSession(row database.ListAIBridgeSessionsRow) codersdk.AIBridgeSession {
session := codersdk.AIBridgeSession{
ID: row.SessionID,
Initiator: MinimalUserFromVisibleUser(database.VisibleUser{
ID: row.UserID,
Username: row.UserUsername,
Name: row.UserName,
AvatarURL: row.UserAvatarUrl,
}),
Providers: row.Providers,
Models: row.Models,
Metadata: jsonOrEmptyMap(pqtype.NullRawMessage{RawMessage: row.Metadata, Valid: len(row.Metadata) > 0}),
StartedAt: row.StartedAt,
Threads: row.Threads,
TokenUsageSummary: codersdk.AIBridgeSessionTokenUsageSummary{
InputTokens: row.InputTokens,
OutputTokens: row.OutputTokens,
},
}
// Ensure non-nil slices for JSON serialization.
if session.Providers == nil {
session.Providers = []string{}
}
if session.Models == nil {
session.Models = []string{}
}
if row.Client != "" {
session.Client = &row.Client
}
if !row.EndedAt.IsZero() {
session.EndedAt = &row.EndedAt
}
if row.LastPrompt != "" {
session.LastPrompt = &row.LastPrompt
}
return session
}
func AIBridgeTokenUsage(usage database.AIBridgeTokenUsage) codersdk.AIBridgeTokenUsage {
return codersdk.AIBridgeTokenUsage{
ID: usage.ID,
@@ -1602,15 +1602,6 @@ func (q *querier) BackoffChatDiffStatus(ctx context.Context, arg database.Backof
return q.db.BackoffChatDiffStatus(ctx, arg)
}
func (q *querier) BatchUpdateWorkspaceAgentConnections(ctx context.Context, arg database.BatchUpdateWorkspaceAgentConnectionsParams) error {
// Could be any workspace agent and checking auth to each workspace
// agent is overkill for the purpose of this function.
if err := q.authorizeContext(ctx, policy.ActionUpdate, rbac.ResourceWorkspace.All()); err != nil {
return err
}
return q.db.BatchUpdateWorkspaceAgentConnections(ctx, arg)
}
func (q *querier) BatchUpdateWorkspaceAgentMetadata(ctx context.Context, arg database.BatchUpdateWorkspaceAgentMetadataParams) error {
// Could be any workspace agent and checking auth to each workspace agent is overkill for
// the purpose of this function.
@@ -1636,13 +1627,6 @@ func (q *querier) BatchUpdateWorkspaceNextStartAt(ctx context.Context, arg datab
return q.db.BatchUpdateWorkspaceNextStartAt(ctx, arg)
}
func (q *querier) BatchUpsertConnectionLogs(ctx context.Context, arg database.BatchUpsertConnectionLogsParams) error {
if err := q.authorizeContext(ctx, policy.ActionUpdate, rbac.ResourceConnectionLog); err != nil {
return err
}
return q.db.BatchUpsertConnectionLogs(ctx, arg)
}
func (q *querier) BulkMarkNotificationMessagesFailed(ctx context.Context, arg database.BulkMarkNotificationMessagesFailedParams) (int64, error) {
if err := q.authorizeContext(ctx, policy.ActionUpdate, rbac.ResourceNotificationMessage); err != nil {
return 0, err
@@ -1725,6 +1709,14 @@ func (q *querier) CountAIBridgeInterceptions(ctx context.Context, arg database.C
return q.db.CountAuthorizedAIBridgeInterceptions(ctx, arg, prep)
}
func (q *querier) CountAIBridgeSessions(ctx context.Context, arg database.CountAIBridgeSessionsParams) (int64, error) {
prep, err := prepareSQLFilter(ctx, q.auth, policy.ActionRead, rbac.ResourceAibridgeInterception.Type)
if err != nil {
return 0, xerrors.Errorf("(dev error) prepare sql filter: %w", err)
}
return q.db.CountAuthorizedAIBridgeSessions(ctx, arg, prep)
}
func (q *querier) CountAuditLogs(ctx context.Context, arg database.CountAuditLogsParams) (int64, error) {
// Shortcut if the user is an owner. The SQL filter is noticeable,
// and this is an easy win for owners. Which is the common case.
@@ -2134,6 +2126,17 @@ func (q *querier) DeleteTask(ctx context.Context, arg database.DeleteTaskParams)
return q.db.DeleteTask(ctx, arg)
}
func (q *querier) DeleteUserChatCompactionThreshold(ctx context.Context, arg database.DeleteUserChatCompactionThresholdParams) error {
u, err := q.db.GetUserByID(ctx, arg.UserID)
if err != nil {
return err
}
if err := q.authorizeContext(ctx, policy.ActionUpdatePersonal, u); err != nil {
return err
}
return q.db.DeleteUserChatCompactionThreshold(ctx, arg)
}
func (q *querier) DeleteUserSecret(ctx context.Context, id uuid.UUID) error {
// First get the secret to check ownership
secret, err := q.GetUserSecret(ctx, id)
@@ -3584,13 +3587,6 @@ func (q *querier) GetTailnetTunnelPeerBindings(ctx context.Context, srcID uuid.U
return q.db.GetTailnetTunnelPeerBindings(ctx, srcID)
}
func (q *querier) GetTailnetTunnelPeerBindingsBatch(ctx context.Context, ids []uuid.UUID) ([]database.GetTailnetTunnelPeerBindingsBatchRow, error) {
if err := q.authorizeContext(ctx, policy.ActionRead, rbac.ResourceTailnetCoordinator); err != nil {
return nil, err
}
return q.db.GetTailnetTunnelPeerBindingsBatch(ctx, ids)
}
func (q *querier) GetTailnetTunnelPeerIDs(ctx context.Context, srcID uuid.UUID) ([]database.GetTailnetTunnelPeerIDsRow, error) {
if err := q.authorizeContext(ctx, policy.ActionRead, rbac.ResourceTailnetCoordinator); err != nil {
return nil, err
@@ -3598,13 +3594,6 @@ func (q *querier) GetTailnetTunnelPeerIDs(ctx context.Context, srcID uuid.UUID)
return q.db.GetTailnetTunnelPeerIDs(ctx, srcID)
}
func (q *querier) GetTailnetTunnelPeerIDsBatch(ctx context.Context, ids []uuid.UUID) ([]database.GetTailnetTunnelPeerIDsBatchRow, error) {
if err := q.authorizeContext(ctx, policy.ActionRead, rbac.ResourceTailnetCoordinator); err != nil {
return nil, err
}
return q.db.GetTailnetTunnelPeerIDsBatch(ctx, ids)
}
func (q *querier) GetTaskByID(ctx context.Context, id uuid.UUID) (database.Task, error) {
return fetch(q.log, q.auth, q.db.GetTaskByID)(ctx, id)
}
@@ -3951,6 +3940,17 @@ func (q *querier) GetUserByID(ctx context.Context, id uuid.UUID) (database.User,
return fetch(q.log, q.auth, q.db.GetUserByID)(ctx, id)
}
func (q *querier) GetUserChatCompactionThreshold(ctx context.Context, arg database.GetUserChatCompactionThresholdParams) (string, error) {
u, err := q.db.GetUserByID(ctx, arg.UserID)
if err != nil {
return "", err
}
if err := q.authorizeContext(ctx, policy.ActionReadPersonal, u); err != nil {
return "", err
}
return q.db.GetUserChatCompactionThreshold(ctx, arg)
}
func (q *querier) GetUserChatCustomPrompt(ctx context.Context, userID uuid.UUID) (string, error) {
u, err := q.db.GetUserByID(ctx, userID)
if err != nil {
@@ -5325,10 +5325,16 @@ func (q *querier) ListAIBridgeModels(ctx context.Context, arg database.ListAIBri
return q.db.ListAuthorizedAIBridgeModels(ctx, arg, prep)
}
func (q *querier) ListAIBridgeSessions(ctx context.Context, arg database.ListAIBridgeSessionsParams) ([]database.ListAIBridgeSessionsRow, error) {
prep, err := prepareSQLFilter(ctx, q.auth, policy.ActionRead, rbac.ResourceAibridgeInterception.Type)
if err != nil {
return nil, xerrors.Errorf("(dev error) prepare sql filter: %w", err)
}
return q.db.ListAuthorizedAIBridgeSessions(ctx, arg, prep)
}
func (q *querier) ListAIBridgeTokenUsagesByInterceptionIDs(ctx context.Context, interceptionIDs []uuid.UUID) ([]database.AIBridgeTokenUsage, error) {
// This function is a system function until we implement a join for aibridge interceptions.
// Matches the behavior of the workspaces listing endpoint.
if err := q.authorizeContext(ctx, policy.ActionRead, rbac.ResourceSystem); err != nil {
if err := q.authorizeContext(ctx, policy.ActionRead, rbac.ResourceAibridgeInterception); err != nil {
return nil, err
}
@@ -5336,9 +5342,7 @@ func (q *querier) ListAIBridgeTokenUsagesByInterceptionIDs(ctx context.Context,
}
func (q *querier) ListAIBridgeToolUsagesByInterceptionIDs(ctx context.Context, interceptionIDs []uuid.UUID) ([]database.AIBridgeToolUsage, error) {
// This function is a system function until we implement a join for aibridge interceptions.
// Matches the behavior of the workspaces listing endpoint.
if err := q.authorizeContext(ctx, policy.ActionRead, rbac.ResourceSystem); err != nil {
if err := q.authorizeContext(ctx, policy.ActionRead, rbac.ResourceAibridgeInterception); err != nil {
return nil, err
}
@@ -5346,9 +5350,7 @@ func (q *querier) ListAIBridgeToolUsagesByInterceptionIDs(ctx context.Context, i
}
func (q *querier) ListAIBridgeUserPromptsByInterceptionIDs(ctx context.Context, interceptionIDs []uuid.UUID) ([]database.AIBridgeUserPrompt, error) {
// This function is a system function until we implement a join for aibridge interceptions.
// Matches the behavior of the workspaces listing endpoint.
if err := q.authorizeContext(ctx, policy.ActionRead, rbac.ResourceSystem); err != nil {
if err := q.authorizeContext(ctx, policy.ActionRead, rbac.ResourceAibridgeInterception); err != nil {
return nil, err
}
@@ -5382,6 +5384,17 @@ func (q *querier) ListTasks(ctx context.Context, arg database.ListTasksParams) (
return fetchWithPostFilter(q.auth, policy.ActionRead, q.db.ListTasks)(ctx, arg)
}
func (q *querier) ListUserChatCompactionThresholds(ctx context.Context, userID uuid.UUID) ([]database.UserConfig, error) {
u, err := q.db.GetUserByID(ctx, userID)
if err != nil {
return nil, err
}
if err := q.authorizeContext(ctx, policy.ActionReadPersonal, u); err != nil {
return nil, err
}
return q.db.ListUserChatCompactionThresholds(ctx, userID)
}
func (q *querier) ListUserSecrets(ctx context.Context, userID uuid.UUID) ([]database.UserSecret, error) {
obj := rbac.ResourceUserSecret.WithOwner(userID.String())
if err := q.authorizeContext(ctx, policy.ActionRead, obj); err != nil {
@@ -6242,6 +6255,17 @@ func (q *querier) UpdateUsageEventsPostPublish(ctx context.Context, arg database
return q.db.UpdateUsageEventsPostPublish(ctx, arg)
}
func (q *querier) UpdateUserChatCompactionThreshold(ctx context.Context, arg database.UpdateUserChatCompactionThresholdParams) (database.UserConfig, error) {
u, err := q.db.GetUserByID(ctx, arg.UserID)
if err != nil {
return database.UserConfig{}, err
}
if err := q.authorizeContext(ctx, policy.ActionUpdatePersonal, u); err != nil {
return database.UserConfig{}, err
}
return q.db.UpdateUserChatCompactionThreshold(ctx, arg)
}
func (q *querier) UpdateUserChatCustomPrompt(ctx context.Context, arg database.UpdateUserChatCustomPromptParams) (database.UserConfig, error) {
u, err := q.db.GetUserByID(ctx, arg.UserID)
if err != nil {
@@ -7114,6 +7138,14 @@ func (q *querier) ListAuthorizedAIBridgeModels(ctx context.Context, arg database
return q.ListAIBridgeModels(ctx, arg)
}
func (q *querier) ListAuthorizedAIBridgeSessions(ctx context.Context, arg database.ListAIBridgeSessionsParams, prepared rbac.PreparedAuthorized) ([]database.ListAIBridgeSessionsRow, error) {
return q.db.ListAuthorizedAIBridgeSessions(ctx, arg, prepared)
}
func (q *querier) CountAuthorizedAIBridgeSessions(ctx context.Context, arg database.CountAIBridgeSessionsParams, prepared rbac.PreparedAuthorized) (int64, error) {
return q.db.CountAuthorizedAIBridgeSessions(ctx, arg, prepared)
}
func (q *querier) GetAuthorizedChats(ctx context.Context, arg database.GetChatsParams, _ rbac.PreparedAuthorized) ([]database.Chat, error) {
return q.GetChats(ctx, arg)
}
@@ -2278,6 +2278,35 @@ func (s *MethodTestSuite) TestUser() {
dbm.EXPECT().UpdateUserChatCustomPrompt(gomock.Any(), arg).Return(uc, nil).AnyTimes()
check.Args(arg).Asserts(u, policy.ActionUpdatePersonal).Returns(uc)
}))
s.Run("ListUserChatCompactionThresholds", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
u := testutil.Fake(s.T(), faker, database.User{})
uc := database.UserConfig{UserID: u.ID, Key: codersdk.ChatCompactionThresholdKeyPrefix + "00000000-0000-0000-0000-000000000001", Value: "75"}
dbm.EXPECT().GetUserByID(gomock.Any(), u.ID).Return(u, nil).AnyTimes()
dbm.EXPECT().ListUserChatCompactionThresholds(gomock.Any(), u.ID).Return([]database.UserConfig{uc}, nil).AnyTimes()
check.Args(u.ID).Asserts(u, policy.ActionReadPersonal).Returns([]database.UserConfig{uc})
}))
s.Run("GetUserChatCompactionThreshold", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
u := testutil.Fake(s.T(), faker, database.User{})
arg := database.GetUserChatCompactionThresholdParams{UserID: u.ID, Key: codersdk.ChatCompactionThresholdKeyPrefix + "00000000-0000-0000-0000-000000000001"}
dbm.EXPECT().GetUserByID(gomock.Any(), u.ID).Return(u, nil).AnyTimes()
dbm.EXPECT().GetUserChatCompactionThreshold(gomock.Any(), arg).Return("75", nil).AnyTimes()
check.Args(arg).Asserts(u, policy.ActionReadPersonal).Returns("75")
}))
s.Run("UpdateUserChatCompactionThreshold", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
u := testutil.Fake(s.T(), faker, database.User{})
uc := database.UserConfig{UserID: u.ID, Key: codersdk.ChatCompactionThresholdKeyPrefix + "00000000-0000-0000-0000-000000000001", Value: "75"}
arg := database.UpdateUserChatCompactionThresholdParams{UserID: u.ID, Key: uc.Key, ThresholdPercent: 75}
dbm.EXPECT().GetUserByID(gomock.Any(), u.ID).Return(u, nil).AnyTimes()
dbm.EXPECT().UpdateUserChatCompactionThreshold(gomock.Any(), arg).Return(uc, nil).AnyTimes()
check.Args(arg).Asserts(u, policy.ActionUpdatePersonal).Returns(uc)
}))
s.Run("DeleteUserChatCompactionThreshold", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
u := testutil.Fake(s.T(), faker, database.User{})
arg := database.DeleteUserChatCompactionThresholdParams{UserID: u.ID, Key: codersdk.ChatCompactionThresholdKeyPrefix + "00000000-0000-0000-0000-000000000001"}
dbm.EXPECT().GetUserByID(gomock.Any(), u.ID).Return(u, nil).AnyTimes()
dbm.EXPECT().DeleteUserChatCompactionThreshold(gomock.Any(), arg).Return(nil).AnyTimes()
check.Args(arg).Asserts(u, policy.ActionUpdatePersonal)
}))
s.Run("UpdateUserTaskNotificationAlertDismissed", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
user := testutil.Fake(s.T(), faker, database.User{})
userConfig := database.UserConfig{UserID: user.ID, Key: "task_notification_alert_dismissed", Value: "false"}
@@ -2623,20 +2652,6 @@ func (s *MethodTestSuite) TestWorkspace() {
dbm.EXPECT().GetWorkspaceAgentMetadata(gomock.Any(), arg).Return([]database.WorkspaceAgentMetadatum{dt}, nil).AnyTimes()
check.Args(arg).Asserts(w, policy.ActionRead).Returns([]database.WorkspaceAgentMetadatum{dt})
}))
s.Run("BatchUpdateWorkspaceAgentConnections", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
agt := testutil.Fake(s.T(), faker, database.WorkspaceAgent{})
now := dbtime.Now()
arg := database.BatchUpdateWorkspaceAgentConnectionsParams{
ID: []uuid.UUID{agt.ID},
FirstConnectedAt: []time.Time{now},
LastConnectedAt: []time.Time{now},
LastConnectedReplicaID: []uuid.UUID{uuid.New()},
DisconnectedAt: []time.Time{{}},
UpdatedAt: []time.Time{now},
}
dbm.EXPECT().BatchUpdateWorkspaceAgentConnections(gomock.Any(), arg).Return(nil).AnyTimes()
check.Args(arg).Asserts(rbac.ResourceWorkspace.All(), policy.ActionUpdate).Returns()
}))
s.Run("BatchUpdateWorkspaceAgentMetadata", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
agt := testutil.Fake(s.T(), faker, database.WorkspaceAgent{})
arg := database.BatchUpdateWorkspaceAgentMetadataParams{
@@ -3174,109 +3189,59 @@ func (s *MethodTestSuite) TestWorkspace() {
}
func (s *MethodTestSuite) TestWorkspacePortSharing() {
s.Run("UpsertWorkspaceAgentPortShare", s.Subtest(func(db database.Store, check *expects) {
u := dbgen.User(s.T(), db, database.User{})
org := dbgen.Organization(s.T(), db, database.Organization{})
tpl := dbgen.Template(s.T(), db, database.Template{
OrganizationID: org.ID,
CreatedBy: u.ID,
})
ws := dbgen.Workspace(s.T(), db, database.WorkspaceTable{
OwnerID: u.ID,
OrganizationID: org.ID,
TemplateID: tpl.ID,
})
ps := dbgen.WorkspaceAgentPortShare(s.T(), db, database.WorkspaceAgentPortShare{WorkspaceID: ws.ID})
//nolint:gosimple // casting is not a simplification
check.Args(database.UpsertWorkspaceAgentPortShareParams{
s.Run("UpsertWorkspaceAgentPortShare", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
ws := testutil.Fake(s.T(), faker, database.Workspace{})
ps := testutil.Fake(s.T(), faker, database.WorkspaceAgentPortShare{})
ps.WorkspaceID = ws.ID
arg := database.UpsertWorkspaceAgentPortShareParams(ps)
dbm.EXPECT().GetWorkspaceByID(gomock.Any(), ws.ID).Return(ws, nil).AnyTimes()
dbm.EXPECT().UpsertWorkspaceAgentPortShare(gomock.Any(), arg).Return(ps, nil).AnyTimes()
check.Args(arg).Asserts(ws, policy.ActionUpdate).Returns(ps)
}))
s.Run("GetWorkspaceAgentPortShare", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
ws := testutil.Fake(s.T(), faker, database.Workspace{})
ps := testutil.Fake(s.T(), faker, database.WorkspaceAgentPortShare{})
ps.WorkspaceID = ws.ID
arg := database.GetWorkspaceAgentPortShareParams{
WorkspaceID: ps.WorkspaceID,
AgentName: ps.AgentName,
Port: ps.Port,
ShareLevel: ps.ShareLevel,
Protocol: ps.Protocol,
}).Asserts(ws, policy.ActionUpdate).Returns(ps)
}
dbm.EXPECT().GetWorkspaceByID(gomock.Any(), ws.ID).Return(ws, nil).AnyTimes()
dbm.EXPECT().GetWorkspaceAgentPortShare(gomock.Any(), arg).Return(ps, nil).AnyTimes()
check.Args(arg).Asserts(ws, policy.ActionRead).Returns(ps)
}))
s.Run("GetWorkspaceAgentPortShare", s.Subtest(func(db database.Store, check *expects) {
u := dbgen.User(s.T(), db, database.User{})
org := dbgen.Organization(s.T(), db, database.Organization{})
tpl := dbgen.Template(s.T(), db, database.Template{
OrganizationID: org.ID,
CreatedBy: u.ID,
})
ws := dbgen.Workspace(s.T(), db, database.WorkspaceTable{
OwnerID: u.ID,
OrganizationID: org.ID,
TemplateID: tpl.ID,
})
ps := dbgen.WorkspaceAgentPortShare(s.T(), db, database.WorkspaceAgentPortShare{WorkspaceID: ws.ID})
check.Args(database.GetWorkspaceAgentPortShareParams{
WorkspaceID: ps.WorkspaceID,
AgentName: ps.AgentName,
Port: ps.Port,
}).Asserts(ws, policy.ActionRead).Returns(ps)
}))
s.Run("ListWorkspaceAgentPortShares", s.Subtest(func(db database.Store, check *expects) {
u := dbgen.User(s.T(), db, database.User{})
org := dbgen.Organization(s.T(), db, database.Organization{})
tpl := dbgen.Template(s.T(), db, database.Template{
OrganizationID: org.ID,
CreatedBy: u.ID,
})
ws := dbgen.Workspace(s.T(), db, database.WorkspaceTable{
OwnerID: u.ID,
OrganizationID: org.ID,
TemplateID: tpl.ID,
})
ps := dbgen.WorkspaceAgentPortShare(s.T(), db, database.WorkspaceAgentPortShare{WorkspaceID: ws.ID})
s.Run("ListWorkspaceAgentPortShares", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
ws := testutil.Fake(s.T(), faker, database.Workspace{})
ps := testutil.Fake(s.T(), faker, database.WorkspaceAgentPortShare{})
ps.WorkspaceID = ws.ID
dbm.EXPECT().GetWorkspaceByID(gomock.Any(), ws.ID).Return(ws, nil).AnyTimes()
dbm.EXPECT().ListWorkspaceAgentPortShares(gomock.Any(), ws.ID).Return([]database.WorkspaceAgentPortShare{ps}, nil).AnyTimes()
check.Args(ws.ID).Asserts(ws, policy.ActionRead).Returns([]database.WorkspaceAgentPortShare{ps})
}))
s.Run("DeleteWorkspaceAgentPortShare", s.Subtest(func(db database.Store, check *expects) {
u := dbgen.User(s.T(), db, database.User{})
org := dbgen.Organization(s.T(), db, database.Organization{})
tpl := dbgen.Template(s.T(), db, database.Template{
OrganizationID: org.ID,
CreatedBy: u.ID,
})
ws := dbgen.Workspace(s.T(), db, database.WorkspaceTable{
OwnerID: u.ID,
OrganizationID: org.ID,
TemplateID: tpl.ID,
})
ps := dbgen.WorkspaceAgentPortShare(s.T(), db, database.WorkspaceAgentPortShare{WorkspaceID: ws.ID})
check.Args(database.DeleteWorkspaceAgentPortShareParams{
s.Run("DeleteWorkspaceAgentPortShare", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
ws := testutil.Fake(s.T(), faker, database.Workspace{})
ps := testutil.Fake(s.T(), faker, database.WorkspaceAgentPortShare{})
ps.WorkspaceID = ws.ID
arg := database.DeleteWorkspaceAgentPortShareParams{
WorkspaceID: ps.WorkspaceID,
AgentName: ps.AgentName,
Port: ps.Port,
}).Asserts(ws, policy.ActionUpdate).Returns()
}
dbm.EXPECT().GetWorkspaceByID(gomock.Any(), ws.ID).Return(ws, nil).AnyTimes()
dbm.EXPECT().DeleteWorkspaceAgentPortShare(gomock.Any(), arg).Return(nil).AnyTimes()
check.Args(arg).Asserts(ws, policy.ActionUpdate).Returns()
}))
s.Run("DeleteWorkspaceAgentPortSharesByTemplate", s.Subtest(func(db database.Store, check *expects) {
u := dbgen.User(s.T(), db, database.User{})
org := dbgen.Organization(s.T(), db, database.Organization{})
tpl := dbgen.Template(s.T(), db, database.Template{
OrganizationID: org.ID,
CreatedBy: u.ID,
})
ws := dbgen.Workspace(s.T(), db, database.WorkspaceTable{
OwnerID: u.ID,
OrganizationID: org.ID,
TemplateID: tpl.ID,
})
_ = dbgen.WorkspaceAgentPortShare(s.T(), db, database.WorkspaceAgentPortShare{WorkspaceID: ws.ID})
s.Run("DeleteWorkspaceAgentPortSharesByTemplate", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
tpl := testutil.Fake(s.T(), faker, database.Template{})
dbm.EXPECT().GetTemplateByID(gomock.Any(), tpl.ID).Return(tpl, nil).AnyTimes()
dbm.EXPECT().DeleteWorkspaceAgentPortSharesByTemplate(gomock.Any(), tpl.ID).Return(nil).AnyTimes()
check.Args(tpl.ID).Asserts(tpl, policy.ActionUpdate).Returns()
}))
s.Run("ReduceWorkspaceAgentShareLevelToAuthenticatedByTemplate", s.Subtest(func(db database.Store, check *expects) {
u := dbgen.User(s.T(), db, database.User{})
org := dbgen.Organization(s.T(), db, database.Organization{})
tpl := dbgen.Template(s.T(), db, database.Template{
OrganizationID: org.ID,
CreatedBy: u.ID,
})
ws := dbgen.Workspace(s.T(), db, database.WorkspaceTable{
OwnerID: u.ID,
OrganizationID: org.ID,
TemplateID: tpl.ID,
})
_ = dbgen.WorkspaceAgentPortShare(s.T(), db, database.WorkspaceAgentPortShare{WorkspaceID: ws.ID})
s.Run("ReduceWorkspaceAgentShareLevelToAuthenticatedByTemplate", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
tpl := testutil.Fake(s.T(), faker, database.Template{})
dbm.EXPECT().GetTemplateByID(gomock.Any(), tpl.ID).Return(tpl, nil).AnyTimes()
dbm.EXPECT().ReduceWorkspaceAgentShareLevelToAuthenticatedByTemplate(gomock.Any(), tpl.ID).Return(nil).AnyTimes()
check.Args(tpl.ID).Asserts(tpl, policy.ActionUpdate).Returns()
}))
}
@@ -4993,113 +4958,69 @@ func (s *MethodTestSuite) TestOAuth2ProviderAppTokens() {
}
func (s *MethodTestSuite) TestResourcesMonitor() {
createAgent := func(t *testing.T, db database.Store) (database.WorkspaceAgent, database.WorkspaceTable) {
t.Helper()
u := dbgen.User(t, db, database.User{})
o := dbgen.Organization(t, db, database.Organization{})
tpl := dbgen.Template(t, db, database.Template{
OrganizationID: o.ID,
CreatedBy: u.ID,
})
tv := dbgen.TemplateVersion(t, db, database.TemplateVersion{
TemplateID: uuid.NullUUID{UUID: tpl.ID, Valid: true},
OrganizationID: o.ID,
CreatedBy: u.ID,
})
w := dbgen.Workspace(t, db, database.WorkspaceTable{
TemplateID: tpl.ID,
OrganizationID: o.ID,
OwnerID: u.ID,
})
j := dbgen.ProvisionerJob(t, db, nil, database.ProvisionerJob{
Type: database.ProvisionerJobTypeWorkspaceBuild,
})
b := dbgen.WorkspaceBuild(t, db, database.WorkspaceBuild{
JobID: j.ID,
WorkspaceID: w.ID,
TemplateVersionID: tv.ID,
})
res := dbgen.WorkspaceResource(t, db, database.WorkspaceResource{JobID: b.JobID})
agt := dbgen.WorkspaceAgent(t, db, database.WorkspaceAgent{ResourceID: res.ID})
return agt, w
}
s.Run("InsertMemoryResourceMonitor", s.Subtest(func(db database.Store, check *expects) {
agt, _ := createAgent(s.T(), db)
check.Args(database.InsertMemoryResourceMonitorParams{
AgentID: agt.ID,
s.Run("InsertMemoryResourceMonitor", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
arg := database.InsertMemoryResourceMonitorParams{
AgentID: uuid.New(),
State: database.WorkspaceAgentMonitorStateOK,
}).Asserts(rbac.ResourceWorkspaceAgentResourceMonitor, policy.ActionCreate)
}
dbm.EXPECT().InsertMemoryResourceMonitor(gomock.Any(), arg).Return(database.WorkspaceAgentMemoryResourceMonitor{}, nil).AnyTimes()
check.Args(arg).Asserts(rbac.ResourceWorkspaceAgentResourceMonitor, policy.ActionCreate)
}))
s.Run("InsertVolumeResourceMonitor", s.Subtest(func(db database.Store, check *expects) {
agt, _ := createAgent(s.T(), db)
check.Args(database.InsertVolumeResourceMonitorParams{
AgentID: agt.ID,
s.Run("InsertVolumeResourceMonitor", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
arg := database.InsertVolumeResourceMonitorParams{
AgentID: uuid.New(),
State: database.WorkspaceAgentMonitorStateOK,
}).Asserts(rbac.ResourceWorkspaceAgentResourceMonitor, policy.ActionCreate)
}
dbm.EXPECT().InsertVolumeResourceMonitor(gomock.Any(), arg).Return(database.WorkspaceAgentVolumeResourceMonitor{}, nil).AnyTimes()
check.Args(arg).Asserts(rbac.ResourceWorkspaceAgentResourceMonitor, policy.ActionCreate)
}))
s.Run("UpdateMemoryResourceMonitor", s.Subtest(func(db database.Store, check *expects) {
agt, _ := createAgent(s.T(), db)
check.Args(database.UpdateMemoryResourceMonitorParams{
AgentID: agt.ID,
s.Run("UpdateMemoryResourceMonitor", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
arg := database.UpdateMemoryResourceMonitorParams{
AgentID: uuid.New(),
State: database.WorkspaceAgentMonitorStateOK,
}).Asserts(rbac.ResourceWorkspaceAgentResourceMonitor, policy.ActionUpdate)
}
dbm.EXPECT().UpdateMemoryResourceMonitor(gomock.Any(), arg).Return(nil).AnyTimes()
check.Args(arg).Asserts(rbac.ResourceWorkspaceAgentResourceMonitor, policy.ActionUpdate)
}))
s.Run("UpdateVolumeResourceMonitor", s.Subtest(func(db database.Store, check *expects) {
agt, _ := createAgent(s.T(), db)
check.Args(database.UpdateVolumeResourceMonitorParams{
AgentID: agt.ID,
s.Run("UpdateVolumeResourceMonitor", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
arg := database.UpdateVolumeResourceMonitorParams{
AgentID: uuid.New(),
State: database.WorkspaceAgentMonitorStateOK,
}).Asserts(rbac.ResourceWorkspaceAgentResourceMonitor, policy.ActionUpdate)
}
dbm.EXPECT().UpdateVolumeResourceMonitor(gomock.Any(), arg).Return(nil).AnyTimes()
check.Args(arg).Asserts(rbac.ResourceWorkspaceAgentResourceMonitor, policy.ActionUpdate)
}))
s.Run("FetchMemoryResourceMonitorsUpdatedAfter", s.Subtest(func(db database.Store, check *expects) {
s.Run("FetchMemoryResourceMonitorsUpdatedAfter", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
dbm.EXPECT().FetchMemoryResourceMonitorsUpdatedAfter(gomock.Any(), gomock.Any()).Return(nil, nil).AnyTimes()
check.Args(dbtime.Now()).Asserts(rbac.ResourceWorkspaceAgentResourceMonitor, policy.ActionRead)
}))
s.Run("FetchVolumesResourceMonitorsUpdatedAfter", s.Subtest(func(db database.Store, check *expects) {
s.Run("FetchVolumesResourceMonitorsUpdatedAfter", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
dbm.EXPECT().FetchVolumesResourceMonitorsUpdatedAfter(gomock.Any(), gomock.Any()).Return(nil, nil).AnyTimes()
check.Args(dbtime.Now()).Asserts(rbac.ResourceWorkspaceAgentResourceMonitor, policy.ActionRead)
}))
s.Run("FetchMemoryResourceMonitorsByAgentID", s.Subtest(func(db database.Store, check *expects) {
agt, w := createAgent(s.T(), db)
dbgen.WorkspaceAgentMemoryResourceMonitor(s.T(), db, database.WorkspaceAgentMemoryResourceMonitor{
AgentID: agt.ID,
Enabled: true,
Threshold: 80,
CreatedAt: dbtime.Now(),
})
monitor, err := db.FetchMemoryResourceMonitorsByAgentID(context.Background(), agt.ID)
require.NoError(s.T(), err)
s.Run("FetchMemoryResourceMonitorsByAgentID", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
w := testutil.Fake(s.T(), faker, database.Workspace{})
agt := testutil.Fake(s.T(), faker, database.WorkspaceAgent{})
monitor := testutil.Fake(s.T(), faker, database.WorkspaceAgentMemoryResourceMonitor{})
dbm.EXPECT().GetWorkspaceByAgentID(gomock.Any(), agt.ID).Return(w, nil).AnyTimes()
dbm.EXPECT().FetchMemoryResourceMonitorsByAgentID(gomock.Any(), agt.ID).Return(monitor, nil).AnyTimes()
check.Args(agt.ID).Asserts(w, policy.ActionRead).Returns(monitor)
}))
s.Run("FetchVolumesResourceMonitorsByAgentID", s.Subtest(func(db database.Store, check *expects) {
agt, w := createAgent(s.T(), db)
dbgen.WorkspaceAgentVolumeResourceMonitor(s.T(), db, database.WorkspaceAgentVolumeResourceMonitor{
AgentID: agt.ID,
Path: "/var/lib",
Enabled: true,
Threshold: 80,
CreatedAt: dbtime.Now(),
})
monitors, err := db.FetchVolumesResourceMonitorsByAgentID(context.Background(), agt.ID)
require.NoError(s.T(), err)
s.Run("FetchVolumesResourceMonitorsByAgentID", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
w := testutil.Fake(s.T(), faker, database.Workspace{})
agt := testutil.Fake(s.T(), faker, database.WorkspaceAgent{})
monitors := []database.WorkspaceAgentVolumeResourceMonitor{
testutil.Fake(s.T(), faker, database.WorkspaceAgentVolumeResourceMonitor{}),
}
dbm.EXPECT().GetWorkspaceByAgentID(gomock.Any(), agt.ID).Return(w, nil).AnyTimes()
dbm.EXPECT().FetchVolumesResourceMonitorsByAgentID(gomock.Any(), agt.ID).Return(monitors, nil).AnyTimes()
check.Args(agt.ID).Asserts(w, policy.ActionRead).Returns(monitors)
}))
}
@@ -5499,22 +5420,50 @@ func (s *MethodTestSuite) TestAIBridge() {
check.Args(params, emptyPreparedAuthorized{}).Asserts()
}))
s.Run("ListAIBridgeSessions", s.Mocked(func(db *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
params := database.ListAIBridgeSessionsParams{}
db.EXPECT().ListAuthorizedAIBridgeSessions(gomock.Any(), params, gomock.Any()).Return([]database.ListAIBridgeSessionsRow{}, nil).AnyTimes()
// No asserts here because SQLFilter.
check.Args(params).Asserts()
}))
s.Run("ListAuthorizedAIBridgeSessions", s.Mocked(func(db *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
params := database.ListAIBridgeSessionsParams{}
db.EXPECT().ListAuthorizedAIBridgeSessions(gomock.Any(), params, gomock.Any()).Return([]database.ListAIBridgeSessionsRow{}, nil).AnyTimes()
// No asserts here because SQLFilter.
check.Args(params, emptyPreparedAuthorized{}).Asserts()
}))
s.Run("CountAIBridgeSessions", s.Mocked(func(db *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
params := database.CountAIBridgeSessionsParams{}
db.EXPECT().CountAuthorizedAIBridgeSessions(gomock.Any(), params, gomock.Any()).Return(int64(0), nil).AnyTimes()
// No asserts here because SQLFilter.
check.Args(params).Asserts()
}))
s.Run("CountAuthorizedAIBridgeSessions", s.Mocked(func(db *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
params := database.CountAIBridgeSessionsParams{}
db.EXPECT().CountAuthorizedAIBridgeSessions(gomock.Any(), params, gomock.Any()).Return(int64(0), nil).AnyTimes()
// No asserts here because SQLFilter.
check.Args(params, emptyPreparedAuthorized{}).Asserts()
}))
s.Run("ListAIBridgeTokenUsagesByInterceptionIDs", s.Mocked(func(db *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
ids := []uuid.UUID{{1}}
db.EXPECT().ListAIBridgeTokenUsagesByInterceptionIDs(gomock.Any(), ids).Return([]database.AIBridgeTokenUsage{}, nil).AnyTimes()
check.Args(ids).Asserts(rbac.ResourceSystem, policy.ActionRead).Returns([]database.AIBridgeTokenUsage{})
check.Args(ids).Asserts(rbac.ResourceAibridgeInterception, policy.ActionRead).Returns([]database.AIBridgeTokenUsage{})
}))
s.Run("ListAIBridgeUserPromptsByInterceptionIDs", s.Mocked(func(db *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
ids := []uuid.UUID{{1}}
db.EXPECT().ListAIBridgeUserPromptsByInterceptionIDs(gomock.Any(), ids).Return([]database.AIBridgeUserPrompt{}, nil).AnyTimes()
check.Args(ids).Asserts(rbac.ResourceSystem, policy.ActionRead).Returns([]database.AIBridgeUserPrompt{})
check.Args(ids).Asserts(rbac.ResourceAibridgeInterception, policy.ActionRead).Returns([]database.AIBridgeUserPrompt{})
}))
s.Run("ListAIBridgeToolUsagesByInterceptionIDs", s.Mocked(func(db *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
ids := []uuid.UUID{{1}}
db.EXPECT().ListAIBridgeToolUsagesByInterceptionIDs(gomock.Any(), ids).Return([]database.AIBridgeToolUsage{}, nil).AnyTimes()
check.Args(ids).Asserts(rbac.ResourceSystem, policy.ActionRead).Returns([]database.AIBridgeToolUsage{})
check.Args(ids).Asserts(rbac.ResourceAibridgeInterception, policy.ActionRead).Returns([]database.AIBridgeToolUsage{})
}))
s.Run("UpdateAIBridgeInterceptionEnded", s.Mocked(func(db *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
@@ -4,9 +4,10 @@ import (
"context"
"encoding/gob"
"errors"
"flag"
"fmt"
"reflect"
"sort"
"slices"
"strings"
"testing"
@@ -90,6 +91,16 @@ func (s *MethodTestSuite) SetupSuite() {
// TearDownSuite asserts that all methods were called at least once.
func (s *MethodTestSuite) TearDownSuite() {
s.Run("Accounting", func() {
// testify/suite's -testify.m flag filters which suite methods
// run, but TearDownSuite still executes. Skip the Accounting
// check when filtering to avoid misleading "method never
// called" errors for every method that was filtered out.
if f := flag.Lookup("testify.m"); f != nil {
if f.Value.String() != "" {
s.T().Skip("Skipping Accounting check: -testify.m flag is set")
}
}
t := s.T()
notCalled := []string{}
for m, c := range s.methodAccounting {
@@ -97,7 +108,7 @@ func (s *MethodTestSuite) TearDownSuite() {
notCalled = append(notCalled, m)
}
}
sort.Strings(notCalled)
slices.Sort(notCalled)
for _, m := range notCalled {
t.Errorf("Method never called: %q", m)
}
@@ -184,14 +184,6 @@ func (m queryMetricsStore) BackoffChatDiffStatus(ctx context.Context, arg databa
return r0
}
func (m queryMetricsStore) BatchUpdateWorkspaceAgentConnections(ctx context.Context, arg database.BatchUpdateWorkspaceAgentConnectionsParams) error {
start := time.Now()
r0 := m.s.BatchUpdateWorkspaceAgentConnections(ctx, arg)
m.queryLatencies.WithLabelValues("BatchUpdateWorkspaceAgentConnections").Observe(time.Since(start).Seconds())
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "BatchUpdateWorkspaceAgentConnections").Inc()
return r0
}
func (m queryMetricsStore) BatchUpdateWorkspaceAgentMetadata(ctx context.Context, arg database.BatchUpdateWorkspaceAgentMetadataParams) error {
start := time.Now()
r0 := m.s.BatchUpdateWorkspaceAgentMetadata(ctx, arg)
@@ -216,14 +208,6 @@ func (m queryMetricsStore) BatchUpdateWorkspaceNextStartAt(ctx context.Context,
return r0
}
func (m queryMetricsStore) BatchUpsertConnectionLogs(ctx context.Context, arg database.BatchUpsertConnectionLogsParams) error {
start := time.Now()
r0 := m.s.BatchUpsertConnectionLogs(ctx, arg)
m.queryLatencies.WithLabelValues("BatchUpsertConnectionLogs").Observe(time.Since(start).Seconds())
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "BatchUpsertConnectionLogs").Inc()
return r0
}
func (m queryMetricsStore) BulkMarkNotificationMessagesFailed(ctx context.Context, arg database.BulkMarkNotificationMessagesFailedParams) (int64, error) {
start := time.Now()
r0, r1 := m.s.BulkMarkNotificationMessagesFailed(ctx, arg)
@@ -296,6 +280,14 @@ func (m queryMetricsStore) CountAIBridgeInterceptions(ctx context.Context, arg d
return r0, r1
}
func (m queryMetricsStore) CountAIBridgeSessions(ctx context.Context, arg database.CountAIBridgeSessionsParams) (int64, error) {
start := time.Now()
r0, r1 := m.s.CountAIBridgeSessions(ctx, arg)
m.queryLatencies.WithLabelValues("CountAIBridgeSessions").Observe(time.Since(start).Seconds())
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "CountAIBridgeSessions").Inc()
return r0, r1
}
func (m queryMetricsStore) CountAuditLogs(ctx context.Context, arg database.CountAuditLogsParams) (int64, error) {
start := time.Now()
r0, r1 := m.s.CountAuditLogs(ctx, arg)
@@ -696,6 +688,14 @@ func (m queryMetricsStore) DeleteTask(ctx context.Context, arg database.DeleteTa
return r0, r1
}
func (m queryMetricsStore) DeleteUserChatCompactionThreshold(ctx context.Context, arg database.DeleteUserChatCompactionThresholdParams) error {
start := time.Now()
r0 := m.s.DeleteUserChatCompactionThreshold(ctx, arg)
m.queryLatencies.WithLabelValues("DeleteUserChatCompactionThreshold").Observe(time.Since(start).Seconds())
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "DeleteUserChatCompactionThreshold").Inc()
return r0
}
func (m queryMetricsStore) DeleteUserSecret(ctx context.Context, id uuid.UUID) error {
start := time.Now()
r0 := m.s.DeleteUserSecret(ctx, id)
@@ -2176,14 +2176,6 @@ func (m queryMetricsStore) GetTailnetTunnelPeerBindings(ctx context.Context, src
return r0, r1
}
func (m queryMetricsStore) GetTailnetTunnelPeerBindingsBatch(ctx context.Context, ids []uuid.UUID) ([]database.GetTailnetTunnelPeerBindingsBatchRow, error) {
start := time.Now()
r0, r1 := m.s.GetTailnetTunnelPeerBindingsBatch(ctx, ids)
m.queryLatencies.WithLabelValues("GetTailnetTunnelPeerBindingsBatch").Observe(time.Since(start).Seconds())
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "GetTailnetTunnelPeerBindingsBatch").Inc()
return r0, r1
}
func (m queryMetricsStore) GetTailnetTunnelPeerIDs(ctx context.Context, srcID uuid.UUID) ([]database.GetTailnetTunnelPeerIDsRow, error) {
start := time.Now()
r0, r1 := m.s.GetTailnetTunnelPeerIDs(ctx, srcID)
@@ -2192,14 +2184,6 @@ func (m queryMetricsStore) GetTailnetTunnelPeerIDs(ctx context.Context, srcID uu
return r0, r1
}
func (m queryMetricsStore) GetTailnetTunnelPeerIDsBatch(ctx context.Context, ids []uuid.UUID) ([]database.GetTailnetTunnelPeerIDsBatchRow, error) {
start := time.Now()
r0, r1 := m.s.GetTailnetTunnelPeerIDsBatch(ctx, ids)
m.queryLatencies.WithLabelValues("GetTailnetTunnelPeerIDsBatch").Observe(time.Since(start).Seconds())
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "GetTailnetTunnelPeerIDsBatch").Inc()
return r0, r1
}
func (m queryMetricsStore) GetTaskByID(ctx context.Context, id uuid.UUID) (database.Task, error) {
start := time.Now()
r0, r1 := m.s.GetTaskByID(ctx, id)
@@ -2480,6 +2464,14 @@ func (m queryMetricsStore) GetUserByID(ctx context.Context, id uuid.UUID) (datab
return r0, r1
}
func (m queryMetricsStore) GetUserChatCompactionThreshold(ctx context.Context, arg database.GetUserChatCompactionThresholdParams) (string, error) {
start := time.Now()
r0, r1 := m.s.GetUserChatCompactionThreshold(ctx, arg)
m.queryLatencies.WithLabelValues("GetUserChatCompactionThreshold").Observe(time.Since(start).Seconds())
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "GetUserChatCompactionThreshold").Inc()
return r0, r1
}
func (m queryMetricsStore) GetUserChatCustomPrompt(ctx context.Context, userID uuid.UUID) (string, error) {
start := time.Now()
r0, r1 := m.s.GetUserChatCustomPrompt(ctx, userID)
@@ -3736,6 +3728,14 @@ func (m queryMetricsStore) ListAIBridgeModels(ctx context.Context, arg database.
return r0, r1
}
func (m queryMetricsStore) ListAIBridgeSessions(ctx context.Context, arg database.ListAIBridgeSessionsParams) ([]database.ListAIBridgeSessionsRow, error) {
start := time.Now()
r0, r1 := m.s.ListAIBridgeSessions(ctx, arg)
m.queryLatencies.WithLabelValues("ListAIBridgeSessions").Observe(time.Since(start).Seconds())
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "ListAIBridgeSessions").Inc()
return r0, r1
}
func (m queryMetricsStore) ListAIBridgeTokenUsagesByInterceptionIDs(ctx context.Context, interceptionIds []uuid.UUID) ([]database.AIBridgeTokenUsage, error) {
start := time.Now()
r0, r1 := m.s.ListAIBridgeTokenUsagesByInterceptionIDs(ctx, interceptionIds)
@@ -3800,6 +3800,14 @@ func (m queryMetricsStore) ListTasks(ctx context.Context, arg database.ListTasks
return r0, r1
}
func (m queryMetricsStore) ListUserChatCompactionThresholds(ctx context.Context, userID uuid.UUID) ([]database.UserConfig, error) {
start := time.Now()
r0, r1 := m.s.ListUserChatCompactionThresholds(ctx, userID)
m.queryLatencies.WithLabelValues("ListUserChatCompactionThresholds").Observe(time.Since(start).Seconds())
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "ListUserChatCompactionThresholds").Inc()
return r0, r1
}
func (m queryMetricsStore) ListUserSecrets(ctx context.Context, userID uuid.UUID) ([]database.UserSecret, error) {
start := time.Now()
r0, r1 := m.s.ListUserSecrets(ctx, userID)
@@ -4392,6 +4400,14 @@ func (m queryMetricsStore) UpdateUsageEventsPostPublish(ctx context.Context, arg
return r0
}
func (m queryMetricsStore) UpdateUserChatCompactionThreshold(ctx context.Context, arg database.UpdateUserChatCompactionThresholdParams) (database.UserConfig, error) {
start := time.Now()
r0, r1 := m.s.UpdateUserChatCompactionThreshold(ctx, arg)
m.queryLatencies.WithLabelValues("UpdateUserChatCompactionThreshold").Observe(time.Since(start).Seconds())
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "UpdateUserChatCompactionThreshold").Inc()
return r0, r1
}
func (m queryMetricsStore) UpdateUserChatCustomPrompt(ctx context.Context, arg database.UpdateUserChatCustomPromptParams) (database.UserConfig, error) {
start := time.Now()
r0, r1 := m.s.UpdateUserChatCustomPrompt(ctx, arg)
@@ -5136,6 +5152,22 @@ func (m queryMetricsStore) ListAuthorizedAIBridgeModels(ctx context.Context, arg
return r0, r1
}
func (m queryMetricsStore) ListAuthorizedAIBridgeSessions(ctx context.Context, arg database.ListAIBridgeSessionsParams, prepared rbac.PreparedAuthorized) ([]database.ListAIBridgeSessionsRow, error) {
start := time.Now()
r0, r1 := m.s.ListAuthorizedAIBridgeSessions(ctx, arg, prepared)
m.queryLatencies.WithLabelValues("ListAuthorizedAIBridgeSessions").Observe(time.Since(start).Seconds())
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "ListAuthorizedAIBridgeSessions").Inc()
return r0, r1
}
func (m queryMetricsStore) CountAuthorizedAIBridgeSessions(ctx context.Context, arg database.CountAIBridgeSessionsParams, prepared rbac.PreparedAuthorized) (int64, error) {
start := time.Now()
r0, r1 := m.s.CountAuthorizedAIBridgeSessions(ctx, arg, prepared)
m.queryLatencies.WithLabelValues("CountAuthorizedAIBridgeSessions").Observe(time.Since(start).Seconds())
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "CountAuthorizedAIBridgeSessions").Inc()
return r0, r1
}
func (m queryMetricsStore) GetAuthorizedChats(ctx context.Context, arg database.GetChatsParams, prepared rbac.PreparedAuthorized) ([]database.Chat, error) {
start := time.Now()
r0, r1 := m.s.GetAuthorizedChats(ctx, arg, prepared)
@@ -190,20 +190,6 @@ func (mr *MockStoreMockRecorder) BackoffChatDiffStatus(ctx, arg any) *gomock.Cal
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "BackoffChatDiffStatus", reflect.TypeOf((*MockStore)(nil).BackoffChatDiffStatus), ctx, arg)
}
// BatchUpdateWorkspaceAgentConnections mocks base method.
func (m *MockStore) BatchUpdateWorkspaceAgentConnections(ctx context.Context, arg database.BatchUpdateWorkspaceAgentConnectionsParams) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "BatchUpdateWorkspaceAgentConnections", ctx, arg)
ret0, _ := ret[0].(error)
return ret0
}
// BatchUpdateWorkspaceAgentConnections indicates an expected call of BatchUpdateWorkspaceAgentConnections.
func (mr *MockStoreMockRecorder) BatchUpdateWorkspaceAgentConnections(ctx, arg any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "BatchUpdateWorkspaceAgentConnections", reflect.TypeOf((*MockStore)(nil).BatchUpdateWorkspaceAgentConnections), ctx, arg)
}
// BatchUpdateWorkspaceAgentMetadata mocks base method.
func (m *MockStore) BatchUpdateWorkspaceAgentMetadata(ctx context.Context, arg database.BatchUpdateWorkspaceAgentMetadataParams) error {
m.ctrl.T.Helper()
@@ -246,20 +232,6 @@ func (mr *MockStoreMockRecorder) BatchUpdateWorkspaceNextStartAt(ctx, arg any) *
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "BatchUpdateWorkspaceNextStartAt", reflect.TypeOf((*MockStore)(nil).BatchUpdateWorkspaceNextStartAt), ctx, arg)
}
// BatchUpsertConnectionLogs mocks base method.
func (m *MockStore) BatchUpsertConnectionLogs(ctx context.Context, arg database.BatchUpsertConnectionLogsParams) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "BatchUpsertConnectionLogs", ctx, arg)
ret0, _ := ret[0].(error)
return ret0
}
// BatchUpsertConnectionLogs indicates an expected call of BatchUpsertConnectionLogs.
func (mr *MockStoreMockRecorder) BatchUpsertConnectionLogs(ctx, arg any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "BatchUpsertConnectionLogs", reflect.TypeOf((*MockStore)(nil).BatchUpsertConnectionLogs), ctx, arg)
}
// BulkMarkNotificationMessagesFailed mocks base method.
func (m *MockStore) BulkMarkNotificationMessagesFailed(ctx context.Context, arg database.BulkMarkNotificationMessagesFailedParams) (int64, error) {
m.ctrl.T.Helper()
@@ -391,6 +363,21 @@ func (mr *MockStoreMockRecorder) CountAIBridgeInterceptions(ctx, arg any) *gomoc
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CountAIBridgeInterceptions", reflect.TypeOf((*MockStore)(nil).CountAIBridgeInterceptions), ctx, arg)
}
// CountAIBridgeSessions mocks base method.
func (m *MockStore) CountAIBridgeSessions(ctx context.Context, arg database.CountAIBridgeSessionsParams) (int64, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "CountAIBridgeSessions", ctx, arg)
ret0, _ := ret[0].(int64)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// CountAIBridgeSessions indicates an expected call of CountAIBridgeSessions.
func (mr *MockStoreMockRecorder) CountAIBridgeSessions(ctx, arg any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CountAIBridgeSessions", reflect.TypeOf((*MockStore)(nil).CountAIBridgeSessions), ctx, arg)
}
// CountAuditLogs mocks base method.
func (m *MockStore) CountAuditLogs(ctx context.Context, arg database.CountAuditLogsParams) (int64, error) {
m.ctrl.T.Helper()
@@ -421,6 +408,21 @@ func (mr *MockStoreMockRecorder) CountAuthorizedAIBridgeInterceptions(ctx, arg,
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CountAuthorizedAIBridgeInterceptions", reflect.TypeOf((*MockStore)(nil).CountAuthorizedAIBridgeInterceptions), ctx, arg, prepared)
}
// CountAuthorizedAIBridgeSessions mocks base method.
func (m *MockStore) CountAuthorizedAIBridgeSessions(ctx context.Context, arg database.CountAIBridgeSessionsParams, prepared rbac.PreparedAuthorized) (int64, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "CountAuthorizedAIBridgeSessions", ctx, arg, prepared)
ret0, _ := ret[0].(int64)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// CountAuthorizedAIBridgeSessions indicates an expected call of CountAuthorizedAIBridgeSessions.
func (mr *MockStoreMockRecorder) CountAuthorizedAIBridgeSessions(ctx, arg, prepared any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CountAuthorizedAIBridgeSessions", reflect.TypeOf((*MockStore)(nil).CountAuthorizedAIBridgeSessions), ctx, arg, prepared)
}
// CountAuthorizedAuditLogs mocks base method.
func (m *MockStore) CountAuthorizedAuditLogs(ctx context.Context, arg database.CountAuditLogsParams, prepared rbac.PreparedAuthorized) (int64, error) {
m.ctrl.T.Helper()
@@ -1154,6 +1156,20 @@ func (mr *MockStoreMockRecorder) DeleteTask(ctx, arg any) *gomock.Call {
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteTask", reflect.TypeOf((*MockStore)(nil).DeleteTask), ctx, arg)
}
// DeleteUserChatCompactionThreshold mocks base method.
func (m *MockStore) DeleteUserChatCompactionThreshold(ctx context.Context, arg database.DeleteUserChatCompactionThresholdParams) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "DeleteUserChatCompactionThreshold", ctx, arg)
ret0, _ := ret[0].(error)
return ret0
}
// DeleteUserChatCompactionThreshold indicates an expected call of DeleteUserChatCompactionThreshold.
func (mr *MockStoreMockRecorder) DeleteUserChatCompactionThreshold(ctx, arg any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteUserChatCompactionThreshold", reflect.TypeOf((*MockStore)(nil).DeleteUserChatCompactionThreshold), ctx, arg)
}
// DeleteUserSecret mocks base method.
func (m *MockStore) DeleteUserSecret(ctx context.Context, id uuid.UUID) error {
m.ctrl.T.Helper()
@@ -4022,21 +4038,6 @@ func (mr *MockStoreMockRecorder) GetTailnetTunnelPeerBindings(ctx, srcID any) *g
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetTailnetTunnelPeerBindings", reflect.TypeOf((*MockStore)(nil).GetTailnetTunnelPeerBindings), ctx, srcID)
}
// GetTailnetTunnelPeerBindingsBatch mocks base method.
func (m *MockStore) GetTailnetTunnelPeerBindingsBatch(ctx context.Context, ids []uuid.UUID) ([]database.GetTailnetTunnelPeerBindingsBatchRow, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetTailnetTunnelPeerBindingsBatch", ctx, ids)
ret0, _ := ret[0].([]database.GetTailnetTunnelPeerBindingsBatchRow)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// GetTailnetTunnelPeerBindingsBatch indicates an expected call of GetTailnetTunnelPeerBindingsBatch.
func (mr *MockStoreMockRecorder) GetTailnetTunnelPeerBindingsBatch(ctx, ids any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetTailnetTunnelPeerBindingsBatch", reflect.TypeOf((*MockStore)(nil).GetTailnetTunnelPeerBindingsBatch), ctx, ids)
}
// GetTailnetTunnelPeerIDs mocks base method.
func (m *MockStore) GetTailnetTunnelPeerIDs(ctx context.Context, srcID uuid.UUID) ([]database.GetTailnetTunnelPeerIDsRow, error) {
m.ctrl.T.Helper()
@@ -4052,21 +4053,6 @@ func (mr *MockStoreMockRecorder) GetTailnetTunnelPeerIDs(ctx, srcID any) *gomock
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetTailnetTunnelPeerIDs", reflect.TypeOf((*MockStore)(nil).GetTailnetTunnelPeerIDs), ctx, srcID)
}
// GetTailnetTunnelPeerIDsBatch mocks base method.
func (m *MockStore) GetTailnetTunnelPeerIDsBatch(ctx context.Context, ids []uuid.UUID) ([]database.GetTailnetTunnelPeerIDsBatchRow, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetTailnetTunnelPeerIDsBatch", ctx, ids)
ret0, _ := ret[0].([]database.GetTailnetTunnelPeerIDsBatchRow)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// GetTailnetTunnelPeerIDsBatch indicates an expected call of GetTailnetTunnelPeerIDsBatch.
func (mr *MockStoreMockRecorder) GetTailnetTunnelPeerIDsBatch(ctx, ids any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetTailnetTunnelPeerIDsBatch", reflect.TypeOf((*MockStore)(nil).GetTailnetTunnelPeerIDsBatch), ctx, ids)
}
// GetTaskByID mocks base method.
func (m *MockStore) GetTaskByID(ctx context.Context, id uuid.UUID) (database.Task, error) {
m.ctrl.T.Helper()
@@ -4622,6 +4608,21 @@ func (mr *MockStoreMockRecorder) GetUserByID(ctx, id any) *gomock.Call {
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetUserByID", reflect.TypeOf((*MockStore)(nil).GetUserByID), ctx, id)
}
// GetUserChatCompactionThreshold mocks base method.
func (m *MockStore) GetUserChatCompactionThreshold(ctx context.Context, arg database.GetUserChatCompactionThresholdParams) (string, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetUserChatCompactionThreshold", ctx, arg)
ret0, _ := ret[0].(string)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// GetUserChatCompactionThreshold indicates an expected call of GetUserChatCompactionThreshold.
func (mr *MockStoreMockRecorder) GetUserChatCompactionThreshold(ctx, arg any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetUserChatCompactionThreshold", reflect.TypeOf((*MockStore)(nil).GetUserChatCompactionThreshold), ctx, arg)
}
// GetUserChatCustomPrompt mocks base method.
func (m *MockStore) GetUserChatCustomPrompt(ctx context.Context, userID uuid.UUID) (string, error) {
m.ctrl.T.Helper()
@@ -6976,6 +6977,21 @@ func (mr *MockStoreMockRecorder) ListAIBridgeModels(ctx, arg any) *gomock.Call {
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ListAIBridgeModels", reflect.TypeOf((*MockStore)(nil).ListAIBridgeModels), ctx, arg)
}
// ListAIBridgeSessions mocks base method.
func (m *MockStore) ListAIBridgeSessions(ctx context.Context, arg database.ListAIBridgeSessionsParams) ([]database.ListAIBridgeSessionsRow, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "ListAIBridgeSessions", ctx, arg)
ret0, _ := ret[0].([]database.ListAIBridgeSessionsRow)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// ListAIBridgeSessions indicates an expected call of ListAIBridgeSessions.
func (mr *MockStoreMockRecorder) ListAIBridgeSessions(ctx, arg any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ListAIBridgeSessions", reflect.TypeOf((*MockStore)(nil).ListAIBridgeSessions), ctx, arg)
}
// ListAIBridgeTokenUsagesByInterceptionIDs mocks base method.
func (m *MockStore) ListAIBridgeTokenUsagesByInterceptionIDs(ctx context.Context, interceptionIds []uuid.UUID) ([]database.AIBridgeTokenUsage, error) {
m.ctrl.T.Helper()
@@ -7051,6 +7067,21 @@ func (mr *MockStoreMockRecorder) ListAuthorizedAIBridgeModels(ctx, arg, prepared
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ListAuthorizedAIBridgeModels", reflect.TypeOf((*MockStore)(nil).ListAuthorizedAIBridgeModels), ctx, arg, prepared)
}
// ListAuthorizedAIBridgeSessions mocks base method.
func (m *MockStore) ListAuthorizedAIBridgeSessions(ctx context.Context, arg database.ListAIBridgeSessionsParams, prepared rbac.PreparedAuthorized) ([]database.ListAIBridgeSessionsRow, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "ListAuthorizedAIBridgeSessions", ctx, arg, prepared)
ret0, _ := ret[0].([]database.ListAIBridgeSessionsRow)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// ListAuthorizedAIBridgeSessions indicates an expected call of ListAuthorizedAIBridgeSessions.
func (mr *MockStoreMockRecorder) ListAuthorizedAIBridgeSessions(ctx, arg, prepared any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ListAuthorizedAIBridgeSessions", reflect.TypeOf((*MockStore)(nil).ListAuthorizedAIBridgeSessions), ctx, arg, prepared)
}
// ListChatUsageLimitGroupOverrides mocks base method.
func (m *MockStore) ListChatUsageLimitGroupOverrides(ctx context.Context) ([]database.ListChatUsageLimitGroupOverridesRow, error) {
m.ctrl.T.Helper()
@@ -7126,6 +7157,21 @@ func (mr *MockStoreMockRecorder) ListTasks(ctx, arg any) *gomock.Call {
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ListTasks", reflect.TypeOf((*MockStore)(nil).ListTasks), ctx, arg)
}
// ListUserChatCompactionThresholds mocks base method.
func (m *MockStore) ListUserChatCompactionThresholds(ctx context.Context, userID uuid.UUID) ([]database.UserConfig, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "ListUserChatCompactionThresholds", ctx, userID)
ret0, _ := ret[0].([]database.UserConfig)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// ListUserChatCompactionThresholds indicates an expected call of ListUserChatCompactionThresholds.
func (mr *MockStoreMockRecorder) ListUserChatCompactionThresholds(ctx, userID any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ListUserChatCompactionThresholds", reflect.TypeOf((*MockStore)(nil).ListUserChatCompactionThresholds), ctx, userID)
}
// ListUserSecrets mocks base method.
func (m *MockStore) ListUserSecrets(ctx context.Context, userID uuid.UUID) ([]database.UserSecret, error) {
m.ctrl.T.Helper()
@@ -8231,6 +8277,21 @@ func (mr *MockStoreMockRecorder) UpdateUsageEventsPostPublish(ctx, arg any) *gom
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateUsageEventsPostPublish", reflect.TypeOf((*MockStore)(nil).UpdateUsageEventsPostPublish), ctx, arg)
}
// UpdateUserChatCompactionThreshold mocks base method.
func (m *MockStore) UpdateUserChatCompactionThreshold(ctx context.Context, arg database.UpdateUserChatCompactionThresholdParams) (database.UserConfig, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "UpdateUserChatCompactionThreshold", ctx, arg)
ret0, _ := ret[0].(database.UserConfig)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// UpdateUserChatCompactionThreshold indicates an expected call of UpdateUserChatCompactionThreshold.
func (mr *MockStoreMockRecorder) UpdateUserChatCompactionThreshold(ctx, arg any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateUserChatCompactionThreshold", reflect.TypeOf((*MockStore)(nil).UpdateUserChatCompactionThreshold), ctx, arg)
}
// UpdateUserChatCustomPrompt mocks base method.
func (m *MockStore) UpdateUserChatCustomPrompt(ctx context.Context, arg database.UpdateUserChatCustomPromptParams) (database.UserConfig, error) {
m.ctrl.T.Helper()
@@ -1099,7 +1099,8 @@ CREATE TABLE aibridge_interceptions (
client character varying(64) DEFAULT 'Unknown'::character varying,
thread_parent_id uuid,
thread_root_id uuid,
client_session_id character varying(256)
client_session_id character varying(256),
session_id text GENERATED ALWAYS AS (COALESCE(client_session_id, ((thread_root_id)::text)::character varying, ((id)::text)::character varying)) STORED NOT NULL
);
COMMENT ON TABLE aibridge_interceptions IS 'Audit log of requests intercepted by AI Bridge';
@@ -1112,6 +1113,8 @@ COMMENT ON COLUMN aibridge_interceptions.thread_root_id IS 'The root interceptio
COMMENT ON COLUMN aibridge_interceptions.client_session_id IS 'The session ID supplied by the client (optional and not universally supported).';
COMMENT ON COLUMN aibridge_interceptions.session_id IS 'Groups related interceptions into a logical session. Determined by a priority chain: (1) client_session_id — an explicit session identifier supplied by the calling client (e.g. Claude Code); (2) thread_root_id — the root of an agentic thread detected by Bridge through tool-call correlation, used when the client does not supply its own session ID; (3) id — the interception''s own ID, used as a last resort so every interception belongs to exactly one session even if it is standalone. This is a generated column stored on disk so it can be indexed and joined without recomputing the COALESCE on every query.';
CREATE TABLE aibridge_model_thoughts (
interception_id uuid NOT NULL,
content text NOT NULL,
@@ -1291,7 +1294,8 @@ CREATE TABLE chat_messages (
content_version smallint NOT NULL,
total_cost_micros bigint,
runtime_ms bigint,
deleted boolean DEFAULT false NOT NULL
deleted boolean DEFAULT false NOT NULL,
provider_response_id text
);
CREATE SEQUENCE chat_messages_id_seq
@@ -1619,6 +1623,7 @@ CREATE VIEW group_members_expanded AS
users.name AS user_name,
users.github_com_user_id AS user_github_com_user_id,
users.is_system AS user_is_system,
users.is_service_account AS user_is_service_account,
groups.organization_id,
groups.name AS group_name,
all_members.group_id
@@ -1627,8 +1632,6 @@ CREATE VIEW group_members_expanded AS
JOIN groups ON ((groups.id = all_members.group_id)))
WHERE (users.deleted = false);
COMMENT ON VIEW group_members_expanded IS 'Joins group members with user information, organization ID, group name. Includes both regular group members and organization members (as part of the "Everyone" group).';
CREATE TABLE inbox_notifications (
id uuid NOT NULL,
user_id uuid NOT NULL,
@@ -3655,6 +3658,10 @@ CREATE INDEX idx_aibridge_interceptions_model ON aibridge_interceptions USING bt
CREATE INDEX idx_aibridge_interceptions_provider ON aibridge_interceptions USING btree (provider);
CREATE INDEX idx_aibridge_interceptions_session_id ON aibridge_interceptions USING btree (session_id) WHERE (ended_at IS NOT NULL);
CREATE INDEX idx_aibridge_interceptions_sessions_filter ON aibridge_interceptions USING btree (initiator_id, started_at DESC, id DESC) WHERE (ended_at IS NOT NULL);
CREATE INDEX idx_aibridge_interceptions_started_id_desc ON aibridge_interceptions USING btree (started_at DESC, id DESC);
CREATE INDEX idx_aibridge_interceptions_thread_parent_id ON aibridge_interceptions USING btree (thread_parent_id);
@@ -3673,6 +3680,8 @@ CREATE INDEX idx_aibridge_tool_usages_provider_tool_call_id ON aibridge_tool_usa
CREATE INDEX idx_aibridge_tool_usagesprovider_response_id ON aibridge_tool_usages USING btree (provider_response_id);
CREATE INDEX idx_aibridge_user_prompts_interception_created ON aibridge_user_prompts USING btree (interception_id, created_at DESC, id DESC);
CREATE INDEX idx_aibridge_user_prompts_interception_id ON aibridge_user_prompts USING btree (interception_id);
CREATE INDEX idx_aibridge_user_prompts_provider_response_id ON aibridge_user_prompts USING btree (provider_response_id);
@@ -0,0 +1,35 @@
DROP VIEW group_members_expanded;
CREATE VIEW group_members_expanded AS
WITH all_members AS (
SELECT group_members.user_id,
group_members.group_id
FROM group_members
UNION
SELECT organization_members.user_id,
organization_members.organization_id AS group_id
FROM organization_members
)
SELECT users.id AS user_id,
users.email AS user_email,
users.username AS user_username,
users.hashed_password AS user_hashed_password,
users.created_at AS user_created_at,
users.updated_at AS user_updated_at,
users.status AS user_status,
users.rbac_roles AS user_rbac_roles,
users.login_type AS user_login_type,
users.avatar_url AS user_avatar_url,
users.deleted AS user_deleted,
users.last_seen_at AS user_last_seen_at,
users.quiet_hours_schedule AS user_quiet_hours_schedule,
users.name AS user_name,
users.github_com_user_id AS user_github_com_user_id,
users.is_system AS user_is_system,
groups.organization_id,
groups.name AS group_name,
all_members.group_id
FROM ((all_members
JOIN users ON ((users.id = all_members.user_id)))
JOIN groups ON ((groups.id = all_members.group_id)))
WHERE (users.deleted = false);
@@ -0,0 +1,36 @@
DROP VIEW group_members_expanded;
CREATE VIEW group_members_expanded AS
WITH all_members AS (
SELECT group_members.user_id,
group_members.group_id
FROM group_members
UNION
SELECT organization_members.user_id,
organization_members.organization_id AS group_id
FROM organization_members
)
SELECT users.id AS user_id,
users.email AS user_email,
users.username AS user_username,
users.hashed_password AS user_hashed_password,
users.created_at AS user_created_at,
users.updated_at AS user_updated_at,
users.status AS user_status,
users.rbac_roles AS user_rbac_roles,
users.login_type AS user_login_type,
users.avatar_url AS user_avatar_url,
users.deleted AS user_deleted,
users.last_seen_at AS user_last_seen_at,
users.quiet_hours_schedule AS user_quiet_hours_schedule,
users.name AS user_name,
users.github_com_user_id AS user_github_com_user_id,
users.is_system AS user_is_system,
users.is_service_account as user_is_service_account,
groups.organization_id,
groups.name AS group_name,
all_members.group_id
FROM ((all_members
JOIN users ON ((users.id = all_members.user_id)))
JOIN groups ON ((groups.id = all_members.group_id)))
WHERE (users.deleted = false);
@@ -0,0 +1,5 @@
DROP INDEX IF EXISTS idx_aibridge_interceptions_session_id;
DROP INDEX IF EXISTS idx_aibridge_user_prompts_interception_created;
DROP INDEX IF EXISTS idx_aibridge_interceptions_sessions_filter;
ALTER TABLE aibridge_interceptions DROP COLUMN IF EXISTS session_id;
@@ -0,0 +1,40 @@
-- A "session" groups related interceptions together. See the COMMENT ON
-- COLUMN below for the full business-logic description.
ALTER TABLE aibridge_interceptions
ADD COLUMN session_id TEXT NOT NULL
GENERATED ALWAYS AS (
COALESCE(
client_session_id,
thread_root_id::text,
id::text
)
) STORED;
-- Searching and grouping on the resolved session ID will be common.
CREATE INDEX idx_aibridge_interceptions_session_id
ON aibridge_interceptions (session_id)
WHERE ended_at IS NOT NULL;
COMMENT ON COLUMN aibridge_interceptions.session_id IS
'Groups related interceptions into a logical session. '
'Determined by a priority chain: '
'(1) client_session_id — an explicit session identifier supplied by the '
'calling client (e.g. Claude Code); '
'(2) thread_root_id — the root of an agentic thread detected by Bridge '
'through tool-call correlation, used when the client does not supply its '
'own session ID; '
'(3) id — the interception''s own ID, used as a last resort so every '
'interception belongs to exactly one session even if it is standalone. '
'This is a generated column stored on disk so it can be indexed and '
'joined without recomputing the COALESCE on every query.';
-- Composite index for the most common filter path used by
-- ListAIBridgeSessions: initiator_id equality + started_at range,
-- with ended_at IS NOT NULL as a partial filter.
CREATE INDEX idx_aibridge_interceptions_sessions_filter
ON aibridge_interceptions (initiator_id, started_at DESC, id DESC)
WHERE ended_at IS NOT NULL;
-- Supports lateral prompt lookup by interception + recency.
CREATE INDEX idx_aibridge_user_prompts_interception_created
ON aibridge_user_prompts (interception_id, created_at DESC, id DESC);
@@ -0,0 +1 @@
ALTER TABLE chat_messages DROP COLUMN provider_response_id;
@@ -0,0 +1 @@
ALTER TABLE chat_messages ADD COLUMN provider_response_id TEXT;
@@ -806,6 +806,8 @@ type aibridgeQuerier interface {
ListAuthorizedAIBridgeInterceptions(ctx context.Context, arg ListAIBridgeInterceptionsParams, prepared rbac.PreparedAuthorized) ([]ListAIBridgeInterceptionsRow, error)
CountAuthorizedAIBridgeInterceptions(ctx context.Context, arg CountAIBridgeInterceptionsParams, prepared rbac.PreparedAuthorized) (int64, error)
ListAuthorizedAIBridgeModels(ctx context.Context, arg ListAIBridgeModelsParams, prepared rbac.PreparedAuthorized) ([]string, error)
ListAuthorizedAIBridgeSessions(ctx context.Context, arg ListAIBridgeSessionsParams, prepared rbac.PreparedAuthorized) ([]ListAIBridgeSessionsRow, error)
CountAuthorizedAIBridgeSessions(ctx context.Context, arg CountAIBridgeSessionsParams, prepared rbac.PreparedAuthorized) (int64, error)
}
func (q *sqlQuerier) ListAuthorizedAIBridgeInterceptions(ctx context.Context, arg ListAIBridgeInterceptionsParams, prepared rbac.PreparedAuthorized) ([]ListAIBridgeInterceptionsRow, error) {
@@ -852,6 +854,7 @@ func (q *sqlQuerier) ListAuthorizedAIBridgeInterceptions(ctx context.Context, ar
&i.AIBridgeInterception.ThreadParentID,
&i.AIBridgeInterception.ThreadRootID,
&i.AIBridgeInterception.ClientSessionID,
&i.AIBridgeInterception.SessionID,
&i.VisibleUser.ID,
&i.VisibleUser.Username,
&i.VisibleUser.Name,
@@ -939,6 +942,109 @@ func (q *sqlQuerier) ListAuthorizedAIBridgeModels(ctx context.Context, arg ListA
return items, nil
}
func (q *sqlQuerier) ListAuthorizedAIBridgeSessions(ctx context.Context, arg ListAIBridgeSessionsParams, prepared rbac.PreparedAuthorized) ([]ListAIBridgeSessionsRow, error) {
authorizedFilter, err := prepared.CompileToSQL(ctx, regosql.ConvertConfig{
VariableConverter: regosql.AIBridgeInterceptionConverter(),
})
if err != nil {
return nil, xerrors.Errorf("compile authorized filter: %w", err)
}
filtered, err := insertAuthorizedFilter(listAIBridgeSessions, fmt.Sprintf(" AND %s", authorizedFilter))
if err != nil {
return nil, xerrors.Errorf("insert authorized filter: %w", err)
}
query := fmt.Sprintf("-- name: ListAuthorizedAIBridgeSessions :many\n%s", filtered)
rows, err := q.db.QueryContext(ctx, query,
arg.AfterSessionID,
arg.Offset,
arg.Limit,
arg.StartedAfter,
arg.StartedBefore,
arg.InitiatorID,
arg.Provider,
arg.Model,
arg.Client,
arg.SessionID,
)
if err != nil {
return nil, err
}
defer rows.Close()
var items []ListAIBridgeSessionsRow
for rows.Next() {
var i ListAIBridgeSessionsRow
if err := rows.Scan(
&i.SessionID,
&i.UserID,
&i.UserUsername,
&i.UserName,
&i.UserAvatarUrl,
pq.Array(&i.Providers),
pq.Array(&i.Models),
&i.Client,
&i.Metadata,
&i.StartedAt,
&i.EndedAt,
&i.Threads,
&i.InputTokens,
&i.OutputTokens,
&i.LastPrompt,
); err != nil {
return nil, err
}
items = append(items, i)
}
if err := rows.Close(); err != nil {
return nil, err
}
if err := rows.Err(); err != nil {
return nil, err
}
return items, nil
}
func (q *sqlQuerier) CountAuthorizedAIBridgeSessions(ctx context.Context, arg CountAIBridgeSessionsParams, prepared rbac.PreparedAuthorized) (int64, error) {
authorizedFilter, err := prepared.CompileToSQL(ctx, regosql.ConvertConfig{
VariableConverter: regosql.AIBridgeInterceptionConverter(),
})
if err != nil {
return 0, xerrors.Errorf("compile authorized filter: %w", err)
}
filtered, err := insertAuthorizedFilter(countAIBridgeSessions, fmt.Sprintf(" AND %s", authorizedFilter))
if err != nil {
return 0, xerrors.Errorf("insert authorized filter: %w", err)
}
query := fmt.Sprintf("-- name: CountAuthorizedAIBridgeSessions :one\n%s", filtered)
rows, err := q.db.QueryContext(ctx, query,
arg.StartedAfter,
arg.StartedBefore,
arg.InitiatorID,
arg.Provider,
arg.Model,
arg.Client,
arg.SessionID,
)
if err != nil {
return 0, err
}
defer rows.Close()
var count int64
for rows.Next() {
if err := rows.Scan(&count); err != nil {
return 0, err
}
}
if err := rows.Close(); err != nil {
return 0, err
}
if err := rows.Err(); err != nil {
return 0, err
}
return count, nil
}
func insertAuthorizedFilter(query string, replaceWith string) (string, error) {
if !strings.Contains(query, authorizedQueryPlaceholder) {
return "", xerrors.Errorf("query does not contain authorized replace string, this is not an authorized query")
@@ -4036,6 +4036,8 @@ type AIBridgeInterception struct {
ThreadRootID uuid.NullUUID `db:"thread_root_id" json:"thread_root_id"`
// The session ID supplied by the client (optional and not universally supported).
ClientSessionID sql.NullString `db:"client_session_id" json:"client_session_id"`
// Groups related interceptions into a logical session. Determined by a priority chain: (1) client_session_id — an explicit session identifier supplied by the calling client (e.g. Claude Code); (2) thread_root_id — the root of an agentic thread detected by Bridge through tool-call correlation, used when the client does not supply its own session ID; (3) id — the interception's own ID, used as a last resort so every interception belongs to exactly one session even if it is standalone. This is a generated column stored on disk so it can be indexed and joined without recomputing the COALESCE on every query.
SessionID string `db:"session_id" json:"session_id"`
}
// Audit log of model thinking in intercepted requests in AI Bridge
@@ -4227,6 +4229,7 @@ type ChatMessage struct {
TotalCostMicros sql.NullInt64 `db:"total_cost_micros" json:"total_cost_micros"`
RuntimeMs sql.NullInt64 `db:"runtime_ms" json:"runtime_ms"`
Deleted bool `db:"deleted" json:"deleted"`
ProviderResponseID sql.NullString `db:"provider_response_id" json:"provider_response_id"`
}
type ChatModelConfig struct {
@@ -4394,7 +4397,6 @@ type Group struct {
ChatSpendLimitMicros sql.NullInt64 `db:"chat_spend_limit_micros" json:"chat_spend_limit_micros"`
}
// Joins group members with user information, organization ID, group name. Includes both regular group members and organization members (as part of the "Everyone" group).
type GroupMember struct {
UserID uuid.UUID `db:"user_id" json:"user_id"`
UserEmail string `db:"user_email" json:"user_email"`
@@ -4412,6 +4414,7 @@ type GroupMember struct {
UserName string `db:"user_name" json:"user_name"`
UserGithubComUserID sql.NullInt64 `db:"user_github_com_user_id" json:"user_github_com_user_id"`
UserIsSystem bool `db:"user_is_system" json:"user_is_system"`
UserIsServiceAccount bool `db:"user_is_service_account" json:"user_is_service_account"`
OrganizationID uuid.UUID `db:"organization_id" json:"organization_id"`
GroupName string `db:"group_name" json:"group_name"`
GroupID uuid.UUID `db:"group_id" json:"group_id"`
@@ -231,7 +231,7 @@ type PGPubsub struct {
// BufferSize is the maximum number of unhandled messages we will buffer
// for a subscriber before dropping messages.
const BufferSize = 8192
const BufferSize = 2048
// Subscribe calls the listener when an event matching the name is received.
func (p *PGPubsub) Subscribe(event string, listener Listener) (cancel func(), err error) {
@@ -62,11 +62,9 @@ type sqlcQuerier interface {
// referenced by the latest build of a workspace.
ArchiveUnusedTemplateVersions(ctx context.Context, arg ArchiveUnusedTemplateVersionsParams) ([]uuid.UUID, error)
BackoffChatDiffStatus(ctx context.Context, arg BackoffChatDiffStatusParams) error
BatchUpdateWorkspaceAgentConnections(ctx context.Context, arg BatchUpdateWorkspaceAgentConnectionsParams) error
BatchUpdateWorkspaceAgentMetadata(ctx context.Context, arg BatchUpdateWorkspaceAgentMetadataParams) error
BatchUpdateWorkspaceLastUsedAt(ctx context.Context, arg BatchUpdateWorkspaceLastUsedAtParams) error
BatchUpdateWorkspaceNextStartAt(ctx context.Context, arg BatchUpdateWorkspaceNextStartAtParams) error
BatchUpsertConnectionLogs(ctx context.Context, arg BatchUpsertConnectionLogsParams) error
BulkMarkNotificationMessagesFailed(ctx context.Context, arg BulkMarkNotificationMessagesFailedParams) (int64, error)
BulkMarkNotificationMessagesSent(ctx context.Context, arg BulkMarkNotificationMessagesSentParams) (int64, error)
// Calculates the telemetry summary for a given provider, model, and client
@@ -78,6 +76,7 @@ type sqlcQuerier interface {
CleanTailnetTunnels(ctx context.Context) error
CleanupDeletedMCPServerIDsFromChats(ctx context.Context) error
CountAIBridgeInterceptions(ctx context.Context, arg CountAIBridgeInterceptionsParams) (int64, error)
CountAIBridgeSessions(ctx context.Context, arg CountAIBridgeSessionsParams) (int64, error)
CountAuditLogs(ctx context.Context, arg CountAuditLogsParams) (int64, error)
CountConnectionLogs(ctx context.Context, arg CountConnectionLogsParams) (int64, error)
// Counts enabled, non-deleted model configs that lack both input and
@@ -150,6 +149,7 @@ type sqlcQuerier interface {
DeleteTailnetPeer(ctx context.Context, arg DeleteTailnetPeerParams) (DeleteTailnetPeerRow, error)
DeleteTailnetTunnel(ctx context.Context, arg DeleteTailnetTunnelParams) (DeleteTailnetTunnelRow, error)
DeleteTask(ctx context.Context, arg DeleteTaskParams) (uuid.UUID, error)
DeleteUserChatCompactionThreshold(ctx context.Context, arg DeleteUserChatCompactionThresholdParams) error
DeleteUserSecret(ctx context.Context, id uuid.UUID) error
DeleteWebpushSubscriptionByUserIDAndEndpoint(ctx context.Context, arg DeleteWebpushSubscriptionByUserIDAndEndpointParams) error
DeleteWebpushSubscriptions(ctx context.Context, ids []uuid.UUID) error
@@ -461,9 +461,7 @@ type sqlcQuerier interface {
GetStaleChats(ctx context.Context, staleThreshold time.Time) ([]Chat, error)
GetTailnetPeers(ctx context.Context, id uuid.UUID) ([]TailnetPeer, error)
GetTailnetTunnelPeerBindings(ctx context.Context, srcID uuid.UUID) ([]GetTailnetTunnelPeerBindingsRow, error)
GetTailnetTunnelPeerBindingsBatch(ctx context.Context, ids []uuid.UUID) ([]GetTailnetTunnelPeerBindingsBatchRow, error)
GetTailnetTunnelPeerIDs(ctx context.Context, srcID uuid.UUID) ([]GetTailnetTunnelPeerIDsRow, error)
GetTailnetTunnelPeerIDsBatch(ctx context.Context, ids []uuid.UUID) ([]GetTailnetTunnelPeerIDsBatchRow, error)
GetTaskByID(ctx context.Context, id uuid.UUID) (Task, error)
GetTaskByOwnerIDAndName(ctx context.Context, arg GetTaskByOwnerIDAndNameParams) (Task, error)
GetTaskByWorkspaceID(ctx context.Context, workspaceID uuid.UUID) (Task, error)
@@ -557,6 +555,7 @@ type sqlcQuerier interface {
GetUserActivityInsights(ctx context.Context, arg GetUserActivityInsightsParams) ([]GetUserActivityInsightsRow, error)
GetUserByEmailOrUsername(ctx context.Context, arg GetUserByEmailOrUsernameParams) (User, error)
GetUserByID(ctx context.Context, id uuid.UUID) (User, error)
GetUserChatCompactionThreshold(ctx context.Context, arg GetUserChatCompactionThresholdParams) (string, error)
GetUserChatCustomPrompt(ctx context.Context, userID uuid.UUID) (string, error)
GetUserChatSpendInPeriod(ctx context.Context, arg GetUserChatSpendInPeriodParams) (int64, error)
GetUserCount(ctx context.Context, includeSystem bool) (int64, error)
@@ -761,6 +760,10 @@ type sqlcQuerier interface {
// (provider, model, client) in the given timeframe for telemetry reporting.
ListAIBridgeInterceptionsTelemetrySummaries(ctx context.Context, arg ListAIBridgeInterceptionsTelemetrySummariesParams) ([]ListAIBridgeInterceptionsTelemetrySummariesRow, error)
ListAIBridgeModels(ctx context.Context, arg ListAIBridgeModelsParams) ([]string, error)
// Returns paginated sessions with aggregated metadata, token counts, and
// the most recent user prompt. A "session" is a logical grouping of
// interceptions that share the same session_id (set by the client).
ListAIBridgeSessions(ctx context.Context, arg ListAIBridgeSessionsParams) ([]ListAIBridgeSessionsRow, error)
ListAIBridgeTokenUsagesByInterceptionIDs(ctx context.Context, interceptionIds []uuid.UUID) ([]AIBridgeTokenUsage, error)
ListAIBridgeToolUsagesByInterceptionIDs(ctx context.Context, interceptionIds []uuid.UUID) ([]AIBridgeToolUsage, error)
ListAIBridgeUserPromptsByInterceptionIDs(ctx context.Context, interceptionIds []uuid.UUID) ([]AIBridgeUserPrompt, error)
@@ -769,6 +772,7 @@ type sqlcQuerier interface {
ListProvisionerKeysByOrganization(ctx context.Context, organizationID uuid.UUID) ([]ProvisionerKey, error)
ListProvisionerKeysByOrganizationExcludeReserved(ctx context.Context, organizationID uuid.UUID) ([]ProvisionerKey, error)
ListTasks(ctx context.Context, arg ListTasksParams) ([]Task, error)
ListUserChatCompactionThresholds(ctx context.Context, userID uuid.UUID) ([]UserConfig, error)
ListUserSecrets(ctx context.Context, userID uuid.UUID) ([]UserSecret, error)
ListWorkspaceAgentPortShares(ctx context.Context, workspaceID uuid.UUID) ([]WorkspaceAgentPortShare, error)
MarkAllInboxNotificationsAsRead(ctx context.Context, arg MarkAllInboxNotificationsAsReadParams) error
@@ -872,6 +876,7 @@ type sqlcQuerier interface {
UpdateTemplateVersionFlagsByJobID(ctx context.Context, arg UpdateTemplateVersionFlagsByJobIDParams) error
UpdateTemplateWorkspacesLastUsedAt(ctx context.Context, arg UpdateTemplateWorkspacesLastUsedAtParams) error
UpdateUsageEventsPostPublish(ctx context.Context, arg UpdateUsageEventsPostPublishParams) error
UpdateUserChatCompactionThreshold(ctx context.Context, arg UpdateUserChatCompactionThresholdParams) (UserConfig, error)
UpdateUserChatCustomPrompt(ctx context.Context, arg UpdateUserChatCustomPromptParams) (UserConfig, error)
UpdateUserDeletedByID(ctx context.Context, id uuid.UUID) error
UpdateUserGithubComUserID(ctx context.Context, arg UpdateUserGithubComUserIDParams) error
+44 -1
@@ -21,7 +21,6 @@ import (
"github.com/stretchr/testify/require"
"cdr.dev/slog/v3/sloggers/slogtest"
"github.com/coder/coder/v2/coderd/chatd/chatprompt"
"github.com/coder/coder/v2/coderd/coderdtest"
"github.com/coder/coder/v2/coderd/database"
"github.com/coder/coder/v2/coderd/database/dbauthz"
@@ -35,6 +34,7 @@ import (
"github.com/coder/coder/v2/coderd/rbac"
"github.com/coder/coder/v2/coderd/rbac/policy"
"github.com/coder/coder/v2/coderd/util/slice"
"github.com/coder/coder/v2/coderd/x/chatd/chatprompt"
"github.com/coder/coder/v2/codersdk"
"github.com/coder/coder/v2/provisionersdk"
"github.com/coder/coder/v2/testutil"
@@ -10417,6 +10417,49 @@ func TestGetPRInsights(t *testing.T) {
assert.Equal(t, int64(0), recent[0].CostMicros)
})
t.Run("BlankDisplayNameFallsBackToModel", func(t *testing.T) {
t.Parallel()
store, userID, _ := setupChatInfra(t)
const modelName = "claude-4.1"
emptyDisplayModel, err := store.InsertChatModelConfig(context.Background(), database.InsertChatModelConfigParams{
Provider: "anthropic",
Model: modelName,
DisplayName: "",
CreatedBy: uuid.NullUUID{UUID: userID, Valid: true},
UpdatedBy: uuid.NullUUID{UUID: userID, Valid: true},
Enabled: true,
IsDefault: false,
ContextLimit: 128000,
CompressionThreshold: 80,
Options: json.RawMessage(`{}`),
})
require.NoError(t, err)
chat := createChat(t, store, userID, emptyDisplayModel.ID, "chat-empty-display-name")
insertCostMessage(t, store, chat.ID, userID, emptyDisplayModel.ID, 1_000_000)
linkPR(t, store, chat.ID, "https://github.com/org/repo/pull/72", "merged", "fix: blank display name", 10, 2, 1)
byModel, err := store.GetPRInsightsPerModel(context.Background(), database.GetPRInsightsPerModelParams{
StartDate: startDate,
EndDate: endDate,
OwnerID: noOwner,
})
require.NoError(t, err)
require.Len(t, byModel, 1)
assert.Equal(t, modelName, byModel[0].DisplayName)
recent, err := store.GetPRInsightsRecentPRs(context.Background(), database.GetPRInsightsRecentPRsParams{
StartDate: startDate,
EndDate: endDate,
OwnerID: noOwner,
LimitVal: 20,
})
require.NoError(t, err)
require.Len(t, recent, 1)
assert.Equal(t, modelName, recent[0].ModelDisplayName)
})
t.Run("MergedCostMicros_OnlyCountsMerged", func(t *testing.T) {
t.Parallel()
store, userID, mcID := setupChatInfra(t)
File diff suppressed because it is too large
+188
@@ -404,6 +404,194 @@ SELECT (
(SELECT COUNT(*) FROM interceptions)
)::bigint as total_deleted;
-- name: CountAIBridgeSessions :one
SELECT
COUNT(DISTINCT (aibridge_interceptions.session_id, aibridge_interceptions.initiator_id))
FROM
aibridge_interceptions
WHERE
-- Remove inflight interceptions (ones which lack an ended_at value).
aibridge_interceptions.ended_at IS NOT NULL
-- Filter by time frame
AND CASE
WHEN @started_after::timestamptz != '0001-01-01 00:00:00+00'::timestamptz THEN aibridge_interceptions.started_at >= @started_after::timestamptz
ELSE true
END
AND CASE
WHEN @started_before::timestamptz != '0001-01-01 00:00:00+00'::timestamptz THEN aibridge_interceptions.started_at <= @started_before::timestamptz
ELSE true
END
-- Filter initiator_id
AND CASE
WHEN @initiator_id::uuid != '00000000-0000-0000-0000-000000000000'::uuid THEN aibridge_interceptions.initiator_id = @initiator_id::uuid
ELSE true
END
-- Filter provider
AND CASE
WHEN @provider::text != '' THEN aibridge_interceptions.provider = @provider::text
ELSE true
END
-- Filter model
AND CASE
WHEN @model::text != '' THEN aibridge_interceptions.model = @model::text
ELSE true
END
-- Filter client
AND CASE
WHEN @client::text != '' THEN COALESCE(aibridge_interceptions.client, 'Unknown') = @client::text
ELSE true
END
-- Filter session_id
AND CASE
WHEN @session_id::text != '' THEN aibridge_interceptions.session_id = @session_id::text
ELSE true
END
-- Authorize Filter clause will be injected below in CountAuthorizedAIBridgeSessions
-- @authorize_filter
;
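Every optional filter above follows the same sentinel-value convention: an empty string, zero UUID, or zero timestamp means "not set", and the CASE expression degenerates to true. A minimal Go sketch of that convention for a text filter (function and variable names here are illustrative, not from the codebase):

```go
package main

import "fmt"

// matchText mirrors the SQL pattern
//   CASE WHEN @filter::text != '' THEN column = @filter ELSE true END
// An empty filter is treated as "unset" and matches every row.
func matchText(filter, column string) bool {
	return filter == "" || column == filter
}

func main() {
	fmt.Println(matchText("", "anthropic"))          // unset filter matches: true
	fmt.Println(matchText("anthropic", "anthropic")) // true
	fmt.Println(matchText("openai", "anthropic"))    // false
}
```

The same shape applies to the UUID and timestamptz filters, with their respective zero values as sentinels.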
-- name: ListAIBridgeSessions :many
-- Returns paginated sessions with aggregated metadata, token counts, and
-- the most recent user prompt. A "session" is a logical grouping of
-- interceptions that share the same session_id (set by the client).
WITH filtered_interceptions AS (
SELECT
aibridge_interceptions.*
FROM
aibridge_interceptions
WHERE
-- Remove inflight interceptions (ones which lack an ended_at value).
aibridge_interceptions.ended_at IS NOT NULL
-- Filter by time frame
AND CASE
WHEN @started_after::timestamptz != '0001-01-01 00:00:00+00'::timestamptz THEN aibridge_interceptions.started_at >= @started_after::timestamptz
ELSE true
END
AND CASE
WHEN @started_before::timestamptz != '0001-01-01 00:00:00+00'::timestamptz THEN aibridge_interceptions.started_at <= @started_before::timestamptz
ELSE true
END
-- Filter initiator_id
AND CASE
WHEN @initiator_id::uuid != '00000000-0000-0000-0000-000000000000'::uuid THEN aibridge_interceptions.initiator_id = @initiator_id::uuid
ELSE true
END
-- Filter provider
AND CASE
WHEN @provider::text != '' THEN aibridge_interceptions.provider = @provider::text
ELSE true
END
-- Filter model
AND CASE
WHEN @model::text != '' THEN aibridge_interceptions.model = @model::text
ELSE true
END
-- Filter client
AND CASE
WHEN @client::text != '' THEN COALESCE(aibridge_interceptions.client, 'Unknown') = @client::text
ELSE true
END
-- Filter session_id
AND CASE
WHEN @session_id::text != '' THEN aibridge_interceptions.session_id = @session_id::text
ELSE true
END
-- Authorize Filter clause will be injected below in ListAuthorizedAIBridgeSessions
-- @authorize_filter
),
session_tokens AS (
-- Aggregate token usage across all interceptions in each session.
-- Group by (session_id, initiator_id) to avoid merging sessions from
-- different users who happen to share the same client_session_id.
SELECT
fi.session_id,
fi.initiator_id,
COALESCE(SUM(tu.input_tokens), 0)::bigint AS input_tokens,
COALESCE(SUM(tu.output_tokens), 0)::bigint AS output_tokens
-- TODO: add extra token types once https://github.com/coder/aibridge/issues/150 lands.
FROM
filtered_interceptions fi
LEFT JOIN
aibridge_token_usages tu ON fi.id = tu.interception_id
GROUP BY
fi.session_id, fi.initiator_id
),
session_root AS (
-- Build one summary row per session. Group by (session_id, initiator_id)
-- to avoid merging sessions from different users who happen to share the
-- same client_session_id. The ARRAY_AGG with ORDER BY picks values from
-- the chronologically first interception for fields that should represent
-- the session as a whole (client, metadata). Threads are counted as
-- distinct root interception IDs: an interception with a NULL
-- thread_root_id is itself a thread root.
SELECT
fi.session_id,
fi.initiator_id,
(ARRAY_AGG(fi.client ORDER BY fi.started_at, fi.id))[1] AS client,
(ARRAY_AGG(fi.metadata ORDER BY fi.started_at, fi.id))[1] AS metadata,
ARRAY_AGG(DISTINCT fi.provider ORDER BY fi.provider) AS providers,
ARRAY_AGG(DISTINCT fi.model ORDER BY fi.model) AS models,
MIN(fi.started_at) AS started_at,
MAX(fi.ended_at) AS ended_at,
COUNT(DISTINCT COALESCE(fi.thread_root_id, fi.id)) AS threads,
-- Collect IDs for lateral prompt lookup.
ARRAY_AGG(fi.id) AS interception_ids
FROM
filtered_interceptions fi
GROUP BY
fi.session_id, fi.initiator_id
)
SELECT
sr.session_id,
visible_users.id AS user_id,
visible_users.username AS user_username,
visible_users.name AS user_name,
visible_users.avatar_url AS user_avatar_url,
sr.providers::text[] AS providers,
sr.models::text[] AS models,
COALESCE(sr.client, '')::varchar(64) AS client,
sr.metadata::jsonb AS metadata,
sr.started_at::timestamptz AS started_at,
sr.ended_at::timestamptz AS ended_at,
sr.threads,
COALESCE(st.input_tokens, 0)::bigint AS input_tokens,
COALESCE(st.output_tokens, 0)::bigint AS output_tokens,
COALESCE(slp.prompt, '') AS last_prompt
FROM
session_root sr
JOIN
visible_users ON visible_users.id = sr.initiator_id
LEFT JOIN
session_tokens st ON st.session_id = sr.session_id AND st.initiator_id = sr.initiator_id
LEFT JOIN LATERAL (
-- Lateral join to efficiently fetch only the most recent user prompt
-- across all interceptions in the session, avoiding a full aggregation.
SELECT up.prompt
FROM aibridge_user_prompts up
WHERE up.interception_id = ANY(sr.interception_ids)
ORDER BY up.created_at DESC, up.id DESC
LIMIT 1
) slp ON true
WHERE
-- Cursor pagination: uses a composite (started_at, session_id) cursor
-- to support keyset pagination. The less-than comparison matches the
-- DESC sort order so that rows after the cursor come later in results.
CASE
WHEN @after_session_id::text != '' THEN (
(sr.started_at, sr.session_id) < (
(SELECT started_at FROM session_root WHERE session_id = @after_session_id),
@after_session_id::text
)
)
ELSE true
END
ORDER BY
sr.started_at DESC,
sr.session_id DESC
LIMIT COALESCE(NULLIF(@limit_::integer, 0), 100)
OFFSET @offset_
;
-- name: ListAIBridgeModels :many
SELECT
model
+4 -3
@@ -147,6 +147,7 @@ deduped AS (
cds.deletions,
cmc.id AS model_config_id,
cmc.display_name,
cmc.model,
cmc.provider
FROM chat_diff_statuses cds
JOIN chats c ON c.id = cds.chat_id
@@ -159,7 +160,7 @@ deduped AS (
)
SELECT
d.model_config_id,
COALESCE(d.display_name, 'Unknown')::text AS display_name,
COALESCE(NULLIF(d.display_name, ''), NULLIF(d.model, ''), 'Unknown')::text AS display_name,
COALESCE(d.provider, 'unknown')::text AS provider,
COUNT(*)::bigint AS total_prs,
COUNT(*) FILTER (WHERE d.pull_request_state = 'merged')::bigint AS merged_prs,
@@ -169,7 +170,7 @@ SELECT
COALESCE(SUM(pc.cost_micros) FILTER (WHERE d.pull_request_state = 'merged'), 0)::bigint AS merged_cost_micros
FROM deduped d
JOIN pr_costs pc ON pc.pr_key = d.pr_key
GROUP BY d.model_config_id, d.display_name, d.provider
GROUP BY d.model_config_id, d.display_name, d.model, d.provider
ORDER BY total_prs DESC;
-- name: GetPRInsightsRecentPRs :many
@@ -227,7 +228,7 @@ deduped AS (
cds.author_login,
cds.author_avatar_url,
COALESCE(cds.base_branch, '')::text AS base_branch,
COALESCE(cmc.display_name, cmc.model, 'Unknown')::text AS model_display_name,
COALESCE(NULLIF(cmc.display_name, ''), NULLIF(cmc.model, ''), 'Unknown')::text AS model_display_name,
c.created_at
FROM chat_diff_statuses cds
JOIN chats c ON c.id = cds.chat_id
+4 -2
@@ -241,7 +241,8 @@ INSERT INTO chat_messages (
context_limit,
compressed,
total_cost_micros,
runtime_ms
runtime_ms,
provider_response_id
)
SELECT
@chat_id::uuid,
@@ -260,7 +261,8 @@ SELECT
NULLIF(UNNEST(@context_limit::bigint[]), 0),
UNNEST(@compressed::boolean[]),
NULLIF(UNNEST(@total_cost_micros::bigint[]), 0),
NULLIF(UNNEST(@runtime_ms::bigint[]), 0)
NULLIF(UNNEST(@runtime_ms::bigint[]), 0),
NULLIF(UNNEST(@provider_response_id::text[]), '')
RETURNING
*;
@@ -303,44 +303,3 @@ DO UPDATE SET
ELSE connection_logs.code
END
RETURNING *;
-- name: BatchUpsertConnectionLogs :exec
INSERT INTO connection_logs (
id, connect_time, organization_id, workspace_owner_id, workspace_id,
workspace_name, agent_name, type, code, ip, user_agent, user_id,
slug_or_port, connection_id, disconnect_reason, disconnect_time
)
SELECT
unnest(sqlc.arg('id')::uuid[]),
unnest(sqlc.arg('connect_time')::timestamptz[]),
unnest(sqlc.arg('organization_id')::uuid[]),
unnest(sqlc.arg('workspace_owner_id')::uuid[]),
unnest(sqlc.arg('workspace_id')::uuid[]),
unnest(sqlc.arg('workspace_name')::text[]),
unnest(sqlc.arg('agent_name')::text[]),
unnest(sqlc.arg('type')::connection_type[]),
unnest(sqlc.arg('code')::int4[]),
unnest(sqlc.arg('ip')::inet[]),
unnest(sqlc.arg('user_agent')::text[]),
unnest(sqlc.arg('user_id')::uuid[]),
unnest(sqlc.arg('slug_or_port')::text[]),
unnest(sqlc.arg('connection_id')::uuid[]),
unnest(sqlc.arg('disconnect_reason')::text[]),
unnest(sqlc.arg('disconnect_time')::timestamptz[])
ON CONFLICT (connection_id, workspace_id, agent_name)
DO UPDATE SET
disconnect_time = CASE
WHEN connection_logs.disconnect_time IS NULL
THEN EXCLUDED.disconnect_time
ELSE connection_logs.disconnect_time
END,
disconnect_reason = CASE
WHEN connection_logs.disconnect_reason IS NULL
THEN EXCLUDED.disconnect_reason
ELSE connection_logs.disconnect_reason
END,
code = CASE
WHEN connection_logs.code IS NULL
THEN EXCLUDED.code
ELSE connection_logs.code
END;
+100 -6
@@ -5,7 +5,9 @@
-- - Use both to get a specific org member row
SELECT
sqlc.embed(organization_members),
users.username, users.avatar_url, users.name, users.email, users.rbac_roles as "global_roles"
users.username, users.avatar_url, users.name, users.email, users.rbac_roles as "global_roles",
users.last_seen_at, users.status, users.login_type,
users.created_at as user_created_at, users.updated_at as user_updated_at
FROM
organization_members
INNER JOIN
@@ -83,23 +85,115 @@ RETURNING *;
SELECT
sqlc.embed(organization_members),
users.username, users.avatar_url, users.name, users.email, users.rbac_roles as "global_roles",
users.last_seen_at, users.status, users.login_type,
users.created_at as user_created_at, users.updated_at as user_updated_at,
COUNT(*) OVER() AS count
FROM
organization_members
INNER JOIN
users ON organization_members.user_id = users.id AND users.deleted = false
WHERE
-- Filter by organization id
CASE
-- This allows using the last element on a page as effectively a cursor.
-- This is an important option for scripts that need to paginate without
-- duplicating or missing data.
WHEN @after_id :: uuid != '00000000-0000-0000-0000-000000000000'::uuid THEN (
-- The pagination cursor is the last ID of the previous page.
-- The query is ordered by the username field, so select all
-- rows after the cursor.
(LOWER(users.username)) > (
SELECT
LOWER(users.username)
FROM
organization_members
INNER JOIN
users ON organization_members.user_id = users.id
WHERE
organization_members.user_id = @after_id
)
)
ELSE true
END
-- Start filters
-- Filter by organization id
AND CASE
WHEN @organization_id :: uuid != '00000000-0000-0000-0000-000000000000'::uuid THEN
organization_id = @organization_id
ELSE true
END
-- Filter by system type
AND CASE WHEN @include_system::bool THEN TRUE ELSE is_system = false END
-- Filter by email or username
AND CASE
WHEN @search :: text != '' THEN (
users.email ILIKE concat('%', @search, '%')
OR users.username ILIKE concat('%', @search, '%')
)
ELSE true
END
-- Filter by name (display name)
AND CASE
WHEN @name :: text != '' THEN
users.name ILIKE concat('%', @name, '%')
ELSE true
END
-- Filter by status
AND CASE
-- @status needs to be text because it can be empty; if it were a
-- user_status enum, it could not be.
WHEN cardinality(@status :: user_status[]) > 0 THEN
users.status = ANY(@status :: user_status[])
ELSE true
END
-- Filter by global rbac_roles
AND CASE
-- @rbac_role allows filtering by rbac roles. If 'member' is included, show everyone, as
-- everyone is a member.
WHEN cardinality(@rbac_role :: text[]) > 0 AND 'member' != ANY(@rbac_role :: text[]) THEN
users.rbac_roles && @rbac_role :: text[]
ELSE true
END
-- Filter by last_seen
AND CASE
WHEN @last_seen_before :: timestamp with time zone != '0001-01-01 00:00:00Z' THEN
users.last_seen_at <= @last_seen_before
ELSE true
END
AND CASE
WHEN @last_seen_after :: timestamp with time zone != '0001-01-01 00:00:00Z' THEN
users.last_seen_at >= @last_seen_after
ELSE true
END
-- Filter by created_at (user creation date, not date added to org)
AND CASE
WHEN @created_before :: timestamp with time zone != '0001-01-01 00:00:00Z' THEN
users.created_at <= @created_before
ELSE true
END
AND CASE
WHEN @created_after :: timestamp with time zone != '0001-01-01 00:00:00Z' THEN
users.created_at >= @created_after
ELSE true
END
-- Filter by system type
AND CASE
WHEN @include_system::bool THEN TRUE
ELSE users.is_system = false
END
-- Filter by github.com user ID
AND CASE
WHEN @github_com_user_id :: bigint != 0 THEN
users.github_com_user_id = @github_com_user_id
ELSE true
END
-- Filter by login_type
AND CASE
WHEN cardinality(@login_type :: login_type[]) > 0 THEN
users.login_type = ANY(@login_type :: login_type[])
ELSE true
END
-- End of filters
ORDER BY
-- Deterministic and consistent ordering of all users. This is to ensure consistent pagination.
LOWER(username) ASC OFFSET @offset_opt
LOWER(users.username) ASC OFFSET @offset_opt
LIMIT
-- A null limit means "no limit", so 0 means return all
NULLIF(@limit_opt :: int, 0);
-20
@@ -118,26 +118,6 @@ WHERE id IN (
WHERE tailnet_tunnels.dst_id = $1
);
-- name: GetTailnetTunnelPeerIDsBatch :many
SELECT src_id AS lookup_id, dst_id AS peer_id, coordinator_id, updated_at
FROM tailnet_tunnels WHERE src_id = ANY(@ids :: uuid[])
UNION ALL
SELECT dst_id AS lookup_id, src_id AS peer_id, coordinator_id, updated_at
FROM tailnet_tunnels WHERE dst_id = ANY(@ids :: uuid[]);
-- name: GetTailnetTunnelPeerBindingsBatch :many
SELECT tp.id AS peer_id, tp.coordinator_id, tp.updated_at, tp.node, tp.status,
tt.src_id AS lookup_id
FROM tailnet_peers tp
INNER JOIN tailnet_tunnels tt ON tp.id = tt.dst_id
WHERE tt.src_id = ANY(@ids :: uuid[])
UNION ALL
SELECT tp.id AS peer_id, tp.coordinator_id, tp.updated_at, tp.node, tp.status,
tt.dst_id AS lookup_id
FROM tailnet_peers tp
INNER JOIN tailnet_tunnels tt ON tp.id = tt.src_id
WHERE tt.dst_id = ANY(@ids :: uuid[]);
-- For PG Coordinator HTMLDebug
-- name: GetAllTailnetCoordinators :many
+20
@@ -193,6 +193,26 @@ WHERE user_configs.user_id = @user_id
AND user_configs.key = 'chat_custom_prompt'
RETURNING *;
-- name: ListUserChatCompactionThresholds :many
SELECT user_id, key, value FROM user_configs
WHERE user_id = @user_id
AND key LIKE 'chat\_compaction\_threshold\_pct:%'
ORDER BY key;
-- name: GetUserChatCompactionThreshold :one
SELECT value AS threshold_percent FROM user_configs
WHERE user_id = @user_id AND key = @key;
-- name: UpdateUserChatCompactionThreshold :one
INSERT INTO user_configs (user_id, key, value)
VALUES (@user_id, @key, (@threshold_percent::int)::text)
ON CONFLICT ON CONSTRAINT user_configs_pkey
DO UPDATE SET value = (@threshold_percent::int)::text
RETURNING *;
-- name: DeleteUserChatCompactionThreshold :exec
DELETE FROM user_configs WHERE user_id = @user_id AND key = @key;
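The queries above store one per-model threshold per user_configs row, namespacing rows with a key of the form "chat_compaction_threshold_pct:<model-config-id>" (the escaped underscores in the LIKE pattern match literally). A minimal Go sketch of building and parsing such a key; the constant name is a stand-in and the model config ID is kept as an opaque string rather than a parsed UUID:

```go
package main

import (
	"fmt"
	"strconv"
	"strings"
)

// thresholdKeyPrefix is inferred from the LIKE pattern in the query above;
// the real code exposes a prefix constant in codersdk.
const thresholdKeyPrefix = "chat_compaction_threshold_pct:"

// thresholdKey builds the user_configs key for one model config.
func thresholdKey(modelConfigID string) string {
	return thresholdKeyPrefix + modelConfigID
}

// parseThresholdRow recovers the model config ID and percentage from a
// (key, value) row, rejecting keys outside the namespace.
func parseThresholdRow(key, value string) (modelConfigID string, pct int, err error) {
	id, ok := strings.CutPrefix(key, thresholdKeyPrefix)
	if !ok {
		return "", 0, fmt.Errorf("unexpected key %q", key)
	}
	pct, err = strconv.Atoi(value)
	if err != nil {
		return "", 0, fmt.Errorf("bad value %q: %w", value, err)
	}
	return id, pct, nil
}

func main() {
	key := thresholdKey("model-config-1")
	id, pct, err := parseThresholdRow(key, "80")
	if err != nil {
		panic(err)
	}
	fmt.Println(key, id, pct)
}
```

Keeping the value as text in user_configs is what forces the defensive re-parsing and range checks seen in the HTTP handler later in this diff.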
-- name: GetUserTaskNotificationAlertDismissed :one
SELECT
value::boolean as task_notification_alert_dismissed
@@ -78,29 +78,6 @@ SET
WHERE
id = $1;
-- name: BatchUpdateWorkspaceAgentConnections :exec
WITH agents AS (
SELECT
unnest(sqlc.arg('id')::uuid[]) AS id,
unnest(sqlc.arg('first_connected_at')::timestamptz[]) AS first_connected_at,
unnest(sqlc.arg('last_connected_at')::timestamptz[]) AS last_connected_at,
unnest(sqlc.arg('last_connected_replica_id')::uuid[]) AS last_connected_replica_id,
unnest(sqlc.arg('disconnected_at')::timestamptz[]) AS disconnected_at,
unnest(sqlc.arg('updated_at')::timestamptz[]) AS updated_at
)
UPDATE
workspace_agents wa
SET
first_connected_at = a.first_connected_at,
last_connected_at = a.last_connected_at,
last_connected_replica_id = a.last_connected_replica_id,
disconnected_at = a.disconnected_at,
updated_at = a.updated_at
FROM
agents a
WHERE
wa.id = a.id;
-- name: UpdateWorkspaceAgentStartupByID :exec
UPDATE
workspace_agents
+2 -2
@@ -3,7 +3,7 @@ package dynamicparameters
import (
"fmt"
"net/http"
"sort"
"slices"
"github.com/hashicorp/hcl/v2"
@@ -94,7 +94,7 @@ func (e *DiagnosticError) Response() (int, codersdk.Response) {
for name := range e.KeyedDiagnostics {
sortedNames = append(sortedNames, name)
}
sort.Strings(sortedNames)
slices.Sort(sortedNames)
for _, name := range sortedNames {
diag := e.KeyedDiagnostics[name]
+178 -3
@@ -28,14 +28,11 @@ import (
"cdr.dev/slog/v3"
"github.com/coder/coder/v2/agent/agentssh"
"github.com/coder/coder/v2/coderd/audit"
"github.com/coder/coder/v2/coderd/chatd"
"github.com/coder/coder/v2/coderd/chatd/chatprovider"
"github.com/coder/coder/v2/coderd/database"
"github.com/coder/coder/v2/coderd/database/db2sdk"
"github.com/coder/coder/v2/coderd/database/dbauthz"
"github.com/coder/coder/v2/coderd/externalauth"
"github.com/coder/coder/v2/coderd/externalauth/gitprovider"
"github.com/coder/coder/v2/coderd/gitsync"
"github.com/coder/coder/v2/coderd/httpapi"
"github.com/coder/coder/v2/coderd/httpapi/httperror"
"github.com/coder/coder/v2/coderd/httpmw"
@@ -46,6 +43,9 @@ import (
"github.com/coder/coder/v2/coderd/tracing"
"github.com/coder/coder/v2/coderd/util/ptr"
"github.com/coder/coder/v2/coderd/workspaceapps"
"github.com/coder/coder/v2/coderd/x/chatd"
"github.com/coder/coder/v2/coderd/x/chatd/chatprovider"
"github.com/coder/coder/v2/coderd/x/gitsync"
"github.com/coder/coder/v2/codersdk"
"github.com/coder/coder/v2/codersdk/wsjson"
"github.com/coder/websocket"
@@ -2542,6 +2542,17 @@ func normalizeChatCompressionThreshold(
return threshold, nil
}
func parseCompactionThresholdKey(key string) (uuid.UUID, error) {
if !strings.HasPrefix(key, codersdk.ChatCompactionThresholdKeyPrefix) {
return uuid.Nil, xerrors.Errorf("invalid compaction threshold key: %q", key)
}
id, err := uuid.Parse(key[len(codersdk.ChatCompactionThresholdKeyPrefix):])
if err != nil {
return uuid.Nil, xerrors.Errorf("invalid model config ID in key %q: %w", key, err)
}
return id, nil
}
const (
// maxChatFileSize is the maximum size of a chat file upload (10 MB).
maxChatFileSize = 10 << 20
@@ -2816,6 +2827,170 @@ func (api *API) putUserChatCustomPrompt(rw http.ResponseWriter, r *http.Request)
})
}
// @Summary Get user chat compaction thresholds
// @x-apidocgen {"skip": true}
// EXPERIMENTAL: this endpoint is experimental and is subject to change.
//
//nolint:revive // get-return: revive assumes get* must be a getter, but this is an HTTP handler.
func (api *API) getUserChatCompactionThresholds(rw http.ResponseWriter, r *http.Request) {
var (
ctx = r.Context()
apiKey = httpmw.APIKey(r)
)
rows, err := api.Database.ListUserChatCompactionThresholds(ctx, apiKey.UserID)
if err != nil {
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
Message: "Error listing user chat compaction thresholds.",
Detail: err.Error(),
})
return
}
resp := codersdk.UserChatCompactionThresholds{
Thresholds: make([]codersdk.UserChatCompactionThreshold, 0, len(rows)),
}
for _, row := range rows {
modelConfigID, err := parseCompactionThresholdKey(row.Key)
if err != nil {
api.Logger.Warn(ctx, "skipping malformed user chat compaction threshold key",
slog.F("key", row.Key),
slog.F("value", row.Value),
slog.Error(err),
)
continue
}
thresholdPercent, err := strconv.ParseInt(row.Value, 10, 32)
if err != nil {
api.Logger.Warn(ctx, "skipping malformed user chat compaction threshold value",
slog.F("key", row.Key),
slog.F("value", row.Value),
slog.Error(err),
)
continue
}
if thresholdPercent < int64(minChatContextCompressionThreshold) ||
thresholdPercent > int64(maxChatContextCompressionThreshold) {
api.Logger.Warn(ctx, "skipping out-of-range user chat compaction threshold",
slog.F("key", row.Key),
slog.F("value", row.Value),
)
continue
}
resp.Thresholds = append(resp.Thresholds, codersdk.UserChatCompactionThreshold{
ModelConfigID: modelConfigID,
ThresholdPercent: int32(thresholdPercent),
})
}
httpapi.Write(ctx, rw, http.StatusOK, resp)
}
// @Summary Set user chat compaction threshold for a model config
// @x-apidocgen {"skip": true}
// EXPERIMENTAL: this endpoint is experimental and is subject to change.
func (api *API) putUserChatCompactionThreshold(rw http.ResponseWriter, r *http.Request) {
var (
ctx = r.Context()
apiKey = httpmw.APIKey(r)
)
modelConfigID, ok := parseChatModelConfigID(rw, r)
if !ok {
return
}
var req codersdk.UpdateUserChatCompactionThresholdRequest
if !httpapi.Read(ctx, rw, r, &req) {
return
}
if req.ThresholdPercent < minChatContextCompressionThreshold ||
req.ThresholdPercent > maxChatContextCompressionThreshold {
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
Message: "threshold_percent is out of range.",
Detail: fmt.Sprintf(
"threshold_percent must be between %d and %d, got %d.",
minChatContextCompressionThreshold,
maxChatContextCompressionThreshold,
req.ThresholdPercent,
),
})
return
}
// Use system context because GetChatModelConfigByID requires
// deployment-config read access, which non-admin users lack.
// The user is only checking if the model exists and is enabled
// before writing their own personal preference.
//nolint:gocritic // Non-admin users need this lookup to save their own setting.
modelConfig, err := api.Database.GetChatModelConfigByID(dbauthz.AsSystemRestricted(ctx), modelConfigID)
if err != nil {
if errors.Is(err, sql.ErrNoRows) || httpapi.Is404Error(err) {
httpapi.ResourceNotFound(rw)
return
}
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
Message: "Failed to get chat model config.",
Detail: err.Error(),
})
return
}
if !modelConfig.Enabled {
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
Message: "Model config is disabled.",
})
return
}
_, err = api.Database.UpdateUserChatCompactionThreshold(ctx, database.UpdateUserChatCompactionThresholdParams{
UserID: apiKey.UserID,
Key: codersdk.CompactionThresholdKey(modelConfigID),
ThresholdPercent: req.ThresholdPercent,
})
if err != nil {
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
Message: "Error updating user chat compaction threshold.",
Detail: err.Error(),
})
return
}
httpapi.Write(ctx, rw, http.StatusOK, codersdk.UserChatCompactionThreshold{
ModelConfigID: modelConfigID,
ThresholdPercent: req.ThresholdPercent,
})
}
// @Summary Delete user chat compaction threshold for a model config
// @x-apidocgen {"skip": true}
// EXPERIMENTAL: this endpoint is experimental and is subject to change.
func (api *API) deleteUserChatCompactionThreshold(rw http.ResponseWriter, r *http.Request) {
var (
ctx = r.Context()
apiKey = httpmw.APIKey(r)
)
modelConfigID, ok := parseChatModelConfigID(rw, r)
if !ok {
return
}
if err := api.Database.DeleteUserChatCompactionThreshold(ctx, database.DeleteUserChatCompactionThresholdParams{
UserID: apiKey.UserID,
Key: codersdk.CompactionThresholdKey(modelConfigID),
}); err != nil {
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
Message: "Error deleting user chat compaction threshold.",
Detail: err.Error(),
})
return
}
rw.WriteHeader(http.StatusNoContent)
}
func (api *API) resolvedChatSystemPrompt(ctx context.Context) string {
custom, err := api.Database.GetChatSystemPrompt(ctx)
if err != nil {
File diff suppressed because it is too large
+10 -27
@@ -21,11 +21,14 @@ import (
func TestPostFiles(t *testing.T) {
t.Parallel()
// Single instance shared across all sub-tests. Each sub-test
// creates independent resources with unique IDs so parallel
// execution is safe.
client := coderdtest.New(t, nil)
_ = coderdtest.CreateFirstUser(t, client)
t.Run("BadContentType", func(t *testing.T) {
t.Parallel()
client := coderdtest.New(t, nil)
_ = coderdtest.CreateFirstUser(t, client)
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong)
defer cancel()
@@ -35,9 +38,6 @@ func TestPostFiles(t *testing.T) {
t.Run("Insert", func(t *testing.T) {
t.Parallel()
client := coderdtest.New(t, nil)
_ = coderdtest.CreateFirstUser(t, client)
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong)
defer cancel()
@@ -47,9 +47,6 @@ func TestPostFiles(t *testing.T) {
t.Run("InsertWindowsZip", func(t *testing.T) {
t.Parallel()
client := coderdtest.New(t, nil)
_ = coderdtest.CreateFirstUser(t, client)
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong)
defer cancel()
@@ -59,9 +56,6 @@ func TestPostFiles(t *testing.T) {
t.Run("InsertAlreadyExists", func(t *testing.T) {
t.Parallel()
client := coderdtest.New(t, nil)
_ = coderdtest.CreateFirstUser(t, client)
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong)
defer cancel()
@@ -73,9 +67,6 @@ func TestPostFiles(t *testing.T) {
})
t.Run("InsertConcurrent", func(t *testing.T) {
t.Parallel()
client := coderdtest.New(t, nil)
_ = coderdtest.CreateFirstUser(t, client)
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong)
defer cancel()
@@ -99,11 +90,12 @@ func TestPostFiles(t *testing.T) {
func TestDownload(t *testing.T) {
t.Parallel()
// Shared instance — see TestPostFiles for rationale.
client := coderdtest.New(t, nil)
_ = coderdtest.CreateFirstUser(t, client)
t.Run("NotFound", func(t *testing.T) {
t.Parallel()
client := coderdtest.New(t, nil)
_ = coderdtest.CreateFirstUser(t, client)
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong)
defer cancel()
@@ -115,9 +107,6 @@ func TestDownload(t *testing.T) {
t.Run("InsertTar_DownloadTar", func(t *testing.T) {
t.Parallel()
client := coderdtest.New(t, nil)
_ = coderdtest.CreateFirstUser(t, client)
// given
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong)
defer cancel()
@@ -139,9 +128,6 @@ func TestDownload(t *testing.T) {
t.Run("InsertZip_DownloadTar", func(t *testing.T) {
t.Parallel()
client := coderdtest.New(t, nil)
_ = coderdtest.CreateFirstUser(t, client)
// given
zipContent := archivetest.TestZipFileBytes()
@@ -164,9 +150,6 @@ func TestDownload(t *testing.T) {
t.Run("InsertTar_DownloadZip", func(t *testing.T) {
t.Parallel()
client := coderdtest.New(t, nil)
_ = coderdtest.CreateFirstUser(t, client)
// given
tarball := archivetest.TestTarFileBytes()
+50 -28
@@ -248,12 +248,9 @@ func PrecheckAPIKey(cfg ValidateAPIKeyConfig) func(http.Handler) http.Handler {
//
// Returns (result, nil) on success or (nil, error) on failure.
func ValidateAPIKey(ctx context.Context, cfg ValidateAPIKeyConfig, r *http.Request) (*ValidateAPIKeyResult, *ValidateAPIKeyError) {
key, resp, ok := APIKeyFromRequest(ctx, cfg.DB, cfg.SessionTokenFunc, r)
if !ok {
return nil, &ValidateAPIKeyError{
Code: http.StatusUnauthorized,
Response: resp,
}
key, valErr := apiKeyFromRequestValidate(ctx, cfg.DB, cfg.SessionTokenFunc, r)
if valErr != nil {
return nil, valErr
}
// Log the API key ID for all requests that have a valid key
@@ -475,7 +472,7 @@ func ValidateAPIKey(ctx context.Context, cfg ValidateAPIKeyConfig, r *http.Reque
actor, userStatus, err := UserRBACSubject(ctx, cfg.DB, key.UserID, key.ScopeSet())
if err != nil {
return nil, &ValidateAPIKeyError{
Code: http.StatusUnauthorized,
Code: http.StatusInternalServerError,
Response: codersdk.Response{
Message: internalErrorMessage,
Detail: fmt.Sprintf("Internal error fetching user's roles. %s", err.Error()),
@@ -492,6 +489,15 @@ func ValidateAPIKey(ctx context.Context, cfg ValidateAPIKeyConfig, r *http.Reque
}
func APIKeyFromRequest(ctx context.Context, db database.Store, sessionTokenFunc func(r *http.Request) string, r *http.Request) (*database.APIKey, codersdk.Response, bool) {
key, valErr := apiKeyFromRequestValidate(ctx, db, sessionTokenFunc, r)
if valErr != nil {
return nil, valErr.Response, false
}
return key, codersdk.Response{}, true
}
func apiKeyFromRequestValidate(ctx context.Context, db database.Store, sessionTokenFunc func(r *http.Request) string, r *http.Request) (*database.APIKey, *ValidateAPIKeyError) {
tokenFunc := APITokenFromRequest
if sessionTokenFunc != nil {
tokenFunc = sessionTokenFunc
@@ -499,45 +505,61 @@ func APIKeyFromRequest(ctx context.Context, db database.Store, sessionTokenFunc
token := tokenFunc(r)
if token == "" {
return nil, codersdk.Response{
Message: SignedOutErrorMessage,
Detail: fmt.Sprintf("Cookie %q or query parameter must be provided.", codersdk.SessionTokenCookie),
}, false
return nil, &ValidateAPIKeyError{
Code: http.StatusUnauthorized,
Response: codersdk.Response{
Message: SignedOutErrorMessage,
Detail: fmt.Sprintf("Cookie %q or query parameter must be provided.", codersdk.SessionTokenCookie),
},
}
}
keyID, keySecret, err := SplitAPIToken(token)
if err != nil {
return nil, codersdk.Response{
Message: SignedOutErrorMessage,
Detail: "Invalid API key format: " + err.Error(),
}, false
return nil, &ValidateAPIKeyError{
Code: http.StatusUnauthorized,
Response: codersdk.Response{
Message: SignedOutErrorMessage,
Detail: "Invalid API key format: " + err.Error(),
},
}
}
//nolint:gocritic // System needs to fetch API key to check if it's valid.
key, err := db.GetAPIKeyByID(dbauthz.AsSystemRestricted(ctx), keyID)
if err != nil {
if errors.Is(err, sql.ErrNoRows) {
return nil, codersdk.Response{
Message: SignedOutErrorMessage,
Detail: "API key is invalid.",
}, false
return nil, &ValidateAPIKeyError{
Code: http.StatusUnauthorized,
Response: codersdk.Response{
Message: SignedOutErrorMessage,
Detail: "API key is invalid.",
},
}
}
return nil, codersdk.Response{
Message: internalErrorMessage,
Detail: fmt.Sprintf("Internal error fetching API key by id. %s", err.Error()),
}, false
return nil, &ValidateAPIKeyError{
Code: http.StatusInternalServerError,
Response: codersdk.Response{
Message: internalErrorMessage,
Detail: fmt.Sprintf("Internal error fetching API key by id. %s", err.Error()),
},
Hard: true,
}
}
// Checking to see if the secret is valid.
if !apikey.ValidateHash(key.HashedSecret, keySecret) {
return nil, codersdk.Response{
Message: SignedOutErrorMessage,
Detail: "API key secret is invalid.",
}, false
return nil, &ValidateAPIKeyError{
Code: http.StatusUnauthorized,
Response: codersdk.Response{
Message: SignedOutErrorMessage,
Detail: "API key secret is invalid.",
},
}
}
return &key, codersdk.Response{}, true
return &key, nil
}
// ExtractAPIKey requires authentication using a valid API key. It handles
+27
@@ -19,12 +19,14 @@ import (
"go.uber.org/mock/gomock"
"golang.org/x/exp/slices"
"golang.org/x/oauth2"
"golang.org/x/xerrors"
"cdr.dev/slog/v3"
"github.com/coder/coder/v2/coderd/apikey"
"github.com/coder/coder/v2/coderd/database"
"github.com/coder/coder/v2/coderd/database/dbauthz"
"github.com/coder/coder/v2/coderd/database/dbgen"
"github.com/coder/coder/v2/coderd/database/dbmock"
"github.com/coder/coder/v2/coderd/database/dbtestutil"
"github.com/coder/coder/v2/coderd/database/dbtime"
"github.com/coder/coder/v2/coderd/httpapi"
@@ -192,6 +194,31 @@ func TestAPIKey(t *testing.T) {
require.Equal(t, http.StatusUnauthorized, res.StatusCode)
})
t.Run("GetAPIKeyByIDInternalError", func(t *testing.T) {
t.Parallel()
ctrl := gomock.NewController(t)
db := dbmock.NewMockStore(ctrl)
id, secret, _ := randomAPIKeyParts()
r := httptest.NewRequest("GET", "/", nil)
rw := httptest.NewRecorder()
r.Header.Set(codersdk.SessionTokenHeader, fmt.Sprintf("%s-%s", id, secret))
db.EXPECT().GetAPIKeyByID(gomock.Any(), id).Return(database.APIKey{}, xerrors.New("db unavailable"))
httpmw.ExtractAPIKeyMW(httpmw.ExtractAPIKeyConfig{
DB: db,
RedirectToLogin: false,
})(successHandler).ServeHTTP(rw, r)
res := rw.Result()
defer res.Body.Close()
require.Equal(t, http.StatusInternalServerError, res.StatusCode)
var resp codersdk.Response
require.NoError(t, json.NewDecoder(res.Body).Decode(&resp))
require.NotEqual(t, httpmw.SignedOutErrorMessage, resp.Message)
require.Contains(t, resp.Detail, "Internal error fetching API key by id")
})
t.Run("UserLinkNotFound", func(t *testing.T) {
t.Parallel()
var (
+5 -5
@@ -14,9 +14,13 @@ import (
func TestInitScript(t *testing.T) {
t.Parallel()
// Single instance shared across all sub-tests. All operations
// are read-only (fetching init scripts) so parallel execution
// is safe.
client := coderdtest.New(t, nil)
t.Run("OK Windows amd64", func(t *testing.T) {
t.Parallel()
client := coderdtest.New(t, nil)
script, err := client.InitScript(context.Background(), "windows", "amd64")
require.NoError(t, err)
require.NotEmpty(t, script)
@@ -26,7 +30,6 @@ func TestInitScript(t *testing.T) {
t.Run("OK Windows arm64", func(t *testing.T) {
t.Parallel()
client := coderdtest.New(t, nil)
script, err := client.InitScript(context.Background(), "windows", "arm64")
require.NoError(t, err)
require.NotEmpty(t, script)
@@ -36,7 +39,6 @@ func TestInitScript(t *testing.T) {
t.Run("OK Linux amd64", func(t *testing.T) {
t.Parallel()
client := coderdtest.New(t, nil)
script, err := client.InitScript(context.Background(), "linux", "amd64")
require.NoError(t, err)
require.NotEmpty(t, script)
@@ -46,7 +48,6 @@ func TestInitScript(t *testing.T) {
t.Run("OK Linux arm64", func(t *testing.T) {
t.Parallel()
client := coderdtest.New(t, nil)
script, err := client.InitScript(context.Background(), "linux", "arm64")
require.NoError(t, err)
require.NotEmpty(t, script)
@@ -56,7 +57,6 @@ func TestInitScript(t *testing.T) {
t.Run("BadRequest", func(t *testing.T) {
t.Parallel()
client := coderdtest.New(t, nil)
_, err := client.InitScript(context.Background(), "darwin", "armv7")
require.Error(t, err)
var apiErr *codersdk.Error
+207 -2
@@ -6,12 +6,16 @@ import (
"encoding/json"
"fmt"
"net/http"
"net/url"
"strings"
"github.com/go-chi/chi/v5"
"github.com/google/uuid"
"github.com/mark3labs/mcp-go/client/transport"
"golang.org/x/oauth2"
"golang.org/x/xerrors"
"cdr.dev/slog/v3"
"github.com/coder/coder/v2/coderd/database"
"github.com/coder/coder/v2/coderd/database/dbauthz"
"github.com/coder/coder/v2/coderd/httpapi"
@@ -107,9 +111,147 @@ func (api *API) createMCPServerConfig(rw http.ResponseWriter, r *http.Request) {
// Validate auth-type-dependent fields.
switch req.AuthType {
case "oauth2":
if req.OAuth2ClientID == "" || req.OAuth2AuthURL == "" || req.OAuth2TokenURL == "" {
// When the admin does not provide OAuth2 credentials, attempt
// automatic discovery and Dynamic Client Registration (RFC 7591)
// using the MCP server URL. This follows the MCP authorization
// spec: discover the authorization server via Protected Resource
// Metadata (RFC 9728) and Authorization Server Metadata
// (RFC 8414), then register a client dynamically.
if req.OAuth2ClientID == "" && req.OAuth2AuthURL == "" && req.OAuth2TokenURL == "" {
// Auto-discovery flow: we need the config ID first to
// build the correct callback URL. Insert the record
// with empty OAuth2 fields, perform discovery, then
// update.
customHeadersJSON, err := marshalCustomHeaders(req.CustomHeaders)
if err != nil {
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
Message: "Invalid custom headers.",
Detail: err.Error(),
})
return
}
inserted, err := api.Database.InsertMCPServerConfig(ctx, database.InsertMCPServerConfigParams{
DisplayName: strings.TrimSpace(req.DisplayName),
Slug: strings.TrimSpace(req.Slug),
Description: strings.TrimSpace(req.Description),
IconURL: strings.TrimSpace(req.IconURL),
Transport: strings.TrimSpace(req.Transport),
Url: strings.TrimSpace(req.URL),
AuthType: strings.TrimSpace(req.AuthType),
OAuth2ClientID: "",
OAuth2ClientSecret: "",
OAuth2ClientSecretKeyID: sql.NullString{},
OAuth2AuthURL: "",
OAuth2TokenURL: "",
OAuth2Scopes: "",
APIKeyHeader: strings.TrimSpace(req.APIKeyHeader),
APIKeyValue: strings.TrimSpace(req.APIKeyValue),
APIKeyValueKeyID: sql.NullString{},
CustomHeaders: customHeadersJSON,
CustomHeadersKeyID: sql.NullString{},
ToolAllowList: coalesceStringSlice(trimStringSlice(req.ToolAllowList)),
ToolDenyList: coalesceStringSlice(trimStringSlice(req.ToolDenyList)),
Availability: strings.TrimSpace(req.Availability),
Enabled: req.Enabled,
CreatedBy: apiKey.UserID,
UpdatedBy: apiKey.UserID,
})
if err != nil {
switch {
case database.IsUniqueViolation(err):
httpapi.Write(ctx, rw, http.StatusConflict, codersdk.Response{
Message: "MCP server config already exists.",
Detail: err.Error(),
})
return
case database.IsCheckViolation(err):
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
Message: "Invalid MCP server config.",
Detail: err.Error(),
})
return
default:
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
Message: "Failed to create MCP server config.",
Detail: err.Error(),
})
return
}
}
// Now build the callback URL with the actual ID.
callbackURL := fmt.Sprintf("%s/api/experimental/mcp/servers/%s/oauth2/callback", api.AccessURL.String(), inserted.ID)
result, err := discoverAndRegisterMCPOAuth2(ctx, strings.TrimSpace(req.URL), callbackURL)
if err != nil {
// Clean up: delete the partially created config.
deleteErr := api.Database.DeleteMCPServerConfigByID(ctx, inserted.ID)
if deleteErr != nil {
api.Logger.Warn(ctx, "failed to clean up MCP server config after OAuth2 discovery failure",
slog.F("config_id", inserted.ID),
slog.Error(deleteErr),
)
}
api.Logger.Warn(ctx, "mcp oauth2 auto-discovery failed",
slog.F("url", req.URL),
slog.Error(err),
)
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
Message: "OAuth2 auto-discovery failed. Provide oauth2_client_id, oauth2_auth_url, and oauth2_token_url manually, or ensure the MCP server supports RFC 9728 (Protected Resource Metadata) and RFC 7591 (Dynamic Client Registration).",
Detail: err.Error(),
})
return
}
// Determine scopes: use the request value if provided,
// otherwise fall back to the discovered value.
oauth2Scopes := strings.TrimSpace(req.OAuth2Scopes)
if oauth2Scopes == "" {
oauth2Scopes = result.scopes
}
// Update the record with discovered OAuth2 credentials.
updated, err := api.Database.UpdateMCPServerConfig(ctx, database.UpdateMCPServerConfigParams{
ID: inserted.ID,
DisplayName: inserted.DisplayName,
Slug: inserted.Slug,
Description: inserted.Description,
IconURL: inserted.IconURL,
Transport: inserted.Transport,
Url: inserted.Url,
AuthType: inserted.AuthType,
OAuth2ClientID: result.clientID,
OAuth2ClientSecret: result.clientSecret,
OAuth2ClientSecretKeyID: sql.NullString{},
OAuth2AuthURL: result.authURL,
OAuth2TokenURL: result.tokenURL,
OAuth2Scopes: oauth2Scopes,
APIKeyHeader: inserted.APIKeyHeader,
APIKeyValue: inserted.APIKeyValue,
APIKeyValueKeyID: inserted.APIKeyValueKeyID,
CustomHeaders: inserted.CustomHeaders,
CustomHeadersKeyID: inserted.CustomHeadersKeyID,
ToolAllowList: inserted.ToolAllowList,
ToolDenyList: inserted.ToolDenyList,
Availability: inserted.Availability,
Enabled: inserted.Enabled,
UpdatedBy: apiKey.UserID,
})
if err != nil {
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
Message: "Failed to update MCP server config with OAuth2 credentials.",
Detail: err.Error(),
})
return
}
httpapi.Write(ctx, rw, http.StatusCreated, convertMCPServerConfig(updated))
return
} else if req.OAuth2ClientID == "" || req.OAuth2AuthURL == "" || req.OAuth2TokenURL == "" {
// Partial manual config: all three fields are required together.
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
Message: "OAuth2 auth type requires oauth2_client_id, oauth2_auth_url, and oauth2_token_url.",
Message: "OAuth2 auth type requires either all of oauth2_client_id, oauth2_auth_url, and oauth2_token_url (manual configuration), or none of them (automatic discovery via RFC 7591).",
})
return
}
@@ -919,3 +1061,66 @@ func coalesceStringSlice(ss []string) []string {
}
return ss
}
// mcpOAuth2Discovery holds the result of MCP OAuth2 auto-discovery
// and Dynamic Client Registration.
type mcpOAuth2Discovery struct {
clientID string
clientSecret string
authURL string
tokenURL string
scopes string // space-separated
}
// discoverAndRegisterMCPOAuth2 uses the mcp-go library's OAuthHandler to
// perform the MCP OAuth2 discovery and Dynamic Client Registration flow:
//
// 1. Discover the authorization server via Protected Resource Metadata
// (RFC 9728) and Authorization Server Metadata (RFC 8414).
// 2. Register a client via Dynamic Client Registration (RFC 7591).
// 3. Return the discovered endpoints and generated credentials.
func discoverAndRegisterMCPOAuth2(ctx context.Context, mcpServerURL, callbackURL string) (*mcpOAuth2Discovery, error) {
// Per the MCP spec, the authorization base URL is the MCP server
// URL with the path component discarded (scheme + host only).
parsed, err := url.Parse(mcpServerURL)
if err != nil {
return nil, xerrors.Errorf("parse MCP server URL: %w", err)
}
origin := fmt.Sprintf("%s://%s", parsed.Scheme, parsed.Host)
oauthHandler := transport.NewOAuthHandler(transport.OAuthConfig{
RedirectURI: callbackURL,
TokenStore: transport.NewMemoryTokenStore(),
})
oauthHandler.SetBaseURL(origin)
// Step 1: Discover authorization server metadata (RFC 9728 + RFC 8414).
metadata, err := oauthHandler.GetServerMetadata(ctx)
if err != nil {
return nil, xerrors.Errorf("discover authorization server: %w", err)
}
if metadata.AuthorizationEndpoint == "" {
return nil, xerrors.New("authorization server metadata missing authorization_endpoint")
}
if metadata.TokenEndpoint == "" {
return nil, xerrors.New("authorization server metadata missing token_endpoint")
}
if metadata.RegistrationEndpoint == "" {
return nil, xerrors.New("authorization server does not advertise a registration_endpoint (dynamic client registration may not be supported)")
}
// Step 2: Register a client via Dynamic Client Registration (RFC 7591).
if err := oauthHandler.RegisterClient(ctx, "Coder"); err != nil {
return nil, xerrors.Errorf("dynamic client registration: %w", err)
}
scopes := strings.Join(metadata.ScopesSupported, " ")
return &mcpOAuth2Discovery{
clientID: oauthHandler.GetClientID(),
clientSecret: oauthHandler.GetClientSecret(),
authURL: metadata.AuthorizationEndpoint,
tokenURL: metadata.TokenEndpoint,
scopes: scopes,
}, nil
}
+310 -4
@@ -3,6 +3,7 @@ package coderd_test
import (
"encoding/json"
"net/http"
"net/http/httptest"
"strings"
"testing"
@@ -430,6 +431,309 @@ func TestMCPServerConfigsOAuth2Disconnect(t *testing.T) {
require.NoError(t, err)
}
func TestMCPServerConfigsOAuth2AutoDiscovery(t *testing.T) {
t.Parallel()
t.Run("Success", func(t *testing.T) {
t.Parallel()
ctx := testutil.Context(t, testutil.WaitLong)
// Stand up a mock auth server that serves RFC 8414 metadata and
// a RFC 7591 dynamic client registration endpoint.
authServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
switch r.URL.Path {
case "/.well-known/oauth-authorization-server":
w.Header().Set("Content-Type", "application/json")
_, _ = w.Write([]byte(`{
"issuer": "` + "http://" + r.Host + `",
"authorization_endpoint": "` + "http://" + r.Host + `/authorize",
"token_endpoint": "` + "http://" + r.Host + `/token",
"registration_endpoint": "` + "http://" + r.Host + `/register",
"response_types_supported": ["code"],
"scopes_supported": ["read", "write"]
}`))
case "/register":
if r.Method != http.MethodPost {
http.Error(w, "method not allowed", http.StatusMethodNotAllowed)
return
}
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusCreated)
_, _ = w.Write([]byte(`{
"client_id": "auto-discovered-client-id",
"client_secret": "auto-discovered-client-secret"
}`))
default:
http.NotFound(w, r)
}
}))
t.Cleanup(authServer.Close)
// Stand up a mock MCP server that serves RFC 9728 Protected
// Resource Metadata pointing to the auth server above.
mcpServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.URL.Path == "/.well-known/oauth-protected-resource" {
w.Header().Set("Content-Type", "application/json")
_, _ = w.Write([]byte(`{
"resource": "` + "http://" + r.Host + `",
"authorization_servers": ["` + authServer.URL + `"]
}`))
return
}
http.NotFound(w, r)
}))
t.Cleanup(mcpServer.Close)
client := newMCPClient(t)
_ = coderdtest.CreateFirstUser(t, client)
// Create config with auth_type=oauth2 but no OAuth2 fields —
// the server should auto-discover them.
created, err := client.CreateMCPServerConfig(ctx, codersdk.CreateMCPServerConfigRequest{
DisplayName: "Auto-Discovery Server",
Slug: "auto-discovery",
Transport: "streamable_http",
URL: mcpServer.URL + "/v1/mcp",
AuthType: "oauth2",
Availability: "default_on",
Enabled: true,
ToolAllowList: []string{},
ToolDenyList: []string{},
})
require.NoError(t, err)
require.Equal(t, "auto-discovered-client-id", created.OAuth2ClientID)
require.True(t, created.HasOAuth2Secret)
require.Equal(t, authServer.URL+"/authorize", created.OAuth2AuthURL)
require.Equal(t, authServer.URL+"/token", created.OAuth2TokenURL)
require.Equal(t, "read write", created.OAuth2Scopes)
})
// Regression test: verify that during dynamic client registration
// the redirect_uris sent to the authorization server contain the
// real config UUID, NOT the literal string "{id}". Before the
// fix, the callback URL was built before the config row existed,
// so it contained "{id}" literally, which caused "redirect URIs
// not approved" errors when the user later tried to connect.
t.Run("RedirectURIContainsRealConfigID", func(t *testing.T) {
t.Parallel()
ctx := testutil.Context(t, testutil.WaitLong)
// Buffered channel so the handler never blocks.
registeredRedirectURI := make(chan string, 1)
// Stand up a mock auth server that captures the redirect_uris
// from the RFC 7591 Dynamic Client Registration request.
authServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
switch r.URL.Path {
case "/.well-known/oauth-authorization-server":
w.Header().Set("Content-Type", "application/json")
_, _ = w.Write([]byte(`{
"issuer": "` + "http://" + r.Host + `",
"authorization_endpoint": "` + "http://" + r.Host + `/authorize",
"token_endpoint": "` + "http://" + r.Host + `/token",
"registration_endpoint": "` + "http://" + r.Host + `/register",
"response_types_supported": ["code"],
"scopes_supported": ["read", "write"]
}`))
case "/register":
if r.Method != http.MethodPost {
http.Error(w, "method not allowed", http.StatusMethodNotAllowed)
return
}
// Decode the registration body and capture redirect_uris.
var body map[string]interface{}
if err := json.NewDecoder(r.Body).Decode(&body); err != nil {
http.Error(w, "bad json", http.StatusBadRequest)
return
}
if uris, ok := body["redirect_uris"].([]interface{}); ok && len(uris) > 0 {
if uri, ok := uris[0].(string); ok {
registeredRedirectURI <- uri
}
}
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusCreated)
_, _ = w.Write([]byte(`{
"client_id": "test-client-id",
"client_secret": "test-client-secret"
}`))
default:
http.NotFound(w, r)
}
}))
t.Cleanup(authServer.Close)
// Stand up a mock MCP server that returns RFC 9728 Protected
// Resource Metadata pointing to the auth server.
mcpServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.URL.Path == "/.well-known/oauth-protected-resource" {
w.Header().Set("Content-Type", "application/json")
_, _ = w.Write([]byte(`{
"resource": "` + "http://" + r.Host + `",
"authorization_servers": ["` + authServer.URL + `"]
}`))
return
}
http.NotFound(w, r)
}))
t.Cleanup(mcpServer.Close)
client := newMCPClient(t)
_ = coderdtest.CreateFirstUser(t, client)
// Create config with auth_type=oauth2 but no OAuth2 fields to
// trigger auto-discovery and dynamic client registration.
created, err := client.CreateMCPServerConfig(ctx, codersdk.CreateMCPServerConfigRequest{
DisplayName: "Redirect URI Test",
Slug: "redirect-uri-test",
Transport: "streamable_http",
URL: mcpServer.URL + "/v1/mcp",
AuthType: "oauth2",
Availability: "default_on",
Enabled: true,
ToolAllowList: []string{},
ToolDenyList: []string{},
})
require.NoError(t, err)
require.Equal(t, "test-client-id", created.OAuth2ClientID)
require.True(t, created.HasOAuth2Secret)
// The registration request has already completed by the time
// CreateMCPServerConfig returns, so the URI is in the channel.
var redirectURI string
select {
case redirectURI = <-registeredRedirectURI:
case <-ctx.Done():
t.Fatal("timed out waiting for registration redirect URI")
}
// Core assertion: the redirect URI must NOT contain the
// literal placeholder "{id}". Before the fix the callback
// URL was built before the database insert, so it had
// "{id}" where the UUID should be.
require.NotContains(t, redirectURI, "{id}",
"redirect URI sent during registration must not contain the literal \"{id}\" placeholder")
// Verify the redirect URI contains the real config UUID that
// was assigned by the database.
require.Contains(t, redirectURI, created.ID.String(),
"redirect URI should contain the actual config UUID")
// Sanity-check the full path structure.
require.Contains(t, redirectURI,
"/api/experimental/mcp/servers/"+created.ID.String()+"/oauth2/callback",
"redirect URI should have the expected callback path")
// Double-check that the ID segment is a valid UUID (not some
// other placeholder or malformed value).
pathParts := strings.Split(redirectURI, "/")
var foundUUID bool
for _, part := range pathParts {
if _, err := uuid.Parse(part); err == nil {
foundUUID = true
require.Equal(t, created.ID.String(), part,
"UUID in redirect URI path should match created config ID")
break
}
}
require.True(t, foundUUID,
"redirect URI path should contain a valid UUID segment")
})
t.Run("PartialOAuth2FieldsRejected", func(t *testing.T) {
t.Parallel()
ctx := testutil.Context(t, testutil.WaitLong)
client := newMCPClient(t)
_ = coderdtest.CreateFirstUser(t, client)
// Provide client_id but omit auth_url and token_url.
_, err := client.CreateMCPServerConfig(ctx, codersdk.CreateMCPServerConfigRequest{
DisplayName: "Partial Fields",
Slug: "partial-oauth2",
Transport: "streamable_http",
URL: "https://mcp.example.com/partial",
AuthType: "oauth2",
OAuth2ClientID: "only-client-id",
Availability: "default_on",
Enabled: true,
ToolAllowList: []string{},
ToolDenyList: []string{},
})
require.Error(t, err)
var sdkErr *codersdk.Error
require.ErrorAs(t, err, &sdkErr)
require.Equal(t, http.StatusBadRequest, sdkErr.StatusCode())
require.Contains(t, sdkErr.Message, "automatic discovery")
})
t.Run("DiscoveryFailure", func(t *testing.T) {
t.Parallel()
ctx := testutil.Context(t, testutil.WaitLong)
// MCP server that returns 404 for the well-known endpoint and
// a non-401 status for the root — discovery has nothing to latch
// onto.
mcpServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
http.Error(w, "not found", http.StatusNotFound)
}))
t.Cleanup(mcpServer.Close)
client := newMCPClient(t)
_ = coderdtest.CreateFirstUser(t, client)
_, err := client.CreateMCPServerConfig(ctx, codersdk.CreateMCPServerConfigRequest{
DisplayName: "Will Fail",
Slug: "discovery-fail",
Transport: "streamable_http",
URL: mcpServer.URL + "/v1/mcp",
AuthType: "oauth2",
Availability: "default_on",
Enabled: true,
ToolAllowList: []string{},
ToolDenyList: []string{},
})
require.Error(t, err)
var sdkErr *codersdk.Error
require.ErrorAs(t, err, &sdkErr)
require.Equal(t, http.StatusBadRequest, sdkErr.StatusCode())
require.Contains(t, sdkErr.Message, "auto-discovery failed")
})
t.Run("ManualConfigStillWorks", func(t *testing.T) {
t.Parallel()
ctx := testutil.Context(t, testutil.WaitLong)
client := newMCPClient(t)
_ = coderdtest.CreateFirstUser(t, client)
// Providing all three OAuth2 fields bypasses discovery entirely.
created, err := client.CreateMCPServerConfig(ctx, codersdk.CreateMCPServerConfigRequest{
DisplayName: "Manual Config",
Slug: "manual-oauth2",
Transport: "streamable_http",
URL: "https://mcp.example.com/manual",
AuthType: "oauth2",
OAuth2ClientID: "manual-client-id",
OAuth2AuthURL: "https://auth.example.com/authorize",
OAuth2TokenURL: "https://auth.example.com/token",
Availability: "default_on",
Enabled: true,
ToolAllowList: []string{},
ToolDenyList: []string{},
})
require.NoError(t, err)
require.Equal(t, "manual-client-id", created.OAuth2ClientID)
require.Equal(t, "https://auth.example.com/authorize", created.OAuth2AuthURL)
require.Equal(t, "https://auth.example.com/token", created.OAuth2TokenURL)
})
}
func TestChatWithMCPServerIDs(t *testing.T) {
t.Parallel()
@@ -437,14 +741,16 @@ func TestChatWithMCPServerIDs(t *testing.T) {
client := newMCPClient(t)
_ = coderdtest.CreateFirstUser(t, client)
expClient := codersdk.NewExperimentalClient(client)
// Create the chat model config required for creating a chat.
_ = createChatModelConfigForMCP(t, client)
_ = createChatModelConfigForMCP(t, expClient)
// Create an enabled MCP server config.
mcpConfig := createMCPServerConfig(t, client, "chat-mcp-server", true)
// Create a chat referencing the MCP server.
chat, err := client.CreateChat(ctx, codersdk.CreateChatRequest{
chat, err := expClient.CreateChat(ctx, codersdk.CreateChatRequest{
Content: []codersdk.ChatInputPart{
{
Type: codersdk.ChatInputPartTypeText,
@@ -458,7 +764,7 @@ func TestChatWithMCPServerIDs(t *testing.T) {
require.Contains(t, chat.MCPServerIDs, mcpConfig.ID)
// Fetch the chat and verify the MCP server IDs persist.
fetched, err := client.GetChat(ctx, chat.ID)
fetched, err := expClient.GetChat(ctx, chat.ID)
require.NoError(t, err)
require.Contains(t, fetched.MCPServerIDs, mcpConfig.ID)
}
@@ -466,7 +772,7 @@ func TestChatWithMCPServerIDs(t *testing.T) {
// createChatModelConfigForMCP sets up a chat provider and model
// config so that CreateChat succeeds. This mirrors the helper in
// chats_test.go but is defined here to avoid coupling.
func createChatModelConfigForMCP(t testing.TB, client *codersdk.Client) codersdk.ChatModelConfig {
func createChatModelConfigForMCP(t testing.TB, client *codersdk.ExperimentalClient) codersdk.ChatModelConfig {
t.Helper()
ctx := testutil.Context(t, testutil.WaitLong)
+44 -12
@@ -242,27 +242,51 @@ func (api *API) listMembers(rw http.ResponseWriter, r *http.Request) {
// @Produce json
// @Tags Members
// @Param organization path string true "Organization ID"
// @Param q query string false "Member search query"
// @Param after_id query string false "After ID" format(uuid)
// @Param limit query int false "Page limit, if 0 returns all members"
// @Param offset query int false "Page offset"
// @Success 200 {object} []codersdk.PaginatedMembersResponse
// @Router /organizations/{organization}/paginated-members [get]
func (api *API) paginatedMembers(rw http.ResponseWriter, r *http.Request) {
var (
ctx = r.Context()
organization = httpmw.OrganizationParam(r)
paginationParams, ok = ParsePagination(rw, r)
ctx = r.Context()
organization = httpmw.OrganizationParam(r)
)
filterQuery := r.URL.Query().Get("q")
userFilterParams, filterErrs := searchquery.Users(filterQuery)
if len(filterErrs) > 0 {
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
Message: "Invalid member search query.",
Validations: filterErrs,
})
return
}
paginationParams, ok := ParsePagination(rw, r)
if !ok {
return
}
paginatedMemberRows, err := api.Database.PaginatedOrganizationMembers(ctx, database.PaginatedOrganizationMembersParams{
OrganizationID: organization.ID,
IncludeSystem: false,
// #nosec G115 - Pagination limits are small and fit in int32
LimitOpt: int32(paginationParams.Limit),
AfterID: paginationParams.AfterID,
OrganizationID: organization.ID,
IncludeSystem: false,
Search: userFilterParams.Search,
Name: userFilterParams.Name,
Status: userFilterParams.Status,
RbacRole: userFilterParams.RbacRole,
LastSeenBefore: userFilterParams.LastSeenBefore,
LastSeenAfter: userFilterParams.LastSeenAfter,
CreatedAfter: userFilterParams.CreatedAfter,
CreatedBefore: userFilterParams.CreatedBefore,
GithubComUserID: userFilterParams.GithubComUserID,
LoginType: userFilterParams.LoginType,
// #nosec G115 - Pagination offsets are small and fit in int32
OffsetOpt: int32(paginationParams.Offset),
// #nosec G115 - Pagination limits are small and fit in int32
LimitOpt: int32(paginationParams.Limit),
})
if httpapi.Is404Error(err) {
httpapi.ResourceNotFound(rw)
@@ -273,18 +297,21 @@ func (api *API) paginatedMembers(rw http.ResponseWriter, r *http.Request) {
return
}
memberRows := make([]database.OrganizationMembersRow, 0)
for _, pRow := range paginatedMemberRows {
row := database.OrganizationMembersRow{
memberRows := make([]database.OrganizationMembersRow, len(paginatedMemberRows))
for i, pRow := range paginatedMemberRows {
memberRows[i] = database.OrganizationMembersRow{
OrganizationMember: pRow.OrganizationMember,
Username: pRow.Username,
AvatarURL: pRow.AvatarURL,
Name: pRow.Name,
Email: pRow.Email,
GlobalRoles: pRow.GlobalRoles,
LastSeenAt: pRow.LastSeenAt,
Status: pRow.Status,
LoginType: pRow.LoginType,
UserCreatedAt: pRow.UserCreatedAt,
UserUpdatedAt: pRow.UserUpdatedAt,
}
memberRows = append(memberRows, row)
}
if len(paginatedMemberRows) == 0 {
@@ -501,6 +528,11 @@ func convertOrganizationMembersWithUserData(ctx context.Context, db database.Sto
Name: rows[i].Name,
Email: rows[i].Email,
GlobalRoles: db2sdk.SlimRolesFromNames(rows[i].GlobalRoles),
LastSeenAt: rows[i].LastSeenAt,
Status: codersdk.UserStatus(rows[i].Status),
LoginType: codersdk.LoginType(rows[i].LoginType),
UserCreatedAt: rows[i].UserCreatedAt,
UserUpdatedAt: rows[i].UserUpdatedAt,
OrganizationMember: convertedMembers[i],
})
}
+63
@@ -1,12 +1,14 @@
package coderd_test
import (
"context"
"database/sql"
"testing"
"github.com/google/uuid"
"github.com/stretchr/testify/require"
"github.com/coder/coder/v2/coderd"
"github.com/coder/coder/v2/coderd/coderdtest"
"github.com/coder/coder/v2/coderd/database"
"github.com/coder/coder/v2/coderd/database/dbgen"
@@ -132,6 +134,67 @@ func TestListMembers(t *testing.T) {
})
}
func TestGetOrgMembersFilter(t *testing.T) {
t.Parallel()
client, _, api := coderdtest.NewWithAPI(t, &coderdtest.Options{
IncludeProvisionerDaemon: true,
OIDCConfig: &coderd.OIDCConfig{
AllowSignups: true,
},
})
first := coderdtest.CreateFirstUser(t, client)
setupCtx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong)
defer cancel()
coderdtest.UsersFilter(setupCtx, t, client, api.Database, nil, func(testCtx context.Context, req codersdk.UsersRequest) []codersdk.ReducedUser {
res, err := client.OrganizationMembersPaginated(testCtx, first.OrganizationID, req)
require.NoError(t, err)
reduced := make([]codersdk.ReducedUser, len(res.Members))
for i, user := range res.Members {
reduced[i] = orgMemberToReducedUser(user)
}
return reduced
})
}
func TestGetOrgMembersPagination(t *testing.T) {
t.Parallel()
client := coderdtest.New(t, nil)
first := coderdtest.CreateFirstUser(t, client)
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong)
defer cancel()
coderdtest.UsersPagination(ctx, t, client, nil, func(req codersdk.UsersRequest) ([]codersdk.ReducedUser, int) {
res, err := client.OrganizationMembersPaginated(ctx, first.OrganizationID, req)
require.NoError(t, err)
reduced := make([]codersdk.ReducedUser, len(res.Members))
for i, user := range res.Members {
reduced[i] = orgMemberToReducedUser(user)
}
return reduced, res.Count
})
}
func onlyIDs(u codersdk.OrganizationMemberWithUserData) uuid.UUID {
return u.UserID
}
func orgMemberToReducedUser(user codersdk.OrganizationMemberWithUserData) codersdk.ReducedUser {
return codersdk.ReducedUser{
MinimalUser: codersdk.MinimalUser{
ID: user.UserID,
Username: user.Username,
Name: user.Name,
AvatarURL: user.AvatarURL,
},
Email: user.Email,
CreatedAt: user.UserCreatedAt,
UpdatedAt: user.UserUpdatedAt,
LastSeenAt: user.LastSeenAt,
Status: user.Status,
LoginType: user.LoginType,
}
}
+2 -3
View File
@@ -18,7 +18,6 @@ import (
"path/filepath"
"regexp"
"slices"
"sort"
"strings"
"sync"
"testing"
@@ -549,8 +548,8 @@ func TestExpiredLeaseIsRequeued(t *testing.T) {
leasedIDs = append(leasedIDs, msg.ID.String())
}
sort.Strings(msgs)
sort.Strings(leasedIDs)
slices.Sort(msgs)
slices.Sort(leasedIDs)
require.EqualValues(t, msgs, leasedIDs)
// Wait out the lease period; all messages should be eligible to be re-acquired.
+16 -32
View File
@@ -18,12 +18,13 @@ import (
func TestOAuth2ClientMetadataValidation(t *testing.T) {
t.Parallel()
// Single instance shared across all sub-tests. Each registers independent OAuth2 apps with unique client names.
client := coderdtest.New(t, nil)
_ = coderdtest.CreateFirstUser(t, client)
t.Run("RedirectURIValidation", func(t *testing.T) {
t.Parallel()
client := coderdtest.New(t, nil)
_ = coderdtest.CreateFirstUser(t, client)
tests := []struct {
name string
redirectURIs []string
@@ -132,9 +133,6 @@ func TestOAuth2ClientMetadataValidation(t *testing.T) {
t.Run("ClientURIValidation", func(t *testing.T) {
t.Parallel()
client := coderdtest.New(t, nil)
_ = coderdtest.CreateFirstUser(t, client)
tests := []struct {
name string
clientURI string
@@ -207,9 +205,6 @@ func TestOAuth2ClientMetadataValidation(t *testing.T) {
t.Run("LogoURIValidation", func(t *testing.T) {
t.Parallel()
client := coderdtest.New(t, nil)
_ = coderdtest.CreateFirstUser(t, client)
tests := []struct {
name string
logoURI string
@@ -272,9 +267,6 @@ func TestOAuth2ClientMetadataValidation(t *testing.T) {
t.Run("GrantTypeValidation", func(t *testing.T) {
t.Parallel()
client := coderdtest.New(t, nil)
_ = coderdtest.CreateFirstUser(t, client)
tests := []struct {
name string
grantTypes []codersdk.OAuth2ProviderGrantType
@@ -347,9 +339,6 @@ func TestOAuth2ClientMetadataValidation(t *testing.T) {
t.Run("ResponseTypeValidation", func(t *testing.T) {
t.Parallel()
client := coderdtest.New(t, nil)
_ = coderdtest.CreateFirstUser(t, client)
tests := []struct {
name string
responseTypes []codersdk.OAuth2ProviderResponseType
@@ -407,9 +396,6 @@ func TestOAuth2ClientMetadataValidation(t *testing.T) {
t.Run("TokenEndpointAuthMethodValidation", func(t *testing.T) {
t.Parallel()
client := coderdtest.New(t, nil)
_ = coderdtest.CreateFirstUser(t, client)
tests := []struct {
name string
authMethod codersdk.OAuth2TokenEndpointAuthMethod
@@ -479,6 +465,10 @@ func TestOAuth2ClientMetadataValidation(t *testing.T) {
func TestOAuth2ClientNameValidation(t *testing.T) {
t.Parallel()
// Single instance shared across all sub-tests. Each registers independent OAuth2 apps.
client := coderdtest.New(t, nil)
_ = coderdtest.CreateFirstUser(t, client)
tests := []struct {
name string
clientName string
@@ -530,8 +520,6 @@ func TestOAuth2ClientNameValidation(t *testing.T) {
t.Run(test.name, func(t *testing.T) {
t.Parallel()
client := coderdtest.New(t, nil)
_ = coderdtest.CreateFirstUser(t, client)
ctx := testutil.Context(t, testutil.WaitLong)
req := codersdk.OAuth2ClientRegistrationRequest{
@@ -554,6 +542,10 @@ func TestOAuth2ClientNameValidation(t *testing.T) {
func TestOAuth2ClientScopeValidation(t *testing.T) {
t.Parallel()
// Single instance shared across all sub-tests. Each registers independent OAuth2 apps.
client := coderdtest.New(t, nil)
_ = coderdtest.CreateFirstUser(t, client)
tests := []struct {
name string
scope string
@@ -615,8 +607,6 @@ func TestOAuth2ClientScopeValidation(t *testing.T) {
t.Run(test.name, func(t *testing.T) {
t.Parallel()
client := coderdtest.New(t, nil)
_ = coderdtest.CreateFirstUser(t, client)
ctx := testutil.Context(t, testutil.WaitLong)
req := codersdk.OAuth2ClientRegistrationRequest{
@@ -682,11 +672,13 @@ func TestOAuth2ClientMetadataDefaults(t *testing.T) {
func TestOAuth2ClientMetadataEdgeCases(t *testing.T) {
t.Parallel()
// Single instance shared across all sub-tests. Each registers independent OAuth2 apps with unique client names.
client := coderdtest.New(t, nil)
_ = coderdtest.CreateFirstUser(t, client)
t.Run("ExtremelyLongRedirectURI", func(t *testing.T) {
t.Parallel()
client := coderdtest.New(t, nil)
_ = coderdtest.CreateFirstUser(t, client)
ctx := testutil.Context(t, testutil.WaitLong)
// Create a very long but valid HTTPS URI
@@ -709,8 +701,6 @@ func TestOAuth2ClientMetadataEdgeCases(t *testing.T) {
t.Run("ManyRedirectURIs", func(t *testing.T) {
t.Parallel()
client := coderdtest.New(t, nil)
_ = coderdtest.CreateFirstUser(t, client)
ctx := testutil.Context(t, testutil.WaitLong)
// Test with many redirect URIs
@@ -732,8 +722,6 @@ func TestOAuth2ClientMetadataEdgeCases(t *testing.T) {
t.Run("URIWithUnusualPort", func(t *testing.T) {
t.Parallel()
client := coderdtest.New(t, nil)
_ = coderdtest.CreateFirstUser(t, client)
ctx := testutil.Context(t, testutil.WaitLong)
req := codersdk.OAuth2ClientRegistrationRequest{
@@ -748,8 +736,6 @@ func TestOAuth2ClientMetadataEdgeCases(t *testing.T) {
t.Run("URIWithComplexPath", func(t *testing.T) {
t.Parallel()
client := coderdtest.New(t, nil)
_ = coderdtest.CreateFirstUser(t, client)
ctx := testutil.Context(t, testutil.WaitLong)
req := codersdk.OAuth2ClientRegistrationRequest{
@@ -764,8 +750,6 @@ func TestOAuth2ClientMetadataEdgeCases(t *testing.T) {
t.Run("URIWithEncodedCharacters", func(t *testing.T) {
t.Parallel()
client := coderdtest.New(t, nil)
_ = coderdtest.CreateFirstUser(t, client)
ctx := testutil.Context(t, testutil.WaitLong)
// Test with URL-encoded characters
+2 -1
View File
@@ -1,6 +1,7 @@
package prometheusmetrics_test
import (
"slices"
"sort"
"testing"
@@ -134,7 +135,7 @@ func collectAndSortMetrics(t *testing.T, collector prometheus.Collector, count i
// Ensure always the same order of metrics
sort.Slice(metrics, func(i, j int) bool {
return sort.StringsAreSorted([]string{metrics[i].Label[0].GetValue(), metrics[j].Label[1].GetValue()})
return slices.IsSorted([]string{metrics[i].Label[0].GetValue(), metrics[j].Label[1].GetValue()})
})
return metrics
}
+19 -2
View File
@@ -316,13 +316,16 @@ func ReloadBuiltinRoles(opts *RoleOptions) {
denyPermissions...,
),
User: append(
allPermsExcept(ResourceWorkspaceDormant, ResourcePrebuiltWorkspace, ResourceWorkspace, ResourceUser, ResourceOrganizationMember, ResourceOrganizationMember, ResourceBoundaryUsage),
allPermsExcept(ResourceWorkspaceDormant, ResourcePrebuiltWorkspace, ResourceWorkspace, ResourceUser, ResourceOrganizationMember, ResourceOrganizationMember, ResourceBoundaryUsage, ResourceAibridgeInterception),
Permissions(map[string][]policy.Action{
// Users cannot do create/update/delete on themselves, but they
// can read their own details.
ResourceUser.Type: {policy.ActionRead, policy.ActionReadPersonal, policy.ActionUpdatePersonal},
// Users can create provisioner daemons scoped to themselves.
ResourceProvisionerDaemon.Type: {policy.ActionRead, policy.ActionCreate, policy.ActionRead, policy.ActionUpdate},
// Members can create and update AI Bridge interceptions but
// cannot read them back.
ResourceAibridgeInterception.Type: {policy.ActionCreate, policy.ActionUpdate},
})...,
),
ByOrgID: map[string]OrgPermissions{},
@@ -345,7 +348,7 @@ func ReloadBuiltinRoles(opts *RoleOptions) {
// Allow auditors to query deployment stats and insights.
ResourceDeploymentStats.Type: {policy.ActionRead},
ResourceDeploymentConfig.Type: {policy.ActionRead},
// Allow auditors to query aibridge interceptions.
// Allow auditors to query AI Bridge interceptions.
ResourceAibridgeInterception.Type: {policy.ActionRead},
}),
User: []Permission{},
@@ -998,6 +1001,7 @@ func OrgMemberPermissions(org OrgSettings) OrgRolePermissions {
ResourcePrebuiltWorkspace,
ResourceUser,
ResourceOrganizationMember,
ResourceAibridgeInterception,
),
Permissions(map[string][]policy.Action{
// Reduced permission set on dormant workspaces. No build,
@@ -1016,6 +1020,12 @@ func OrgMemberPermissions(org OrgSettings) OrgRolePermissions {
ResourceOrganizationMember.Type: {
policy.ActionRead,
},
// Members can create and update AI Bridge interceptions but
// cannot read them back.
ResourceAibridgeInterception.Type: {
policy.ActionCreate,
policy.ActionUpdate,
},
})...,
)
@@ -1073,6 +1083,7 @@ func OrgServiceAccountPermissions(org OrgSettings) OrgRolePermissions {
ResourcePrebuiltWorkspace,
ResourceUser,
ResourceOrganizationMember,
ResourceAibridgeInterception,
),
Permissions(map[string][]policy.Action{
// Reduced permission set on dormant workspaces. No build,
@@ -1091,6 +1102,12 @@ func OrgServiceAccountPermissions(org OrgSettings) OrgRolePermissions {
ResourceOrganizationMember.Type: {
policy.ActionRead,
},
// Service accounts can create and update AI Bridge
// interceptions but cannot read them back.
ResourceAibridgeInterception.Type: {
policy.ActionCreate,
policy.ActionUpdate,
},
})...,
)
+19 -2
View File
@@ -1023,8 +1023,9 @@ func TestRolePermissions(t *testing.T) {
},
},
{
Name: "AIBridgeInterceptions",
Actions: []policy.Action{policy.ActionCreate, policy.ActionRead, policy.ActionUpdate},
// Members can create/update records but can't read them afterwards.
Name: "AIBridgeInterceptionsCreateUpdate",
Actions: []policy.Action{policy.ActionCreate, policy.ActionUpdate},
Resource: rbac.ResourceAibridgeInterception.WithOwner(currentUser.String()),
AuthorizeMap: map[bool][]hasAuthSubjects{
true: {owner, memberMe},
@@ -1036,6 +1037,22 @@ func TestRolePermissions(t *testing.T) {
},
},
},
{
// Only owners and site-wide auditors can view interceptions and their sub-resources.
Name: "AIBridgeInterceptionsRead",
Actions: []policy.Action{policy.ActionRead},
Resource: rbac.ResourceAibridgeInterception.WithOwner(currentUser.String()),
AuthorizeMap: map[bool][]hasAuthSubjects{
true: {owner, auditor},
false: {
memberMe,
orgAdmin, otherOrgAdmin,
orgAuditor, otherOrgAuditor,
templateAdmin, orgTemplateAdmin, otherOrgTemplateAdmin,
userAdmin, orgUserAdmin, otherOrgUserAdmin,
},
},
},
{
Name: "BoundaryUsage",
Actions: []policy.Action{policy.ActionRead, policy.ActionUpdate, policy.ActionDelete},
+1 -2
View File
@@ -3,7 +3,6 @@ package rbac
import (
"fmt"
"slices"
"sort"
"strings"
"github.com/google/uuid"
@@ -176,7 +175,7 @@ func CompositeScopeNames() []string {
for k := range compositePerms {
out = append(out, string(k))
}
sort.Strings(out)
slices.Sort(out)
return out
}
+2 -1
View File
@@ -40,7 +40,8 @@ var externalLowLevel = map[ScopeName]struct{}{
"file:create": {},
"file:*": {},
// Users (personal profile only)
// Users
"user:read": {},
"user:read_personal": {},
"user:update_personal": {},
"user.*": {},
+3 -2
View File
@@ -1,7 +1,7 @@
package rbac
import (
"sort"
"slices"
"strings"
"testing"
@@ -16,7 +16,7 @@ func TestExternalScopeNames(t *testing.T) {
// Ensure sorted ascending
sorted := append([]string(nil), names...)
sort.Strings(sorted)
slices.Sort(sorted)
require.Equal(t, sorted, names)
// Ensure each entry expands to site-only
@@ -62,6 +62,7 @@ func TestIsExternalScope(t *testing.T) {
require.True(t, IsExternalScope("template:use"))
require.True(t, IsExternalScope("workspace:*"))
require.True(t, IsExternalScope("coder:workspaces.create"))
require.True(t, IsExternalScope("user:read"))
require.False(t, IsExternalScope("debug_info:read")) // internal-only
require.False(t, IsExternalScope("unknown:read"))
}
+43
View File
@@ -401,6 +401,49 @@ func AIBridgeInterceptions(ctx context.Context, db database.Store, query string,
return filter, parser.Errors
}
func AIBridgeSessions(ctx context.Context, db database.Store, query string, page codersdk.Pagination, actorID uuid.UUID, afterSessionID string) (database.ListAIBridgeSessionsParams, []codersdk.ValidationError) {
// nolint:exhaustruct // Empty values just means "don't filter by that field".
filter := database.ListAIBridgeSessionsParams{
AfterSessionID: afterSessionID,
// #nosec G115 - Safe conversion for pagination limit which is expected to be within int32 range
Limit: int32(page.Limit),
// #nosec G115 - Safe conversion for pagination offset which is expected to be within int32 range
Offset: int32(page.Offset),
}
if query == "" {
return filter, nil
}
values, errors := searchTerms(query, func(string, url.Values) error {
// Do not specify a default search key; let's be explicit to prevent user confusion.
return xerrors.New("no search key specified")
})
if len(errors) > 0 {
return filter, errors
}
parser := httpapi.NewQueryParamParser()
filter.InitiatorID = parseUser(ctx, db, parser, values, "initiator", actorID)
filter.Provider = parser.String(values, "", "provider")
filter.Model = parser.String(values, "", "model")
filter.Client = parser.String(values, "", "client")
filter.SessionID = parser.String(values, "", "session_id")
// Time must be between started_after and started_before.
filter.StartedAfter = parser.Time3339Nano(values, time.Time{}, "started_after")
filter.StartedBefore = parser.Time3339Nano(values, time.Time{}, "started_before")
if !filter.StartedBefore.IsZero() && !filter.StartedAfter.IsZero() && !filter.StartedBefore.After(filter.StartedAfter) {
parser.Errors = append(parser.Errors, codersdk.ValidationError{
Field: "started_before",
Detail: `Query param "started_before" has invalid value: "started_before" must be after "started_after" if set`,
})
}
parser.ErrorExcessParams(values)
return filter, parser.Errors
}
func AIBridgeModels(query string, page codersdk.Pagination) (database.ListAIBridgeModelsParams, []codersdk.ValidationError) {
// nolint:exhaustruct // Empty values just means "don't filter by that field".
filter := database.ListAIBridgeModelsParams{
+16 -27
View File
@@ -1272,10 +1272,14 @@ func TestTemplateVersionsByTemplate(t *testing.T) {
func TestTemplateVersionByName(t *testing.T) {
t.Parallel()
// Single instance shared across all sub-tests. Each sub-test
// creates its own template version and template with unique
// IDs so parallel execution is safe.
client := coderdtest.New(t, nil)
user := coderdtest.CreateFirstUser(t, client)
t.Run("NotFound", func(t *testing.T) {
t.Parallel()
client := coderdtest.New(t, nil)
user := coderdtest.CreateFirstUser(t, client)
version := coderdtest.CreateTemplateVersion(t, client, user.OrganizationID, nil)
template := coderdtest.CreateTemplate(t, client, user.OrganizationID, version.ID)
@@ -1290,8 +1294,6 @@ func TestTemplateVersionByName(t *testing.T) {
t.Run("Found", func(t *testing.T) {
t.Parallel()
client := coderdtest.New(t, nil)
user := coderdtest.CreateFirstUser(t, client)
version := coderdtest.CreateTemplateVersion(t, client, user.OrganizationID, nil)
template := coderdtest.CreateTemplate(t, client, user.OrganizationID, version.ID)
@@ -1935,10 +1937,12 @@ func TestPaginatedTemplateVersions(t *testing.T) {
func TestTemplateVersionByOrganizationTemplateAndName(t *testing.T) {
t.Parallel()
// Shared instance — see TestTemplateVersionByName for rationale.
client := coderdtest.New(t, nil)
user := coderdtest.CreateFirstUser(t, client)
t.Run("NotFound", func(t *testing.T) {
t.Parallel()
client := coderdtest.New(t, nil)
user := coderdtest.CreateFirstUser(t, client)
version := coderdtest.CreateTemplateVersion(t, client, user.OrganizationID, nil)
template := coderdtest.CreateTemplate(t, client, user.OrganizationID, version.ID)
@@ -1953,8 +1957,6 @@ func TestTemplateVersionByOrganizationTemplateAndName(t *testing.T) {
t.Run("Found", func(t *testing.T) {
t.Parallel()
client := coderdtest.New(t, nil)
user := coderdtest.CreateFirstUser(t, client)
version := coderdtest.CreateTemplateVersion(t, client, user.OrganizationID, nil)
template := coderdtest.CreateTemplate(t, client, user.OrganizationID, version.ID)
@@ -2204,10 +2206,14 @@ func TestTemplateVersionVariables(t *testing.T) {
func TestTemplateVersionPatch(t *testing.T) {
t.Parallel()
// Single instance shared across all 9 sub-tests. Each sub-test
// creates its own template version(s) and template(s) with
// unique IDs so parallel execution is safe.
client := coderdtest.New(t, nil)
user := coderdtest.CreateFirstUser(t, client)
t.Run("Update the name", func(t *testing.T) {
t.Parallel()
client := coderdtest.New(t, nil)
user := coderdtest.CreateFirstUser(t, client)
version := coderdtest.CreateTemplateVersion(t, client, user.OrganizationID, nil)
coderdtest.CreateTemplate(t, client, user.OrganizationID, version.ID)
@@ -2226,8 +2232,6 @@ func TestTemplateVersionPatch(t *testing.T) {
t.Run("Update the message", func(t *testing.T) {
t.Parallel()
client := coderdtest.New(t, nil)
user := coderdtest.CreateFirstUser(t, client)
version := coderdtest.CreateTemplateVersion(t, client, user.OrganizationID, nil, func(req *codersdk.CreateTemplateVersionRequest) {
req.Message = "Example message"
})
@@ -2247,8 +2251,6 @@ func TestTemplateVersionPatch(t *testing.T) {
t.Run("Remove the message", func(t *testing.T) {
t.Parallel()
client := coderdtest.New(t, nil)
user := coderdtest.CreateFirstUser(t, client)
version := coderdtest.CreateTemplateVersion(t, client, user.OrganizationID, nil, func(req *codersdk.CreateTemplateVersionRequest) {
req.Message = "Example message"
})
@@ -2268,8 +2270,6 @@ func TestTemplateVersionPatch(t *testing.T) {
t.Run("Keep the message", func(t *testing.T) {
t.Parallel()
client := coderdtest.New(t, nil)
user := coderdtest.CreateFirstUser(t, client)
wantMessage := "Example message"
version := coderdtest.CreateTemplateVersion(t, client, user.OrganizationID, nil, func(req *codersdk.CreateTemplateVersionRequest) {
req.Message = wantMessage
@@ -2291,8 +2291,6 @@ func TestTemplateVersionPatch(t *testing.T) {
t.Run("Use the same name if a new name is not passed", func(t *testing.T) {
t.Parallel()
client := coderdtest.New(t, nil)
user := coderdtest.CreateFirstUser(t, client)
version := coderdtest.CreateTemplateVersion(t, client, user.OrganizationID, nil)
coderdtest.CreateTemplate(t, client, user.OrganizationID, version.ID)
@@ -2306,9 +2304,6 @@ func TestTemplateVersionPatch(t *testing.T) {
t.Run("Use the same name for two different templates", func(t *testing.T) {
t.Parallel()
client := coderdtest.New(t, nil)
user := coderdtest.CreateFirstUser(t, client)
version1 := coderdtest.CreateTemplateVersion(t, client, user.OrganizationID, nil)
coderdtest.CreateTemplate(t, client, user.OrganizationID, version1.ID)
version2 := coderdtest.CreateTemplateVersion(t, client, user.OrganizationID, nil)
@@ -2334,8 +2329,6 @@ func TestTemplateVersionPatch(t *testing.T) {
t.Run("Use the same name for two versions for the same templates", func(t *testing.T) {
t.Parallel()
client := coderdtest.New(t, nil)
user := coderdtest.CreateFirstUser(t, client)
version1 := coderdtest.CreateTemplateVersion(t, client, user.OrganizationID, nil, func(ctvr *codersdk.CreateTemplateVersionRequest) {
ctvr.Name = "v1"
})
@@ -2356,8 +2349,6 @@ func TestTemplateVersionPatch(t *testing.T) {
t.Run("Rename the unassigned template", func(t *testing.T) {
t.Parallel()
client := coderdtest.New(t, nil)
user := coderdtest.CreateFirstUser(t, client)
version1 := coderdtest.CreateTemplateVersion(t, client, user.OrganizationID, nil)
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong)
@@ -2373,8 +2364,6 @@ func TestTemplateVersionPatch(t *testing.T) {
t.Run("Use incorrect template version name", func(t *testing.T) {
t.Parallel()
client := coderdtest.New(t, nil)
user := coderdtest.CreateFirstUser(t, client)
version1 := coderdtest.CreateTemplateVersion(t, client, user.OrganizationID, nil)
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong)
+3 -40
View File
@@ -7,7 +7,7 @@ import (
"fmt"
"net/http"
"net/mail"
"sort"
"slices"
"strconv"
"strings"
"sync"
@@ -744,43 +744,6 @@ func (api *API) postLogout(rw http.ResponseWriter, r *http.Request) {
})
}
// @Summary Set session token cookie
// @Description Converts the current session token into a Set-Cookie response.
// @Description This is used by embedded iframes (e.g. VS Code chat) that
// @Description receive a session token out-of-band via postMessage but need
// @Description cookie-based auth for WebSocket connections.
// @ID set-session-token-cookie
// @Security CoderSessionToken
// @Tags Authorization
// @Success 204
// @Router /users/me/session/token-to-cookie [post]
// @x-apidocgen {"skip": true}
func (api *API) postSessionTokenCookie(rw http.ResponseWriter, r *http.Request) {
// Only accept the token from the Coder-Session-Token header.
// Other sources (query params, cookies) should not be allowed
// to bootstrap a new cookie.
token := r.Header.Get(codersdk.SessionTokenHeader)
if token == "" {
httpapi.Write(r.Context(), rw, http.StatusBadRequest, codersdk.Response{
Message: "Session token must be provided via the Coder-Session-Token header.",
})
return
}
apiKey := httpmw.APIKey(r)
cookie := api.DeploymentValues.HTTPCookies.Apply(&http.Cookie{
Name: codersdk.SessionTokenCookie,
Value: token,
Path: "/",
HttpOnly: true,
// Expire the cookie when the underlying API key expires.
Expires: apiKey.ExpiresAt,
})
http.SetCookie(rw, cookie)
rw.WriteHeader(http.StatusNoContent)
}
// GithubOAuth2Team represents a team scoped to an organization.
type GithubOAuth2Team struct {
Organization string
@@ -1626,7 +1589,7 @@ func claimFields(claims map[string]interface{}) []string {
for field := range claims {
fields = append(fields, field)
}
sort.Strings(fields)
slices.Sort(fields)
return fields
}
@@ -1639,7 +1602,7 @@ func blankFields(claims map[string]interface{}) []string {
fields = append(fields, field)
}
}
sort.Strings(fields)
slices.Sort(fields)
return fields
}
+9 -28
View File
@@ -17,7 +17,6 @@ import (
"cdr.dev/slog/v3"
"github.com/coder/coder/v2/agent/proto"
"github.com/coder/coder/v2/coderd/agentapi"
"github.com/coder/coder/v2/coderd/agentconnectionbatcher"
"github.com/coder/coder/v2/coderd/database"
"github.com/coder/coder/v2/coderd/database/dbauthz"
"github.com/coder/coder/v2/coderd/database/dbtime"
@@ -258,7 +257,6 @@ func (api *API) startAgentYamuxMonitor(ctx context.Context,
db: api.Database,
replicaID: api.ID,
updater: api,
connectionBatcher: api.connectionBatcher,
disconnectTimeout: api.AgentInactiveDisconnectTimeout,
logger: api.Logger.With(
slog.F("workspace_id", workspaceBuild.WorkspaceID),
@@ -294,8 +292,6 @@ type agentConnectionMonitor struct {
logger slog.Logger
pingPeriod time.Duration
connectionBatcher *agentconnectionbatcher.Batcher
// state manipulated by both sendPings() and monitor() goroutines: needs to be threadsafe
lastPing atomic.Pointer[time.Time]
@@ -458,32 +454,17 @@ func (m *agentConnectionMonitor) monitor(ctx context.Context) {
Valid: true,
}
if m.connectionBatcher != nil {
m.connectionBatcher.Add(agentconnectionbatcher.Update{
ID: m.workspaceAgent.ID,
FirstConnectedAt: m.firstConnectedAt,
LastConnectedAt: m.lastConnectedAt,
DisconnectedAt: m.disconnectedAt,
UpdatedAt: dbtime.Now(),
LastConnectedReplicaID: uuid.NullUUID{
UUID: m.replicaID,
Valid: true,
},
})
} else {
err = m.updateConnectionTimes(ctx)
if err != nil {
reason = err.Error()
if !database.IsQueryCanceledError(err) {
m.logger.Error(ctx, "failed to update agent connection times", slog.Error(err))
}
return
err = m.updateConnectionTimes(ctx)
if err != nil {
reason = err.Error()
if !database.IsQueryCanceledError(err) {
m.logger.Error(ctx, "failed to update agent connection times", slog.Error(err))
}
return
}
// We don't need to publish a workspace update here because we
// published an update when the workspace first connected. Since
// all we've done is updated lastConnectedAt, the workspace is
// still connected and hasn't changed status.
// we don't need to publish a workspace update here because we published an update when the workspace first
// connected. Since all we've done is updated lastConnectedAt, the workspace is still connected and hasn't
// changed status. We don't expect to get updates just for the times changing.
ctx, err := dbauthz.WithWorkspaceRBAC(ctx, m.workspace.RBACObject())
if err != nil {
@@ -6,8 +6,8 @@ import (
"github.com/shopspring/decimal"
"github.com/stretchr/testify/require"
"github.com/coder/coder/v2/coderd/chatd/chatcost"
"github.com/coder/coder/v2/coderd/util/ptr"
"github.com/coder/coder/v2/coderd/x/chatd/chatcost"
"github.com/coder/coder/v2/codersdk"
)
+215 -14
View File
@@ -7,6 +7,7 @@ import (
"errors"
"fmt"
"net/http"
"strconv"
"strings"
"sync"
"time"
@@ -20,12 +21,6 @@ import (
"golang.org/x/xerrors"
"cdr.dev/slog/v3"
"github.com/coder/coder/v2/coderd/chatd/chatcost"
"github.com/coder/coder/v2/coderd/chatd/chatloop"
"github.com/coder/coder/v2/coderd/chatd/chatprompt"
"github.com/coder/coder/v2/coderd/chatd/chatprovider"
"github.com/coder/coder/v2/coderd/chatd/chattool"
"github.com/coder/coder/v2/coderd/chatd/mcpclient"
"github.com/coder/coder/v2/coderd/database"
"github.com/coder/coder/v2/coderd/database/db2sdk"
"github.com/coder/coder/v2/coderd/database/dbauthz"
@@ -34,6 +29,12 @@ import (
"github.com/coder/coder/v2/coderd/util/ptr"
"github.com/coder/coder/v2/coderd/webpush"
"github.com/coder/coder/v2/coderd/workspacestats"
"github.com/coder/coder/v2/coderd/x/chatd/chatcost"
"github.com/coder/coder/v2/coderd/x/chatd/chatloop"
"github.com/coder/coder/v2/coderd/x/chatd/chatprompt"
"github.com/coder/coder/v2/coderd/x/chatd/chatprovider"
"github.com/coder/coder/v2/coderd/x/chatd/chattool"
"github.com/coder/coder/v2/coderd/x/chatd/mcpclient"
"github.com/coder/coder/v2/codersdk"
"github.com/coder/coder/v2/codersdk/workspacesdk"
"github.com/coder/quartz"
@@ -1199,6 +1200,7 @@ type chatMessage struct {
contextLimit int64
totalCostMicros int64
runtimeMs int64
providerResponseID string
}
func newChatMessage(
@@ -1255,6 +1257,101 @@ func (m chatMessage) withRuntimeMs(ms int64) chatMessage {
return m
}
func (m chatMessage) withProviderResponseID(id string) chatMessage {
m.providerResponseID = id
return m
}
// chainModeInfo holds the information needed to determine whether
// a follow-up turn can use OpenAI's previous_response_id chaining
// instead of replaying full conversation history.
type chainModeInfo struct {
// previousResponseID is the provider response ID from the last
// assistant message, if any.
previousResponseID string
// modelConfigID is the model configuration used to produce the
// assistant message referenced by previousResponseID.
modelConfigID uuid.UUID
// trailingUserCount is the number of contiguous user messages
// at the end of the conversation that form the current turn.
trailingUserCount int
}
// resolveChainMode scans DB messages from the end to count trailing user
// messages for the current turn and detect whether the immediately
// preceding assistant/tool block can chain from a provider response ID.
func resolveChainMode(messages []database.ChatMessage) chainModeInfo {
var info chainModeInfo
i := len(messages) - 1
for ; i >= 0; i-- {
if messages[i].Role == database.ChatMessageRoleUser {
info.trailingUserCount++
continue
}
break
}
for ; i >= 0; i-- {
switch messages[i].Role {
case database.ChatMessageRoleAssistant:
if messages[i].ProviderResponseID.Valid &&
messages[i].ProviderResponseID.String != "" {
info.previousResponseID = messages[i].ProviderResponseID.String
if messages[i].ModelConfigID.Valid {
info.modelConfigID = messages[i].ModelConfigID.UUID
}
return info
}
return info
case database.ChatMessageRoleTool:
continue
default:
return info
}
}
return info
}
// filterPromptForChainMode keeps only system messages and the last
// trailingUserCount user messages from the prompt. Assistant and tool
// messages are dropped because the provider already has them via the
// previous_response_id chain.
func filterPromptForChainMode(
prompt []fantasy.Message,
trailingUserCount int,
) []fantasy.Message {
if trailingUserCount <= 0 {
return prompt
}
totalUsers := 0
for _, msg := range prompt {
if msg.Role == "user" {
totalUsers++
}
}
usersToSkip := totalUsers - trailingUserCount
if usersToSkip < 0 {
usersToSkip = 0
}
filtered := make([]fantasy.Message, 0, len(prompt))
usersSeen := 0
for _, msg := range prompt {
switch msg.Role {
case "system":
filtered = append(filtered, msg)
case "user":
usersSeen++
if usersSeen > usersToSkip {
filtered = append(filtered, msg)
}
}
}
return filtered
}
// appendChatMessage appends a single message to the batch insert params.
func appendChatMessage(
params *database.InsertChatMessagesParams,
@@ -1276,6 +1373,7 @@ func appendChatMessage(
params.Compressed = append(params.Compressed, msg.compressed)
params.TotalCostMicros = append(params.TotalCostMicros, msg.totalCostMicros)
params.RuntimeMs = append(params.RuntimeMs, msg.runtimeMs)
params.ProviderResponseID = append(params.ProviderResponseID, msg.providerResponseID)
}
func insertUserMessageAndSetPending(
@@ -2823,6 +2921,7 @@ func (p *Server) runChat(
if err := g.Wait(); err != nil {
return result, err
}
chainInfo := resolveChainMode(messages)
result.PushSummaryModel = model
result.ProviderKeys = providerKeys
// Fire title generation asynchronously so it doesn't block the
@@ -3092,7 +3191,8 @@ func (p *Server) runChat(
reasoningTokens, cacheCreationTokens, cacheReadTokens,
).withContextLimit(contextLimit).
withTotalCostMicros(totalCostVal).
withRuntimeMs(runtimeMs))
withRuntimeMs(runtimeMs).
withProviderResponseID(step.ProviderResponseID))
}
for _, resultContent := range toolResultContents {
@@ -3150,8 +3250,14 @@ func (p *Server) runChat(
// "Summarizing..." tool call with the "Summarized" tool
// result.
compactionToolCallID := "chat_summarized_" + uuid.NewString()
effectiveThreshold := modelConfig.CompressionThreshold
thresholdSource := "model_default"
if override, ok := p.resolveUserCompactionThreshold(ctx, chat.OwnerID, modelConfig.ID); ok {
effectiveThreshold = override
thresholdSource = "user_override"
}
compactionOptions := &chatloop.CompactionOptions{
ThresholdPercent: modelConfig.CompressionThreshold,
ThresholdPercent: effectiveThreshold,
ContextLimit: modelConfig.ContextLimit,
Persist: func(
persistCtx context.Context,
@@ -3168,6 +3274,7 @@ func (p *Server) runChat(
}
logger.Info(persistCtx, "chat context summarized",
slog.F("chat_id", chat.ID),
slog.F("threshold_source", thresholdSource),
slog.F("threshold_percent", result.ThresholdPercent),
slog.F("usage_percent", result.UsagePercent),
slog.F("context_tokens", result.ContextTokens),
@@ -3227,6 +3334,7 @@ func (p *Server) runChat(
// create workspaces or spawn further subagents — they should
// focus on completing their delegated task.
if !chat.ParentChatID.Valid {
// Workspace provisioning tools.
tools = append(tools,
chattool.ListTemplates(chattool.ListTemplatesOptions{
DB: p.db,
@@ -3254,6 +3362,37 @@ func (p *Server) runChat(
WorkspaceMu: &workspaceMu,
}),
)
// Plan presentation tool.
tools = append(tools, chattool.ProposePlan(chattool.ProposePlanOptions{
GetWorkspaceConn: workspaceCtx.getWorkspaceConn,
StoreFile: func(ctx context.Context, name string, mediaType string, data []byte) (uuid.UUID, error) {
workspaceCtx.chatStateMu.Lock()
chatSnapshot := *workspaceCtx.currentChat
workspaceCtx.chatStateMu.Unlock()
if !chatSnapshot.WorkspaceID.Valid {
return uuid.Nil, xerrors.New("chat has no workspace")
}
ws, err := p.db.GetWorkspaceByID(ctx, chatSnapshot.WorkspaceID.UUID)
if err != nil {
return uuid.Nil, xerrors.Errorf("resolve workspace: %w", err)
}
row, err := p.db.InsertChatFile(ctx, database.InsertChatFileParams{
OwnerID: chatSnapshot.OwnerID,
OrganizationID: ws.OrganizationID,
Name: name,
Mimetype: mediaType,
Data: data,
})
if err != nil {
return uuid.Nil, xerrors.Errorf("insert chat file: %w", err)
}
return row.ID, nil
},
}))
tools = append(tools, p.subagentTools(ctx, func() database.Chat {
return chat
})...)
@@ -3272,24 +3411,49 @@ func (p *Server) runChat(
}
if isComputerUse {
desktopGeometry := workspacesdk.DefaultDesktopGeometry()
providerTools = append(providerTools, chatloop.ProviderTool{
Definition: chattool.ComputerUseProviderTool(
workspacesdk.DesktopDisplayWidth,
workspacesdk.DesktopDisplayHeight),
desktopGeometry.DeclaredWidth,
desktopGeometry.DeclaredHeight,
),
Runner: chattool.NewComputerUseTool(
workspacesdk.DesktopDisplayWidth,
workspacesdk.DesktopDisplayHeight,
workspaceCtx.getWorkspaceConn, quartz.NewReal(),
desktopGeometry.DeclaredWidth,
desktopGeometry.DeclaredHeight,
workspaceCtx.getWorkspaceConn,
quartz.NewReal(),
),
})
}
providerOptions := chatprovider.ProviderOptionsFromChatModelConfig(
model,
callConfig.ProviderOptions,
)
// When the OpenAI Responses API has store=true, the provider
// retains conversation history server-side. For follow-up turns,
// we set previous_response_id and send only system instructions
// plus the new user input, avoiding redundant replay of prior
// assistant and tool messages that the provider already has.
chainModeActive := chatprovider.IsResponsesStoreEnabled(providerOptions) &&
chainInfo.previousResponseID != "" &&
chainInfo.trailingUserCount > 0 &&
chainInfo.modelConfigID == modelConfig.ID
if chainModeActive {
providerOptions = chatprovider.CloneWithPreviousResponseID(
providerOptions,
chainInfo.previousResponseID,
)
prompt = filterPromptForChainMode(prompt, chainInfo.trailingUserCount)
}
err = chatloop.Run(ctx, chatloop.RunOptions{
Model: model,
Messages: prompt,
Tools: tools, MaxSteps: maxChatSteps,
ModelConfig: callConfig,
ProviderOptions: chatprovider.ProviderOptionsFromChatModelConfig(model, callConfig.ProviderOptions),
ProviderOptions: providerOptions,
ProviderTools: providerTools,
ContextLimitFallback: modelConfigContextLimit,
@@ -3337,8 +3501,17 @@ func (p *Server) runChat(
if reloadUserPrompt != "" {
reloadedPrompt = chatprompt.InsertSystem(reloadedPrompt, reloadUserPrompt)
}
if chainModeActive {
reloadedPrompt = filterPromptForChainMode(
reloadedPrompt,
chainInfo.trailingUserCount,
)
}
return reloadedPrompt, nil
},
DisableChainMode: func() {
chainModeActive = false
},
OnRetry: func(attempt int, retryErr error, delay time.Duration) {
if val, ok := p.chatStreams.Load(chat.ID); ok {
@@ -3715,6 +3888,34 @@ func (p *Server) resolveInstructions(
return instruction
}
// resolveUserCompactionThreshold looks up the user's per-model
// compaction threshold override. Returns the override value and
// true if one exists and is valid, or 0 and false otherwise.
func (p *Server) resolveUserCompactionThreshold(ctx context.Context, userID uuid.UUID, modelConfigID uuid.UUID) (int32, bool) {
raw, err := p.db.GetUserChatCompactionThreshold(ctx, database.GetUserChatCompactionThresholdParams{
UserID: userID,
Key: codersdk.CompactionThresholdKey(modelConfigID),
})
if errors.Is(err, sql.ErrNoRows) {
return 0, false
}
if err != nil {
p.logger.Warn(ctx, "failed to fetch compaction threshold override",
slog.F("user_id", userID),
slog.F("model_config_id", modelConfigID),
slog.Error(err),
)
return 0, false
}
// Range 0..100 must stay in sync with handler validation in
// coderd/chats.go.
val, err := strconv.ParseInt(raw, 10, 32)
if err != nil || val < 0 || val > 100 {
return 0, false
}
return int32(val), true
}
// resolveUserPrompt fetches the user's custom chat prompt from the
// database and wraps it in <user-instructions> tags. Returns empty
// string if no prompt is set.
@@ -2,6 +2,7 @@ package chatd
import (
"context"
"database/sql"
"sync"
"testing"
"time"
@@ -606,6 +607,85 @@ func TestPublishToStream_DropWarnRateLimiting(t *testing.T) {
requireFieldValue(t, subWarn[2], "dropped_count", int64(1))
}
func TestResolveUserCompactionThreshold(t *testing.T) {
t.Parallel()
userID := uuid.New()
modelConfigID := uuid.New()
expectedKey := codersdk.CompactionThresholdKey(modelConfigID)
tests := []struct {
name string
dbReturn string
dbErr error
wantVal int32
wantOK bool
wantWarnLog bool
}{
{
name: "NoRowsReturnsDefault",
dbErr: sql.ErrNoRows,
wantOK: false,
},
{
name: "ValidOverride",
dbReturn: "75",
wantVal: 75,
wantOK: true,
},
{
name: "OutOfRangeValue",
dbReturn: "101",
wantOK: false,
},
{
name: "NonIntegerValue",
dbReturn: "abc",
wantOK: false,
},
{
name: "UnexpectedDBError",
dbErr: xerrors.New("connection refused"),
wantOK: false,
wantWarnLog: true,
},
}
for _, tc := range tests {
tc := tc
t.Run(tc.name, func(t *testing.T) {
t.Parallel()
ctrl := gomock.NewController(t)
mockDB := dbmock.NewMockStore(ctrl)
sink := testutil.NewFakeSink(t)
srv := &Server{
db: mockDB,
logger: sink.Logger(),
}
mockDB.EXPECT().GetUserChatCompactionThreshold(gomock.Any(), database.GetUserChatCompactionThresholdParams{
UserID: userID,
Key: expectedKey,
}).Return(tc.dbReturn, tc.dbErr)
val, ok := srv.resolveUserCompactionThreshold(context.Background(), userID, modelConfigID)
require.Equal(t, tc.wantVal, val)
require.Equal(t, tc.wantOK, ok)
warns := sink.Entries(func(e slog.SinkEntry) bool {
return e.Level == slog.LevelWarn
})
if tc.wantWarnLog {
require.NotEmpty(t, warns, "expected a warning log entry")
return
}
require.Empty(t, warns, "unexpected warning log entry")
})
}
}
// requireFieldValue asserts that a SinkEntry contains a field with
// the given name and value.
func requireFieldValue(t *testing.T, entry slog.SinkEntry, name string, expected interface{}) {
@@ -26,10 +26,6 @@ import (
"cdr.dev/slog/v3/sloggers/slogtest"
"github.com/coder/coder/v2/agent/agenttest"
"github.com/coder/coder/v2/coderd/chatd"
"github.com/coder/coder/v2/coderd/chatd/chatprompt"
"github.com/coder/coder/v2/coderd/chatd/chattest"
"github.com/coder/coder/v2/coderd/chatd/chattool"
"github.com/coder/coder/v2/coderd/coderdtest"
"github.com/coder/coder/v2/coderd/database"
"github.com/coder/coder/v2/coderd/database/db2sdk"
@@ -41,6 +37,10 @@ import (
"github.com/coder/coder/v2/coderd/rbac"
"github.com/coder/coder/v2/coderd/util/slice"
"github.com/coder/coder/v2/coderd/workspacestats"
"github.com/coder/coder/v2/coderd/x/chatd"
"github.com/coder/coder/v2/coderd/x/chatd/chatprompt"
"github.com/coder/coder/v2/coderd/x/chatd/chattest"
"github.com/coder/coder/v2/coderd/x/chatd/chattool"
"github.com/coder/coder/v2/codersdk"
"github.com/coder/coder/v2/codersdk/workspacesdk"
"github.com/coder/coder/v2/codersdk/workspacesdk/agentconnmock"
@@ -111,6 +111,7 @@ func TestSubagentChatExcludesWorkspaceProvisioningTools(t *testing.T) {
IncludeProvisionerDaemon: true,
})
user := coderdtest.CreateFirstUser(t, client)
expClient := codersdk.NewExperimentalClient(client)
agentToken := uuid.NewString()
version := coderdtest.CreateTemplateVersion(t, client, user.OrganizationID, &echo.Responses{
@@ -161,7 +162,7 @@ func TestSubagentChatExcludesWorkspaceProvisioningTools(t *testing.T) {
)
})
_, err := client.CreateChatProvider(ctx, codersdk.CreateChatProviderConfigRequest{
_, err := expClient.CreateChatProvider(ctx, codersdk.CreateChatProviderConfigRequest{
Provider: "openai-compat",
APIKey: "test-api-key",
BaseURL: openAIURL,
@@ -170,7 +171,7 @@ func TestSubagentChatExcludesWorkspaceProvisioningTools(t *testing.T) {
contextLimit := int64(4096)
isDefault := true
_, err = client.CreateChatModelConfig(ctx, codersdk.CreateChatModelConfigRequest{
_, err = expClient.CreateChatModelConfig(ctx, codersdk.CreateChatModelConfigRequest{
Provider: "openai-compat",
Model: "gpt-4o-mini",
ContextLimit: &contextLimit,
@@ -179,7 +180,7 @@ func TestSubagentChatExcludesWorkspaceProvisioningTools(t *testing.T) {
require.NoError(t, err)
// Create a root chat whose first model call will spawn a subagent.
chat, err := client.CreateChat(ctx, codersdk.CreateChatRequest{
chat, err := expClient.CreateChat(ctx, codersdk.CreateChatRequest{
Content: []codersdk.ChatInputPart{
{
Type: codersdk.ChatInputPartTypeText,
@@ -193,7 +194,7 @@ func TestSubagentChatExcludesWorkspaceProvisioningTools(t *testing.T) {
// The root chat finishes first, then the chatd server
// picks up and runs the child (subagent) chat.
require.Eventually(t, func() bool {
got, getErr := client.GetChat(ctx, chat.ID)
got, getErr := expClient.GetChat(ctx, chat.ID)
if getErr != nil {
return false
}
@@ -217,7 +218,7 @@ func TestSubagentChatExcludesWorkspaceProvisioningTools(t *testing.T) {
require.GreaterOrEqual(t, len(recorded), 2,
"expected at least 2 streamed LLM calls (root + subagent)")
workspaceTools := []string{"list_templates", "read_template", "create_workspace"}
workspaceTools := []string{"propose_plan", "list_templates", "read_template", "create_workspace"}
subagentTools := []string{"spawn_agent", "wait_agent", "message_agent", "close_agent"}
// Identify root and subagent calls. Root chat calls include
@@ -901,15 +902,32 @@ func TestInterruptAutoPromotionIgnoresLaterUsageLimitIncrease(t *testing.T) {
acquireTrap := clock.Trap().NewTicker("chatd", "acquire")
defer acquireTrap.Close()
assertPendingWithoutQueuedMessages := func(chatID uuid.UUID) {
t.Helper()
queued, dbErr := db.GetChatQueuedMessages(ctx, chatID)
require.NoError(t, dbErr)
require.Empty(t, queued)
fromDB, dbErr := db.GetChatByID(ctx, chatID)
require.NoError(t, dbErr)
require.Equal(t, database.ChatStatusPending, fromDB.Status)
require.False(t, fromDB.WorkerID.Valid)
}
streamStarted := make(chan struct{})
interrupted := make(chan struct{})
secondRequestStarted := make(chan struct{})
thirdRequestStarted := make(chan struct{})
allowFinish := make(chan struct{})
var requestCount atomic.Int32
openAIURL := chattest.NewOpenAI(t, func(req *chattest.OpenAIRequest) chattest.OpenAIResponse {
if !req.Stream {
return chattest.OpenAINonStreamingResponse("title")
}
if requestCount.Add(1) == 1 {
switch requestCount.Add(1) {
case 1:
chunks := make(chan chattest.OpenAIChunk, 1)
go func() {
defer close(chunks)
@@ -928,7 +946,12 @@ func TestInterruptAutoPromotionIgnoresLaterUsageLimitIncrease(t *testing.T) {
<-allowFinish
}()
return chattest.OpenAIResponse{StreamingChunks: chunks}
case 2:
close(secondRequestStarted)
case 3:
close(thirdRequestStarted)
}
return chattest.OpenAIStreamingResponse(
chattest.OpenAITextChunks("done")...,
)
@@ -953,15 +976,7 @@ func TestInterruptAutoPromotionIgnoresLaterUsageLimitIncrease(t *testing.T) {
require.NoError(t, err)
clock.Advance(acquireInterval).MustWait(ctx)
require.Eventually(t, func() bool {
select {
case <-streamStarted:
return true
default:
return false
}
}, testutil.WaitMedium, testutil.IntervalFast)
testutil.TryReceive(ctx, t, streamStarted)
queuedResult, err := server.SendMessage(ctx, chatd.SendMessageOptions{
ChatID: chat.ID,
@@ -972,29 +987,11 @@ func TestInterruptAutoPromotionIgnoresLaterUsageLimitIncrease(t *testing.T) {
require.True(t, queuedResult.Queued)
require.NotNil(t, queuedResult.QueuedMessage)
require.Eventually(t, func() bool {
select {
case <-interrupted:
return true
default:
return false
}
}, testutil.WaitMedium, testutil.IntervalFast)
testutil.TryReceive(ctx, t, interrupted)
close(allowFinish)
require.Eventually(t, func() bool {
queued, dbErr := db.GetChatQueuedMessages(ctx, chat.ID)
if dbErr != nil || len(queued) != 0 {
return false
}
fromDB, dbErr := db.GetChatByID(ctx, chat.ID)
if dbErr != nil {
return false
}
return fromDB.Status == database.ChatStatusPending && !fromDB.WorkerID.Valid
}, testutil.WaitMedium, testutil.IntervalFast)
chatd.WaitUntilIdleForTest(server)
assertPendingWithoutQueuedMessages(chat.ID)
// Keep the acquire loop frozen here so "queued" stays pending.
// That makes the later send queue because the chat is still busy,
@@ -1045,63 +1042,41 @@ func TestInterruptAutoPromotionIgnoresLaterUsageLimitIncrease(t *testing.T) {
require.NoError(t, err)
clock.Advance(acquireInterval).MustWait(ctx)
require.Eventually(t, func() bool {
return requestCount.Load() >= 2
}, testutil.WaitMedium, testutil.IntervalFast)
require.Eventually(t, func() bool {
queued, dbErr := db.GetChatQueuedMessages(ctx, chat.ID)
if dbErr != nil || len(queued) != 0 {
return false
}
fromDB, dbErr := db.GetChatByID(ctx, chat.ID)
if dbErr != nil {
return false
}
return fromDB.Status == database.ChatStatusPending && !fromDB.WorkerID.Valid
}, testutil.WaitMedium, testutil.IntervalFast)
testutil.TryReceive(ctx, t, secondRequestStarted)
chatd.WaitUntilIdleForTest(server)
assertPendingWithoutQueuedMessages(chat.ID)
clock.Advance(acquireInterval).MustWait(ctx)
testutil.TryReceive(ctx, t, thirdRequestStarted)
chatd.WaitUntilIdleForTest(server)
require.Eventually(t, func() bool {
queued, dbErr := db.GetChatQueuedMessages(ctx, chat.ID)
if dbErr != nil || len(queued) != 0 {
return false
}
queued, err := db.GetChatQueuedMessages(ctx, chat.ID)
require.NoError(t, err)
require.Empty(t, queued)
fromDB, dbErr := db.GetChatByID(ctx, chat.ID)
if dbErr != nil || fromDB.Status != database.ChatStatusWaiting {
return false
}
fromDB, err := db.GetChatByID(ctx, chat.ID)
require.NoError(t, err)
require.Equal(t, database.ChatStatusWaiting, fromDB.Status)
require.False(t, fromDB.WorkerID.Valid)
messages, dbErr := db.GetChatMessagesByChatID(ctx, database.GetChatMessagesByChatIDParams{
ChatID: chat.ID,
AfterID: 0,
})
if dbErr != nil {
return false
}
messages, err := db.GetChatMessagesByChatID(ctx, database.GetChatMessagesByChatIDParams{
ChatID: chat.ID,
AfterID: 0,
})
require.NoError(t, err)
userTexts := make([]string, 0, 3)
for _, message := range messages {
if message.Role != database.ChatMessageRoleUser {
continue
}
sdkMessage := db2sdk.ChatMessage(message)
if len(sdkMessage.Content) != 1 {
continue
}
userTexts = append(userTexts, sdkMessage.Content[0].Text)
userTexts := make([]string, 0, 3)
for _, message := range messages {
if message.Role != database.ChatMessageRoleUser {
continue
}
if len(userTexts) != 3 {
return false
sdkMessage := db2sdk.ChatMessage(message)
if len(sdkMessage.Content) != 1 {
continue
}
return requestCount.Load() >= 3 &&
userTexts[0] == "hello" &&
userTexts[1] == "queued" &&
userTexts[2] == "later queued"
}, testutil.WaitLong, testutil.IntervalFast)
userTexts = append(userTexts, sdkMessage.Content[0].Text)
}
require.Equal(t, []string{"hello", "queued", "later queued"}, userTexts)
}
func TestEditMessageRejectsWhenUsageLimitReached(t *testing.T) {
@@ -1844,6 +1819,7 @@ func TestCreateWorkspaceTool_EndToEnd(t *testing.T) {
IncludeProvisionerDaemon: true,
})
user := coderdtest.CreateFirstUser(t, client)
expClient := codersdk.NewExperimentalClient(client)
agentToken := uuid.NewString()
// Add a startup script so the agent spends time in the
@@ -1898,7 +1874,7 @@ func TestCreateWorkspaceTool_EndToEnd(t *testing.T) {
)
})
_, err := client.CreateChatProvider(ctx, codersdk.CreateChatProviderConfigRequest{
_, err := expClient.CreateChatProvider(ctx, codersdk.CreateChatProviderConfigRequest{
Provider: "openai-compat",
APIKey: "test-api-key",
BaseURL: openAIURL,
@@ -1907,7 +1883,7 @@ func TestCreateWorkspaceTool_EndToEnd(t *testing.T) {
contextLimit := int64(4096)
isDefault := true
_, err = client.CreateChatModelConfig(ctx, codersdk.CreateChatModelConfigRequest{
_, err = expClient.CreateChatModelConfig(ctx, codersdk.CreateChatModelConfigRequest{
Provider: "openai-compat",
Model: "gpt-4o-mini",
ContextLimit: &contextLimit,
@@ -1915,7 +1891,7 @@ func TestCreateWorkspaceTool_EndToEnd(t *testing.T) {
})
require.NoError(t, err)
chat, err := client.CreateChat(ctx, codersdk.CreateChatRequest{
chat, err := expClient.CreateChat(ctx, codersdk.CreateChatRequest{
Content: []codersdk.ChatInputPart{
{
Type: codersdk.ChatInputPartTypeText,
@@ -1927,7 +1903,7 @@ func TestCreateWorkspaceTool_EndToEnd(t *testing.T) {
var chatResult codersdk.Chat
require.Eventually(t, func() bool {
got, getErr := client.GetChat(ctx, chat.ID)
got, getErr := expClient.GetChat(ctx, chat.ID)
if getErr != nil {
return false
}
@@ -1949,7 +1925,7 @@ func TestCreateWorkspaceTool_EndToEnd(t *testing.T) {
require.NoError(t, err)
require.Equal(t, workspaceName, workspace.Name)
chatMsgs, err := client.GetChatMessages(ctx, chat.ID, nil)
chatMsgs, err := expClient.GetChatMessages(ctx, chat.ID, nil)
require.NoError(t, err)
var foundCreateWorkspaceResult bool
@@ -2023,6 +1999,7 @@ func TestStartWorkspaceTool_EndToEnd(t *testing.T) {
IncludeProvisionerDaemon: true,
})
user := coderdtest.CreateFirstUser(t, client)
expClient := codersdk.NewExperimentalClient(client)
version := coderdtest.CreateTemplateVersion(t, client, user.OrganizationID, &echo.Responses{
Parse: echo.ParseComplete,
@@ -2067,7 +2044,7 @@ func TestStartWorkspaceTool_EndToEnd(t *testing.T) {
)
})
_, err := client.CreateChatProvider(ctx, codersdk.CreateChatProviderConfigRequest{
_, err := expClient.CreateChatProvider(ctx, codersdk.CreateChatProviderConfigRequest{
Provider: "openai-compat",
APIKey: "test-api-key",
BaseURL: openAIURL,
@@ -2076,7 +2053,7 @@ func TestStartWorkspaceTool_EndToEnd(t *testing.T) {
contextLimit := int64(4096)
isDefault := true
_, err = client.CreateChatModelConfig(ctx, codersdk.CreateChatModelConfigRequest{
_, err = expClient.CreateChatModelConfig(ctx, codersdk.CreateChatModelConfigRequest{
Provider: "openai-compat",
Model: "gpt-4o-mini",
ContextLimit: &contextLimit,
@@ -2085,7 +2062,7 @@ func TestStartWorkspaceTool_EndToEnd(t *testing.T) {
require.NoError(t, err)
// Create a chat with the stopped workspace pre-associated.
chat, err := client.CreateChat(ctx, codersdk.CreateChatRequest{
chat, err := expClient.CreateChat(ctx, codersdk.CreateChatRequest{
Content: []codersdk.ChatInputPart{
{
Type: codersdk.ChatInputPartTypeText,
@@ -2098,7 +2075,7 @@ func TestStartWorkspaceTool_EndToEnd(t *testing.T) {
var chatResult codersdk.Chat
require.Eventually(t, func() bool {
got, getErr := client.GetChat(ctx, chat.ID)
got, getErr := expClient.GetChat(ctx, chat.ID)
if getErr != nil {
return false
}
@@ -2120,7 +2097,7 @@ func TestStartWorkspaceTool_EndToEnd(t *testing.T) {
require.NoError(t, err)
require.Equal(t, codersdk.WorkspaceTransitionStart, updatedWorkspace.LatestBuild.Transition)
chatMsgs, err := client.GetChatMessages(ctx, chat.ID, nil)
chatMsgs, err := expClient.GetChatMessages(ctx, chat.ID, nil)
require.NoError(t, err)
// Verify start_workspace tool result exists in the chat messages.
@@ -13,11 +13,12 @@ import (
"charm.land/fantasy"
fantasyanthropic "charm.land/fantasy/providers/anthropic"
fantasyopenai "charm.land/fantasy/providers/openai"
"charm.land/fantasy/schema"
"golang.org/x/xerrors"
"github.com/coder/coder/v2/coderd/chatd/chatprompt"
"github.com/coder/coder/v2/coderd/chatd/chatretry"
"github.com/coder/coder/v2/coderd/x/chatd/chatprompt"
"github.com/coder/coder/v2/coderd/x/chatd/chatretry"
"github.com/coder/coder/v2/codersdk"
)
@@ -39,9 +40,10 @@ var ErrInterrupted = xerrors.New("chat interrupted")
// persistence layer is responsible for splitting these into
// separate database messages by role.
type PersistedStep struct {
Content []fantasy.Content
Usage fantasy.Usage
ContextLimit sql.NullInt64
Content []fantasy.Content
Usage fantasy.Usage
ContextLimit sql.NullInt64
ProviderResponseID string
// Runtime is the wall-clock duration of this step,
// covering LLM streaming, tool execution, and retries.
// Zero indicates the duration was not measured (e.g.
@@ -80,8 +82,9 @@ type RunOptions struct {
role codersdk.ChatMessageRole,
part codersdk.ChatMessagePart,
)
Compaction *CompactionOptions
ReloadMessages func(context.Context) ([]fantasy.Message, error)
Compaction *CompactionOptions
ReloadMessages func(context.Context) ([]fantasy.Message, error)
DisableChainMode func()
// OnRetry is called before each retry attempt when the LLM
// stream fails with a retryable error. It provides the attempt
@@ -245,6 +248,18 @@ func Run(ctx context.Context, opts RunOptions) error {
messages := opts.Messages
var lastUsage fantasy.Usage
var lastProviderMetadata fantasy.ProviderMetadata
needsFullHistoryReload := false
reloadFullHistory := func(stage string) error {
if opts.ReloadMessages == nil {
return nil
}
reloaded, err := opts.ReloadMessages(ctx)
if err != nil {
return xerrors.Errorf("reload messages %s: %w", stage, err)
}
messages = reloaded
return nil
}
totalSteps := 0
// When totalSteps reaches MaxSteps the inner loop exits immediately
@@ -368,10 +383,11 @@ func Run(ctx context.Context, opts RunOptions) error {
// check and here, fall back to the interrupt-safe
// path so partial content is not lost.
if err := opts.PersistStep(ctx, PersistedStep{
Content: result.content,
Usage: result.usage,
ContextLimit: contextLimit,
Runtime: time.Since(stepStart),
Content: result.content,
Usage: result.usage,
ContextLimit: contextLimit,
ProviderResponseID: extractOpenAIResponseIDIfStored(opts.ProviderOptions, result.providerMetadata),
Runtime: time.Since(stepStart),
}); err != nil {
if errors.Is(err, ErrInterrupted) {
persistInterruptedStep(ctx, opts, &result)
@@ -382,14 +398,41 @@ func Run(ctx context.Context, opts RunOptions) error {
lastUsage = result.usage
lastProviderMetadata = result.providerMetadata
// Append the step's response messages so that both
// inline and post-loop compaction see the full
// conversation including the latest assistant reply.
// When chain mode is active (PreviousResponseID set), exit
// it after persisting the first chained step. Continuation
// steps include tool-result messages, which fantasy rejects
// when previous_response_id is set, so we must leave chain
// mode and reload the full history before the next call.
stepMessages := result.toResponseMessages()
messages = append(messages, stepMessages...)
if hasPreviousResponseID(opts.ProviderOptions) {
clearPreviousResponseID(opts.ProviderOptions)
if opts.DisableChainMode != nil {
opts.DisableChainMode()
}
switch {
case opts.ReloadMessages != nil:
if err := reloadFullHistory("after chain mode exit"); err != nil {
return err
}
needsFullHistoryReload = false
default:
messages = append(messages, stepMessages...)
needsFullHistoryReload = false
}
} else {
messages = append(messages, stepMessages...)
}
if needsFullHistoryReload && !result.shouldContinue &&
opts.ReloadMessages != nil {
if err := reloadFullHistory("before final compaction after chain mode exit"); err != nil {
return err
}
needsFullHistoryReload = false
}
// Inline compaction.
if opts.Compaction != nil && opts.ReloadMessages != nil {
if !needsFullHistoryReload && opts.Compaction != nil && opts.ReloadMessages != nil {
did, compactErr := tryCompact(
ctx,
opts.Model,
@@ -405,14 +448,11 @@ func Run(ctx context.Context, opts RunOptions) error {
if did {
alreadyCompacted = true
compactedOnFinalStep = true
reloaded, reloadErr := opts.ReloadMessages(ctx)
if reloadErr != nil {
return xerrors.Errorf("reload messages after compaction: %w", reloadErr)
if err := reloadFullHistory("after compaction"); err != nil {
return err
}
messages = reloaded
}
}
if !result.shouldContinue {
stoppedByModel = true
break
@@ -423,9 +463,16 @@ func Run(ctx context.Context, opts RunOptions) error {
compactedOnFinalStep = false
}
if needsFullHistoryReload && stoppedByModel && opts.ReloadMessages != nil {
if err := reloadFullHistory("before post-run compaction after chain mode exit"); err != nil {
return err
}
needsFullHistoryReload = false
}
// Post-run compaction safety net: if we never compacted
// during the loop, try once at the end.
if !alreadyCompacted && opts.Compaction != nil && opts.ReloadMessages != nil {
if !needsFullHistoryReload && !alreadyCompacted && opts.Compaction != nil && opts.ReloadMessages != nil {
did, err := tryCompact(
ctx,
opts.Model,
@@ -973,6 +1020,85 @@ func addAnthropicPromptCaching(messages []fantasy.Message) {
}
}
// hasPreviousResponseID checks whether the provider options contain
// an OpenAI Responses entry with a non-empty PreviousResponseID.
func hasPreviousResponseID(providerOptions fantasy.ProviderOptions) bool {
if providerOptions == nil {
return false
}
for _, entry := range providerOptions {
if options, ok := entry.(*fantasyopenai.ResponsesProviderOptions); ok {
return options.PreviousResponseID != nil &&
*options.PreviousResponseID != ""
}
}
return false
}
// clearPreviousResponseID removes PreviousResponseID from the OpenAI
// Responses provider options entry, if present.
func clearPreviousResponseID(providerOptions fantasy.ProviderOptions) {
if providerOptions == nil {
return
}
for _, entry := range providerOptions {
if options, ok := entry.(*fantasyopenai.ResponsesProviderOptions); ok {
options.PreviousResponseID = nil
}
}
}
// extractOpenAIResponseID extracts the OpenAI Responses API response
// ID from provider metadata. Returns an empty string if no OpenAI
// Responses metadata is present.
func extractOpenAIResponseID(metadata fantasy.ProviderMetadata) string {
if len(metadata) == 0 {
return ""
}
for _, entry := range metadata {
if providerMetadata, ok := entry.(*fantasyopenai.ResponsesProviderMetadata); ok && providerMetadata != nil {
return providerMetadata.ResponseID
}
}
return ""
}
// extractOpenAIResponseIDIfStored returns the OpenAI response ID
// only when the provider options indicate store=true. Response IDs
// from store=false turns are not persisted server-side and cannot
// be used for chaining.
func extractOpenAIResponseIDIfStored(
providerOptions fantasy.ProviderOptions,
metadata fantasy.ProviderMetadata,
) string {
if !isResponsesStoreEnabled(providerOptions) {
return ""
}
return extractOpenAIResponseID(metadata)
}
// isResponsesStoreEnabled checks whether the OpenAI Responses
// provider options explicitly enable store=true.
func isResponsesStoreEnabled(providerOptions fantasy.ProviderOptions) bool {
if providerOptions == nil {
return false
}
for _, entry := range providerOptions {
if options, ok := entry.(*fantasyopenai.ResponsesProviderOptions); ok {
return options.Store != nil && *options.Store
}
}
return false
}
func extractContextLimit(metadata fantasy.ProviderMetadata) sql.NullInt64 {
if len(metadata) == 0 {
return sql.NullInt64{}
@@ -14,11 +14,11 @@ import (
"github.com/stretchr/testify/require"
"cdr.dev/slog/v3/sloggers/slogtest"
"github.com/coder/coder/v2/coderd/chatd/chatprompt"
"github.com/coder/coder/v2/coderd/database"
"github.com/coder/coder/v2/coderd/database/db2sdk"
"github.com/coder/coder/v2/coderd/database/dbgen"
"github.com/coder/coder/v2/coderd/database/dbtestutil"
"github.com/coder/coder/v2/coderd/x/chatd/chatprompt"
"github.com/coder/coder/v2/codersdk"
"github.com/coder/coder/v2/testutil"
)
@@ -1063,6 +1063,46 @@ func ProviderOptionsFromChatModelConfig(
return result
}
// IsResponsesStoreEnabled checks if the OpenAI Responses provider
// options are present and have Store set to true. When true, the
// provider stores conversation history server-side, enabling
// follow-up chaining via PreviousResponseID.
func IsResponsesStoreEnabled(opts fantasy.ProviderOptions) bool {
if opts == nil {
return false
}
raw, ok := opts[fantasyopenai.Name]
if !ok {
return false
}
respOpts, ok := raw.(*fantasyopenai.ResponsesProviderOptions)
if !ok || respOpts == nil {
return false
}
return respOpts.Store != nil && *respOpts.Store
}
// CloneWithPreviousResponseID shallow-clones the provider options
// map and the OpenAI Responses entry, setting PreviousResponseID
// on the clone. The original map and entry are not mutated.
func CloneWithPreviousResponseID(
opts fantasy.ProviderOptions,
previousResponseID string,
) fantasy.ProviderOptions {
cloned := make(fantasy.ProviderOptions, len(opts))
for k, v := range opts {
cloned[k] = v
}
if raw, ok := cloned[fantasyopenai.Name]; ok {
if respOpts, ok := raw.(*fantasyopenai.ResponsesProviderOptions); ok && respOpts != nil {
clone := *respOpts
clone.PreviousResponseID = &previousResponseID
cloned[fantasyopenai.Name] = &clone
}
}
return cloned
}
func openAIProviderOptionsFromChatConfig(
model fantasy.LanguageModel,
options *codersdk.ChatModelOpenAIProviderOptions,
@@ -9,8 +9,8 @@ import (
fantasyvercel "charm.land/fantasy/providers/vercel"
"github.com/stretchr/testify/require"
"github.com/coder/coder/v2/coderd/chatd/chatprovider"
"github.com/coder/coder/v2/coderd/util/ptr"
"github.com/coder/coder/v2/coderd/x/chatd/chatprovider"
"github.com/coder/coder/v2/codersdk"
)
@@ -12,8 +12,8 @@ import (
"github.com/stretchr/testify/require"
"github.com/coder/coder/v2/buildinfo"
"github.com/coder/coder/v2/coderd/chatd/chatprovider"
"github.com/coder/coder/v2/coderd/chatd/chattest"
"github.com/coder/coder/v2/coderd/x/chatd/chatprovider"
"github.com/coder/coder/v2/coderd/x/chatd/chattest"
)
func TestUserAgent(t *testing.T) {

Some files were not shown because too many files have changed in this diff.