Compare commits

..

200 Commits

Author SHA1 Message Date
Ben Potter ed88155b3f fix: add missing return after error response in returnDAUsInternal and fix %n format verbs in tests
Bug 1 (coderd/insights.go:85): returnDAUsInternal writes a 500 error
response when GetTemplateInsightsByInterval fails, but does not return.
Execution falls through to write a second 200 OK response with empty
data. Every other error handler in the same file correctly returns after
writing the error response.

Bug 2 (coderd/database/querier_test.go): Four test assertions use %n as
a format verb for an int argument. %n is not a valid Go fmt verb, so on
assertion failure the row index renders as '%!n(int=X)' instead of the
integer. Changed to %d.
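A minimal sketch of the Bug 1 shape, with hypothetical names rather than the actual coderd handler; the entire fix is the `return` after writing the error response:

```go
// Hedged sketch of the missing-return pattern (placeholder package and
// names, not the real coderd/insights.go code).
package insights

import (
	"encoding/json"
	"net/http"
)

func returnDAUs(rw http.ResponseWriter, fetch func() ([]int64, error)) {
	rows, err := fetch()
	if err != nil {
		http.Error(rw, "fetching insights", http.StatusInternalServerError)
		return // previously missing: without it, the 200 response below was also written
	}
	rw.Header().Set("Content-Type", "application/json")
	rw.WriteHeader(http.StatusOK)
	_ = json.NewEncoder(rw).Encode(rows)
}
```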
2026-03-04 15:50:58 +00:00
Kyle Carberry f56563b406 fix(site): replace modal delete confirmation with inline UI in agents admin (#22587)
## Problem

The agents admin panel (`/agents` → Admin button) is rendered inside a
Radix Dialog (`ConfigureAgentsDialog`). Deleting a model or provider
previously opened a MUI `DeleteDialog` on top, creating a modal-on-modal
situation. The two dialog systems (Radix and MUI) don't coordinate focus
trapping, scroll locking, or backdrop behavior, so the delete
confirmation was broken.

## Solution

Replace the modal `DeleteDialog` in both `ModelForm` and `ProviderForm`
with an inline confirmation strip rendered in the footer area. Clicking
"Delete" now swaps the footer to show:

- A warning message ("Are you sure? This action is irreversible.")
- Cancel and a destructive confirm button with loading spinner

This keeps everything within the existing Radix Dialog content pane — no
layering issues, no second modal.

## Changes

| File | Change |
|---|---|
| `ModelForm.tsx` | Added `isDeleting` prop, changed `onDeleteModel` signature to async, added `confirmingDelete` state, inline confirmation footer |
| `ProviderForm.tsx` | Removed `DeleteDialog` import/usage, replaced with inline confirmation footer |
| `ModelsSection.tsx` | Removed `DeleteDialog` import/usage, removed `modelToDelete` state, passes new props to `ModelForm` |
2026-03-04 10:09:13 -05:00
Matt Vollmer 77c80c30c0 docs: add Coder Agents overview page (#22584)
Adds a new documentation page at `docs/ai-coder/agents.md` describing
Coder Agents — the built-in chat interface, API, and lightweight AI
coding agent that runs in the Coder control plane.

## What's included

- Overview of what Coder Agents is and who it's for (regulated
industries, platform teams, existing Coder deployments)
- How the architecture works (agent loop in coderd, outbound to LLM
providers, connects to workspaces via existing daemon connection)
- Key features: automatic template/workspace selection, sub-agents, chat
persistence, message queuing
- Security benefits of the control plane architecture (no API keys in
workspaces, simpler network boundaries, centralized enforced control,
user identity attached)
- LLM provider support table (verified against
`coderd/chatd/chatprovider/chatprovider.go`)
- Built-in tools reference
- Comparison to Coder Tasks
- Product status (internal preview, early access next)
2026-03-04 10:06:48 -05:00
Ethan e738ff5299 ci: remove dylib build pipeline (#22592)
## Summary

The macOS `.dylib` is only used by Coder Desktop macOS v0.7.2 or older.
v0.7.2 was released in August 2025. v0.8.0 of Coder Desktop macOS, also
released in August 2025, uses a signed Coder slim binary from the
deployment instead.

It's unlikely customers will be using Coder Desktop macOS v0.7.2 and the
next release of Coder simultaneously, so I think we can safely remove
this process, given it slows down CI & release processes.

## Changes

- **Makefile**: Remove `DYLIB_ARCHES`, `CODER_DYLIBS` variables and
`build/coder-dylib` target
- **scripts/build_go.sh**: Remove `--dylib` flag and all dylib-specific
logic (c-shared buildmode, CGO, plist embedding, vpn/dylib entrypoint)
- **scripts/sign_darwin.sh**: Remove dylib-specific comment
- **CI (ci.yaml)**: Remove `build-dylib` job, artifact download/insert
steps, and `build-dylib` dependency from `build` job
- **Release (release.yaml)**: Remove `build-dylib` job, artifact
download/insert steps, and `build-dylib` dependency from `release` job
- **vpn/dylib/**: Delete entire directory (`lib.go` + `info.plist.tmpl`)
- **vpn/router.go, vpn/dns.go**: Clean up comments referencing dylib

The slim and fat binary builds are completely unaffected — the dylib was
an independent build target with its own CI job.

_Generated by mux but reviewed by a human_
2026-03-05 01:50:50 +11:00
Kyle Carberry 1635b18856 fix: persist draft message in localStorage on agent detail page (#22600)
## Problem

On the `/agents/:agentId` detail page, text typed into the chat input
was lost when navigating away and returning. The empty-state page
(`/agents`) already persisted drafts via `localStorage`, but individual
conversation pages did not.

## Solution

Adds per-conversation draft persistence to `useConversationEditingState`
in `AgentDetail.tsx`, following the same patterns used elsewhere in the
agents page:

- Drafts are stored under `agents.draft-input.<chatID>` keys
- The saved draft is read as the editor's initial value on mount
- `localStorage` is updated on every content change
- The key is removed when the input is cleared or a message is sent
successfully
2026-03-04 14:42:13 +00:00
Kacper Sawicki 52a42af1ca chore(deps): bump clistat to v1.2.1 (#22599)
Bumps `github.com/coder/clistat` from v1.2.0 to v1.2.1.
2026-03-04 15:29:00 +01:00
Danielle Maywood 90f686d684 feat(agents): add unarchive agent support (#22579) 2026-03-04 14:08:12 +00:00
Sas Swart 8c09df52f9 fix(coderd): use WaitSuperLong in TestReinitializeAgent (#22593)
Fixes coder/internal#642

We recently fixed Windows-specific flakes for this test and re-enabled it.
It then failed intermittently due to context deadline expiration: the
temporary path created on Windows contained invalid characters, which
caused a silent startup script failure, so the test fruitlessly waited
until the context expired. The test now uses a valid path on Windows.
2026-03-04 15:22:43 +02:00
Kyle Carberry 012a0497ce fix(agents): remove optimistic message rendering and fix auto-promote delivery (#22588)
## Problem

Two bugs in the agents chat flow:

1. **Optimistic rendering glitch**: When sending a message while the
agent is busy, a fake message with a negative ID appears in the
timeline, then gets rolled back to the queued state. This causes a
jarring flash.

2. **Auto-promoted messages not appearing**: When the server
auto-promotes a queued message after finishing a task, the promoted user
message doesn't show up in the timeline until the LLM finishes its
response.

## Root Causes

**Bug 1**: The optimistic rendering system injected placeholder messages
with `id: -Date.now()` into the store. When the server responded with
`queued: true`, the optimistic message was rolled back — but the user
had already seen it flash in the timeline.

**Bug 2**: In `processChat`'s deferred cleanup, the auto-promoted
message was published via `publishEvent()`, which only delivers to local
in-process stream subscribers. The SSE subscriber goroutine only
forwards `message_part` events from the local channel — it ignores
`message` events. Durable events reach the SSE client via pubsub → DB
read, but `publishEvent` doesn't trigger a pubsub notification. The
explicit `PromoteQueued` endpoint correctly used `publishMessage()`
(which does both), but the auto-promote path did not.

## Changes

### Frontend (`site/`)
- **AgentDetail.tsx**: Remove optimistic message injection from send and
edit flows. Instead, use the `CreateChatMessageResponse.message` from
the POST response to insert the real server message into the store
immediately.
- **ChatContext.ts**: Remove the negative-ID cleanup logic from
`upsertDurableMessage` that stripped optimistic placeholders when real
messages arrived.
- **chatStore.test.ts**: Remove 2 tests for negative-ID optimistic
message behavior.

### Backend (`coderd/chatd/`)
- **chatd.go**: In `processChat` cleanup, replace `publishEvent()` with
`publishMessage()` for auto-promoted messages. This ensures the pubsub
notification (`AfterMessageID`) is sent, so SSE subscribers read the new
message from the DB immediately.
2026-03-04 07:49:39 -05:00
Danielle Maywood f28f56d02c test(coderd/rbac): parallelize TestRolePermissions subtests (#22259) 2026-03-04 12:47:39 +00:00
Jeremy Ruppel f07fdce20a flake: add page event logging to e2e tests (#22569)
I'm having a hard time reproducing [this
Heisenbug](https://github.com/coder/internal/issues/1154) in PR CI, but
it seems to happen pretty often on main, so I would like to add some
logging for a few more page events to the ones we already have.
2026-03-04 07:39:20 -05:00
Danielle Maywood a0b3a32cd3 fix(site): refactor agents sidebar timestamp/action cell (#22595) 2026-03-04 11:36:54 +00:00
Sas Swart cfcb81fb0f fix: user status change chart accommodates DST (#22191)
closes https://github.com/coder/internal/issues/464

# Summary

This PR resolves a flaky test that was sensitive to DST transitions in
various time zones. The root of the flake was:
* a bug; the query and its tests assume 24 hours per day
* the tests used local system time, which resulted in failures for dates
proximal to DST transitions

# Changes

Query:

The original query assumed 24 hour intervals between each day, which is
not a valid assumption. It now increments `1 day` at a time.

Database tests:

Database level tests for the query all assumed 24 hour days. They now
increment in DST-aware days instead. Instead of using time.Now() as a
base for testing, the test uses a series of dates over the course of an
entire year, to ensure that DST transition dates are present in every
test run.

# API Endpoint

The endpoint that delivers the user status chart now accepts an IANA
timezone name as a parameter and passes it, keeping the existing offset
as a fallback, to the database query.

API level tests were added to ensure the correct response form and error
behaviour. Correctness of content is tested at the database level.
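A standalone illustration of why the 24-hour assumption breaks, using only the Go standard library (the zone and date are arbitrary; the actual fix increments `1 day` inside the SQL query):

```go
// Adding a fixed 24 hours vs adding a calendar day in an IANA zone.
// On a DST transition day these diverge by an hour.
package main

import (
	"fmt"
	"time"
)

func main() {
	loc, _ := time.LoadLocation("Europe/Helsinki")
	// Midnight on the day the EU switches to summer time (a 23-hour day).
	day := time.Date(2026, time.March, 29, 0, 0, 0, 0, loc)
	fmt.Println(day.Add(24 * time.Hour)) // 2026-03-30 01:00:00 +0300 EEST: fixed 24h drifts past midnight
	fmt.Println(day.AddDate(0, 0, 1))    // 2026-03-30 00:00:00 +0300 EEST: DST-aware "1 day"
}
```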
2026-03-04 12:54:39 +02:00
Danielle Maywood 2882e36222 fix(site): move chat input outside flex-col-reverse scroller (#22585) 2026-03-04 01:04:04 +00:00
Mathias Fredriksson 13411c8a8a docs: add task lifecycle and agent compatibility pages (#22222)
Closes coder/internal#1359
Closes coder/internal#1329
2026-03-04 02:39:48 +02:00
Kyle Carberry 47199ab475 refactor(site): replace bespoke chat model provider UI with schema-driven rendering (#22577)
## Summary

Replace hand-coded per-provider field components, form state types,
validation schemas, and builder functions with generic schema-driven
code that reads from the auto-generated
`chatModelOptionsGenerated.json`.

## Changes

### `ModelConfigFields.tsx` (492 → 341 lines)
- Remove 6 per-provider components (`OpenAIFields`, `AnthropicFields`,
`GoogleFields`, `OpenAICompatFields`, `OpenRouterFields`,
`VercelFields`)
- Remove exported option arrays (`modelConfigReasoningEffortOptions`,
etc.)
- Add `renderSchemaField()` that dispatches to
`InputField`/`SelectField`/`JSONField` based on `field.input_type` from
the generated schema
- `ModelConfigFields` now calls `getVisibleProviderFields()` instead of
a switch statement
- `GeneralModelConfigFields` now calls `getVisibleGeneralFields()`
instead of hard-coding 6 InputField instances

### `modelConfigFormLogic.ts` (742 → 525 lines)
- Remove 6 per-provider form state types and empty defaults
- Remove 6 per-provider Yup validation schemas
- Remove 6 per-provider builder functions (`buildOpenAIOptions`, etc.)
- Remove 2 switch-case dispatch blocks (validation + build)
- Add `buildEmptyProviderState()` that walks schema fields to create
empty form state
- Add schema-driven `extractModelConfigFormState()` and
`buildModelConfigFromForm()`
- Add `yupTestForField()` + `buildYupSchema()` generating Yup validation
from field metadata
- Lazy-cache per-provider Yup schemas for performance

### `modelConfigFormLogic.test.ts`
- All 83 tests updated for the new nested state shape
- Uses `toContain` for error message assertions since labels now come
from schema descriptions

## Motivation

The auto-generated schema (`chatModelOptionsGenerated.json`) was merged
in #22568 but not yet consumed by the UI. This PR wires it up so that
when a new provider or field is added in Go (`codersdk/chats.go`),
running `make gen` regenerates the JSON schema and the UI automatically
picks up the new fields — no manual TypeScript changes needed.

**Production code reduced from 1234 to 866 lines (-30%).**
2026-03-03 17:52:35 -05:00
Kyle Carberry 4ee5306eca fix(site): request notification permission before push subscription (#22576)
## Problem

The subscribe flow in `useWebpushNotifications` called
`pushManager.subscribe()` without first requesting the `Notification`
permission. When the browser permission state is `"denied"` (e.g. from a
previous prompt dismissal), the browser throws:

```
DOMException: Registration failed - permission denied
```

This surfaced as a confusing error toast on the agents page. The error
has nothing to do with Coder RBAC roles — it's the browser denying the
push subscription because notification permission was previously
declined. An admin who had granted browser permission wouldn't see this;
a user who previously dismissed or denied the prompt would.

## Fix

Added an explicit `Notification.requestPermission()` call before
`pushManager.subscribe()`. This:

1. **Re-prompts** the user if the permission state is `"default"` (not
yet decided)
2. **Throws a clear, actionable error** if the permission is `"denied"`:
*"Notifications are blocked by your browser. Please allow notifications
for this site in your browser settings."*
3. **Only proceeds** to `pushManager.subscribe()` after permission is
confirmed as `"granted"`

## Tests

New test file `useWebpushNotifications.jest.ts`:
- **requests notification permission before subscribing** — verifies
`requestPermission()` is called before `pushManager.subscribe()`
- **throws a clear error when permission is denied** — verifies the
user-friendly error message
- **does not call pushManager.subscribe when permission is denied** —
verifies we bail out early
2026-03-03 17:13:31 -05:00
Matt Vollmer 39bde165b8 fix(site): open View Workspace link in new window on agents page (#22578)
On the `/agents` page, the "View Workspace" link in the header dropdown
menu was navigating in the same tab via `navigate()`. This changes it to
`window.open(workspaceRoute, "_blank")` so it opens in a new browser
window/tab instead.

It's frustrating when I want to view my workspace and then I have to go
back and find my chat.
2026-03-03 17:10:11 -05:00
Kyle Carberry f758443f44 feat(codersdk): generate chat model provider options schema from Go structs (#22568) 2026-03-03 21:29:58 +00:00
Kyle Carberry 5b1cf4a6a3 fix(chatd): start stream buffering before publishing running status (#22571)
## Problem

There is a race condition in the chat stream reconnect path. When a
client connects (or reconnects) to `/stream`, sometimes they only see a
`status: running` event but never receive any `message_part` events —
the stream appears stuck.

## Root Cause

In `processChat`, the sequence is:

1. `publishStatus(running)` — broadcasts `status: running` to all
subscribers and via pubsub.
2. `runChat()` is called.
3. Inside `runChat`, there's significant setup work (model resolution,
DB queries, title generation, prompt building, instruction resolution).
4. Only **after** all that setup does `runChat` set `buffering = true`
on the stream state.

If a client connects to `/stream` between steps 1 and 4:
- `Subscribe()` reads `chat.Status == running` from the DB, so it
includes `status: running` in the snapshot.
- But `buffering` is still `false`, so `subscribeToStream` returns an
**empty** local snapshot (no message_parts).
- `publishToStream` **drops** all `message_part` events when `buffering`
is false.
- Result: client sees `running` but never gets any streaming content.

## Fix

Move the `buffering = true` setup (and its deferred cleanup) from
`runChat` into `processChat`, right before `publishStatus(running)`.
This guarantees the buffer is active before any subscriber can observe
`status: running`, so:
- The snapshot always includes any in-flight `message_part` events.
- `publishToStream` never drops parts because buffering is already on.
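A hedged sketch of the reordering; every type and function name below is an assumption drawn from the commit message, not the real chatd code:

```go
// Enable buffering before any subscriber can observe status=running, so
// message_part events are never dropped during runChat's setup work.
package chatd

import "context"

type stream struct{ buffering bool }

type server struct{ s *stream }

func (srv *server) publishStatus(ctx context.Context, status string) {}
func (srv *server) runChat(ctx context.Context) error                { return nil }

func (srv *server) processChat(ctx context.Context) error {
	srv.s.buffering = true                // moved here from runChat
	defer func() { srv.s.buffering = false }() // deferred cleanup moved with it

	srv.publishStatus(ctx, "running") // subscribers now always see a live buffer
	return srv.runChat(ctx)
}
```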
2026-03-03 21:27:59 +00:00
Danielle Maywood f98761ff67 refactor(site): use button instead of div role="button" (#22575) 2026-03-03 21:26:01 +00:00
Steven Masley f6b4b7edab ci: remove sqlc push to cloud (#22574)
I left a vestigial piece, whooops
2026-03-03 14:49:00 -06:00
Danielle Maywood d2d956edb1 fix: add archived query parameter to chat list endpoint (#22562)
Despite the SDK type having an `Archived` field for chats, this data was
never fetched from the database — the `GetChatsByOwnerID` query
hardcoded `AND archived = false`, and the `convertChat` function never
mapped the field.

This PR adds an optional `archived` query parameter to `GET
/api/experimental/chats`:

| Value | Behavior |
|-------|----------|
| *(not provided)* | Returns all chats (active and archived) |
| `archived=false` | Returns only non-archived chats |
| `archived=true` | Returns only archived chats |

This follows the same pattern used by template versions
(`sqlc.narg('archived')` nullable boolean).

Also fixes `convertChat` to populate the `Archived` field in API
responses, which was never being set despite existing on the SDK type.
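A hedged sketch of how an optional `archived` parameter can be parsed into the nullable boolean that `sqlc.narg('archived')` expects; the helper below is hypothetical, not the actual coderd handler:

```go
package chats

import (
	"database/sql"
	"net/http"
	"strconv"
)

// parseArchivedFilter maps the three cases from the table above:
// absent => no filter, "false" => non-archived only, "true" => archived only.
func parseArchivedFilter(r *http.Request) (sql.NullBool, error) {
	raw := r.URL.Query().Get("archived")
	if raw == "" {
		return sql.NullBool{}, nil // not provided: no filter, all chats returned
	}
	b, err := strconv.ParseBool(raw)
	if err != nil {
		return sql.NullBool{}, err // caller responds with 400
	}
	return sql.NullBool{Bool: b, Valid: true}, nil // Valid=true activates the SQL filter
}
```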
2026-03-03 20:39:19 +00:00
Kyle Carberry 8a2635285b fix(site): remove stale ArchivedAgentsSearchAutoExpands story (#22573)
The search input was removed from `AgentsSidebar` but the
`ArchivedAgentsSearchAutoExpands` story still referenced the `Search
agents...` placeholder, causing the Storybook interaction test to fail:

```
within(<div#storybook-root>).getByPlaceholderText("Search agents...")
Unable to find an element with the placeholder text of: Search agents...
```

This PR removes the stale story.
2026-03-03 20:28:04 +00:00
Steven Masley 8ea0c2f3bc ci: remove ci action to push schema to sqlc cloud (#22572)
SQLc cloud no longer exists
2026-03-03 14:21:07 -06:00
Kyle Carberry 810b509290 feat: refactor the agents admin UI layout (#22567)
I am working on a subsequent change to make the fields auto-generated
with `make gen` from the Go code itself, rather than us needing to
create a UI compatibility layer.

Once the above is done, I'll be adding in the payload so users can very
easily just click "Opus 4.6" to add the model, and the config values
will be set appropriately.

These are really just UI changes; nothing should change functionally here.
But the code will be cleaned up a lot after the above changes.

<img width="1197" height="978" alt="image"
src="https://github.com/user-attachments/assets/45f9afff-89bb-47a6-b9a1-534f50a9676e"
/>
<img width="1180" height="949" alt="image"
src="https://github.com/user-attachments/assets/b3fd963f-1c1d-4d2c-b501-ac8118b019ec"
/>
<img width="1185" height="957" alt="image"
src="https://github.com/user-attachments/assets/08faca29-2b38-476a-adab-0bd8ab17ddcc"
/>
2026-03-03 15:19:07 -05:00
Jeremy Ruppel b73f21662b flake: verify parameters in parallel in e2e tests (#22557)
This is an attempt to address coder/internal#1154

Tests appear to fail often on `verifyParameters`, which asserts input
visibility and value in series for all expected parameters. This change
makes the same assertions in parallel, hopefully completing before
timeout.
2026-03-03 14:56:41 -05:00
Danielle Maywood 6acdd6ca7d fix: wire agents-tab-visible metadata to experiments flag (#22553) 2026-03-03 13:51:10 -06:00
Zach 5b7377c375 feat: add Prometheus metrics for boundary log drop reporting (#22521)
Add Prometheus metrics to the boundary log proxy for observability:
- batches_dropped_total (reason: buffer_full, forward_failed)
- logs_dropped_total (reason: buffer_full, forward_failed,
  boundary_channel_full, boundary_batch_full)
- batches_forwarded_total

Also add BoundaryStatus to the BoundaryMessage envelope so boundary
can report dropped log counts as a separate wire message. The agent
records these as Prometheus metrics, making boundary-side data loss
visible. Backwards compatibility for older versions of boundary is maintained.
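A hedged sketch of the counters listed above using the Prometheus client library; metric and label names come from the commit message, while the package name and registration details are assumptions:

```go
package boundarylogs

import "github.com/prometheus/client_golang/prometheus"

var (
	batchesDropped = prometheus.NewCounterVec(prometheus.CounterOpts{
		Name: "batches_dropped_total",
		Help: "Log batches dropped by the boundary log proxy.",
	}, []string{"reason"}) // buffer_full, forward_failed

	logsDropped = prometheus.NewCounterVec(prometheus.CounterOpts{
		Name: "logs_dropped_total",
		Help: "Individual log lines dropped.",
	}, []string{"reason"}) // buffer_full, forward_failed, boundary_channel_full, boundary_batch_full

	batchesForwarded = prometheus.NewCounter(prometheus.CounterOpts{
		Name: "batches_forwarded_total",
		Help: "Log batches successfully forwarded.",
	})
)
```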

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-03-03 12:42:34 -07:00
Danny Kopping 9b5573d7fa feat: store tool call IDs to determine interception lineage (#22246)
Adds database columns and server-side logic to track interception lineage via tool call IDs. When an interception ends, the server resolves the correlating tool call ID to find the parent interception and links them via `parent_id`.

New `provider_tool_call_id` column on `aibridge_tool_usages` and `parent_id` column on `aibridge_interceptions`, with indexes for lookup. `findParentInterceptionID` queries by tool call ID and filters out the current interception to find the parent.

Adapted from the [coder/coder `dk/prompt_provenance_poc`](https://github.com/coder/coder/compare/main...dk/prompt_provenance_poc) branch.
Depends on [coder/aibridge#188](https://github.com/coder/aibridge/pull/188).  
  
Closes https://github.com/coder/internal/issues/1334
2026-03-03 21:07:01 +02:00
Danny Kopping 1b08bc76a6 feat: store tool call IDs to determine interception lineage (#22246)
Adds database columns and server-side logic to track interception lineage via tool call IDs. When an interception ends, the server resolves the correlating tool call ID to find the parent interception and links them via `parent_id`.

New `provider_tool_call_id` column on `aibridge_tool_usages` and `parent_id` column on `aibridge_interceptions`, with indexes for lookup. `findParentInterceptionID` queries by tool call ID and filters out the current interception to find the parent.

Adapted from the [coder/coder `dk/prompt_provenance_poc`](https://github.com/coder/coder/compare/main...dk/prompt_provenance_poc) branch.
Depends on [coder/aibridge#188](https://github.com/coder/aibridge/pull/188).  
  
Closes https://github.com/coder/internal/issues/1334
2026-03-03 21:04:41 +02:00
Kyle Carberry c05fbfec6c feat(site): restyle agents sidebar New Agent button (#22555)
Removes the search input and restyles the New Agent button in the agents
sidebar:

- Removed the search input box
- Replaced the outlined button with a subtle, left-aligned button
featuring a `SquarePenIcon`
- Button icon and text alignment matches the tree node items in the
sidebar

<img width="769" height="337" alt="image"
src="https://github.com/user-attachments/assets/2284c8c0-6294-4823-9ce0-5cc72b0d0054"
/>
2026-03-03 13:26:34 -05:00
Steven Masley f49dea683c chore: prematurely refresh oidc token near expiry during workspace build (#22502)
Closes https://github.com/coder/coder/issues/22429
2026-03-03 18:13:00 +00:00
Spike Curtis 56eb57caf4 chore: enable agent socket by default (#22352)
relates to #21335

Enables the agent socket by default and updates docs to strike references to having to enable it.

The PRs in this stack change the MCP server that Tasks use to update their status to rely on the agent socket, rather than directly dialing Coderd with the agent token.

Default disable was a reasonable default when it was only used for the experimental script ordering features, but now that we want to use it for Tasks, it should be default on.
2026-03-03 21:23:59 +04:00
Kyle Carberry 2ceac319b8 fix(site): eagerly fetch API key for Open in Cursor/VS Code buttons (#22554)
## Problem

The **Open in Cursor** and **Open in VS Code** buttons on the agent
detail page were broken. Clicking them did nothing.

### Root Cause

The `handleOpenInEditor` handler in `AgentDetail.tsx` called
`window.location.assign()` with a custom protocol URI (`vscode://` or
`cursor://`) **after** an `await API.getApiKey()` call. This creates an
async boundary that breaks the browser's user gesture chain, causing
custom protocol navigations (`vscode://`, `cursor://`) to be silently
blocked by the browser.

The handler was invoked from a Radix `DropdownMenuItem.onSelect`, which
adds another layer of event indirection that makes the gesture chain
more fragile.

In contrast, the workspace page's `VSCodeDesktopButton` works because it
uses a direct `onClick` handler on a button element.

## Fix

- **Eagerly fetch and cache the API key** via `useQuery` when workspace
and agent data is available
- **Make `handleOpenInEditor` synchronous** — it reads the cached key
instead of awaiting a network call, keeping `window.location.assign()`
within the original user gesture context
- **Disable buttons** while the API key is still loading
(`canOpenEditors` now gates on key availability)
- **Simplify** the `onOpenInEditor` callback (remove `void` async
wrapper)
2026-03-03 16:54:27 +00:00
Steven Masley bca638a498 feat: validate prebuild presets using dynamic parameter validation (#21858)
Prebuilds need to be valid. Before this change, you could push a template
version whose preset would fail when creating a prebuild. This PR ensures
all presets that are used for prebuilds are valid.
2026-03-03 16:50:18 +00:00
Cian Johnston 8a095ae722 fix(site): remove ugly webpush button (#22560) 2026-03-03 16:44:46 +00:00
Kyle Carberry 059ed7ab5c fix(chatd): return chat to pending when server shuts down during successful completion (#22559)
## Problem

Flaky test:
`TestCloseDuringShutdownContextCanceledShouldRetryOnNewReplica`
(coder/internal#1371)

The test intermittently fails because the chat ends up in `waiting`
status instead of `pending` after server shutdown.

## Root Cause

There is a race condition in `processChat` where `runChat` completes
successfully just as the server context is being canceled during
`Close()`. The sequence:

1. Server calls `Close()`, canceling the server context.
2. The LLM HTTP response has already been fully written by the mock
server (the stream closes normally before context cancellation
propagates to the HTTP client).
3. `runChat` returns `nil` (success) instead of `context.Canceled`.
4. The existing `isShutdownCancellation` check only runs when `runChat`
returns an error, so the shutdown is not detected.
5. `processChat`'s deferred cleanup marks the chat as `waiting` instead
of `pending`.
6. The test's assertion that the chat is `pending` never becomes true.

This race is timing-dependent — it only triggers when the mock server's
HTTP response completes in the narrow window between context
cancellation being initiated and it propagating through the HTTP
transport layer.

## Fix

Add a server context check after `runChat` returns successfully. If the
server is shutting down (`ctx.Err() != nil`), override the status to
`pending` so another replica can pick up the chat.

This is the same pattern already used for the error path
(`isShutdownCancellation`), extended to cover the success path.
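A hedged sketch of the status decision with placeholder types; the real logic lives in `processChat`'s deferred cleanup:

```go
package chatd

import (
	"context"
	"errors"
)

type status string

const (
	statusWaiting status = "waiting"
	statusPending status = "pending"
)

// finalStatus treats a "successful" run during server shutdown as pending so
// another replica retries it, mirroring the existing error-path check.
func finalStatus(serverCtx context.Context, runErr error) status {
	if runErr == nil && serverCtx.Err() != nil {
		return statusPending // success raced with shutdown: hand the chat back
	}
	if errors.Is(runErr, context.Canceled) && serverCtx.Err() != nil {
		return statusPending // the existing shutdown-cancellation path
	}
	return statusWaiting
}
```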
2026-03-03 11:34:08 -05:00
Mathias Fredriksson 96cfb7d06a docs(.claude/docs): add modern Go reference for AI agents (#22558)
Add a Go 1.18-1.26 reference document (`.claude/docs/GO.md`) to guide AI
agents toward modern Go idioms.
2026-03-03 18:24:28 +02:00
Zach 66954aead0 feat: add TagV2 BoundaryMessage envelope protocol (#22520)
Extend the wire protocol for the boundary <-> agent unix socket with
a message envelope.

The envelope creates a boundary <-> agent data path that is separate
from the agent <-> coderd path. This lets boundary send operational
metadata (drop counts, configuration like jail type, capabilities)
that the agent can act on locally (e.g. Prometheus metrics) or use
to enrich outbound requests, without polluting the coderd-facing proto
with fields coderd never consumes.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-03-03 09:13:11 -07:00
Danielle Maywood cdb7145982 feat(site): add collapsible archived agents section to sidebar (#22551) 2026-03-03 15:30:59 +00:00
Kyle Carberry 2d7009e50d test: reduce unnecessary sleep durations in tests (#22552)
## Summary

Removes `time.Sleep` calls in two test files by replacing them with
deterministic or event-driven alternatives.

### Changes

**`coderd/provisionerjobs_test.go`** (34.5s → 0.25s)

Replaced `time.Sleep(1500ms)` with a direct SQL `UPDATE` to bump
`created_at` by 2 seconds. The sleep existed purely to ensure different
timestamps for sort-order testing. The fix is deterministic and cannot
flake. Uses `NewDBWithSQLDB` (the test already required real Postgres
via `WithDumpOnFailure`).

**`coderd/database/pubsub/pubsub_test.go`** (2.05s → 1.3s)

Replaced `time.Sleep(1s)` with a `testutil.Eventually` retry loop that
publishes and checks for subscriber receipt. This is the idiomatic
pattern in the codebase. The old sleep waited for pq.Listener to
re-issue LISTEN after reconnect; the new code polls until it actually
works.
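A hedged sketch of the poll-instead-of-sleep pattern; testify's `require.Eventually` stands in for the codebase's `testutil.Eventually`, and the pubsub wiring is a placeholder:

```go
package pubsub_test

import (
	"testing"
	"time"

	"github.com/stretchr/testify/require"
)

func TestListenerReestablished(t *testing.T) {
	received := make(chan struct{}, 1)
	publish := func() {
		// Placeholder for publishing on the reconnected pubsub; here it just
		// signals the channel so the sketch is self-contained.
		select {
		case received <- struct{}{}:
		default:
		}
	}

	// Publish and check in a loop until the subscriber actually receives,
	// instead of sleeping a fixed second and hoping LISTEN was re-issued.
	require.Eventually(t, func() bool {
		publish()
		select {
		case <-received:
			return true
		default:
			return false
		}
	}, 10*time.Second, 25*time.Millisecond)
}
```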
2026-03-03 10:19:00 -05:00
Kyle Carberry 10a33ebc75 test: reduce Await* polling interval from 250ms to 25ms (#22536)
## Summary

Change the four main `coderdtest` Await helper functions to poll at
`IntervalFast` (25ms) instead of `IntervalMedium` (250ms):

- `AwaitTemplateVersionJobCompleted`
- `AwaitWorkspaceBuildJobCompleted`
- `WorkspaceAgentWaiter.WaitFor`
- `WorkspaceAgentWaiter.Wait`

These are called **~855 times** across the test suite. Each call
previously wasted ~125ms on average waiting for the next poll tick.
`AwaitTemplateVersionJobRunning` already used `IntervalFast` — this
makes all Await helpers consistent.

## Measured Impact

Local benchmarks (postgres, `-short -count=1 -p 8 -parallel 8
-tags=testsmallbatch`):

| Package | Before | After | Delta |
|---|---|---|---|
| enterprise/coderd | 90.8s | 76.0s | **-16.3%** |
| coderd | 65.6s | 56.5s | **-13.8%** |
| cli | 57.9s | 37.8s | **-34.7%** |
| enterprise (root) | 41.1s | 39.9s | -2.9% |
| **Sum of all packages** | **623s** | **543s** | **-12.8%** |

Zero test failures across all 199 packages.
2026-03-03 13:48:58 +00:00
Ehab Younes 9d2aed88c4 fix: register task pause/resume routes under /api/v2 (#22544)
The pause/resume endpoints were only registered under /api/experimental
but the frontend and Go SDK were calling /api/v2, resulting in 404s.
Register the routes in the v2 group, update the SDK client paths, and
fix swagger annotations (Accept → Produce) since these POST endpoints
have no request body.
2026-03-03 16:34:33 +03:00
Jake Howell e2a3b99d3a fix: remove mui components from <WorkspaceScheduleForm /> (#22232)
This pull-request removes the last instance of the
`@mui/material/Switch` from the codebase, whilst also cleaning up the
`<WorkspaceScheduleForm />` page of MUI.

<img width="1067" height="666" alt="image"
src="https://github.com/user-attachments/assets/b32094f6-f1a4-42fc-b927-64749e131f1b"
/>
2026-03-04 00:30:02 +11:00
Jake Howell 02328605a9 feat: <RequestLogsPage /> not enabled message (#22546)
This pull-request implements a message that alerts when the user has AI
Governance entitled in their license but not necessarily enabled, this
is an improvement to #22385 where the content simply didn't render (due
to a failed request).

<img width="1403" height="400" alt="image"
src="https://github.com/user-attachments/assets/54685ce1-5cc1-421f-b290-de9b3d160c03"
/>

<img width="1411" height="824" alt="image"
src="https://github.com/user-attachments/assets/4d44fa18-0914-4e51-a415-c511e5b13e98"
/>
2026-03-04 00:25:38 +11:00
Danielle Maywood 17f5fa452e fix: remove opacity-50 from agents page logo so it renders black (#22547) 2026-03-03 13:10:12 +00:00
Jake Howell b4a53acfd6 fix: resolve <Checkbox /> classes (#22540)
This pull-request cleans up various issues I noticed while using the
`<Checkbox />` element.

* Margins were missing around the outsides of the `<Checkbox />`
components.
* Resolved the story of the `WithLabel` so things are lined up.
* Lowered the `borderRadius` down to `2px` (This can be cleaned up by
Tailwind 4 later).
* Refined checkbox styling across focus, disabled, checked, and
indeterminate states to be in line with Figma.
* Simplified checkbox indicator rendering and centered the icon with
absolute positioning.
2026-03-03 23:54:28 +11:00
Danielle Maywood 2a3b6643ea fix(site): cap chat input height and make it scrollable (#22545) 2026-03-03 12:35:39 +00:00
Cian Johnston 88465ea48a chore: misc improvements to compose.dev.yaml (#22541)
- Unexposes port 5432 so it doesn't conflict with existing databases.
You can still hop into the DB if you need.
- Updates multiple CODER_ envs to sensible defaults, overrideable via env
2026-03-03 12:07:57 +00:00
Jake Howell 8aebd73466 feat: implement new default monospace font Geist Mono (#22081)
This pull-request follows up #22060

It felt wrong to only make use of Geist when there is a Monospace variant
here too. It felt best to make this the default monospace font, as it's in
line with the rest of the application. This also updates the lower line for
Workspace Statistics 🙂
2026-03-03 12:00:50 +00:00
Jake Howell aa9fafa372 fix: resolve <Avatar /> incorrect sizes (#22538)
This pull-request updates our icons to be in line with the Figma file.
They were slightly too small in the two variants of `--avatar-default`
and `--avatar-sm`. Now these are in line with what we have defined and
use the correct variants in the breadcrumbs.
2026-03-03 22:48:41 +11:00
Cian Johnston 3c4a416b55 feat(site): add Open Terminal and Copy SSH Command to agent chat TopBar (#22529)
Adds two new items to the agent chat TopBar dropdown menu:

- **Open Terminal**: opens the workspace web terminal in a new browser
window, reusing the existing `getTerminalHref`/`openAppInNewWindow`
infra.
- **Copy SSH Command**: copies the SSH command (e.g. `ssh
agent.workspace.owner.suffix`) to the clipboard with a toast
confirmation. Only shown when the deployment SSH hostname suffix is
configured.

Both items appear after a separator below the existing editor/workspace
actions.

## Changes

| File | What |
|---|---|
| `TopBar.tsx` | Added `Open Terminal` and `Copy SSH Command` dropdown items with separator, `TerminalIcon`/`CopyIcon`, toast on copy |
| `AgentDetail.tsx` | Wired up `getTerminalHref`, `openAppInNewWindow`, `deploymentSSHConfig` query, and passed new props to TopBar in all 3 render paths |
| `TopBar.stories.tsx` | Added new fields to default story props |
2026-03-03 11:44:39 +00:00
Michael Suchacz 2203b259e6 fix(dogfood): upgrade Rust from apt (1.75) to rustup stable (#22458)
The Ubuntu Jammy `cargo` apt package provides Rust 1.75, which is too
old for transitive dependencies requiring edition 2024 (Rust 1.85+).

**Changes:**
- Replace apt `cargo` with a rustup-based install (stable channel,
minimal profile).
- Override `CARGO_HOME` to `/home/coder/.cargo` after `USER coder` so
cargo registry/cache writes go to the user's home (the rustup-installed
binaries remain on PATH via `/usr/local/cargo/bin`).
- Add `--fail` to all `curl` commands in the tool-download block so HTTP
errors fail fast with clear messages instead of silently piping error
pages into `tar`.
- Bump kube-linter 0.6.3 → 0.8.1 and trivy 0.41.0 → 0.69.2 (old releases
were removed from GitHub, causing persistent 404s).
2026-03-03 12:11:04 +01:00
Danielle Maywood c483bfa24f refactor(site): convert archive agent callbacks to React Query mutations (#22542) 2026-03-03 11:03:59 +00:00
Cian Johnston 517cb0ce73 refactor(webpush): use RequireExperimentWithDevBypass middleware (#22525)
Replace manual experiment checks in web-push handlers with the
`RequireExperimentWithDevBypass` middleware on the route group, matching
the pattern used by OAuth2, Agents, and MCP experiments.

## Changes

- **`coderd/coderd.go`**: Add `RequireExperimentWithDevBypass`
middleware to `/webpush` route group
- **`coderd/webpush.go`**: Remove inline
`api.Experiments.Enabled(codersdk.ExperimentWebPush)` checks from all
three handlers
- **`cli/server.go`**: Gate webpush dispatcher initialization with
`buildinfo.IsDev()` fallback so dev builds always init the real
dispatcher
- **`coderd/webpush_test.go`**: Remove experiment enablement from tests
(dev bypass handles it)

Net effect: -26 lines removed, +5 added.
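A hedged sketch of the route-group pattern with placeholder middleware and handler names; the point is only that the experiment check is applied once via `r.Use` instead of inside each handler:

```go
package coderd

import (
	"net/http"

	"github.com/go-chi/chi/v5"
)

// webpushRoutes applies the experiment gate at the group level. The real
// middleware is httpmw.RequireExperimentWithDevBypass; here it is passed in
// as a plain http middleware to keep the sketch self-contained.
func webpushRoutes(requireWebpushExperiment func(http.Handler) http.Handler, h http.HandlerFunc) chi.Router {
	r := chi.NewRouter()
	r.Route("/webpush", func(r chi.Router) {
		r.Use(requireWebpushExperiment)
		r.Post("/subscription", h)
		r.Delete("/subscription", h)
		r.Post("/test", h)
	})
	return r
}
```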

Created using whatchamacallits (Opus 4.6 Max)
2026-03-03 09:49:04 +00:00
Sas Swart e563766722 tests: re-enable 'TestReinitializeAgent' on Windows (#22488)
closes https://github.com/coder/internal/issues/642

This PR:
* re-enables `func TestReinitializeAgent(t *testing.T)`
* adjusts it to use a Windows specific command in Windows environments
2026-03-03 11:22:02 +02:00
Mathias Fredriksson b80dbd2d4e test(cli): fix flaky TestProvisioners_Golden (#22491)
Fixes coder/internal#449
2026-03-03 08:47:34 +00:00
Andrew Aquino ef2e408c0c chore: add GitHub Action to build+deploy coder.com whenever docs paths change (#22283)
## problem

Fixes an issue where updates to docs resulted in docs links returning
HTTP 404, sometimes taking 4-12 hours before returning HTTP 200 (OK).

coder.com is deployed to Vercel from a separate Next.js repo, which has
no knowledge of when docs pages in this repo get updated.

### examples (non-exhaustive)

PR | 404 description
---|---
#19625 | URL for https://coder.com/docs/install/offline was updated to https://coder.com/docs/install/airgap, but the latter returned 404 for 3 hr 56 min after the PR was merged
#21434 | URLs https://coder.com/docs/ai-coder/nsjail and https://coder.com/docs/ai-coder/landjail were added, but both paths 404ed for 1 hr 30 min after the PR was merged. Note that these paths have changed since then--don't be alarmed if clicking those links returns 404s while reviewing this PR
#21708 | URL https://coder.com/docs/ai-coder/boundary/agent-boundary was added, but it returned 404 for 1 hr 19 min after the PR was merged

## solution

All 3 PRs listed above modify manifest.json. This file is fetched during
coder.com's `getStaticPaths` for docs pages, defining which docs URLs
get statically generated at build time. In the latter 2 cases, the 404s
were resolved by manually triggering a redeploy of coder.com in the
Vercel dashboard.

The new CI workflow in this PR automatically triggers a Vercel deploy
hook ([see
docs](https://vercel.com/docs/deploy-hooks#triggering-a-deploy-hook))
with a POST request that runs whenever commits are pushed to main that
modify manifest.json. The deploy hook initiates a new build+deploy of
the coder.com Next.js app, which reruns `getStaticPaths`, updating docs
pages' URLs.

**Note:** I have not tested this workflow yet. I will verify that it
works after this PR is merged. I confirmed in a local terminal that the
webhook URL does successfully initiate a new Vercel build. I also tested
with a malformed URL and received error JSON output, so if the action
fails for some reason, we should see error output in the workflow logs
([example](https://github.com/coder/coder/actions/runs/22361453442/job/64722503802)).
2026-03-03 00:38:49 -08:00
Kyle Carberry 56f95a3e6d fix: scope git askpass diff status updates to initiating chat (#22534)
## Problem

When the git askpass flow triggered diff status refreshes, it updated
**every chat** connected to the workspace. This was wasteful and could
cause confusing status updates on unrelated chats.

## Solution

Thread the chat ID through the entire git askpass flow so only the chat
that initiated the git operation gets updated:

1. **`coderd/chatd/chattool/execute.go`** — Sets `CODER_CHAT_ID` env var
on spawned processes (alongside the existing `CODER_CHAT_AGENT`)
2. **`cli/gitaskpass.go`** — Reads `CODER_CHAT_ID` from the environment
and sends it as a `chat_id` query parameter in the `ExternalAuthRequest`
3. **`codersdk/agentsdk/agentsdk.go`** — Adds `ChatID` field to
`ExternalAuthRequest` and encodes it as a query param
4. **`coderd/workspaceagents.go`** — Parses `chat_id` query param and
passes it through to `storeChatGitRef` and
`triggerWorkspaceChatDiffStatusRefresh`
5. **`coderd/chats.go`** — `storeChatGitRef` and
`refreshWorkspaceChatDiffStatuses` now scope updates to just the
initiating chat when a chat ID is provided, falling back to
all-workspace-chats behavior for backwards compatibility (non-chat git
operations)
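A hedged sketch of the handoff; the env var and query parameter names come from the commit message, while the helper itself is hypothetical:

```go
package gitaskpass

import (
	"net/url"
	"os"
)

// chatIDQuery adds chat_id to the external-auth request query when the
// spawned git process carries CODER_CHAT_ID (set by chattool/execute.go).
// If the variable is absent, the query is returned unchanged and coderd
// falls back to refreshing all workspace chats.
func chatIDQuery(q url.Values) url.Values {
	if id := os.Getenv("CODER_CHAT_ID"); id != "" {
		q.Set("chat_id", id)
	}
	return q
}
```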
2026-03-02 22:52:39 -05:00
Kayla はな 2bdf80d452 fix: disable sharing ui when sharing is unavailable (#22390)
Currently the sharing UI is only hidden under certain circumstances,
rather than on a permission basis. This makes it permissions-based, and
makes some backend changes to make sure permissions are correct.
2026-03-03 02:04:55 +00:00
Kyle Carberry b7a7683ac0 fix(chatd): harden cross-replica relay for chat stream parts (#22533)
## Problem

Subscribers connecting to a different replica than the one running the
chat see full messages appear but no streaming partials (`message_part`
events). The relay mechanism that forwards ephemeral parts across
replicas had several bugs.

## Root Causes

1. **`openRelay()` blocked the event loop** — The WebSocket dial (TCP +
TLS + HTTP upgrade) to the worker replica ran synchronously inside the
select loop. While dialing, no events could be processed, channels
filled up, and parts were silently dropped.

2. **Relay drops were permanent** — When the relay WebSocket closed
mid-stream, `relayParts` was set to nil and never reopened. No status
notification would re-trigger it since the chat was still running on the
same worker.

3. **`drainInitial` snapshot race** — The `default` case in the initial
drain loop caused the snapshot to be empty if the remote hadn't flushed
data yet (common immediately after WebSocket connect).

4. **Duplicate event delivery** — The `preloaded` slice caused snapshot
events to be sent both in the return value and re-sent through the
channel goroutine.

## Fixes

### `coderd/chatd/chatd.go` (Subscribe method)
- **Async relay dial**: `openRelayAsync()` spawns a goroutine to dial
the remote replica. The result (channel + cancel func) is delivered on a
`relayReadyCh` channel that the select loop reads without blocking.
- **Relay reconnection**: When the relay channel closes, a 500ms timer
fires. The handler re-checks chat status from the DB and reopens the
relay if the chat is still running on a remote worker.
- **Snapshot parts via channel**: Relay snapshot + live parts are
wrapped into a single channel so they flow through the same path,
avoiding races with the select loop.

### `enterprise/coderd/chats.go` (newRemotePartsProvider)
- **Timer-based drain**: Replaced `default` with a 1-second timer. After
the first event, `Reset(0)` switches to non-blocking drain for remaining
buffered events.
- **Remove preloaded duplication**: The goroutine now only forwards new
events; snapshot events are returned to the caller directly.

## Testing

All existing tests pass:
- `TestInterruptChatBroadcastsStatusAcrossInstances`
- `TestSubscribeSnapshotIncludesStatusEvent`
- `TestSubscribeNoPubsubNoDuplicateMessageParts`
- `TestSubscribeAfterMessageID`
- `TestChatStreamRelay/RelayMessagePartsAcrossReplicas`
2026-03-02 19:57:13 -05:00
Kyle Carberry b8a74a4fcb feat: add confirmation dialog when archiving chat with workspace (#22524)
When archiving a chat that has an attached workspace, a dialog now pops
up asking whether to also delete the associated workspace.

## Changes

### New file: `ArchiveAgentDialog.tsx`
A Radix-based dialog component that appears when archiving a chat that
has a `workspace_id`. It provides:
- A checkbox to opt into deleting the associated workspace
- **Cancel** — closes without archiving
- **Archive only** — archives the chat, leaves the workspace intact
- **Archive & Delete Workspace** — archives the chat and triggers
workspace deletion (enabled only when checkbox is checked)

### Modified: `AgentsPage.tsx`
- Extracted archive logic into a `performArchive` helper
- `requestArchiveAgent` now checks if the chat has a `workspace_id`:
  - If yes, opens the `ArchiveAgentDialog`
  - If no, proceeds with archiving directly (existing behavior)
- Added `handleArchiveOnly`, `handleArchiveAndDeleteWorkspace`, and
`handleCloseArchiveDialog` handlers
- Renders the `<ArchiveAgentDialog>` at the page level

Chats without a workspace are archived immediately as before — no UX
change for those.
2026-03-02 18:52:51 -05:00
Kyle Carberry ddfe630757 refactor(chatd): replace fantasy.Agent with custom agent loop (#22507)
## Summary

Replaces fantasy's `Agent` abstraction with a direct step loop calling
`LanguageModel.Stream()`. Fantasy is retained as the provider
abstraction layer (streaming parsers, types, tool schema) but we no
longer use `fantasy.Agent`, `AgentStreamCall`, `AgentResult`, or
`StepResult`.

## Problems solved

| Problem | Before | After |
|---|---|---|
| **Sentinel prompt hack** | fantasy.Agent requires non-empty Prompt → UUID sentinel generated and stripped in PrepareStep | Messages passed directly to `model.Stream()` |
| **Discarded PersistStep errors** | `_ = opts.OnStepFinish(result)` silently swallows errors | Errors propagate directly from `PersistStep()` |
| **Shadow draft state** | ~160 LOC tracking content in parallel because fantasy doesn't expose in-progress content on interruption | `stepResult` owns content directly; `flushActiveState()` is trivial |
| **Nested retry layers** | fantasy's 2-attempt retry nested inside chatretry's indefinite retry | Single `chatretry.Retry` layer |
| **Callback-mediated compaction** | Mutex + boolean flag + coordination between OnStepFinish/PrepareStep callbacks | Inline `if` statement between steps |
| **Duplicate compaction paths** | `compactStep()` + `maybeCompact()` sharing ~80% logic | Single `tryCompact()` function |
## Changes

### `coderd/chatd/chatloop/chatloop.go` — Rewritten
- **Removed**: `fantasy.NewAgent()`, `AgentStreamCall`, sentinel prompt,
shadow draft state (~160 LOC of closures), `compactedMu`/`compacted`
flag, `PrepareStepResult`
- **Added**: `stepResult` struct, `processStepStream()` (stream
consumer), `executeTools()` (sequential tool execution),
`flushActiveState()` (interrupt handling), `buildToolDefinitions()`,
`toResponseMessages()`
- **Changed**: `Run()` return type from `(*fantasy.AgentResult, error)`
to `error` (callers already discarded the result)
- **Preserved**: Anthropic prompt caching, reasoning title extraction,
`extractContextLimit()`, `ErrInterrupted` semantics

### `coderd/chatd/chatloop/compaction.go` — Simplified
- Merged `compactStep()` + `maybeCompact()` → single `tryCompact()`
- Removed `[]StepResult` parameter from `generateCompactionSummary()`
(caller provides complete message list)
- Kept helper functions: `normalizedCompactionConfig`,
`contextTokensFromUsage`, `resolveContextLimit`, `shouldCompact`

### `coderd/chatd/chatd.go` — Caller updates
- Removed `AgentStreamCall` construction
- Changed `_, err = chatloop.Run(...)` to `err = chatloop.Run(...)`
- Model parameters moved from `AgentStreamCall` fields to `RunOptions`
fields

### Tests — 4 new tests
- `MidLoopCompactionReloadsMessages` — compaction fires mid-loop,
messages reloaded
- `PostRunCompactionSkippedAfterMidLoop` — no double compaction
- `MultiStepToolExecution` — tools execute between steps, results feed
next step
- `PersistStepErrorPropagates` — persistence errors propagate (was
silently discarded)
2026-03-02 18:51:57 -05:00
Kyle Carberry 7e0895a1ee fix(site): roll back optimistic message when server queues it (#22522)
## Problem

When a user sends a message while the agent is busy, the message appears
in the chat timeline as if it was sent and being processed (with the
"Thinking..." shimmer), instead of appearing in the queued messages list
above the input.

## Root Cause

`handleSend` in `AgentDetail.tsx` unconditionally injects an optimistic
user message into the conversation timeline and sets chat status to
`"pending"` **before** awaiting the server response. However, the server
can respond with `{ queued: true, queued_message: {...} }` (via
`CreateChatMessageResponse`) when the agent is already busy — meaning
the message was queued, not processed.

The client never inspected `response.queued` after the request
succeeded, so the optimistic message stayed in the timeline even though
the server queued it.

## Fix

After `sendMutation.mutateAsync(request)` resolves, check
`response.queued`. If true, roll back the optimistic message and restore
the previous chat status. The `queue_update` SSE event from the
WebSocket stream handles adding it to the queued messages list.

## Changes

- **`site/src/pages/AgentsPage/AgentDetail.tsx`**: Capture the response
from `sendMutation.mutateAsync` and roll back the optimistic message +
status when `response.queued === true`.
2026-03-02 16:31:12 -05:00
Kyle Carberry 5eebd3829f fix: use cursor-based query for chat stream notifications (#22510)
## Problem

The pubsub notification handler in `chatd` re-fetched **all** messages
from the DB on every new message notification, then filtered in Go with
`msg.ID > lastMessageID`. This grows linearly with conversation length —
every new message triggers a full table scan of that chat's history.

The `AfterMessageID` field in the pubsub notification payload was
clearly designed for cursor-based fetching, but no matching query
existed.

## Fix

- Add `GetChatMessagesByChatIDAfter` SQL query with `WHERE id >
@after_id`, so the database does the filtering instead of Go.
- Use it in the pubsub notification handler in `chatd.go`, passing
`lastMessageID` as the cursor.
- Implement the dbauthz wrapper (was a `panic("not implemented")` stub
from codegen) with the same read-check-on-parent-chat pattern as
adjacent methods.
- Add dbauthz test coverage for the new method.

**Not changed:** The initial snapshot in `Subscribe()` still loads all
messages — that's correct, since a newly-connecting client needs the
full conversation state. The waste was only in the ongoing notification
path.
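A hedged sketch of the cursor-based notification handler with placeholder types; the assumed method mirrors the new `GetChatMessagesByChatIDAfter` query:

```go
// Before, all messages were loaded and filtered in Go; now the database
// applies the cursor via WHERE id > @after_id and Go consumes only the tail.
package chatd

import "context"

type message struct{ ID int64 }

type store interface {
	// Assumed shape of the new sqlc-generated method.
	GetChatMessagesByChatIDAfter(ctx context.Context, chatID string, afterID int64) ([]message, error)
}

func onMessageNotification(ctx context.Context, db store, chatID string, lastMessageID int64) (int64, error) {
	msgs, err := db.GetChatMessagesByChatIDAfter(ctx, chatID, lastMessageID)
	if err != nil {
		return lastMessageID, err
	}
	for _, msg := range msgs {
		lastMessageID = msg.ID // advance the cursor for the next pubsub notification
	}
	return lastMessageID, nil
}
```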
2026-03-02 16:31:04 -05:00
Kyle Carberry e3c5d734ba fix(site): move gradient mask below title bar in agent detail (#22515)
The gradient mask overlay was positioned at the top of the parent
container (`absolute top-0`), causing it to overlap the title bar
instead of fading the scroll content beneath it.

**Changes:**
- Wrap the TopBar, archived banner, and gradient in a `relative z-10
shrink-0 overflow-visible` container
- Change the gradient from `top-0` to `top-full` so it anchors to the
bottom of the title bar and fades downward over the message area
2026-03-02 16:16:09 -05:00
Cian Johnston d787b3cada fix(coderd): fix error handling in deleteUserWebpushSubscription (#22500)
## Summary

`deleteUserWebpushSubscription` in `coderd/webpush.go` had incorrect
error handling that masked database errors as 404 responses.

## Bug

`GetWebpushSubscriptionsByUserID` is a `:many` query — it returns `([],
nil)` when no rows match, never `sql.ErrNoRows`. The previous `if/else
if` chain:

```go
if existing, err := api.Database.GetWebpushSubscriptionsByUserID(ctx, user.ID); err != nil && errors.Is(err, sql.ErrNoRows) {
    // dead code — :many queries never return sql.ErrNoRows
} else if idx := slices.IndexFunc(existing, ...); idx == -1 {
    // real DB errors fall through here, existing is nil, idx is -1 → 404
}
```

Any real database error (connection failure, timeout, authorization
error) fell through to the `else if` branch where `slices.IndexFunc(nil,
...)` returns `-1`, returning 404 "subscription not found" instead of
500.

## Fix

Split into two separate checks so database errors properly return 500:

```go
existing, err := api.Database.GetWebpushSubscriptionsByUserID(ctx, user.ID)
if err != nil {
    // 500
}
if idx := slices.IndexFunc(existing, ...); idx == -1 {
    // 404
}
```

## Testing

Added `TestDeleteWebpushSubscription/database_error_returns_500` which
wraps the DB store to inject an error into
`GetWebpushSubscriptionsByUserID` and asserts the handler returns 500
(not 404).
2026-03-02 21:11:20 +00:00
Kyle Carberry c4a4ad6008 feat(site): add smooth streaming text engine for LLM responses (#22503)
## Problem

LLM responses currently stream in bulk chunks — multiple `message_part`
events arrive per WebSocket frame, get batched into a single
`startTransition` state update, and render as a visual jump. This looks
janky compared to smooth character-by-character reveal.

## Solution

Port the jitter-buffer approach from
[coder/mux](https://github.com/coder/mux) into a single self-contained
file: `SmoothText.ts`.

### What's in the file

| Component | Purpose |
|---|---|
| `STREAM_SMOOTHING` constants | Tuning knobs (72–420 cps adaptive rate, 120 char max visual lag, 48 char frame cap) |
| `SmoothTextEngine` class | Pure state machine — two-clock model (ingestion vs presentation) with budget-gated adaptive reveal |
| `useSmoothStreamingText` hook | React bridge via `requestAnimationFrame` loop, single `useState<number>`, grapheme-safe slicing |

### How the engine works

- **Adaptive rate:** Linear interpolation from 72 → 420 chars/sec based
on backlog pressure (how far behind the display is from ingested text)
- **Budget accumulation:** Fractional character budget accrues per RAF
tick. Only reveals when ≥1 whole character is ready. This makes it
frame-rate invariant — 60Hz and 240Hz displays reveal the same amount
over wall-clock time (tested to ≤2 char deviation)
- **Max visual lag:** Hard cap of 120 chars. If the gap exceeds this,
the visible pointer jumps forward immediately
- **Clean flush:** When streaming ends, remaining buffer appears
instantly — no trailing animation
- **Grapheme safety:** Uses `Intl.Segmenter` (with codepoint fallback)
to never split emoji mid-animation

### Integration

To wire this up, wrap the `<Response>` component in
`ConversationTimeline.tsx` with the hook:

```tsx
const SmoothedResponse: FC<{text: string; isStreaming: boolean; streamKey: string}> =
    ({ text, isStreaming, streamKey }) => {
        const { visibleText } = useSmoothStreamingText({
            fullText: text,
            isStreaming,
            bypassSmoothing: false,
            streamKey,
        });
        return <Response>{visibleText}</Response>;
    };
```

### Tests

8 engine tests covering: steady reveal, adaptive acceleration, max lag
cap, immediate flush on stream end, bypass mode, content shrink,
sub-char budget gating, and frame-rate invariance.

---------

Co-authored-by: Danielle Maywood <danielle@themaywoods.com>
2026-03-02 15:18:30 -05:00
Kyle Carberry 2f684002b8 fix(site): stay on archived chat instead of redirecting (#22505)
When archiving a chat, the frontend no longer navigates away to a
different chat. Instead it stays on the current chat and shows an
archived state.

## Changes

**AgentsPage.tsx** — Removed the redirect logic from
`requestArchiveAgent`. After a successful archive, invalidates the
individual chat query so the detail view picks up the `archived` flag
immediately.

**AgentDetail.tsx** — Detects `chatRecord.archived` and:
- Disables the chat input
- Shows a banner: "This agent has been archived and is read-only."
- Passes `isArchived` to the top bar
- Guards `handleArchiveAgentAction` against double-archiving

**AgentDetail/TopBar.tsx** — When `isArchived`:
- Shows an "Archived" badge next to the chat title
- Hides the "Archive Agent" dropdown menu item

**AgentDetail/TopBar.stories.tsx** — Added an `Archived` story variant.
2026-03-02 19:43:45 +00:00
Kyle Carberry bdc1a0e798 fix: raise diff panel drag handle z-index above sticky file headers (#22504)
## Problem

The drag handle (resize slider) on the diff right panel and the sticky
file headers inside `FilesChangedPanel` both had `z-index: 10`. Because
the sticky headers render later in the DOM and are positioned, they
painted on top of the drag handle — making it appear to go "below" the
headers when dragging.

## Fix

Bump the drag handle from `z-10` to `z-20` so it always stays above the
sticky `[data-diffs-header]` elements (`z-index: 10`).
2026-03-02 19:16:46 +00:00
Kyle Carberry 7aef0bf25e fix(chatd): increase title generation timeout from 10s to 30s (#22501)
## Problem

Production logs frequently show:

```
[debu] coderd.chats.chat-processor: failed to generate chat title
    error= generate title text: context deadline exceeded
```

## Root Cause

The title generation timeout in `maybeGenerateChatTitle` is 10 seconds.
Many LLM providers routinely exceed this under load (cold starts, rate
limits, large models). Since `chatretry` classifies `context deadline
exceeded` as non-retryable, the first timeout kills the entire attempt
with no retry.

## Fix

Increase the timeout from 10s to 30s. Title generation is async and
best-effort — it runs in a background goroutine and doesn't block the
chat response — so a longer timeout has no user-facing impact.
2026-03-02 14:11:25 -05:00
Steven Masley 583e6daaab chore: also proxy coder_session_token headers for websockets (#22499)
When using dev frontend flow
2026-03-02 12:23:28 -06:00
Jaayden Halko 3daac86efe refactor(site): add tabbed layout for single group page (#22486)
## Summary

Replaces the standalone **Settings** button on the single-group page
with a tabbed layout containing **Group members** and **Group settings**
tabs.

This uses the new figma designs here:
https://www.figma.com/design/klGTlHSPQwI4KBvAMdebrx/Customer-Usage-Controls-for-AI-Governance-Add-On?node-id=51-4907&m=dev

<img width="797" height="371" alt="Screenshot 2026-03-02 at 22 53 28"
src="https://github.com/user-attachments/assets/88d2ca8e-928f-404d-8569-ec4aba6c2ce4"
/>



### What changed

| File | Change |
|------|--------|
| `site/src/router.tsx` | Nested `settings` route under `:groupName`
layout route; added `GroupMembersPage` lazy import |
| `site/src/pages/GroupsPage/GroupPage.tsx` | Converted to shared
layout: header + tabs (`Tabs`/`TabsList`/`TabLink`) + `<Outlet />` with
context. Removed settings button and member-management code |
| `site/src/pages/GroupsPage/GroupMembersPage.tsx` | **New file** —
extracted member-management UI (add/remove members, member table) from
GroupPage; consumes data via `useOutletContext` |
| `site/src/pages/GroupsPage/GroupSettingsPage.tsx` | Switched from
independent group query to outlet context; removed duplicate
loading/error/title handling |
| `site/src/pages/GroupsPage/GroupSettingsPageView.tsx` | Removed
duplicate `ResourcePageHeader`; renders only the settings form |
| `site/e2e/tests/organizationGroups.spec.ts` | Updated selectors from
`"Settings"` link to `"Group settings"` link |

### How it works

- `:groupName` route now renders `GroupPage` as a **layout route** with
header, tabs, and `<Outlet />`.
- Index child route renders `GroupMembersPage` (member table +
add/remove).
- `settings` child route renders `GroupSettingsPage` (group settings
form).
- Shared group data + permissions are passed via React Router outlet
context, eliminating duplicate queries.
- URL structure is unchanged: `/organizations/:org/groups/:groupName`
(members) and `.../settings` (settings).

### Verification

- `pnpm exec tsc --noEmit` — passes
- `pnpm exec biome check --error-on-warnings` on all touched files —
passes
2026-03-02 18:12:40 +00:00
Kyle Carberry a33ca95df2 fix(chatd): prevent chat re-acquisition during server shutdown (#22497)
Fixes https://github.com/coder/internal/issues/1371

## Problem

`TestCloseDuringShutdownContextCanceledShouldRetryOnNewReplica` flakes
intermittently in CI. The observed failure is that the chat never
reaches `pending` status after `serverA.Close()`.

## Root cause

Race between context cancellation and the mock OpenAI server's stream
completion marker.

When `Close()` cancels the server context, the in-flight HTTP streaming
request is canceled. The mock server's handler detects this via
`req.Context().Done()` and closes its chunks channel. The mock's
`writeChatCompletionsStreaming` then writes `data: [DONE]` — the SSE
completion marker. On a loopback connection, this marker can reach the
client **before** the client's HTTP transport honors the context
cancellation.

When this happens:
1. The client sees a successful stream completion (not an error)
2. `chatloop.Run` returns `nil`
3. `processChat` falls through without error → status stays `waiting`
(the default)
4. The test expects `pending` → **flake**

## Fix

Skip writing the `[DONE]` marker when the request context is already
canceled, in both `writeChatCompletionsStreaming` and
`writeResponsesAPIStreaming`.
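
A rough sketch of the guard, assuming a simplified SSE handler (the real chattest mock is more elaborate):

```go
package main

import (
	"fmt"
	"net/http"
)

// writeStreaming is a stripped-down stand-in for the mock server's streaming
// writer; only the final guard reflects the fix described above.
func writeStreaming(w http.ResponseWriter, r *http.Request, chunks <-chan string) {
	w.Header().Set("Content-Type", "text/event-stream")
	flusher, _ := w.(http.Flusher)

	for chunk := range chunks {
		fmt.Fprintf(w, "data: %s\n\n", chunk)
		if flusher != nil {
			flusher.Flush()
		}
	}

	// The fix: if the client already canceled the request, do not emit the
	// completion marker, otherwise the client can observe a "successful"
	// stream that was actually torn down by shutdown.
	if r.Context().Err() != nil {
		return
	}
	fmt.Fprint(w, "data: [DONE]\n\n")
	if flusher != nil {
		flusher.Flush()
	}
}

func main() {
	http.HandleFunc("/v1/chat/completions", func(w http.ResponseWriter, r *http.Request) {
		chunks := make(chan string, 1)
		chunks <- `{"choices":[{"delta":{"content":"hi"}}]}`
		close(chunks)
		writeStreaming(w, r, chunks)
	})
	_ = http.ListenAndServe("127.0.0.1:8099", nil) // port is arbitrary for the sketch
}
```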
2026-03-02 18:00:21 +00:00
Cian Johnston 49aefdd973 chore: update test directions in AGENTS.md (#22490)
Our `AGENTS.md` previously contained this directive:

> When adding tests for new behavior, add new test cases instead of
modifying existing ones. This preserves coverage for the original
behavior and makes it clear what the new test covers.

This leads to inflated diffs and test explosions. Updating it to bias
more towards updating existing tests where applicable.

---------

Co-authored-by: Danielle Maywood <danielle@themaywoods.com>
Co-authored-by: Mathias Fredriksson <mafredri@gmail.com>
2026-03-02 17:07:38 +00:00
Kyle Carberry 0908505348 fix(chats): archive chat tree with single query instead of loop (#22496)
## Problem

When archiving an agent with subagents, the children briefly flash in
the sidebar as root-level items before disappearing. Two issues:

1. **Backend:** Archive used N+1 queries — a recursive DFS
(`archiveChatTree`, no transaction) or BFS loop (`chatd.ArchiveChat`,
N+1 queries in a tx) to walk the tree and archive each chat
individually.
2. **Frontend:** The SSE `deleted` event handler only filtered out the
parent chat from the cache. Children remained briefly, got promoted to
root-level by `buildChatTree`, then disappeared on the next re-fetch.

## Fix

**Backend:** Replace both tree-walk implementations with a single SQL
query:
```sql
UPDATE chats SET archived = true, updated_at = NOW()
WHERE id = @id OR root_chat_id = @id;
```
This leverages the existing `root_chat_id` column (already indexed) to
archive the entire tree atomically.

**Frontend:** When a `deleted` event arrives, also filter out any chats
whose `root_chat_id` matches the deleted chat, so children vanish from
the sidebar immediately with the parent.

## Changes

- `coderd/database/queries/chats.sql` — Added `ArchiveChatTreeByID`
query
- `coderd/chats.go` — Use single query, delete `archiveChatTree`
function
- `coderd/chatd/chatd.go` — Simplify `ArchiveChat` to use single query
- `coderd/database/dbauthz/dbauthz.go` — Auth wrapper for new query
- `coderd/chats_test.go` — Added `TestArchiveChat/ArchivesChildren`
subtest
- `site/src/pages/AgentsPage/AgentsPage.tsx` — Filter children in SSE
handler
- Generated files updated via `make gen`
2026-03-02 12:00:00 -05:00
Steven Masley 7bc454eed8 chore: version is 2.31 not 1.31 (#22494) 2026-03-02 16:23:09 +00:00
Cian Johnston a62f2fbfc4 feat(rbac): add AsChatd subject to replace AsSystemRestricted in chatd (#22487)
Add a new SubjectTypeChatd RBAC subject with minimal permissions:
- Chat: CRUD
- Workspace: Read
- DeploymentConfig: Read

Replace all 10 AsSystemRestricted calls in coderd/chatd/chatd.go:
- Line 890: Use AsChatd instead of AsSystemRestricted for the background
processor context.
- Subscribe() path (5 calls): Remove system escalation entirely; these
run under the authenticated user's context from the HTTP handler.
- processChat path (4 calls): Remove redundant per-call wraps; the
context already carries AsChatd from the processor start.

Add TestAsChatd verifying allowed and denied actions.

Created using Mux (Opus 4.6)
2026-03-02 15:57:04 +00:00
Kacper Sawicki 8cfb294291 fix: initialize API key LastUsed to Unix epoch instead of zero time (#22327)
## Flake Fix

Resolves https://github.com/coder/internal/issues/1301

`TestAIBridgeListInterceptions/Pagination/offset` flakes with a 500
caused by `runtime error: integer divide by zero` in `pq.ParseTimestamp`
(encode.go:430) during `GetAPIKeyByID` in the auth middleware.

### Root Cause

**PostgreSQL historical timezone formatting + fragile pq parser:**

1. **Year-0001 timestamps trigger unusual PostgreSQL formatting.** New
API keys were initialized with `LastUsed: time.Time{}` (year
0001-01-01). When the PostgreSQL server timezone is non-UTC, it applies
historical Local Mean Time (LMT) offsets for pre-1900 dates. For year
0001, this can produce timestamps with seconds in the timezone offset
like `0001-12-31 19:03:58-04:56:02`, a format the pq parser was never
designed to handle.

2. **The pq parser panics on unexpected formats.** The
fractional-seconds parser at encode.go:430 computes `fracOff` via
`strings.IndexAny`. When the timestamp has an unusual LMT format, index
arithmetic can produce `fracOff ≤ 0`, causing `int(math.Pow(10,
float64(negative))) = 0` → divide-by-zero panic.

3. **Why it is intermittent:** CI Postgres instances may have varying
timezone configs across runs. The pagination test makes 80+ API calls,
each reading `last_used` via `GetAPIKeyByID`, increasing the probability
of hitting the edge case.

4. **Ruled out pq race condition.** The decode path copies bytes to a Go
string via `string(s)` before `ParseTimestamp`, so buffer reuse cannot
corrupt the input.

### Fix

Initialize `LastUsed` to `time.Unix(0, 0).UTC()` (Unix epoch,
1970-01-01) instead of `time.Time{}` (year 0001). This avoids the entire
class of historical timestamp formatting edge cases.

**Why not `dbtime.Now()`?** The auth middleware debounces `LastUsed`
updates — it only writes when `now.Sub(key.LastUsed) > time.Hour`. Using
`dbtime.Now()` makes the key appear freshly used so the debounce never
triggers, breaking `TestPostUsers/LastSeenAt` and
`TestUsersFilter/LastSeenBeforeNow`. Unix epoch is always >1 hour in the
past, so debounce works correctly.
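
A small standalone illustration of why the epoch value keeps the debounce working (not the middleware code itself):

```go
package main

import (
	"fmt"
	"time"
)

func main() {
	now := time.Now().UTC()

	// New keys start at the Unix epoch rather than time.Time{} (year 0001),
	// sidestepping PostgreSQL's historical-timezone formatting edge cases.
	lastUsed := time.Unix(0, 0).UTC()

	// The auth middleware only writes last_used when it is more than an hour
	// stale. The epoch is always far enough in the past that the first real
	// request still triggers an update; initializing to "now" would not.
	if now.Sub(lastUsed) > time.Hour {
		fmt.Println("debounce window elapsed: update last_used to", now)
	}
}
```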

### Follow-up

A defensive fix should also be added to the `coder/pq` fork (guard
`fracOff ≤ 0` before the division in `ParseTimestamp`). Other year-0001
sentinel values exist across the codebase (`workspace_builds.deadline`,
`users.last_seen_at`, `workspaces.last_used_at`, etc.) and remain
theoretically vulnerable until the pq fork is hardened.
2026-03-02 16:02:01 +01:00
Danielle Maywood f2e5636a8a fix: remove extra left border from diff panel header on agents detail page (#22489) 2026-03-02 14:21:12 +00:00
Kyle Carberry ebe8c8a5b4 fix(site): scope chevron to icon hover when any child is running (#22456)
Follow-up to #22452. The previous fix only checked the chat's own
status, so a root chat in `waiting` status with actively running
sub-agents still showed the expand/collapse chevron on full-row hover.

## Problem

A root chat that's idle (`waiting`/`completed`) but has running
sub-agents would still swap its status icon for the `>` chevron on row
hover. The fix in #22452 only gated on `chat.status` being
`pending`/`running`, which doesn't cover the parent when sub-agents are
the ones executing.

## Fix

`isExecuting` now also checks whether **any direct child** is
`pending`/`running`:

```ts
const isExecuting =
  chat.status === "pending" ||
  chat.status === "running" ||
  (hasChildren &&
    childIDs.some((id) => {
      const c = chatById.get(id);
      return c?.status === "pending" || c?.status === "running";
    }));
```

When `isExecuting` is true, the chevron only appears on hover of the
icon area itself (`group-hover/icon`), not the entire row.

## New story

Added `IdleParentWithRunningChild` — verifies a `waiting` parent with a
`running` child uses icon-only hover scope for the toggle.

Co-authored-by: Coder <coder@users.noreply.github.com>
2026-03-02 09:20:46 -05:00
dependabot[bot] 80688ec221 chore: bump rust from 7e6fa79 to c0a38f5 in /dogfood/coder (#22484)
Bumps rust from `7e6fa79` to `c0a38f5`.


Dependabot will resolve any conflicts with this PR as long as you don't
alter it yourself. You can also trigger a rebase manually by commenting
`@dependabot rebase`.


Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2026-03-02 12:58:42 +00:00
Ethan 552f342a5b fix(codersdk): use header auth for non-browser websocket dials (#22461)
## Context
This commit is part of the fix for a downstream provider outage observed
during `coderd_template` updates.

Observed downstream symptoms (terraform-provider-coderd):
- Template-version websocket log stream requests returned `401`:
  `GET /api/v2/templateversions/<id>/logs`.
- In older provider code (`waitForJob`), stream-init errors could produce
  `(nil, nil, err)` and then trigger a nil dereference when `closer.Close()`
  was deferred before checking `err`.
- Net effect: template update path crashed instead of returning a controlled
  provisioning error.

That provider panic is being hardened in the provider repo separately
(https://github.com/coder/terraform-provider-coderd/pull/308). This commit
addresses the upstream SDK auth mismatch that caused the websocket `401`
side of the chain.

## Root cause

On deployments with host-prefixed cookie handling (dev.coder.com) enabled
(`--host-prefix-cookie` / `EnableHostPrefix=true`), middleware rewrites
cookie state to enforce prefixed auth cookies.

For non-browser websocket clients that still sent unprefixed
`coder_session_token` via cookie jars, this created an auth mismatch:
- cookie-based credential expected by the client path,
- but cookie normalization/stripping applied server-side,
- resulting in no usable token at auth extraction time.

## Fix in this commit

Apply the #22226 non-browser auth principle to remaining websocket callsites
in `codersdk` by replacing cookie-jar session auth with header-token auth.

_Generated with mux but reviewed by a human_
2026-03-02 19:32:36 +11:00
blinkagent[bot] 451dedc3ee chore: replace mux■ icon with m■ icon (#22460)
Co-authored-by: blink-so[bot] <211532188+blink-so[bot]@users.noreply.github.com>
Co-authored-by: Jake Howell <jake@hwll.me>
2026-03-02 01:42:12 +00:00
blinkagent[bot] d033487fff fix(dogfood): rename mux module input add-project to add_project (#22459)
The mux module's input variable was renamed from `add-project` to
`add_project`. This updates the dogfood template to use the new name.

Ref:
https://github.com/coder/registry/blob/main/registry/coder/modules/mux/main.tf
(variable `add_project`)

Co-authored-by: blink-so[bot] <211532188+blink-so[bot]@users.noreply.github.com>
2026-03-02 00:53:51 +00:00
dependabot[bot] 300f15c98a chore: bump the coder-modules group across 3 directories with 5 updates (#22457)
Dependabot will resolve any conflicts with this PR as long as you don't
alter it yourself. You can also trigger a rebase manually by commenting
`@dependabot rebase`.


Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2026-03-02 00:38:13 +00:00
Danielle Maywood 17d3d8a2f9 fix(site): prevent layout shift when opening dropdowns on /agents page (#22448) 2026-03-01 08:01:28 +00:00
Kyle Carberry c9ed1e17fc feat(agents): add desktop notifications via VAPID web push (#22454)
## Summary

Wire VAPID web push notifications into the Agents (chat) system so users
get desktop notifications when an agent finishes running.

### Backend

- Add `webpush.Dispatcher` to `chatd.Server` and pass it through from
`coderd.Options.WebPushDispatcher`
- In `processChat()`'s deferred cleanup, dispatch a web push
notification when the chat reaches a terminal state:
  - **`waiting`** (success): "Agent has finished running."
- **`error`** (failure): the error message, or "Agent encountered an
error."
- Sub-agent chats (`ParentChatID.Valid`) are skipped to avoid
notification spam from internal delegation
- Gracefully no-ops when the dispatcher is nil (web push disabled)

### Frontend

- New `WebPushButton` component — a bell icon that uses the existing
`useWebpushNotifications` hook
  - Returns `null` when the `web-push` experiment is off
- Three states: loading spinner, green bell (subscribed), muted bell-off
(unsubscribed)
  - Tooltip + toast feedback on toggle
- Added to both the Agents page empty state top bar and the AgentDetail
top bar
- The Agents page has its own layout (no standard Navbar), so it needs
its own subscribe button

### End-to-end flow

1. User clicks the bell icon on `/agents` → browser subscribes via VAPID
2. User starts an agent chat → chat enters `running` status
3. Agent finishes → `processChat` defer sets status to `waiting`/`error`
→ dispatches web push
4. Browser service worker shows a desktop notification with the chat
title and status

---------

Co-authored-by: Coder <coder@users.noreply.github.com>
2026-02-28 23:40:17 -05:00
Kyle Carberry 68e4155fed feat(agent/filefinder): add plocate-lite file finder package (#22453)
Adds an in-memory trigram-indexed file finder package at
`agent/filefinder`, designed to power a future `FindFiles` HTTP handler
on the WorkspaceAgent.

## What it does

Fast fuzzy file search with VS Code-quality matching across millions of
files. Sub-millisecond search latency at 100K files.

## Architecture

- **Index**: append-only docs slice with trigram + prefix posting lists (see the sketch after this list)
- **Snapshot**: lock-free reader view via frozen slice headers +
shallow-copied deleted set
- **Search pipeline**: trigram intersection → fuzzy fallback (prefix
bucket + subsequence) → brute-force scan (capped at 5K docs)
- **Scoring**: subsequence match, basename prefix, boundary hits,
contiguous runs, depth/length penalties
- **Engine**: multi-root with fsnotify watcher (50ms batch coalescing),
atomic snapshot publishing
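
A rough sketch of the trigram extraction that feeds those posting lists (simplified; the real package also normalizes separators, deduplicates per document, and maintains prefix buckets):

```go
package main

import (
	"fmt"
	"strings"
)

// trigrams returns the 3-byte substrings of a lower-cased path. It only
// illustrates the posting-list keys, not the full indexing pipeline.
func trigrams(path string) []string {
	s := strings.ToLower(path)
	if len(s) < 3 {
		return []string{s}
	}
	out := make([]string, 0, len(s)-2)
	for i := 0; i+3 <= len(s); i++ {
		out = append(out, s[i:i+3])
	}
	return out
}

func main() {
	// Candidate documents are those whose posting lists intersect on the
	// query's trigrams; fuzzy and brute-force passes handle the rest.
	fmt.Println(trigrams("internal/handler.go"))
}
```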

## Benchmarks (10K files)

| Query Type | Latency |
|---|---|
| exact_basename (`handler.go`) | ~43µs |
| short_query (`ha`) | ~7µs |
| fuzzy_basename (`hndlr`) | ~50µs |
| path_structured (`internal/handler`) | ~29µs |
| multi_token (`api handler`) | ~15µs |

## File inventory (11 files, 3273 lines)

| File | Lines | Purpose |
|---|---|---|
| `text.go` | 264 | Normalization, trigram extraction, scoring |
| `delta.go` | 128 | Index, Snapshot, CRUD operations |
| `query.go` | 272 | Query planning, search strategies, top-K merge |
| `engine.go` | 323 | Multi-root engine, watcher integration |
| `watcher_fs.go` | 201 | fsnotify wrapper with batch coalescing |
| `*_test.go` | 2085 | Unit tests, integration tests, benchmarks |

---------

Co-authored-by: Coder <coder@users.noreply.github.com>
2026-02-28 23:37:07 -05:00
Kyle Carberry 897f178a5c feat(site): replace Agent chat textarea with Lexical editor (#22449)
## Summary

Replaces the plain `<TextareaAutosize>` in the Agent chat input
(`AgentChatInput`) with a Lexical-based editor component, matching the
pattern used in [coder/blink](https://github.com/coder/blink).

## What changed

### New component: `ChatMessageInput`
`site/src/components/ChatMessageInput/ChatMessageInput.tsx`

A Lexical-powered text input that behaves as a plain-text editor with:
- **Enter** submits, **Shift+Enter** inserts newline
- Rich-text formatting disabled (Cmd+B/I/U blocked)
- Paste sanitization (strips formatting, inserts plain text)
- Undo/redo via HistoryPlugin
- Imperative ref API: `insertText()`, `clear()`, `focus()`, `getValue()`

### Updated components
- **`AgentChatInput.tsx`** — Swapped `<TextareaAutosize>` for
`<ChatMessageInput>`. Moved from controlled `value`/`onChange` to
ref-based pattern with `initialValue`/`onContentChange`.
- **`AgentDetail.tsx`** — Updated to use `useRef` for input value
tracking and `editorInitialValue` state for editor resets (edit/cancel
flows).
- **`AgentsPage.tsx`** — Updated to use `useRef` + `initialValue`
pattern.
- **`AgentChatInput.stories.tsx`** — Updated prop names.

### Why Lexical?
This lays the groundwork for features that a native `<textarea>` can't
support:
- Ghost text / inline autocomplete suggestions
- @-mentions and slash commands
- Programmatic text insertion (e.g. from speech-to-text)
- Custom inline decorators (chips, pills, badges)
- Syntax-highlighted code blocks

No adornments are added in this PR — it's a drop-in replacement that
matches existing behavior.

---------

Co-authored-by: Coder <coder@coder.com>
2026-02-28 22:18:36 -05:00
Kyle Carberry a6d2462076 fix(site): preserve running spinner on agents sidebar row hover (#22452)
When hovering over a running/pending chat in the agents sidebar, the
spinning status icon was being replaced by the expand/collapse chevron
button. This was disorienting because the spinner conveys important "in
progress" state.

## Changes

**`AgentsSidebar.tsx`**:
- Added `group/icon` scoped hover group to the icon container div
- When a chat is executing (`pending`/`running`), the chevron toggle
only appears on hover of the icon area itself, not the entire row
- Non-executing chats retain the original whole-row hover behavior (no
UX change)

**`AgentsSidebar.stories.tsx`**:
- Added `RunningChatPreservesSpinner` story verifying the spinner is
present and the toggle button starts invisible for running chats with
children

Co-authored-by: Coder <coder@users.noreply.github.com>
2026-02-28 22:05:18 -05:00
Kyle Carberry fb6bf3a568 fix(agent): wire updateCommandEnv into process manager (#22451)
## Problem

The `agentproc` process manager spawns processes with only
`os.Environ()`, missing agent-level environment variables like
`GIT_ASKPASS`, `CODER_*`, and `GIT_SSH_COMMAND` that are injected by the
agent's `updateCommandEnv` function. This means processes started
through the HTTP process API (used by chat tools) cannot authenticate
git operations via the Coder gitaskpass helper.

By contrast, SSH sessions get the full agent environment because the SSH
server calls `updateCommandEnv` via its `UpdateEnv` config hook.

## Fix

Wire the agent's `updateCommandEnv` hook into the process manager so all
spawned processes receive the full agent environment. The hook is:

- Passed as a parameter through `NewAPI` → `newManager`
- Called in `manager.start()` with `os.Environ()` as the base, producing
the same enriched env that SSH sessions get
- Gracefully falls back to `os.Environ()` if the hook returns an error

Request-level env vars (`req.Env`, set by chat tools) are still appended
last and take precedence.
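
A simplified sketch of the env construction described above; function and variable names are illustrative, not the exact agentproc API:

```go
package main

import (
	"fmt"
	"os"
)

// buildProcessEnv mirrors the behavior described above: enrich the base
// environment through the agent hook, fall back to os.Environ() on error, and
// let request-level variables win by appending them last.
func buildProcessEnv(updateEnv func([]string) ([]string, error), reqEnv []string) []string {
	env := os.Environ()
	if updateEnv != nil {
		if enriched, err := updateEnv(env); err == nil {
			env = enriched
		}
		// On error, keep the plain os.Environ() base.
	}
	return append(env, reqEnv...)
}

func main() {
	hook := func(base []string) ([]string, error) {
		// Stand-in for the agent's updateCommandEnv; the path is hypothetical.
		return append(base, "GIT_ASKPASS=/opt/coder/gitaskpass"), nil
	}
	env := buildProcessEnv(hook, []string{"FOO=from-request"})
	fmt.Println(len(env), "environment variables")
}
```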

## Changes

- `agent/agentproc/process.go`: Add `updateEnv` field to manager, call
it when building process env
- `agent/agentproc/api.go`: Accept `updateEnv` parameter in `NewAPI`
- `agent/agent.go`: Pass `a.updateCommandEnv` when creating the process
API
- `agent/agentproc/api_test.go`: Add `UpdateEnvHook` and
`UpdateEnvHookOverriddenByReqEnv` tests

Co-authored-by: Coder <coder@coder.com>
2026-02-28 21:58:59 -05:00
Kyle Carberry 533b90a3a4 fix: resolve chat title update race conditions and improve resilience (#22450)
## Problem

Chat titles sometimes don't update in the UI. The generated AI title
gets stuck as the fallback (first 6 words of the message) even though
the backend successfully generates a proper title.

## Root Causes

### 1. Cancelable context used during cleanup DB read (P0)
In `processChat`, the deferred cleanup re-reads the chat from the DB to
pick up the AI-generated title for the `status_change` pubsub event. But
it used the cancelable `ctx` instead of `cleanupCtx`:

```go
// Before — ctx may already be canceled here
if freshChat, readErr := p.db.GetChatByID(ctx, chat.ID); readErr == nil {
```

When the context is canceled, the DB read fails silently and the
`status_change` event carries the stale fallback title.

### 2. Title goroutine not tracked by inflight WaitGroup (P2)
The `maybeGenerateChatTitle` goroutine was fire-and-forget — not tracked
by `p.inflight`. During graceful shutdown, the server could exit before
the goroutine completes its DB write or pubsub publish.

### 3. No recovery when watchChats() WebSocket misses events
The frontend relies entirely on the `watchChats()` SSE connection for
title updates. If the connection drops or misses events, titles never
recover — the only fix was a full page reload.

## Fixes

1. **Use `cleanupCtx`** for the `GetChatByID` call and logger in the
deferred cleanup block.
2. **Track the title goroutine** with `p.inflight.Add(1)` / `defer
p.inflight.Done()` so shutdown waits for it.
3. **Invalidate chats query** on WebSocket open/close/error events so
missed updates are recovered via refetch. Also enable
`refetchOnWindowFocus` for the chats query.

Co-authored-by: Coder <coder@users.noreply.github.com>
2026-02-28 21:38:16 -05:00
Kyle Carberry 1c71fd69f6 fix: workspace auto-refresh during the chat flow (#22447) 2026-02-28 19:07:17 -05:00
Kyle Carberry 948fd0fc06 test(site): add comprehensive ChatContext store and integration tests (#22444)
## Summary

Export `createChatStore`, `ChatStore`, and `ChatStoreState` from
`ChatContext.ts` so the pure store logic can be unit tested directly
without React rendering overhead.

## Changes

### Production code (3-line change)
- Added `export` to `ChatStoreState`, `ChatStore`, and `createChatStore`
in `ChatContext.ts`

### chatStore.test.ts — 35 pure store unit tests (runs in ~6ms)
Covers every store method directly with synchronous, zero-React tests:
- `replaceMessages`: population, ordering, undefined handling,
dedup/no-emit
- `upsertDurableMessage`: insert, duplicate detection, value-change
update, optimistic placeholder removal (negative IDs), role-scoped
cleanup, in-place update without reorder
- `setChatStatus`: set, null clear, idempotency
- `setStreamError` / `clearStreamError`: set, clear, no-op guards, dedup
- `setRetryState` / `clearRetryState`: set, clear, no-op guard
- `setSubagentStatusOverride`: single, accumulation, dedup, overwrite
- `setQueuedMessages`: set, undefined handling, ID-based dedup
- `clearStreamState`: clear, no-op guard
- `applyMessagePart` / `applyMessageParts`: text, append, batch, empty
no-op
- `resetTransientState`: clears all transient state, preserves messages,
no-op guard
- `subscribe`: unsubscribe lifecycle, multiple subscribers

### ChatContext.test.tsx — 8 new integration tests
WebSocket event handling that was previously untested:
- **Error events**: sets chatStatus to error, populates streamError,
clears retryState, calls setChatErrorReason; uses fallback when error
has blank text
- **Retry events**: populates retryState; status transition to running
clears retryState
- **Subagent status overrides**: status events with different chat_id go
to subagentStatusOverrides, not main chatStatus
- **WebSocket disconnect**: sets streamError; preserves existing error
on disconnect
- **Status transitions**: clears chatErrorReason on non-error status

### Test infrastructure improvements
- Added `emitError()` helper to MockSocket for testing WebSocket
disconnect
- Added `vi.mocked(watchChat).mockReset()` to `afterEach` for reliable
test isolation between tests that use `mockReturnValueOnce`

## Test results
```
✓ chatStore.test.ts (35 tests) 6ms
✓ ChatContext.test.tsx (23 tests) 107ms
  58 passed (58)
```

---------

Co-authored-by: Coder <coder@users.noreply.github.com>
2026-02-28 17:14:58 -05:00
Kyle Carberry 2abe55549c fix: return in-flight chats to pending on server shutdown (#22443)
When a chatd server shuts down (`Close()`), the server context is
canceled. Previously, in-flight chats would be marked as `error` because
the `context.Canceled` error was not distinguished from actual
processing failures.

This adds `isShutdownCancellation()` to detect when the error is caused
by the server context being canceled (as opposed to a chat-specific
cancellation like `ErrInterrupted`). When detected, the chat status is
set to `pending` with no `last_error`, allowing another replica to pick
it up and retry.
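
A minimal sketch of the check, assuming the shutdown signal is the server context's cancellation (the real helper may inspect more error types):

```go
package main

import (
	"context"
	"errors"
	"fmt"
)

// isShutdownCancellation sketches the described check: the failure is a
// context cancellation and the server context (not just the chat's) is what
// was canceled.
func isShutdownCancellation(serverCtx context.Context, err error) bool {
	return errors.Is(err, context.Canceled) && serverCtx.Err() != nil
}

func main() {
	serverCtx, cancel := context.WithCancel(context.Background())
	cancel() // simulate Close()

	if isShutdownCancellation(serverCtx, context.Canceled) {
		// Return the chat to "pending" with no last_error so another replica
		// can pick it up, instead of marking it as errored.
		fmt.Println("requeue chat as pending")
	}
}
```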

Extracted from #22440 — only the context cancellation bug fix, no
chattest changes.
2026-02-28 17:14:11 -05:00
Danielle Maywood 7860b99597 refactor(site): refactor AgentsPage createPortal soup (#22438) 2026-02-28 22:11:11 +00:00
Kyle Carberry 5945febf06 feat(agent): add fuzzy whitespace matching to edit_files tool (#22446)
Inspired by openai/codex's `apply_patch` implementation, this changes
the `edit_files` search-and-replace to use a cascading match strategy
when the exact search string isn't found:

1. **Exact substring match** (byte-for-byte) — existing behavior,
unchanged
2. **Line-by-line match ignoring trailing whitespace** — handles
trailing spaces/tabs the LLM omits
3. **Line-by-line match ignoring all leading/trailing whitespace** —
handles tabs-vs-spaces and wrong indentation depth

## Problem

When the chat agent uses `edit_files`, it generates a search string that
must match the file content exactly. LLMs frequently get whitespace
wrong:
- Emitting spaces when the file uses tabs (or vice versa)
- Getting the indentation depth wrong by one or more levels
- Omitting trailing whitespace that exists in the file

When this happens, the edit silently does nothing, and the agent falls
into a retry loop using `cat -A` to diagnose the exact whitespace
characters.

## Solution

Adopted the same cascading fuzzy match strategy that [openai/codex uses in
`seek_sequence.rs`](https://github.com/openai/codex/blob/main/codex-rs/apply-patch/src/seek_sequence.rs):

- Pass 1: exact match (existing behavior)
- Pass 2: `TrimRight` each line before comparing (trailing whitespace
tolerance)
- Pass 3: `TrimSpace` each line before comparing (full indentation
tolerance)

When a fuzzy match is found, the matched lines in the original file are
replaced with the replacement text. This preserves surrounding content
exactly.
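
A simplified sketch of the cascading line matcher; it only shows the three normalization passes, not the agentfiles implementation:

```go
package main

import (
	"fmt"
	"strings"
)

// seekLines finds the first window of file lines matching the search lines
// under a per-line normalization.
func seekLines(file, search []string, norm func(string) string) (int, bool) {
	if len(search) == 0 || len(search) > len(file) {
		return 0, false
	}
	for i := 0; i+len(search) <= len(file); i++ {
		match := true
		for j := range search {
			if norm(file[i+j]) != norm(search[j]) {
				match = false
				break
			}
		}
		if match {
			return i, true
		}
	}
	return 0, false
}

func main() {
	file := []string{"func main() {", "\tfmt.Println(\"hi\")", "}"}
	search := []string{"func main() {", "    fmt.Println(\"hi\")", "}"} // spaces vs tab

	passes := []func(string) string{
		func(s string) string { return s },                            // pass 1: exact
		func(s string) string { return strings.TrimRight(s, " \t") }, // pass 2: trailing whitespace
		strings.TrimSpace,                                             // pass 3: all edge whitespace
	}
	for i, norm := range passes {
		if at, ok := seekLines(file, search, norm); ok {
			fmt.Printf("matched with pass %d at line %d\n", i+1, at)
			return
		}
	}
	fmt.Println("no match")
}
```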

## Changes

- `agent/agentfiles/files.go`: Replaced `icholy/replace` streaming
transformer with in-memory `fuzzyReplace` + helper functions
(`seekLines`, `spliceLines`)
- `agent/agentfiles/files_test.go`: Added 6 new test cases covering
trailing whitespace, tabs-vs-spaces, different indent depths, exact
match preference, no-match behavior, and mixed whitespace multiline
edits
- Removed `icholy/replace` dependency from go.mod/go.sum

---------

Co-authored-by: Kyle Carberry <kylecarbs@users.noreply.github.com>
2026-02-28 17:02:57 -05:00
Kyle Carberry 22d4539a7a fix(chatd): clear stream buffer after each step is persisted (#22445)
The in-memory stream buffer accumulated message-part events for the
entire duration of a chat run. Late-joining subscribers received all
buffered parts even though the backing messages had already been
committed to the database, wasting memory and potentially duplicating
content.

Clear the buffer at the end of each `persistStep` call so that only
in-flight (uncommitted) parts remain in the buffer.
2026-02-28 16:51:04 -05:00
Kyle Carberry 34d9392e37 chore(db): remove workspace_agent_id from chats table (#22442)
## Summary

Remove the `workspace_agent_id` column from the `chats` table and
dynamically look up the first workspace agent instead.

## Problem

When a workspace is stopped and restarted, the workspace agent gets a
new ID. The `workspace_agent_id` stored on the chat at creation time
becomes stale, making the agent unreachable. This caused chats to break
after workspace restarts.

## Solution

Instead of persisting the agent ID, dynamically look up the first agent
from the workspace's latest build via
`GetWorkspaceAgentsInLatestBuildByWorkspaceID` whenever an agent
connection is needed. The `workspace_id` on the chat remains stable
across restarts.

This behavior may be refined later (e.g., agent selection heuristics),
but picking the first agent resolves the immediate breakage.
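
A rough sketch of the lookup-at-connect-time approach, using stand-in types rather than the generated querier's `GetWorkspaceAgentsInLatestBuildByWorkspaceID`:

```go
package main

import (
	"context"
	"errors"
	"fmt"
)

// Agent and agentStore are illustrative stand-ins for the database types.
type Agent struct {
	ID   string
	Name string
}

type agentStore interface {
	AgentsInLatestBuild(ctx context.Context, workspaceID string) ([]Agent, error)
}

// firstAgent resolves the agent at connection time instead of persisting its
// ID, so a workspace restart (which changes agent IDs) cannot strand the chat.
func firstAgent(ctx context.Context, db agentStore, workspaceID string) (Agent, error) {
	agents, err := db.AgentsInLatestBuild(ctx, workspaceID)
	if err != nil {
		return Agent{}, err
	}
	if len(agents) == 0 {
		return Agent{}, errors.New("workspace has no agents in its latest build")
	}
	return agents[0], nil
}

type fakeStore struct{}

func (fakeStore) AgentsInLatestBuild(context.Context, string) ([]Agent, error) {
	return []Agent{{ID: "agent-1", Name: "main"}}, nil
}

func main() {
	agent, err := firstAgent(context.Background(), fakeStore{}, "ws-123")
	if err != nil {
		panic(err)
	}
	fmt.Println("connecting to agent", agent.Name)
}
```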

## Changes

- **Migration 000425**: Drop `workspace_agent_id` column from `chats`
- **SQL queries**: Remove `workspace_agent_id` from `InsertChat` and
`UpdateChatWorkspace`
- **chatd.go**: `getWorkspaceConn` and `resolveInstructions` now look up
agents dynamically from workspace ID
- **chatd.go**: Remove `refreshChatWorkspaceSnapshot` (no longer needed)
- **createworkspace.go**: Stop persisting agent ID when associating
workspace with chat
- **subagent.go**: Stop passing agent ID to child chats
- **SDK/frontend**: Remove `WorkspaceAgentID` / `workspace_agent_id`
from Chat type

---------

Co-authored-by: Kyle Carberry <kylecarbs@gmail.com>
2026-02-28 16:46:51 -05:00
Kyle Carberry c316d0a3e7 fix(chatd): improve subagent tool descriptions and strip tools from child agents (#22441)
Two changes:

1. **Gate subagent tools behind `!chat.ParentChatID.Valid`** so child
agents never receive `spawn_agent`, `wait_agent`, `message_agent`, or
`close_agent`. Previously all 4 tools were given to every chat.
`spawn_agent` would fail at runtime ("delegated chats cannot create
child subagents") but the other 3 had no guard at all — meaning a child
could theoretically operate on sibling chats. Removing the tools
entirely is cleaner and saves context window.

2. **Rewrite tool descriptions to explain *when* to use them**, not just
what they do. `spawn_agent` now says to use it for clearly scoped,
independent, self-contained tasks (e.g. fixing a specific bug, writing a
single module, running a migration) and explicitly says *not* to use it
for simple operations you can handle with
`execute`/`read_file`/`write_file`. It also states that child agents
cannot spawn their own subagents. The other 3 tools get similar
guidance-oriented descriptions.

Co-authored-by: Coder <coder@users.noreply.github.com>
2026-02-28 16:30:25 -05:00
Kyle Carberry eec0b299e8 fix(site): add chromatic ignore to shimmer component (#22439)
The shimmer component has an infinitely repeating animation that causes
Chromatic snapshot diffs on every run. Adding `data-chromatic="ignore"`
to prevent false positives, consistent with how other animated
components in the codebase handle this (e.g. `Spinner`, `Alert`,
`SyntaxHighlighter`).

Co-authored-by: Coder <coder@users.noreply.github.com>
2026-02-28 15:17:04 -05:00
Kyle Carberry c5619746d1 fix(chat): fix stream state discrepancies between frontend and backend (#22437)
## Summary

Fixes four frontend↔backend discrepancies in chat stream state
management that could cause duplicate content, UI flicker, and stale
stream state.

### Backend fixes (`coderd/chatd/chatd.go`)

**1. No-pubsub path double-replayed message_part events**

`Subscribe()` built an `initialSnapshot` containing `message_part`
events from `localSnapshot`, then the no-pubsub goroutine replayed the
same `localSnapshot` into the `mergedEvents` channel. Since `streamChat`
sends the snapshot first then reads the channel, the frontend received
every `message_part` twice. `applyMessagePartToStreamState` doesn't
deduplicate — text gets concatenated, so content appeared doubled.

Fix: Only forward live `localParts` in the no-pubsub goroutine; the
snapshot already contains the historical events.

**2. Snapshot missing status event**

The initial snapshot never included a `status` event. The frontend's
`shouldApplyMessagePart()` gates on status (`pending`/`waiting`), but
the initial status came from a separate REST query via `useEffect`.
During the race window between snapshot arrival and REST resolution,
`message_part` events could be incorrectly accepted or rejected.

Fix: Prepend a `status` event to the snapshot after loading the chat
from DB, so the frontend has the authoritative status from the very
first batch.

### Frontend fixes (`ChatContext.ts`)

**3. Scheduled stream reset not canceled by subsequent message_parts**

When a `message` event arrived, `scheduleStreamReset()` queued
`clearStreamState` via `requestAnimationFrame`. If new `message_part`
events arrived in the next WebSocket frame before the rAF fired, they
were pushed to `pendingMessageParts` without canceling the scheduled
reset. The rAF would fire between frames, clearing stream state, then
the next flush would re-populate it — causing a visible flash.

Fix: Call `cancelScheduledStreamReset()` when accumulating
`message_part` events.

**4. startTransition race with synchronous clearStreamState**

`flushMessageParts` wrapped `applyMessageParts` in `startTransition`,
which React can defer. If a `status: "waiting"` event arrived in the
same batch after `message_part` events, the status handler cleared
stream state synchronously, but the deferred `applyMessageParts`
callback could fire afterward and re-populate it.

Fix: Re-check `shouldApplyMessagePart()` inside the `startTransition`
callback at execution time.

### Tests added

- **Go**: `TestSubscribeSnapshotIncludesStatusEvent` — asserts the first
snapshot event is a status event
- **Go**: `TestSubscribeNoPubsubNoDuplicateMessageParts` — asserts the
events channel doesn't replay snapshot events
- **TS**: `cancels scheduled stream reset when message_part arrives
after message` — verifies stream state survives a [message,
message_part] batch
- **TS**: `does not apply message parts after status changes to waiting`
— verifies deferred applyMessageParts respects status transitions
2026-02-28 13:35:23 -05:00
Kyle Carberry a621c3cb13 feat(agent): add process execution API and rewrite execute tool (#22416)
## Summary

Adds a new agent-side process management HTTP API and rewrites the chat
execute tool to use it instead of SSH sessions.

## What changed

### New agent/agentproc/ package

- **headtail.go** — Thread-safe io.Writer with bounded memory (16KB head
+ 16KB tail ring buffer). Provides LLM-ready output with truncation
metadata and long-line truncation at 2048 bytes (see the sketch after this
list).
- **headtail_test.go** — 16 tests including race detector coverage for
concurrent writes.
- **process.go** — Manager + Process types for lifecycle management
using agentexec.Execer for proper OOM/nice scores.
- **api.go** — HTTP API following the agentfiles chi router pattern. 4
endpoints: start, list, output, signal.
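
A simplified, non-concurrent sketch of the head/tail idea; the real HeadTailBuffer is thread-safe and also truncates long lines and reports structured metadata:

```go
package main

import (
	"fmt"
	"strings"
)

// headTail keeps the first headMax bytes and the last tailMax bytes written to it.
type headTail struct {
	headMax, tailMax int
	head, tail       []byte
	total            int
}

func (b *headTail) Write(p []byte) (int, error) {
	written := len(p)
	b.total += written
	if room := b.headMax - len(b.head); room > 0 {
		n := room
		if n > len(p) {
			n = len(p)
		}
		b.head = append(b.head, p[:n]...)
		p = p[n:]
	}
	b.tail = append(b.tail, p...)
	if len(b.tail) > b.tailMax {
		b.tail = b.tail[len(b.tail)-b.tailMax:] // keep only the most recent bytes
	}
	return written, nil
}

// String renders head + an elision note + tail for LLM consumption.
func (b *headTail) String() string {
	dropped := b.total - len(b.head) - len(b.tail)
	if dropped <= 0 {
		return string(b.head) + string(b.tail)
	}
	return fmt.Sprintf("%s\n[... %d bytes truncated ...]\n%s", b.head, dropped, b.tail)
}

func main() {
	b := &headTail{headMax: 8, tailMax: 8}
	fmt.Fprint(b, strings.Repeat("x", 100))
	fmt.Println(b)
}
```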

### Agent wiring (agent/agent.go, agent/api.go)

Mounts the process API at /api/v0/processes, mirroring how agentfiles is
mounted.

### SDK (codersdk/workspacesdk/agentconn.go)

4 new AgentConn interface methods + 7 request/response types:
- StartProcess, ListProcesses, ProcessOutput, SignalProcess

### Execute tool rewrite (coderd/chatd/chattool/execute.go)

- SSH → Agent API: conn.StartProcess() + conn.ProcessOutput() polling
- New parameters: workdir, run_in_background
- Structured response: success, exit_code, wall_duration_ms, error,
truncated, note, background_process_id
- Non-interactive env vars: GIT_EDITOR=true, TERM=dumb, NO_COLOR=1,
PAGER=cat, etc.
- Output truncation: HeadTailBuffer caps at 32KB for LLM consumption
- File-dump detection with advisory notes suggesting read_file
- Default timeout: 60s → 10s
- Foreground polling: 200ms intervals until exit or timeout

## Architecture

State lives on the agent, surviving coderd failover and instance
changes. Any coderd replica can query any agent via HTTP over tailnet.
2026-02-28 12:33:52 -05:00
Kyle Carberry 0ad2f9ecd7 feat(chatd): persist last_error on chats table (#22436)
Adds a nullable `last_error` column to the `chats` table so error
reasons survive page reloads.

**Backend:**
- Migration adds `last_error TEXT` (nullable) to chats
- `UpdateChatStatus` writes the error reason when status transitions to
`error`, clears it (NULL) on recovery
- `convertChat` maps `sql.NullString` to `*string` in the SDK

**Frontend:**
- Sidebar falls back to `chat.last_error` when no stream error reason is
cached
- Chat detail page does the same for `persistedErrorReason`
- Fixtures updated for new required field
2026-02-28 12:27:26 -05:00
Danielle Maywood d412972cd5 refactor(site): use diff library for inline tool diffs (#22423)
Replaces the hand-rolled LCS diffing in `buildEditDiff` and the
manual patch-string assembly in `buildWriteFileDiff` with
[`Diff.createPatch()`](https://www.npmjs.com/package/diff) from the
`diff` npm package.

Both functions now just call `Diff.createPatch()` and feed the result
straight into `parsePatchFiles()`, removing all the manual line
splitting, prefix tagging, hunk-header arithmetic, and trailing-newline
cleanup.

### Changes
- Add `diff` as a dependency
- `buildWriteFileDiff`: replaced ~20 lines of manual patch assembly
  with a single `Diff.createPatch()` call
- `buildEditDiff`: replaced ~60 lines (line splitting, `Diff.diffLines`
  → prefixed strings, hunk counting) with a `Diff.createPatch()` call
  per edit
- Removed the `chunkLines` helper and the `diffLines` wrapper +
  its test block

Net: +21 / -157 lines across source and tests.
2026-02-28 16:31:51 +00:00
Danielle Maywood 607c25b07e fix(site): remove optimistic message when real server message arrives on agents page (#22432) 2026-02-28 16:29:42 +00:00
Danielle Maywood bde772cfa3 fix(site): filter agents workspace dropdown to owner and show owner/name format (#22409) 2026-02-28 11:01:23 +00:00
Danielle Maywood 31ad3cdd0c fix(site): wrap long lines in agents diff panel (#22414)
The diff view on the `/agents` page had no way to handle lines wider
than the panel. The `@pierre/diffs` library supports an `overflow`
option — switching it from `"scroll"` (the shared default) to `"wrap"`
for the side panel makes long lines wrap naturally instead of being
clipped.

Also adds a long import line to the Storybook sample diff so the
wrapping behavior is easy to verify visually.
2026-02-28 10:33:06 +00:00
Danielle Maywood 1dec6da358 refactor(site): simplify AgentChatInput into a controlled component (#22426) 2026-02-28 10:32:48 +00:00
Jaayden Halko f95ae63c96 feat: require typed confirmation for license removal (#22082)
## Summary
Adds a typed-confirmation step before deleting a deployment license to
reduce accidental removals.

<img width="457" height="440" alt="Screenshot 2026-02-13 at 15 31 58"
src="https://github.com/user-attachments/assets/b13320a7-4b10-43fa-ab01-56f3284435b6"
/>

## Changes
- Swapped the license removal dialog from `ConfirmDialog` to
`DeleteDialog`, requiring the admin to type the license ID before
enabling **Remove**.
- Added interaction coverage to verify the confirmation guard.
2026-02-28 07:58:10 +00:00
Matt Vollmer 60793aa277 fix(site): fix double-tap required to open chat on mobile (#22430) 2026-02-28 02:17:08 -05:00
Jeremy Ruppel 816d99e46c flake: increase test timeout for TemplateVersionEditorPage tests (#22412)
TemplateVersionEditorPage tests have been flaking since I ported them to
vitest in 99a4ecd. Turns out our test timeout on jest is 20s (presumably
for these sorts of page-level journey tests). I kinda like the current
5s timeout as it forces us to write speedy tests, but I think in this
case it's unavoidable and makes sense to lengthen the timeout just for
these tests.

Hopefully fixes coder/internal#1369

You may want the whitespaceless diff here:
https://github.com/coder/coder/pull/22412/changes?w=1
2026-02-27 19:16:09 -05:00
Kyle Carberry 256284b7fe fix: add sticky headers to the git diff (#22425)
2026-02-27 19:03:11 -05:00
Kyle Carberry 2bdacae5f5 feat(chatd): add LLM stream retry with exponential backoff (#22418)
## Summary

Adds automatic retry with exponential backoff for transient LLM errors
during chat streaming and title generation. Inspired by
[coder/mux](https://github.com/coder/mux)'s retry mechanism.

## Key Behaviors

- **Infinite retries** with exponential backoff: 1s → 2s → 4s → ... →
60s cap
- **Deterministic delays** (no jitter)
- **Error classification**: retryable (429, 5xx, overloaded, rate limit,
network errors) vs non-retryable (auth, quota, context exceeded, model
not found, canceled)
- **Retry status published to SSE stream** so frontend can show
"Retrying in Xs..." UI
- **Title generation** retries silently (best-effort, nil onRetry
callback)

## New Package: `coderd/chatd/chatretry/`

| File | Purpose |
|------|---------|
| `classify.go` | `IsRetryable(err)` and `StatusCodeRetryable(code)` |
| `backoff.go` | `Delay(attempt)` — exponential doubling with 60s cap |
| `retry.go` | `Retry(ctx, fn, onRetry)` — infinite loop with
context-aware timer |
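
A minimal sketch of the deterministic backoff schedule (the real `chatretry` package may differ in detail):

```go
package main

import (
	"fmt"
	"time"
)

// delay sketches the deterministic backoff described above: 1s, 2s, 4s, ...,
// capped at 60s, with no jitter.
func delay(attempt int) time.Duration {
	d := time.Second << attempt // 1s * 2^attempt
	if d <= 0 || d > 60*time.Second {
		return 60 * time.Second
	}
	return d
}

func main() {
	for attempt := 0; attempt < 8; attempt++ {
		fmt.Printf("attempt %d: wait %s\n", attempt, delay(attempt))
	}
	// Prints 1s, 2s, 4s, 8s, 16s, 32s, then 1m0s for every later attempt.
}
```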

## Test Helpers: `coderd/chatd/chattest/errors.go`

Anthropic and OpenAI error response builders for use in chattest
providers:
- `AnthropicErrorResponse()`, `AnthropicOverloadedResponse()`,
`AnthropicRateLimitResponse()`
- `OpenAIErrorResponse()`, `OpenAIRateLimitResponse()`,
`OpenAIServerErrorResponse()`

## SDK Changes: `codersdk/chats.go`

- New `ChatStreamEventType: "retry"`
- New `ChatStreamRetry` struct with `Attempt`, `DelayMs`, `Error`,
`RetryingAt` fields
- TypeScript types auto-generated

## Changed Files

- `coderd/chatd/chatloop/chatloop.go` — wraps `agent.Stream()` in
`chatretry.Retry()`
- `coderd/chatd/chatd.go` — publishes retry events to SSE stream with
logging
- `coderd/chatd/title.go` — wraps `model.Generate()` in silent retry
- `coderd/chatd/chattest/anthropic.go` / `openai.go` — error injection
support

## Tests

42 tests covering classification (33), backoff (9), and retry scenarios
(8).
2026-02-27 18:34:33 -05:00
Kyle Carberry 4b5ec8a9a4 feat: add diff_status_change event to /chats/watch pubsub stream (#22419)
## Summary

Adds a new `diff_status_change` event kind to the `/chats/watch` pubsub
stream so the sidebar can update diff status (PR created, files changed,
branch info) without a full page reload.

### Problem

When a chat's diff status changes (e.g. PR created via GitHub, git
branch pushed), the sidebar didn't update because:
1. The backend `publishChatPubsubEvent` didn't include diff status data
2. The frontend watch handler only merged `status`, `title`, and
`updated_at` from events

### Solution

A **notify-only** approach: a new `ChatEventKindDiffStatusChange` event
kind tells the frontend "diff status changed for chat X" — the frontend
then invalidates the relevant React Query cache entries to re-fetch.

### Backend changes

- **`coderd/pubsub/chatevent.go`**: New `ChatEventKindDiffStatusChange =
"diff_status_change"` constant
- **`coderd/chatd/chatd.go`**: New `PublishDiffStatusChange(ctx,
chatID)` method on `Server`
- **`coderd/chats.go`**: New `publishChatDiffStatusEvent` helper.
Published from:
- `refreshWorkspaceChatDiffStatuses` — after each chat's diff status is
refreshed via GitHub API
- `storeChatGitRef` — after persisting git branch/origin info from
workspace agent

### Frontend changes

- **`AgentsPage.tsx`**: Handle `diff_status_change` event by
invalidating `chatDiffStatusKey` and `chatDiffContentsKey` queries
- **`ChatContext.ts`**: Remove redundant diff status invalidation that
fired on every chat status change (the new event kind handles this
properly)
2026-02-27 18:06:54 -05:00
Kyle Carberry b0c6a6dc25 fix(site): optimistic message on agent chat submit (#22422)
## Problem

When sending a message in the agent detail chat, the text lingered in
the input textarea while the HTTP POST round-tripped to the server. Only
after the server responded did the input clear and the message appear in
the timeline (via WebSocket). This created a noticeable delay where the
user couldn't start typing their next message.

## Solution

**Optimistic input clear** (`AgentChatInput.tsx`):
- Clear the textarea and editing state *immediately* on submit, before
awaiting the network call.
- Capture the input text beforehand so it can be restored in the `catch`
block if the request fails.

**Optimistic user bubble** (`AgentDetail.tsx`):
- Inject a temporary `ChatMessage` (with a negative ID) into the chat
store so the user's message bubble appears in the timeline instantly.
- Set chat status to `pending` and clear stream state, mirroring the
existing edit-message path.
- On error, roll back: remove the optimistic message and restore the
previous chat status.

The real message arrives via the WebSocket stream and
`upsertDurableMessage` replaces the optimistic entry naturally (the
server message has a positive ID, so it's inserted alongside; the
optimistic negative-ID message gets cleaned up when `replaceMessages` is
called with the authoritative message list from the next query
invalidation).

## Testing

- Type a message and press Enter — input clears and bubble appears
immediately.
- Simulate a network error — input text is restored, optimistic bubble
is removed.
- Edit an existing message — unchanged behavior (already had optimistic
updates).
- Queue a message while streaming — unchanged behavior.
2026-02-27 17:49:53 -05:00
Kyle Carberry 5fb644a6cd feat(site): add keyboard shortcuts to agents page (#22417)
Adds two keyboard shortcuts to the agents page:

- **Escape** — Interrupts the running agent when viewing a chat detail
page. Only fires when focus is outside text inputs/textareas so it
doesn't conflict with the existing edit-cancel Escape handler in the
chat input.
- **Ctrl+N / Cmd+N** — Navigates to create a new agent. Also skipped
when focus is in a text input/textarea.

Both keybindings are implemented in a new `useAgentsPageKeybindings.ts`
hook file:
- `useAgentsPageKeybindings` — used in `AgentsPage.tsx` for Ctrl+N
- `useAgentDetailKeybindings` — used in `AgentDetail.tsx` for Escape →
interrupt
2026-02-27 17:33:43 -05:00
Kyle Carberry 12083441e0 feat(chats): archive chats instead of hard-deleting them (#22406)
## Summary

The UI has always labeled the action as "Archive agent" but the backend
was performing a hard `DELETE`, permanently destroying chats and all
their messages.

This change replaces the hard delete with a soft archive, consistent
with the pattern used by template versions.

## Changes

### Database
- **Migration 000423**: Add `archived boolean DEFAULT false NOT NULL`
column to `chats` table
- Replace `DeleteChatByID` query with `ArchiveChatByID` (`UPDATE SET
archived = true`)
- Add `UnarchiveChatByID` query (`UPDATE SET archived = false`)
- Filter archived chats from `GetChatsByOwnerID` (`WHERE archived =
false`)

### API
- Remove `DELETE /api/experimental/chats/{chat}`
- Add `POST /api/experimental/chats/{chat}/archive` — archives a chat
and all its descendants
- Add `POST /api/experimental/chats/{chat}/unarchive` — unarchives a
single chat (API only, no UI yet)

### Backend
- `archiveChatTree()` recursively archives child chats (replaces
`deleteChatTree()` which hard-deleted)
- Chat daemon's `ArchiveChat()` archives the full chat tree in a
transaction
- Authorization uses `ActionUpdate` instead of `ActionDelete`

### SDK
- Replace `DeleteChat()` with `ArchiveChat()` and `UnarchiveChat()`
- Add `Archived` field to `Chat` struct

### Frontend
- `archiveChat` API call uses `POST .../archive` instead of `DELETE`
- No UI changes — the "Archive agent" button now actually archives
instead of deleting

## Design Decision

This follows the **template version archive pattern** (Pattern B in the
codebase):
- `archived boolean` column (not `deleted boolean`)
- Dedicated `POST .../archive` and `POST .../unarchive` routes (not
repurposing `DELETE`)
- Reversible — users can unarchive via the API (UI for this will come
later)
2026-02-27 16:46:19 -05:00
Kyle Carberry 52dad56462 fix(coderd): refresh OAuth token before GitHub API calls in chat diff (#22415)
## Problem

`resolveChatGitHubAccessToken` reads the `OAuthAccessToken` directly
from the database without refreshing it. When the token expires, GitHub
returns "bad credentials" and the chat diff features break.

## Fix

Call `config.RefreshToken()` before returning the token — the same code
path used by `provisionerdserver` when handing tokens to provisioners.

- Builds a map of provider ID → `*externalauth.Config` during the
existing config iteration
- After fetching the `ExternalAuthLink` from the DB, calls
`cfg.RefreshToken()` if a matching config exists
- On refresh failure, falls through to the existing token (GitHub tokens
without expiry still work) with a debug log
2026-02-27 16:37:17 -05:00
Kyle Carberry 360df1d84f fix(chatd): publish streaming message_part events during compaction (#22410)
## Problem

Context compaction in chatd persisted durable messages for the
`chat_summarized` tool call and result via `publishMessage`, but never
published `message_part` streaming events via `publishMessagePart`. This
meant connected clients had no streaming representation of the
compaction.

The client's `streamState` (built entirely from `message_part` events in
`streamState.ts`) never saw the compaction tool call, so:

- No **"Summarizing..."** running state was shown to the user during
summary generation (which can take up to 90s).
- The durable `message` events arrived after or interleaved with the
`status: waiting` event, causing the tool to appear as "Summarized" with
the chat appearing to just stop.

## Fix

### 1. `CompactionOptions.OnStart` callback (chatloop)

Added an `OnStart` callback to `CompactionOptions`, called in
`maybeCompact` right before `generateCompactionSummary` (the slow LLM
call). This gives `chatd` a hook to publish the tool-call `message_part`
immediately when compaction begins.
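
A hypothetical sketch of the hook's shape and call order; field and function names are illustrative, not the exact chatloop types:

```go
package main

import (
	"context"
	"fmt"
)

// CompactionOptions here only illustrates the new hook; the real struct
// carries more fields.
type CompactionOptions struct {
	// OnStart fires right before the slow summary LLM call so the caller can
	// publish a streaming "Summarizing..." tool-call part immediately.
	OnStart func(ctx context.Context)
}

func maybeCompact(ctx context.Context, opts CompactionOptions, overThreshold bool) {
	if !overThreshold {
		return // OnStart must not fire below the compaction threshold
	}
	if opts.OnStart != nil {
		opts.OnStart(ctx)
	}
	generateSummary(ctx) // stands in for the long-running LLM call
	persistSummary(ctx)  // durable messages plus the tool-result part
}

func generateSummary(context.Context) {}
func persistSummary(context.Context)  {}

func main() {
	maybeCompact(context.Background(), CompactionOptions{
		OnStart: func(context.Context) {
			fmt.Println("publish tool-call message_part: Summarizing...")
		},
	}, true)
}
```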

### 2. Tool-result streaming part (chatd)

`persistChatContextSummary` now publishes a tool-result `message_part`
before the durable `message` events, so clients transition from
"Summarizing..." to "Summarized" before the status change arrives.

### Event ordering is now:
1. `message_part` (tool call via `OnStart`) — client shows
"Summarizing..."
2. LLM generates summary (up to 90s)
3. `message_part` (tool result) — client shows "Summarized" in stream
state
4. `message` (assistant) — durable message persisted, stream state
resets
5. `message` (tool) — durable tool result persisted
6. `status: waiting` — chat transitions to idle

## Tests

- **`OnStartFiresBeforePersist`**: Verifies callback ordering is
`on_start` → `generate` → `persist`.
- **`OnStartNotCalledBelowThreshold`**: Verifies `OnStart` is not called
when context usage is below the compaction threshold.
2026-02-27 16:33:39 -05:00
blinkagent[bot] 8bb80b060e fix(e2e): fix flaky verifyParameters assertion in updateWorkspace test (#22413)
## Problem

The `update workspace, new required, mutable parameter added` e2e test
has been flaking consistently
([internal#1328](https://github.com/coder/internal/issues/1328)). The
error:

```
Error: Timed out 5000ms waiting for expect(locator).toHaveValue(expected)
Locator: getByTestId('parameter-field-Sixth parameter').locator('input')
Expected string: "99"
Received string: ""
```

## Root Cause

A race between page navigation and data hydration in `verifyParameters`:

1. The page navigates with `waitUntil: "domcontentloaded"` which does
not wait for API responses to settle
2. React Query may serve stale cached workspace data initially (from
before the update), causing the form to render with empty/old parameter
values
3. The `toHaveValue` assertion uses the default `actionTimeout` of
5000ms which isn't enough time for fresh data to arrive and the form to
re-render

## Fix

- Switch `verifyParameters` navigation to `waitUntil: "networkidle"` to
ensure API responses (workspace data, build parameters) are settled
before the form renders
- Increase the `toHaveValue` timeout to 15s to handle cases where
dynamic parameters hydrate slowly after initial render

Fixes coder/internal#1328

---------

Co-authored-by: blink-so[bot] <211532188+blink-so[bot]@users.noreply.github.com>
2026-02-27 16:32:16 -05:00
Kyle Carberry 1a87e74574 fix(site): prevent race conditions when switching chats on agents page (#22404)
## Problem

When switching between chats on the agents page, stream parts could be
lost or applied to the wrong chat due to several race conditions in
`ChatContext.ts`:

1. **`startTransition` deferred parts escape cleanup** —
`startTransition(() => store.applyMessageParts(parts))` defers the state
update. If a chat switch happens between `flushMessageParts` being
called and the transition executing, old-chat parts could apply after
`resetTransientState()` has already cleared stream state for the new
chat.

2. **`message` event has no `chat_id` filter** — Unlike `message_part`,
`queue_update`, and `status` events, the `message` event handler did not
check `streamEvent.chat_id`. While the server-scoped WebSocket makes
this safe in practice, it's an inconsistency in defensive programming.

3. **Brief stale message window on switch** — Between `chatID` changing
and `replaceMessages()` firing (after the query resolves), the store
held old-chat messages while the new WebSocket was already connected.

## Changes

### `ChatContext.ts`
- Added `activeChatIDRef` to track the currently active chat ID
- Guard `startTransition` callback: check `activeChatIDRef` before
applying message parts, discarding them if the chat has switched
- Added `chat_id` filter to `message` event handler, matching the
pattern used by all other event types
- Added `store.replaceMessages([])` to the chatID-change effect so
messages are cleared immediately on switch

### `ChatContext.test.tsx`
Four new tests covering the chat-switch lifecycle:
- WebSocket closure and state reset when chatID changes
- `message` event filtering by `chat_id`
- `startTransition` deferred parts discarded after switch
- Messages cleared immediately before new query resolves

All 13 tests pass (9 existing + 4 new).
2026-02-27 15:33:16 -05:00
Kyle Carberry bb97ba727f fix(coderd): allow non-admin users to list chat model configs (#22407)
## Problem

Non-admin users of the Agents (chat) feature send `model_config_id:
"00000000-0000-0000-0000-000000000000"` (nil UUID) when creating chats,
because the `GET /api/experimental/chats/model-configs` endpoint
requires `policy.ActionRead` on `rbac.ResourceDeploymentConfig`, which
is only granted to admins.

The flow:
1. `AgentsPage.tsx` calls `useQuery(chatModelConfigs())` → hits
`listChatModelConfigs`
2. Non-admin users get a **403 Forbidden** response
3. `chatModelConfigsQuery.data` is `undefined`, so the
`modelConfigIDByModelID` map is empty
4. `handleCreateChat` falls back to `nilUUID` for `model_config_id`
5. The backend rejects the nil UUID: `"Invalid model config ID."`

## Fix

Changed `listChatModelConfigs` to allow all authenticated users to read
model configs:
- **Admin users** continue to see all configs (including disabled ones)
for management via `GetChatModelConfigs`
- **Non-admin users** now see only enabled configs via
`GetEnabledChatModelConfigs` with a system context, which is sufficient
for using the chat feature

This follows the same pattern as `listChatModels`, which already uses
`dbauthz.AsSystemRestricted(ctx)` to allow all authenticated users to
see available models.

Write endpoints (create/update/delete) retain their existing
`ResourceDeploymentConfig` authorization.

## Testing

- Updated `TestListChatModelConfigs/ForbiddenForOrganizationMember` →
`SuccessForOrganizationMember` to verify non-admin users can list
enabled model configs
- All existing chat tests continue to pass
2026-02-27 15:31:04 -05:00
Kyle Carberry f509c841cf fix(chatd): recover stale chats after coderd redeployment (#22405)
## Problem

When coderd instances are redeployed (e.g. rolling deployment on
dogfood), in-flight chats get stuck in `running` status permanently. The
UI shows them as "thinking" with a spinning indicator, but no worker is
actually processing them. They never error or resume.

## Root Cause

Two bugs combine to cause this:

### Bug 1: Shutdown cleanup uses a canceled context

The `processChat` defer block updates the chat status in the DB when
processing completes. But it uses `ctx`, which `Close()` cancels
*before* the defer runs. The DB transaction silently fails with
`context.Canceled`, leaving the chat in `status=running` with a dead
`worker_id`.

```go
// Close() calls p.cancel() which cancels ctx
// Then the defer tries to use the now-canceled ctx:
defer func() {
    err := p.db.InTx(func(tx database.Store) error {
        tx.GetChatByIDForUpdate(ctx, chat.ID) // FAILS
        tx.UpdateChatStatus(ctx, ...)          // FAILS
    }, nil)
}()
```

### Bug 2: Stale recovery runs only once at startup

`recoverStaleChats()` was called only once in `start()`, not
periodically. During a rolling deployment, the new instance starts while
the old one is still alive (fresh heartbeat). By the time the old
instance crashes, no one checks again.

## Fix

1. **Use `context.WithoutCancel(ctx)` in the processChat defer** — the
cleanup transaction now completes even during graceful shutdown.

2. **Run `recoverStaleChats` periodically** — a second ticker in the
`start()` loop checks for stale chats at `inFlightChatStaleAfter / 5`
intervals (default: every 1 minute). This catches orphaned chats even
when the instance that owns them crashes without clean shutdown.
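
A minimal sketch of fix 1 (simplified, hypothetical helper names): the deferred cleanup derives its context with `context.WithoutCancel`, so it survives `Close()` canceling `ctx`.

```go
package chatd

import (
	"context"
	"log"
)

type processor struct {
	updateChatStatus func(ctx context.Context, chatID, status string) error
}

func (p *processor) processChat(ctx context.Context, chatID string) {
	defer func() {
		// Detached from ctx's cancellation (context values are preserved),
		// so this final status update still commits during graceful shutdown.
		cleanupCtx := context.WithoutCancel(ctx)
		if err := p.updateChatStatus(cleanupCtx, chatID, "waiting"); err != nil {
			log.Printf("finalize chat %s: %v", chatID, err)
		}
	}()

	// ... run the chat loop using ctx, which Close() may cancel ...
}
```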

## Tests

- `TestRecoverStaleChatsPeriodically` — Verifies chats orphaned *after*
startup are recovered by the periodic loop (not just the startup check).
- `TestNewReplicaRecoversStaleChatFromDeadReplica` — Verifies a new
replica recovers stale chats on startup.
- `TestWaitingChatsAreNotRecoveredAsStale` — Negative test: `waiting`
chats are not incorrectly modified by recovery.
2026-02-27 15:25:40 -05:00
Danielle Maywood 7e8559aac0 fix: respect light/dark theme in agents FilesChangedPanel diff viewer (#22403)
## Problem

The git diff on the `/agents` page had color issues: the editor
background followed light mode but the syntax highlighting used dark
mode (`github-dark-high-contrast`), and the filename header used
light-colored text on a light background.

The root cause was hardcoded dark theme options in the `FileDiff`
component:

```tsx
themeType: "dark",
theme: "github-dark-high-contrast",
```

## Fix

Uses the same theme-aware pattern as every other diff/file viewer in the
codebase (`WriteFileTool`, `EditFilesTool`, `ReadFileTool`, `Tool`,
`response.tsx`):

1. `useTheme()` from `@emotion/react` to read `palette.mode`
2. `getDiffViewerOptions(isDark)` from the shared `utils.ts` module —
returns `github-light` theme for light mode, `github-dark-high-contrast`
for dark mode
3. Reuses `DIFFS_FONT_STYLE` and `diffViewerCSS` constants instead of
inlining duplicates

## Storybook coverage

Added four new stories with real unified diff content:
- **WithDiffDark** — dark mode with a PR link
- **WithDiffLight** — light mode with a PR link
- **NoPullRequestDark** — dark mode, "Files Changed" header
- **NoPullRequestLight** — light mode, "Files Changed" header

The existing stories only covered empty and parse-error states with no
rendered diff.
2026-02-27 20:13:41 +00:00
Kyle Carberry b65c0766d2 feat: add line-based read_file tool with safety limits (#22400)
## Summary

Adds a new line-based file reading endpoint to the workspace agent,
replacing the unbounded byte-based approach for the `read_file` chat
tool and `coder_workspace_read_file` MCP tool.

**Problem**: The current `read_file` tool returns the entire file
contents with no limits, which can blow up LLM context windows and cause
OOM issues with large files.

**Solution**: Inspired by [`coder/mux`](https://github.com/coder/mux)
and [`openai/codex`](https://github.com/openai/codex), implement a
line-based reader with safety limits.

## Changes

### Agent (`agent/agentfiles/`)
- New `/read-file-lines` endpoint with `HandleReadFileLines` handler
- Line-based `offset` (1-based line number, default: 1) and `limit`
(line count, default: 2000)
- Safety constants:
  | Constant | Value | Purpose |
  |---|---|---|
  | `MaxFileSize` | 1 MB | Reject files larger than this at stat |
  | `MaxLineBytes` | 1,024 | Per-line truncation with `... [truncated]` marker |
  | `MaxResponseLines` | 2,000 | Max lines per response |
  | `MaxResponseBytes` | 32 KB | Max total response size |
  | `DefaultLineLimit` | 2,000 | Default when no limit specified |
- Line numbering format: `1\tcontent` (tab-separated)
- Structured JSON response: `{ success, file_size, total_lines,
lines_read, content, error }`
- Hard errors when limits exceeded — tells the LLM to use
`offset`/`limit`
- Existing byte-based `/read-file` endpoint preserved (used by
`instruction.go`)

### SDK (`codersdk/workspacesdk/`)
- `ReadFileLinesResponse` type added
- `ReadFileLines` method added to `AgentConn` interface
- Mock regenerated

### Chat tool (`coderd/chatd/chattool/`)
- `read_file` tool now uses `conn.ReadFileLines()` instead of
`conn.ReadFile()`
- Updated tool description to document line-based parameters
- Response includes `file_size`, `total_lines`, `lines_read` metadata

### MCP tool (`codersdk/toolsdk/`)
- `coder_workspace_read_file` updated to use line-based reading
- Schema descriptions updated for line-based offset/limit
- Removed `maxFileLimit` constant (agent handles limits now)

### Tests
- 13 new test cases for `TestReadFileLines`:
- Path validation (empty, relative, non-existent, directory, no
permissions)
  - Empty file handling
  - Basic read, offset, limit, offset+limit combinations
  - Offset beyond file length
  - Long line truncation (>1024 bytes)
  - Large file rejection (>1MB)
- All existing tests pass unchanged

## Design decisions

| Decision | Rationale |
|---|---|
| Line-based, not byte-based | Both coder/mux and openai/codex use line-based — matches how LLMs reason about code |
| Default limit of 2000 | Matches codex; prevents accidental full-file dumps while being generous |
| 32 KB response cap | Compromise between mux (16 KB) and codex (no cap) |
| 1024 byte/line truncation with marker | More generous than codex (500), marker helps LLM know data is missing |
| Hard errors on overflow | Matches mux; forces LLM to paginate rather than getting partial data |
| Preserve byte-based endpoint | `instruction.go` needs raw byte access for AGENTS.md |
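
For reference, a rough sketch of the line-based reading logic (not the actual agent handler): a 1-based `offset`, a `limit`, numbered `N\tcontent` output, and per-line truncation.

```go
package agentfiles

import (
	"bufio"
	"fmt"
	"io"
)

const (
	maxLineBytes     = 1024
	defaultLineLimit = 2000
)

// readFileLines returns numbered lines starting at the 1-based offset,
// truncating any line longer than maxLineBytes.
func readFileLines(r io.Reader, offset, limit int) ([]string, error) {
	if offset < 1 {
		offset = 1
	}
	if limit <= 0 {
		limit = defaultLineLimit
	}
	sc := bufio.NewScanner(r)
	sc.Buffer(make([]byte, 64*1024), 1024*1024) // tolerate long lines up to 1 MB
	var out []string
	for lineNo := 1; sc.Scan(); lineNo++ {
		if lineNo < offset {
			continue
		}
		if len(out) >= limit {
			break
		}
		line := sc.Text()
		if len(line) > maxLineBytes {
			line = line[:maxLineBytes] + "... [truncated]"
		}
		out = append(out, fmt.Sprintf("%d\t%s", lineNo, line)) // "N\tcontent" numbering
	}
	return out, sc.Err()
}
```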
2026-02-27 15:12:56 -05:00
Kyle Carberry ff687aa780 fix: re-read chat before publishing status event to preserve AI title (#22402)
## Problem

Chat titles revert to the fallback truncated title after briefly showing
the AI-generated title. Even reloading the page doesn't help — the
correct title flashes then gets overwritten.

## Root Cause

Single bug, two symptoms.

In `processChat` (`coderd/chatd/chatd.go`), the `chat` variable is
passed by value. The flow:

1. `processChat(ctx, chat)` receives `chat` with the initial fallback
title (truncated first message).
2. Inside `runChat`, `maybeGenerateChatTitle` generates an AI title,
writes it to the DB via `UpdateChatByID`, and publishes a `title_change`
event. **The DB has the correct title.** The client briefly displays it.
3. `runChat` returns. The **deferred cleanup** in `processChat`
publishes `publishChatPubsubEvent(chat, StatusChange)` — but `chat` here
is the original value copy that still has the **old fallback title**.
4. The frontend receives the `status_change` SSE event and
**unconditionally applies `title` from every event kind** (see
`AgentsPage.tsx` line ~305: `title: updatedChat.title`). This overwrites
the correct AI title with the stale fallback.

**Why reload doesn't help:** If the chat is still processing when the
page reloads, `listChats` loads the correct title from the DB, but then
the deferred `status_change` event arrives moments later and clobbers
it. The title was always in the DB — it was the pubsub event that kept
overwriting it.

## Fix

Re-read the chat from the database in the deferred cleanup before
publishing the final `status_change` event, so it carries the current
(AI-generated) title.
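
A standalone demonstration of the underlying mechanics (hypothetical `Chat` type): a value copy taken before processing never reflects a later DB update, so the final event must be built from a fresh read.

```go
package main

import "fmt"

type Chat struct {
	ID    string
	Title string
}

func main() {
	db := map[string]Chat{"c1": {ID: "c1", Title: "truncated fallback title"}}
	chat := db["c1"] // value copy held by processChat

	// maybeGenerateChatTitle updates the DB mid-processing.
	updated := db["c1"]
	updated.Title = "AI-generated title"
	db["c1"] = updated

	fmt.Println(chat.Title)     // "truncated fallback title" — stale copy
	fmt.Println(db["c1"].Title) // "AI-generated title" — what a re-read returns
}
```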
2026-02-27 15:06:36 -05:00
Kyle Carberry d4cfb24a4a fix: update document title when switching agents on the agents page (#22401)
When navigating to a specific agent on the Agents page, the browser tab
title now reflects the agent's chat title (e.g. `Fix login bug - Agents
- Coder`). When the title hasn't loaded yet or when navigating away, it
falls back to `Agents - Coder`.

**Changes:**
- Added a `useEffect` in `AgentDetail` that sets `document.title` via
the existing `pageTitle` utility whenever the chat title changes.
- The cleanup function resets the title back to `Agents - Coder` when
unmounting (navigating away from the agent).
2026-02-27 14:26:28 -05:00
Kyle Carberry 344d11fa22 feat: include OS and working directory in workspace agent prompt injection (#22399)
When injecting system instructions into the chat prompt, include:

1. **Operating system** and **working directory** from the
`workspace_agents` table
2. **Home-level instructions** from `~/.coder/AGENTS.md` (existing
behavior)
3. **Project-level instructions** from `<pwd>/AGENTS.md` (new)

The XML tag is renamed from `<coder-home-instructions>` to
`<system-instructions>` since it now carries more than just the home
instruction file.

### Example output (both files present)

```xml
<system-instructions>
Operating System: linux
Working Directory: /home/coder/coder

Source: /home/coder/.coder/AGENTS.md
... home instructions ...

Source: /home/coder/coder/AGENTS.md
... project instructions ...
</system-instructions>
```

### Example output (no AGENTS.md files)

```xml
<system-instructions>
Operating System: linux
Working Directory: /home/coder/coder
</system-instructions>
```
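
Roughly, the rendering boils down to something like this (simplified sketch, not the exact implementation):

```go
package chatd

import (
	"fmt"
	"strings"
)

type instructionFile struct {
	Source  string // path the instructions were read from
	Content string
}

// formatInstructions renders the agent context followed by any AGENTS.md
// files under a single <system-instructions> tag.
func formatInstructions(operatingSystem, workingDirectory string, files []instructionFile) string {
	var b strings.Builder
	b.WriteString("<system-instructions>\n")
	fmt.Fprintf(&b, "Operating System: %s\n", operatingSystem)
	fmt.Fprintf(&b, "Working Directory: %s\n", workingDirectory)
	for _, f := range files {
		fmt.Fprintf(&b, "\nSource: %s\n%s\n", f.Source, f.Content)
	}
	b.WriteString("</system-instructions>")
	return b.String()
}
```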

### Changes

- **`coderd/chatd/instruction.go`**:
- Renamed types: `homeInstructionContext` → `agentContext`, added
`instructionFile` struct
  - Extracted `readInstructionFileAtPath` shared helper
- Added `readWorkingDirectoryInstructionFile` to read `<pwd>/AGENTS.md`
- Replaced `formatHomeInstruction` with `formatInstructions` that
renders both files under `<system-instructions>`
- **`coderd/chatd/chatd.go`**:
- Renamed `resolveHomeInstruction` → `resolveInstructions`; now reads
both home and pwd instruction files
- `resolveAgentContext` returns `agentContext` (renamed from
`homeInstructionContext`)
- pwd file read is skipped gracefully if directory is empty or file
doesn't exist
- **`coderd/chatd/instruction_test.go`**:
- Added `TestReadWorkingDirectoryInstructionFile` (success, not-found,
empty-directory)
- Replaced `TestFormatHomeInstruction` with `TestFormatInstructions`
covering all combinations
- Added ordering test (`AgentContextBeforeFiles`) to verify OS/pwd
appear before file sources
2026-02-27 14:21:23 -05:00
Kyle Carberry 59cec5be65 feat: add pagination and popularity sorting to chattool list_templates (#22398)
## Summary

The `chattool` `list_templates` tool previously returned all templates
in a single response with no popularity signal. On deployments with many
templates (e.g. 71 on dogfood), this wastes tokens and makes it hard for
the AI to pick the right template for broad user questions.

## Changes

Single file: `coderd/chatd/chattool/listtemplates.go`

- **`page` parameter** — optional, 1-indexed, 10 results per page
- **Popularity sort** — queries
`GetWorkspaceUniqueOwnerCountByTemplateIDs` to get active developer
counts, then sorts descending (most popular first). The DB query returns
templates alphabetically, so this explicit sort is needed.
- **`active_developers`** — included on each template item so the agent
can see the signal
- **Pagination metadata** — `page`, `total_pages`, `total_count` in the
response so the agent knows there are more results
- **Updated tool description** — tells the agent that results are
ordered by popularity and paginated
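
A simplified sketch of the sort-then-slice logic (hypothetical types; the real implementation reads counts from `GetWorkspaceUniqueOwnerCountByTemplateIDs`):

```go
package chattool

import "sort"

const pageSize = 10

type templateItem struct {
	Name             string
	ActiveDevelopers int
}

// paginateByPopularity sorts most-popular-first and returns the 1-indexed page.
func paginateByPopularity(items []templateItem, page int) (pageItems []templateItem, totalPages int) {
	sort.SliceStable(items, func(i, j int) bool {
		return items[i].ActiveDevelopers > items[j].ActiveDevelopers
	})
	totalPages = (len(items) + pageSize - 1) / pageSize
	if page < 1 {
		page = 1
	}
	start := (page - 1) * pageSize
	if start >= len(items) {
		return nil, totalPages
	}
	end := start + pageSize
	if end > len(items) {
		end = len(items)
	}
	return items[start:end], totalPages
}
```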

## Frontend

No frontend changes needed. The renderer already reads `rec.templates`
and `rec.count` from the response — the new fields (`page`,
`total_pages`, `total_count`) are additive and safely ignored.
2026-02-27 14:06:22 -05:00
Kyle Carberry 7043e773cf fix: auto-scroll to bottom when switching chats on agents page (#22397)
When switching between chats on the agents page, the scroll position was
preserved from the previous chat instead of resetting to show the most
recent messages.

## Problem
Clicking a different chat in the sidebar loaded the new chat's messages
but kept the scroll container at whatever position the user had scrolled
to in the previous chat. This meant users often landed in the middle of
a conversation instead of at the bottom where the latest messages are.

## Fix
Added a `useEffect` in `AgentDetail` that resets `scrollTop` to `0`
whenever `agentId` changes. The scroll container uses
`flex-col-reverse`, so `scrollTop = 0` corresponds to the bottom (most
recent messages).
2026-02-27 14:00:50 -05:00
Cian Johnston 0cfa03718e fix(stringutil): operate on runes instead of bytes in Truncate (#22388)
Fixes https://github.com/coder/coder/issues/22375

Updates `stringutil.Truncate` to properly handle multi-byte UTF-8
characters.
Adds tests for multi-byte truncation with word boundary.
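
The core of the fix, as a sketch (not the actual `stringutil` code): slice on runes so a multi-byte character is never split mid-sequence.

```go
package stringutil

// truncateRunes keeps at most n characters (runes), not n bytes.
// Byte-based slicing like s[:n] can cut through a multi-byte UTF-8
// character and produce invalid UTF-8.
func truncateRunes(s string, n int) string {
	if n <= 0 {
		return ""
	}
	r := []rune(s)
	if len(r) <= n {
		return s
	}
	return string(r[:n])
}
```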

Created by Mux using Opus 4.6
2026-02-27 17:46:37 +00:00
Kyle Carberry 0252205374 agents: do not use bridge config vars for models (#22392)
2026-02-27 12:24:38 -05:00
Michael Suchacz 6248520130 chore(dogfood): update Rust from 1.86.0 to 1.93.1 (#22344)
Update the Rust Docker image in the dogfood template from 1.86.0 to
1.93.1, including the pinned `rust:slim` digest.
2026-02-27 18:02:40 +01:00
Kyle Carberry edee917d88 feat: add experimental agents support (#22290)
feat: add AI chat system with agent tools and chat UI

Introduce the chatd subsystem and Agents UI for AI-powered chat
within Coder workspaces.

- Add chatd package with chat loop, message compaction, prompt
  management, and LLM provider integration (OpenAI, Anthropic)
- Add agent tools: create workspace, list/read templates, read/write/
  edit files, execute commands
- Add chat API endpoints with streaming, message editing, and
  durable reconnection
- Add database schema and migrations for chats, chat messages, chat
  providers, and chat model configs
- Add RBAC policies and dbauthz enforcement for chat resources
- Add Agents UI pages with conversation timeline, queued messages
  list, diff viewer, and model configuration panel
- Add comprehensive test coverage including coderd integration tests,
  chatd unit tests, and Storybook stories
- Gate feature behind experiments flag

---------

Co-authored-by: Cian Johnston <cian@coder.com>
Co-authored-by: Danielle Maywood <danielle@themaywoods.com>
Co-authored-by: Jeremy Ruppel <jeremy@coder.com>
Co-authored-by: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-02-27 16:50:56 +00:00
Cian Johnston 67da4e8b56 ci: add temporary deploy override (#22378)
Temporary override for deploying `main` to `dev.coder.com`.
2026-02-27 16:32:10 +00:00
Jake Howell bcb5b43aa7 fix: resolve entitlement check on ai bridge settings/view (#22385)
Resolves cases where the user is entitled to AI Governance but we don't
show them the page because it's not enabled. If, for some reason, the user
no longer has AI Bridge enabled but still wants to access the old logs
page, they now can.

Furthermore, we link to the docs regardless of whether AI Bridge is
enabled, which is in line with our other settings pages.
2026-02-28 03:25:04 +11:00
Jake Howell 6f3385d5e4 fix: resolve stability in scrollbar-gutter: stable (#22387)
Replaces the approach in #22061 (with a cleaner `git history`)

This now ensures that we don't cause a layout shift when the sidebars
pop in and out of existence (when scroll locking within `radix`).
2026-02-28 03:20:43 +11:00
Paweł Banaszewski 6c097797a1 feat: add Mux icon to client column in AI Bridge request log page (#22386)
Adds Mux to the recognized clients list in AI Bridge documentation.

Adds Mux icon to AI Bridge requests log page:
<img width="1886" height="848" alt="image"
src="https://github.com/user-attachments/assets/e7cb8d47-595c-4be3-93c9-00dbea3d1153"
/>
2026-02-27 16:13:39 +00:00
Jake Howell 12372c4b1e fix: restore emptystate to <OrganizationProvisionerKeysPageView /> (#22372)
This element was receiving the provisioner key daemons and then
immediately filtering them. This led to the default state being a table
with nothing rendered rather than the `<TableEmpty />` as we would
expect.

<img width="1133" height="608" alt="image"
src="https://github.com/user-attachments/assets/229edb00-b108-4ec3-ac2f-33633c3e5760"
/>
2026-02-28 03:11:57 +11:00
Steven Masley 21bc185254 doc: add language to mention disruptive nature of cookie host prefix (#22384) 2026-02-27 15:59:01 +00:00
Jake Howell 0bafc05c37 feat: add permission check around <AppearanceSettingsPage /> (#22383)
This previously let auditors view the page even though they can't update
anything. In contrast to #22382, the user can see all of this anyway
since they're logged in to the application, so we can simply tell them
`Sorry, no access`.
2026-02-28 02:53:29 +11:00
Zach 2b9baffdcb chore: update setup-go action to fix Go download failures (#22306)
setup-go has been sporadically failing to download Go, and we were advised
by a member of the Go team that downloading Go from `storage.googleapis.com`
is not guaranteed (which is what setup-go <= v5.6.0 does).

Also remove the use-preinstalled-go optimization for Windows runners.
setup-go v6 sets GOTOOLCHAIN=local, which prevents the pre-installed
Go from auto-downloading the toolchain specified in go.mod. The windows
optimization with v5 relied on GOTOOLCHAIN=auto. setup-go uses the runner
cache, which is a different caching path but should serve the same purpose.
2026-02-27 08:43:53 -07:00
Jake Howell 358f521bbb feat: implement error message on failure to popup (#22380)
This change adds user-facing feedback when opening apps in a new window
fails due to popup blocking, replacing a silent no-op with a clear
recovery message. It improves reliability and supportability across
app-launch flows by helping users immediately understand and fix the
issue.
2026-02-28 02:41:44 +11:00
Paweł Banaszewski 2b0535b83f feat: update aibride library to 1.0.7 (#22381)
Updates aibridge library to `v1.0.7`
Includes Mux client identification:
https://github.com/coder/aibridge/pull/194
Fixes: https://github.com/coder/aibridge/issues/193
2026-02-27 16:35:46 +01:00
Jake Howell 39093dbd61 feat: implement invalidateQueries instead of location.reload() (#22377)
It was a poor UX decision to reload the entire page when a template got
invalidated. Now we simply refetch the data so that things come across
much smoother.
2026-02-28 02:02:07 +11:00
Jake Howell 7a3a228377 feat: implement toast on failure to export template (#22376)
This pull-request adds a small toast (fixing a previous `TODO`) when we
aren't able to export a template as we'd hoped.
2026-02-28 01:57:25 +11:00
Susana Ferreira ca234f346d fix: mark presets as validation_failed to prevent endless prebuild retries (#22085)
## Description

- Updates `wsbuilder` to return a `BuildError` with
`http.StatusBadRequest` to signify a "validation error" on missing or
invalid parameters
- Adds a short-circuit in `prebuilds.StoreReconciler` to mark presets
for which creating a build returns a "validation error" as "validation
failed" and skip further attempts to reconcile.
- Adds a test to verify the above
- Introduces a new Prometheus metric
`coderd_prebuilt_workspaces_preset_validation_failed` to track the above

Closes: https://github.com/coder/coder/issues/21237

---------

Co-authored-by: Cian Johnston <cian@coder.com>
2026-02-27 14:26:48 +00:00
Jeremy Ruppel dea451de41 fix(site): await dialog close after publish to prevent act() warnings (#22334)
State updates from setIsPublishingDialogOpen,
setLastSuccessfulPublishedVersion, and navigation were firing after
waitFor resolved, causing sporadic act() warnings and timeouts in the
publish template version tests (or so says Claude Sonnet 4.6).

Fixes coder/internal#1369

Co-authored-by: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-02-27 08:59:22 -05:00
Jake Howell 173299fcec fix: restore dividing line in <BuildParametersPopover /> (#22373)
Related to #22367 

It was pointed out to me that we actually did regress this mildly by
removing a dividing line in the changes made in #22367. I've restored it
in a better way by taking advantage of `divide-y` and wrapping this in a
proper `<div />`.

<img width="332" height="385" alt="image"
src="https://github.com/user-attachments/assets/2827a9ae-7b54-4c48-aae9-2f6e965e7f8b"
/>
2026-02-28 00:52:55 +11:00
Jake Howell 6364cfa360 fix: resolve double border on <WorkspaceBuildPageView /> (#22374)
Not sure how I didn't see this before in #22362: there's a sneaky double
border which makes things slightly thicker than we would expect. This
border should only show with `border-b`, but we forgot to reset down to
`border-0` first.

<img width="1253" height="109" alt="image"
src="https://github.com/user-attachments/assets/4b31f5bd-e48b-48d8-a0a3-abeac3d6720b"
/>
<img width="1244" height="199" alt="image"
src="https://github.com/user-attachments/assets/01c2567e-e723-47ea-a5cc-ed8e025df5d0"
/>
2026-02-27 13:51:28 +00:00
Jeremy Ruppel e161083053 fix(site): use cross-browser compatible assertions in MonacoEditor story (#22337)
Switch to asserting only on the onChange spy, which is the actual
component contract being tested. Monaco's textarea value is always empty
regardless of model content, so the toHaveValue assertions were
unreliable anyway.

Fixes the new storybook test introduced in #22202
2026-02-27 08:40:23 -05:00
Jake Howell a51eb40dca fix: marshal convertLicenses() into a [] instead of nil (#22366)
This was a bad smell that the frontend was working around. The field
marshaled to `nil`/`null` instead of an empty `License[]`. Now it returns
an empty array, so we can check for no licenses by a length of `0`.
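
The distinction in a nutshell (standalone example, hypothetical `License` type): a nil slice marshals to `null`, while an empty slice marshals to `[]`.

```go
package main

import (
	"encoding/json"
	"fmt"
)

type License struct {
	ID int `json:"id"`
}

func main() {
	var nilLicenses []License    // zero value: nil
	emptyLicenses := []License{} // explicitly empty

	a, _ := json.Marshal(nilLicenses)
	b, _ := json.Marshal(emptyLicenses)
	fmt.Println(string(a)) // null
	fmt.Println(string(b)) // []
}
```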
2026-02-28 00:23:41 +11:00
Jake Howell d204e6fb84 feat: implement file icons in <TemplateFiles /> (#22369)
This pull-request takes our icons shown in the sidebar tree and shows
them alongside the names of the files in the `Source Code` page of our
templates.

Also does a quick de-mui of this page.

<img width="637" height="345" alt="image"
src="https://github.com/user-attachments/assets/f3013eb6-9572-4d05-a683-10bb99b4e802"
/>
2026-02-27 22:53:09 +11:00
Jake Howell c6638f5547 feat: remove mui from <AppearanceSettingsPage /> (#22368)
This migrates the content from using `MUI` to the new standard `shadcn`
components on the `<AppearanceSettingsPage />`. Things should work
exactly how they did before, but with a new shinier coat of paint.

| Old | New |
| --- | --- |
| <img width="1019" height="651" alt="APPEARANCE_SETTINGS_OLD"
src="https://github.com/user-attachments/assets/d514ea17-f669-4343-99c1-9c8896ae85d8"
/> | <img width="1019" height="653" alt="APPEARANCE_SETTINGS_NEW"
src="https://github.com/user-attachments/assets/424616af-c975-43fa-bd4a-13d0b0fe136b"
/> |
2026-02-27 22:29:40 +11:00
Jake Howell fb154cb60a fix: center view raw logs button in <WorkspaceBuildPage /> (#22362)
This pull-request resolves the location of `View Raw Logs` in
`<WorkspaceBuildPage />`. It wasn't properly centred before due to some
odd padding.

| Old | New |
| --- | --- |
| <img width="304" height="210" alt="OLD_VIEW_RAW_LOGS"
src="https://github.com/user-attachments/assets/80a5aa61-8d01-48eb-91c0-df61dd59d1fb"
/> | <img width="310" height="210" alt="NEW_VIEW_RAW_LOGS"
src="https://github.com/user-attachments/assets/8649a478-993d-4cd2-98bb-f503e0f22a5c"
/> |
2026-02-27 11:27:09 +00:00
Jake Howell 900f6ef576 fix: remove double border <BuildParametersPopover /> (#22367)
This `<Popover />` had a double border on the bottom for some reason; I
believe there used to be multiple things within it, but there no longer
are. All this does is tighten things up ever so slightly.

| Old | New |
| --- | --- |
| <img width="321" height="209" alt="OLD_BUILD_OPTIONS"
src="https://github.com/user-attachments/assets/9eb3322d-5ac1-4d74-ab2a-39e963eae668"
/> | <img width="321" height="209" alt="NEW_BUILD_OPTIONS"
src="https://github.com/user-attachments/assets/f2200003-0fd5-4c90-8660-5424c3cf1807"
/> |
2026-02-27 22:14:55 +11:00
Jake Howell 9cce241202 fix: migrate <TemplatePermissionsPageView /> out of mui (#22363)
This pull-request migrates `<TemplatePermissionsPageView />` out of
using MUI components. Using proper selects that are inline with the rest
of the codebase.

| Old | New |
| --- | --- |
| <img width="1030" height="284" alt="OLD_TEMPLATE_PERMISSIONS"
src="https://github.com/user-attachments/assets/778e983b-7ac1-4429-87ca-6107b176a762"
/> | <img width="1030" height="283" alt="NEW_TEMPLATE_PERMISSIONS"
src="https://github.com/user-attachments/assets/f7acf3c7-0cbd-4433-adc1-7ba7f44f3fe2"
/> |
2026-02-27 22:14:37 +11:00
Atif Ali b9fd9bc0ca chore(dogfood): set OPENAI_BASE_URL and OPENAI_API_KEY if aibridge is enabled (#22364) 2026-02-27 16:00:16 +05:00
Jake Howell dbc0daa64b feat: add animation to copy button <Check />s (#22319)
This pull-request adds a tiny little animation for the `<Check />` when
the copy button activates. This is in line with how `sonner` does its
animations, so things are a little more uniform.


![preview-check-animation](https://github.com/user-attachments/assets/533f644a-1d86-4b11-8ca3-c670a9913b57)
2026-02-27 21:22:57 +11:00
Jake Howell 54a7ec4b5b fix: resolve borders on <WorkspaceProxyRow /> (#22345)
This pull-request resolves some stupid border issues we were having in
`<WorkspaceProxyRow />`. We were duplicating borders and this would lead
to a `2px` border (`1px + 1px` added together). I've resolved this by
using the `divide-y` class in Tailwind so that things look more uniform.

| Old | New |
| --- | --- |
| <img width="1139" height="449" alt="OLD_WORKSPACE_PROXY"
src="https://github.com/user-attachments/assets/809926e4-ac53-4244-b215-72408ab6fa51"
/> | <img width="1139" height="449" alt="NEW_WORKSPACE_PROXY"
src="https://github.com/user-attachments/assets/b451b79c-6275-4dd5-ad2d-0ac6472b53bb"
/> |
2026-02-27 21:04:54 +11:00
Michael Suchacz 5194cc8050 chore(dogfood): bump Mux to 1.3.0, use next channel and bun runner (#22357)
Bump the Mux module in the dogfood template:

- Version: 1.1.0 → 1.3.0
- `install_version`: set to `"next"`
- `runner`: set to `"bun"`
2026-02-27 10:41:02 +01:00
blinkagent[bot] 24ab5205d2 docs: add AI Bridge structured logging section to setup page (#22361)
Adds a brief "Structured Logging" section to the [AI Bridge
Setup](https://coder.com/docs/ai-coder/ai-bridge/setup) page documenting
the `--aibridge-structured-logging` /
`CODER_AIBRIDGE_STRUCTURED_LOGGING` flag.

Covers:
- How to enable structured logging (CLI flag, env var, YAML)
- The five `record_type` values emitted (`interception_start`,
`interception_end`, `token_usage`, `prompt_usage`, `tool_usage`) and
their key fields
- How to filter for these records in a logging pipeline

Created on behalf of @dannykopping

---------

Co-authored-by: blink-so[bot] <211532188+blink-so[bot]@users.noreply.github.com>
2026-02-27 10:40:59 +01:00
Kacper Sawicki ab28ecde88 fix(cli): reuse multi-select parameter values on workspace update (#22261)
Fixes three bugs that caused `coder update` to always re-prompt for
multi-select (`list(string)`) parameters instead of reusing previous
build values:

1. **`isValidTemplateParameterOption` failed for multi-select values**
(`cli/parameterresolver.go`): It compared the entire JSON array string
(e.g. `["vim","emacs"]`) against individual option values, which never
matched. Now parses the JSON array and validates each element
separately.

2. **`RichParameter` ignored previous build value for multi-select**
(`cli/cliui/parameter.go`): The `list(string)` branch always used the
template's default value instead of the `defaultValue` argument (which
carries the previous build's value). Now uses `defaultValue` when
available, falling back to the template default.

3. **Pre-existing crash when `list(string)` has no default value**
(`cli/cliui/parameter.go`): `json.Unmarshal` on an empty string caused
`unexpected end of JSON input`. Now skips unmarshaling when the default
source is empty.

Fixes #19956
2026-02-26 14:34:30 +01:00
Dean Sheather bef7eb9dcc fix: avoid derp-related panic during wsproxy registration (#22322) 2026-02-27 00:07:14 +11:00
Ehab Younes bf639d0016 refactor(site): use dedicated task pause/resume API endpoints (#22303)
Switch from workspace stop/start operations to the dedicated tasks pause and resume endpoints for cleaner semantics.
2026-02-26 10:46:20 +01:00
dependabot[bot] 83f2bb15c8 chore: bump github.com/cloudflare/circl from 1.6.1 to 1.6.3 (#22312)
Bumps [github.com/cloudflare/circl](https://github.com/cloudflare/circl)
from 1.6.1 to 1.6.3.
<details>
<summary>Release notes</summary>
<p><em>Sourced from <a
href="https://github.com/cloudflare/circl/releases">github.com/cloudflare/circl's
releases</a>.</em></p>
<blockquote>
<h2>CIRCL v1.6.3</h2>
<p>Fix a bug on ecc/p384 scalar multiplication.</p>
<h3>What's Changed</h3>
<ul>
<li>sign/mldsa: Check opts for nil value by <a
href="https://github.com/armfazh"><code>@​armfazh</code></a> in <a
href="https://redirect.github.com/cloudflare/circl/pull/582">cloudflare/circl#582</a></li>
<li>ecc/p384: Point addition must handle point doubling case. by <a
href="https://github.com/armfazh"><code>@​armfazh</code></a> in <a
href="https://redirect.github.com/cloudflare/circl/pull/583">cloudflare/circl#583</a></li>
<li>Release CIRCL v1.6.3 by <a
href="https://github.com/armfazh"><code>@​armfazh</code></a> in <a
href="https://redirect.github.com/cloudflare/circl/pull/584">cloudflare/circl#584</a></li>
</ul>
<p><strong>Full Changelog</strong>: <a
href="https://github.com/cloudflare/circl/compare/v1.6.2...v1.6.3">https://github.com/cloudflare/circl/compare/v1.6.2...v1.6.3</a></p>
<h2>CIRCL v1.6.2</h2>
<ul>
<li>New SLH-DSA, improvements in ML-DSA for arm64.</li>
<li>Tested compilation on WASM.</li>
</ul>
<h2>What's Changed</h2>
<ul>
<li>Optimize pairing product computation by moving exponentiations to
G1. by <a href="https://github.com/dfaranha"><code>@​dfaranha</code></a>
in <a
href="https://redirect.github.com/cloudflare/circl/pull/547">cloudflare/circl#547</a></li>
<li>sign: Adding SLH-DSA signature by <a
href="https://github.com/armfazh"><code>@​armfazh</code></a> in <a
href="https://redirect.github.com/cloudflare/circl/pull/512">cloudflare/circl#512</a></li>
<li>Update code generators to CIRCL v1.6.1. by <a
href="https://github.com/armfazh"><code>@​armfazh</code></a> in <a
href="https://redirect.github.com/cloudflare/circl/pull/548">cloudflare/circl#548</a></li>
<li>ML-DSA: Add preliminary Wycheproof test vectors by <a
href="https://github.com/bwesterb"><code>@​bwesterb</code></a> in <a
href="https://redirect.github.com/cloudflare/circl/pull/552">cloudflare/circl#552</a></li>
<li>go fmt by <a
href="https://github.com/bwesterb"><code>@​bwesterb</code></a> in <a
href="https://redirect.github.com/cloudflare/circl/pull/554">cloudflare/circl#554</a></li>
<li>gz-compressing test vectors, use of HexBytes and ReadGzip functions.
by <a href="https://github.com/armfazh"><code>@​armfazh</code></a> in <a
href="https://redirect.github.com/cloudflare/circl/pull/555">cloudflare/circl#555</a></li>
<li>group: Removes use of elliptic Marshal and Unmarshal functions. by
<a href="https://github.com/armfazh"><code>@​armfazh</code></a> in <a
href="https://redirect.github.com/cloudflare/circl/pull/556">cloudflare/circl#556</a></li>
<li>Support encoding/decoding ML-DSA private keys (as long as they
contain seeds) by <a
href="https://github.com/bwesterb"><code>@​bwesterb</code></a> in <a
href="https://redirect.github.com/cloudflare/circl/pull/559">cloudflare/circl#559</a></li>
<li>Update to golangci-lint v2 by <a
href="https://github.com/bwesterb"><code>@​bwesterb</code></a> in <a
href="https://redirect.github.com/cloudflare/circl/pull/560">cloudflare/circl#560</a></li>
<li>Preparation for ARM64 Implementation of poly operations for
dilithium package. by <a
href="https://github.com/elementrics"><code>@​elementrics</code></a> in
<a
href="https://redirect.github.com/cloudflare/circl/pull/562">cloudflare/circl#562</a></li>
<li>prepare power2Round for custom implementations in assembly by <a
href="https://github.com/elementrics"><code>@​elementrics</code></a> in
<a
href="https://redirect.github.com/cloudflare/circl/pull/564">cloudflare/circl#564</a></li>
<li>ARM64 implementation for poly.PackLe16 by <a
href="https://github.com/elementrics"><code>@​elementrics</code></a> in
<a
href="https://redirect.github.com/cloudflare/circl/pull/563">cloudflare/circl#563</a></li>
<li>add arm64 version of polyMulBy2toD by <a
href="https://github.com/elementrics"><code>@​elementrics</code></a> in
<a
href="https://redirect.github.com/cloudflare/circl/pull/565">cloudflare/circl#565</a></li>
<li>add arm64 version of polySub by <a
href="https://github.com/elementrics"><code>@​elementrics</code></a> in
<a
href="https://redirect.github.com/cloudflare/circl/pull/566">cloudflare/circl#566</a></li>
<li>group: add byteLen method for short groups and RandomScalar uses
rand.Int by <a
href="https://github.com/armfazh"><code>@​armfazh</code></a> in <a
href="https://redirect.github.com/cloudflare/circl/pull/568">cloudflare/circl#568</a></li>
<li>add arm64 version of poly.Add/Sub by <a
href="https://github.com/elementrics"><code>@​elementrics</code></a> in
<a
href="https://redirect.github.com/cloudflare/circl/pull/572">cloudflare/circl#572</a></li>
<li>group: Adding cryptobyte marshaling to scalars by <a
href="https://github.com/armfazh"><code>@​armfazh</code></a> in <a
href="https://redirect.github.com/cloudflare/circl/pull/569">cloudflare/circl#569</a></li>
<li>Bumping up to Go1.25 by <a
href="https://github.com/armfazh"><code>@​armfazh</code></a> in <a
href="https://redirect.github.com/cloudflare/circl/pull/574">cloudflare/circl#574</a></li>
<li>ci: Including WASM compilation. by <a
href="https://github.com/armfazh"><code>@​armfazh</code></a> in <a
href="https://redirect.github.com/cloudflare/circl/pull/577">cloudflare/circl#577</a></li>
<li>Revert to using package-declared HPKE errors for shortkem instead of
standard library errors by <a
href="https://github.com/harshiniwho"><code>@​harshiniwho</code></a> in
<a
href="https://redirect.github.com/cloudflare/circl/pull/578">cloudflare/circl#578</a></li>
<li>Release v1.6.2 by <a
href="https://github.com/armfazh"><code>@​armfazh</code></a> in <a
href="https://redirect.github.com/cloudflare/circl/pull/579">cloudflare/circl#579</a></li>
</ul>
<h2>New Contributors</h2>
<ul>
<li><a href="https://github.com/dfaranha"><code>@​dfaranha</code></a>
made their first contribution in <a
href="https://redirect.github.com/cloudflare/circl/pull/547">cloudflare/circl#547</a></li>
<li><a
href="https://github.com/elementrics"><code>@​elementrics</code></a>
made their first contribution in <a
href="https://redirect.github.com/cloudflare/circl/pull/562">cloudflare/circl#562</a></li>
<li><a
href="https://github.com/harshiniwho"><code>@​harshiniwho</code></a>
made their first contribution in <a
href="https://redirect.github.com/cloudflare/circl/pull/578">cloudflare/circl#578</a></li>
</ul>
<p><strong>Full Changelog</strong>: <a
href="https://github.com/cloudflare/circl/compare/v1.6.1...v1.6.2">https://github.com/cloudflare/circl/compare/v1.6.1...v1.6.2</a></p>
</blockquote>
</details>
<details>
<summary>Commits</summary>
<ul>
<li><a
href="https://github.com/cloudflare/circl/commit/24ae53c5d6f7fe18203adc125ba3ed76a38703e1"><code>24ae53c</code></a>
Release CIRCL v1.6.3</li>
<li><a
href="https://github.com/cloudflare/circl/commit/581020bd4a836b8ce7bd4e414ba2884c07dbc906"><code>581020b</code></a>
Rename method to oddMultiplesProjective.</li>
<li><a
href="https://github.com/cloudflare/circl/commit/12209a4566605692a8402594e367a5aed5148460"><code>12209a4</code></a>
Removing unused cmov for jacobian points.</li>
<li><a
href="https://github.com/cloudflare/circl/commit/fcba359f4178645d2c9e50f29ab6966337da4b95"><code>fcba359</code></a>
ecc/p384: use of complete projective formulas for scalar
multiplication.</li>
<li><a
href="https://github.com/cloudflare/circl/commit/5e1bae8d8c2df4e717c2c5c2d5b5d60b629b2ac6"><code>5e1bae8</code></a>
ecc/p384: handle point doubling in point addition with Jacobian
coordinates.</li>
<li><a
href="https://github.com/cloudflare/circl/commit/341604685ff97e8f7440ae4b4711ba1c118c648c"><code>3416046</code></a>
Check opts for nil value.</li>
<li><a
href="https://github.com/cloudflare/circl/commit/a763d47a6dce43d1f4f7b697d1d7810463a526f6"><code>a763d47</code></a>
Release CIRCL v1.6.2</li>
<li><a
href="https://github.com/cloudflare/circl/commit/3c70bf9ad53b681fbe5ba6067e454a86549fee8a"><code>3c70bf9</code></a>
Bump x/crypto x/sys dependencies.</li>
<li><a
href="https://github.com/cloudflare/circl/commit/3f0f15b2bfe67bad81a35e8aec81ae42ca78349d"><code>3f0f15b</code></a>
Revert to using package-declared HPKE errors for shortkem instead of
standard...</li>
<li><a
href="https://github.com/cloudflare/circl/commit/23491bd573cf29b6f567057a158203a2c9dfa30d"><code>23491bd</code></a>
Adding generic Power2Round method.</li>
<li>Additional commits viewable in <a
href="https://github.com/cloudflare/circl/compare/v1.6.1...v1.6.3">compare
view</a></li>
</ul>
</details>

Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2026-02-25 20:17:36 +00:00
Zach 4f1ddeeaad fix(site): scope TemplateSettingsPage validation error assertion to form (#22308)
The sonner migration (https://github.com/coder/coder/pull/22258) shows
validation errors in both the inline form field and a toast. Scoping the
assertion to the form element avoids flaky matches against the toast.

Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-25 13:15:03 -07:00
Jeremy Ruppel 77006f241b fix: save empty template files (#22202)
The Monaco editor wrapper was only calling `onChange` if the template
file has content, but we want to allow saving an empty file.

Fixes #19721

Claude was used to port tests from jest to vitest, and for the stories.

---------

Co-authored-by: Claude Sonnet 4.6 <noreply@anthropic.com>
Co-authored-by: Kayla はな <mckayla@hey.com>
2026-02-25 13:43:07 -05:00
Jaayden Halko 4e1cedf8fd chore: migrate workspace agent row styling to Tailwind (#22195) 2026-02-25 17:43:29 +00:00
Jon Ayers 4e365e59b6 fix: add provision/tags to prebuilds scenario (#22294) 2026-02-25 11:16:20 -06:00
blinkagent[bot] d140920248 fix(coderd): bump taskname default model from Claude 3.5 Haiku to Claude Haiku 4.5 (#22304)
Claude 3.5 Haiku (`claude-3-5-haiku-20241022`) was retired by Anthropic
on February 19th, 2026. Requests to this model now return errors.

Switch to Claude Haiku 4.5 (`claude-haiku-4-5`), which is the
[recommended
replacement](https://docs.anthropic.com/en/docs/resources/model-deprecations).

---

One-line change in `coderd/taskname/taskname.go` L25:
```diff
- defaultModel = anthropic.ModelClaude3_5HaikuLatest
+ defaultModel = anthropic.ModelClaudeHaiku4_5
```

Co-authored-by: blink-so[bot] <211532188+blink-so[bot]@users.noreply.github.com>
2026-02-25 16:38:04 +00:00
Steven Masley 3353e687e7 chore: use header auth over cookies for agents (#22226)
All non-browser connections should not use cookies
2026-02-25 09:53:41 -06:00
Zach 2bac4eb739 fix: use time.Equal() for external auth token expiry comparison (#22295)
The listen loop in workspaceAgentsExternalAuthListen compared
OAuthExpiry using == which compares `time.Time` internal struct fields
including the `*time.Location` pointer.

`time.LoadLocation` does not cache the returned `*Location` pointer, so
each lib/pq connection gets a distinct pointer for the same timezone.
When `pq.ParseTimestamp()` applies the connection's location to a parsed
timestamp, the resulting time.Time embeds that connection-specific
pointer. If the `sql.DB` pool hands out different connections for the
two GetExternalAuthLink reads, the identical timestamp produces
`time.Time` values where == returns false despite representing the same
instant. This is intermittent because the pool _usually_ reuses the same
connection for sequential queries.

This change uses `.Equal()` to compare instants regardless of location.
Also makes the test's validation call counter atomic to fix a possible
data race between the HTTP server and test goroutines.
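
A standalone illustration of the difference: the same instant with two distinct `*time.Location` values compares unequal with `==` but equal with `.Equal()`.

```go
package main

import (
	"fmt"
	"time"
)

func main() {
	t1 := time.Date(2026, 2, 25, 8, 45, 0, 0, time.UTC)
	// A separate *Location value for the same zone, similar to what each
	// lib/pq connection effectively gets from time.LoadLocation.
	t2 := t1.In(time.FixedZone("UTC", 0))

	fmt.Println(t1 == t2)     // false: == also compares the location pointer
	fmt.Println(t1.Equal(t2)) // true: Equal compares only the instant
}
```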
2026-02-25 08:45:00 -07:00
Jake Howell 15a2bab1cd feat: migrate from <GlobalSnackbar /> to sonner (#22258)
Replaces our custom `<GlobalSnackbar />` (MUI Snackbar + event emitter)
with [`sonner`](https://github.com/emilkowalski/sonner). Deletes
`GlobalSnackbar/`, the custom event emitter infra, and migrates ~80
source files to `toast.success()` / `toast.error()` from `sonner`.

- ~47 error toasts now surface API error detail via
`getErrorDetail(error)` in the toast description, not just a generic
message. Coincides with #22229.
- Toast messages follow an `{Action} "{entity}" {result}.` format (e.g.
`User "alice" suspended successfully.`) since toasts persist across
navigation now.
- 17 uses of `toast.promise()` for loading → success → error lifecycle.
- Some toasts include action buttons for quick navigation (e.g. "View
task", "View template").
- Multiple toasts can stack and display simultaneously.

---------

Co-authored-by: Kayla はな <mckayla@hey.com>
2026-02-26 02:42:34 +11:00
Jake Howell 1c4d8fafc7 fix: move baseline css from mui to index.css (#22238)
This pull-request moves our baseline CSS styles from the MUI theme
(`site/src/theme/mui.ts`) definition to `index.css`. As these are global
styles, they should live in one dedicated place, not two.
2026-02-26 02:42:16 +11:00
Jake Howell b0b9ea6fbf fix: remove mui components from <LicenseCard /> (#22236)
This pull-request removes the `@mui/material/Paper` import from
`<LicenseCard />` so that we can nuke the `<Paper />` dependency
component. 🥳🥳🥳
2026-02-26 02:41:37 +11:00
Jake Howell 98587cfc03 fix: remove mui components from <TagInput/> (#22234)
This pull-request removes the last instance of `@mui/material/Chip` from
the codebase. And removes it from our `vite.config.mts` so we no longer
have to cache it 🙂
2026-02-26 02:41:22 +11:00
Jake Howell d2787df442 feat: add AI Bridge request logs model filter (#22230)
This pull-request implements a simple filtering logic so that we're able
to pick which model the user actually used when logs were sent to AI
Bridge.

- Add `GET /aibridge/models` API endpoint that returns distinct model
names from AI Bridge interceptions, with pagination and search support
- New `ListAIBridgeModels` SQL query using case-sensitive prefix
matching (`LIKE model || '%'`) to allow B-tree index usage
- Hand-written `ListAuthorizedAIBridgeModels` in `modelqueries.go` for
RBAC authorization filter injection
- `AIBridgeModels` search query parser in searchquery/search.go
(defaults bare terms to the `model` field)
- dbauthz wrappers, dbmetrics, and dbmock implementations for the new
query

<img width="292" height="185" alt="image"
src="https://github.com/user-attachments/assets/134771df-2d26-4c54-acc4-27f58128b351"
/>
2026-02-26 02:40:45 +11:00
Jake Howell 1dec1ec4ad fix: add getValidationErrorMessage() to getErrorDetail() (#22229)
This pull-request ensures that we render validation errors back to the
user when the errors contain context as to why they happened. Previously
we'd guide the user to simply check the Dev Console. This wasn't a great
approach, as the user still had to decode the error themselves; now the
context is explicit.

The error messages could use some improvement, but we already make use of
this in
[`Filter.tsx`](https://github.com/coder/coder/blob/main/site/src/components/Filter/Filter.tsx#L259),
so at least it's consistent with existing behavior.

<img width="834" height="335" alt="image"
src="https://github.com/user-attachments/assets/78864d6f-b4df-4eeb-815a-3fd46cf9f31b"
/>

---------

Co-authored-by: Phorcys <57866459+phorcys420@users.noreply.github.com>
2026-02-26 02:35:25 +11:00
Mathias Fredriksson d2f33932c0 test(coderd): remove provisioner daemon from SendToNonActiveStates test (#22298)
This change fixes a test flake by disabling the provisioner daemon that
was modifying jobs created by dbgen.

Fixes coder/internal#1367
2026-02-25 13:14:32 +02:00
Garrett Delfosse 4057363f78 fix(coderd): add organization_name label to insights Prometheus metrics (#22296)
## Description

When multiple organizations have templates with the same name, the
Prometheus `/metrics` endpoint returns HTTP 500 because Prometheus
rejects duplicate label combinations. The three `coderd_insights_*`
metrics (`coderd_insights_templates_active_users`,
`coderd_insights_applications_usage_seconds`,
`coderd_insights_parameters`) used only `template_name` as a
distinguishing label, so two templates named e.g. `"openstack-v1"` in
different orgs would produce duplicate metric series.

This adds `organization_name` as a label to all three insight metric
descriptors to disambiguate templates across organizations.

## Changes

**`coderd/prometheusmetrics/insights/metricscollector.go`**:
- Added `organization_name` label to all three metric descriptors
- Added `organizationNames` field (template ID → org name) to the
`insightsData` struct
- In `doTick`: after fetching templates, collect unique org IDs, fetch
organizations via `GetOrganizations`, and build a
template-ID-to-org-name mapping
- In `Collect()`: pass the organization name as an additional label
value in every `MustNewConstMetric` call

**`coderd/prometheusmetrics/insights/testdata/insights-metrics.json`**:
Updated golden file to include `organization_name=coder` in all metric
label keys.
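
To illustrate the shape of the change (simplified sketch, not the real collector): the extra `organization_name` label makes otherwise-identical series distinct.

```go
package insights

import "github.com/prometheus/client_golang/prometheus"

var activeUsersDesc = prometheus.NewDesc(
	"coderd_insights_templates_active_users",
	"Active users per template.",
	[]string{"template_name", "organization_name"}, nil,
)

func collectExample(ch chan<- prometheus.Metric) {
	// Two templates named "openstack-v1" in different orgs now emit distinct
	// label sets instead of duplicate series (which made /metrics return 500).
	ch <- prometheus.MustNewConstMetric(activeUsersDesc, prometheus.GaugeValue, 12, "openstack-v1", "org-a")
	ch <- prometheus.MustNewConstMetric(activeUsersDesc, prometheus.GaugeValue, 7, "openstack-v1", "org-b")
}
```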

Fixes #21748
2026-02-25 08:58:50 +00:00
Jon Ayers 43b8df86c1 fix: log WARN on ErrConnectionClosed in tailnet.Controller.Run (#22293) 2026-02-25 01:27:53 -06:00
Jon Ayers 4f34452bcc fix: use separate http.Transports for wsproxy tests (#22292)
- Previously all tests were sharing the global http.Transport meaning on
`Close` it would close connections presumed to be idle for other tests.
fixes https://github.com/coder/internal/issues/112
2026-02-24 23:56:58 -06:00
Steven Masley 93e823931b fix: allow sharing ports >9999 (#22273)
Closes https://github.com/coder/coder/issues/22267
2026-02-24 23:46:43 -06:00
Garrett Delfosse 6c16794173 fix(cli): proactively use active template version when require_active_version is set (#22033)
Fixes #22030

## Problem

When a template has `require_active_version = true` and a workspace is
outdated, the web UI always shows "Update and start" as the **only**
button (for all users including admins), but `coder start` starts with
the old version. For admins, this silently succeeds on the stale
version. For non-admins, it goes through a clunky 403→retry path. This
also affects the VS Code extension, which calls `coder start --yes`
under the hood.

## Root Cause

`buildWorkspaceStartRequest()` in `cli/start.go` checks
`workspace.AutomaticUpdates == "always"` but ignores
`workspace.TemplateRequireActiveVersion`. The server-side autostart
already ORs both settings together:

```go
// coderd/autobuild/lifecycle_executor.go
func useActiveVersion(opts, ws) bool {
    return opts.RequireActiveVersion || ws.AutomaticUpdates == "always"
}
```

The CLI was missing the `RequireActiveVersion` check.

## Fix

Add `workspace.TemplateRequireActiveVersion` to the existing OR
condition:

```go
// Before:
if workspace.AutomaticUpdates == codersdk.AutomaticUpdatesAlways || action == WorkspaceUpdate {

// After:
if workspace.AutomaticUpdates == codersdk.AutomaticUpdatesAlways || workspace.TemplateRequireActiveVersion || action == WorkspaceUpdate {
```

Now `coder start` and `coder restart` proactively use the active
template version when `require_active_version` is set, matching the web
UI and server autostart behavior. The 403→retry fallback remains as a
safety net but is no longer the primary path for any user.

## Testing

Updated `enterprise/cli/start_test.go` — all user types (owner, template
admin, ACL admin, group ACL admin, member) now expect the active version
when `require_active_version` is set, and verify the 403→retry message
does NOT appear.
2026-02-24 19:51:48 -05:00
George K 119d436071 chore(docs): add app access section to workspace sharing docs (#22281)
Path-based routing (the default for, e.g., code-server) will prevent
access in workspace sharing scenarios. This commit documents the
workaround.

Closes: https://linear.app/codercom/issue/MAN-15/bug-code-server-throws-404-for-a-shared-workspace
2026-02-24 11:29:57 -08:00
Zach 9613e41d21 chore: update boundary version (#22289)
Updating to the latest tag before the 2.31 code freeze.
2026-02-24 13:33:37 -05:00
Mathias Fredriksson 947b390c5a fix: allow agent-reported final states, add SSE reconnection (#22286)
When AgentAPI is configured, `WithTaskReporter` unconditionally
overrides all self-reported states to `working`. The intent was to
distrust the agent's `idle` and rely on the screen watcher, but the
override also blocks `failure` and `complete`, which only the agent can
produce (the screen watcher only knows `running`/`stable`). Tasks get
stuck as `working` or `null` forever.

Now only `idle` is overridden to `working`; `failure`, `complete`, and
`working` pass through as-is.

Also:

- Remove misplaced unconditional `"Failed to watch screen events"` log
that fired on every startup
- Add SSE reconnection with exponential backoff (1s-30s) in
`startWatcher` so it recovers from dropped connections instead of dying
silently
- Add `complete` to the `coder_report_task` tool enum, which the
`coder/claude-code` registry module already instructs agents to use but
was missing from the schema

Refs coder/internal#1350
2026-02-24 20:28:50 +02:00
Cian Johnston 6336fee3a7 feat: add telemetry for task lifecycle events (#21922)
Relates to https://github.com/coder/internal/issues/1259

Adds new database queries and telemetry collection functions to gather
task lifecycle events (pause/resume cycles, idle time) for analytics.
    
Task events track pause/resume activity, idle duration before pausing,
paused duration, and time from resume to first app status, filtered to
recent activity based on the telemetry snapshot interval.

🤖 Created with Mux (Opus 4.6).
2026-02-24 17:04:42 +00:00
Danielle Maywood 974ca3eda6 fix: use "idle timeout" as task auto-pause reason (#22287) 2026-02-24 16:45:56 +00:00
Sushant P 20797347b4 chore: update shared workspaces beta docs to include some screenshots (#22280)
Updating the docs to include some screenshots before Shared Workspace
goes into beta!
2026-02-24 08:28:22 -08:00
Jake Howell adcdbfd562 feat: implement AI Bridge client table column (#22228)
Closes #22144

Add client information column to AI Bridge request logs, showing which
coding tool initiated each request with matching icons.

- Added `Client` column to request logs table header and row, displaying
client name with icon badge
- Created `AIBridgeClientIcon` component mapping backend client
constants to their icons (Claude Code, Codex, Kilo Code, Roo Code, Zed,
Cursor, GitHub Copilot)
([ref.](https://github.com/coder/aibridge/blob/11fe0799402a652743104d047140fbeb28f02d24/bridge.go#L33-L41))
- Moved `AIBridgeModelIcon` and `AIBridgeProviderIcon` into `icons/`
subdirectory and cleaned up the `props.className` → `className` prop
- Added new static icons: `github-copilot.svg`, `kilo-code.svg`,
`roo-code.svg` with entries in `icons.json` and `externalImages.ts`
- Sorted `externalImages.ts` map alphabetically

| Name | Preview |
| --- | --- |
| GitHub Copilot | <img width="1332" height="67" alt="image 11"
src="https://github.com/user-attachments/assets/0b06ea42-aaf9-431b-9f9f-3a0146d3eb44"
/> |
| Claude | <img width="1332" height="327" alt="PREVIEW_CLAUDE"
src="https://github.com/user-attachments/assets/7e1afcbc-b94b-4017-bbdc-f40e0ca237d8"
/> |
| Codex CLI | <img width="1332" height="67" alt="PREVIEW_CODEX"
src="https://github.com/user-attachments/assets/2a9ffde1-2516-4d81-baf0-6e689d8a37bf"
/> |
| Cursor | <img width="1332" height="67" alt="PREVIEW_CURSOR"
src="https://github.com/user-attachments/assets/2c4883e8-35cd-4b08-8463-82ba7c95d96d"
/> |
| KiloCode | <img width="1332" height="132" alt="PREVIEW_KILO_CODE"
src="https://github.com/user-attachments/assets/e8bc2854-6fdb-47e0-a304-fb138ac0e2fe"
/> |
| Roo Code | <img width="1332" height="262" alt="PREVIEW_ROO_CODE"
src="https://github.com/user-attachments/assets/d2977090-525b-44ee-9ab6-e6019e559bbd"
/> |
| Zed | <img width="1332" height="67" alt="PREVIEW_ZED"
src="https://github.com/user-attachments/assets/1d982ae0-1d08-4b85-8b4a-5c13fb7754f1"
/> |
2026-02-25 02:38:57 +11:00
Kacper Sawicki 1e274063d4 feat(coderd): filter expired API tokens server-side (#22263)
## Summary

Moves expired token filtering from client-side to server-side by adding
an `include_expired` parameter to the `GetAPIKeysByLoginType` and
`GetAPIKeysByUserID` database queries. This is more efficient for large
deployments with many expired/short-lived tokens.

## Changes

- Add `include_expired` parameter to SQL queries using `OR`
short-circuit
- Add `include_expired` query parameter to `GET
/users/{user}/keys/tokens`
- Add `IncludeExpired` field to `codersdk.TokensFilter`
- Remove client-side filtering from CLI `tokens list` command
- Add `TestTokensFilterExpired` test

Fixes coder/internal#1357
2026-02-24 15:27:03 +00:00
Spike Curtis 393b3874ac feat: add UpdateAppStatus to the workspace agent API (#22219)

part of https://github.com/coder/coder/issues/21335  
  
This moves updating app status (used by Tasks) into the workspace agent
API over dRPC. This will allow us to update the status without having to
re-authenticate each time, like we would with an HTTP PATCH request.
  
Further PRs in this stack will pipe these requests through from the CLI MCP
server to the agentsock and finally to this dRPC call to coderd.
2026-02-24 13:26:55 +04:00
Kacper Sawicki 3c69d683f4 fix(cli): allow new immutable parameters via --parameter flag during update (#22221)
## Problem

When a template adds a new immutable parameter, `coder update
--parameter param=value` fails with:

```
error: start workspace: parameter "machine_type" is immutable and cannot be updated
```

The interactive prompt handles this correctly (allows setting first-time
immutable params), but the CLI `--parameter` flag path does not.

## Root Cause

In `cli/parameterresolver.go`, `verifyConstraints()` runs before the
interactive prompt and unconditionally rejects any immutable parameter
during updates. It doesn't distinguish between **new** immutable
parameters (first-time use, should be allowed) and **existing** ones
(already set, should be blocked from changing).

## Fix

Added an `isFirstTimeUse` check to the immutable parameter constraint,
matching the logic already used by the interactive prompt path (line
323). New immutable parameters can now be set via `--parameter`, while
existing immutable parameters are still blocked from being changed.
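A minimal sketch of the adjusted constraint (the helper and field names below are illustrative, not the exact `cli/parameterresolver.go` code):

```go
// Illustrative only; not the exact parameterresolver.go logic.
// An immutable parameter may be set when it has no previous value
// (first-time use); changing an existing value is still rejected.
func checkImmutable(name string, immutable bool, previousValue *string) error {
	isFirstTimeUse := previousValue == nil
	if immutable && !isFirstTimeUse {
		return fmt.Errorf("parameter %q is immutable and cannot be updated", name)
	}
	return nil
}
```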

## Testing

Added `TestUpdateValidateRichParameters/NewImmutableParameterViaFlag`
which:
1. Creates a workspace with a mutable parameter
2. Updates the template to add a new immutable parameter
3. Runs `coder update --parameter immutable_param=value`
4. Verifies the update succeeds and the parameter is set correctly

Fixes #22164
2026-02-24 09:15:02 +01:00
Jon Ayers 0a7a3da178 fix: exclude provisioner_state from workspace_build_with_user view (#22159)
The provisioner state for a workspace build was being loaded for every
long-lived agent RPC connection. Since this state can be anywhere from
kilobytes to megabytes, it can cause the `coderd` memory footprint to
grow over time. It's also a lot of unnecessary allocations
for every query that fetches a workspace build since only a few callers
ever actually reference the provisioner state.

This PR removes it from the returned workspace build and adds a query to
fetch the provisioner state explicitly.
2026-02-23 22:46:17 -06:00
blinkagent[bot] bf076fb7ee feat: add anthropic and gemini-monochrome icons (#22270)
Adds two new icons to the icon library:

- **`anthropic.svg`** — Anthropic logo
- **`gemini-monochrome.svg`** — Gemini logo, monochrome variant

Both use `monochrome` theme handling to adapt for dark and light
backgrounds.

### Changes
- Added `anthropic.svg` and `gemini-monochrome.svg` to
`site/static/icon/`
- Registered both in `site/src/theme/icons.json` (alphabetically sorted)
- Added `monochrome` theme handling for both in
`site/src/theme/externalImages.ts`

---
Created on behalf of @tracyjohnsonux

---------

Co-authored-by: blink-so[bot] <211532188+blink-so[bot]@users.noreply.github.com>
2026-02-24 13:15:21 +11:00
643 changed files with 72160 additions and 7775 deletions
+249
View File
@@ -0,0 +1,249 @@
# Modern Go (1.18–1.26)
Reference for writing idiomatic Go. Covers what changed, what it
replaced, and what to reach for. Respect the project's `go.mod` `go`
line: don't emit features from a version newer than what the module
declares. Check `go.mod` before writing code.
## How modern Go thinks differently
**Generics** (1.18): Design reusable code with type parameters instead
of `interface{}` casts, code generation, or the `sort.Interface`
pattern. Use `any` for unconstrained types, `comparable` for map keys
and equality, `cmp.Ordered` for sortable types. Type inference usually
makes explicit type arguments unnecessary (improved in 1.21).
**Per-iteration loop variables** (1.22): Each loop iteration gets its
own variable copy. Closures inside loops capture the correct value. The
`v := v` shadow trick is dead. Remove it when you see it.
**Iterators** (1.23): `iter.Seq[V]` and `iter.Seq2[K,V]` are the
standard iterator types. Containers expose `.All()` methods returning
these. Combined with `slices.Collect`, `slices.Sorted`, `maps.Keys`,
etc., they replace ad-hoc "loop and append" code with composable,
lazy pipelines. When a sequence is consumed only once, prefer an
iterator over materializing a slice.
**Error trees** (1.20–1.26): Errors compose as trees, not chains.
`errors.Join` aggregates multiple errors. `fmt.Errorf` accepts multiple
`%w` verbs. `errors.Is`/`As` traverse the full tree. Custom error
types that wrap multiple causes must implement `Unwrap() []error` (the
slice form), not `Unwrap() error`, or tree traversal won't find the
children. `errors.AsType[T]` (1.26) is the type-safe way to match
error types. Propagate cancellation reasons with
`context.WithCancelCause`.
**Structured logging** (1.21): `log/slog` is the standard structured
logger. This project uses `cdr.dev/slog/v3` instead, which has a
different API. Do not use `log/slog` directly.
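A short, self-contained sketch tying a few of these ideas together (illustrative; assumes `go 1.23` or later in `go.mod`):

```go
package main

import (
	"errors"
	"fmt"
	"maps"
	"slices"
)

func main() {
	// Iterators: maps.Keys returns an iter.Seq; collect and sort in one call.
	m := map[string]int{"b": 2, "a": 1, "c": 3}
	fmt.Println(slices.Sorted(maps.Keys(m))) // [a b c]

	// Error trees: multiple %w verbs build a tree; Is traverses all of it.
	errA := errors.New("a failed")
	errB := errors.New("b failed")
	err := fmt.Errorf("batch: %w, %w", errA, errB)
	fmt.Println(errors.Is(err, errB)) // true

	// Join skips nil arguments; joining only nils yields nil.
	fmt.Println(errors.Join(nil, nil) == nil) // true
}
```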
## Replace these patterns
The left column reflects common patterns from pre-1.22 Go. Write the
right column instead. The "Since" column tells you the minimum `go`
directive version required in `go.mod`.
| Old pattern | Modern replacement | Since |
|---|---|---|
| `interface{}` | `any` | 1.18 |
| `v := v` inside loops | remove it | 1.22 |
| `for i := 0; i < n; i++` | `for i := range n` | 1.22 |
| `for i := 0; i < b.N; i++` (benchmarks) | `for b.Loop()` (correct timing, future-proof) | 1.24 |
| `sort.Slice(s, func(i,j int) bool{…})` | `slices.SortFunc(s, cmpFn)` | 1.21 |
| `wg.Add(1); go func(){ defer wg.Done(); … }()` | `wg.Go(func(){…})` | 1.25 |
| `func ptr[T any](v T) *T { return &v }` | `new(expr)` e.g. `new(time.Now())` | 1.26 |
| `var target *E; errors.As(err, &target)` | `t, ok := errors.AsType[*E](err)` | 1.26 |
| Custom multi-error type | `errors.Join(err1, err2, …)` | 1.20 |
| Single `%w` for multiple causes | `fmt.Errorf("…: %w, %w", e1, e2)` | 1.20 |
| `rand.Seed(time.Now().UnixNano())` | delete it (auto-seeded); prefer `math/rand/v2` | 1.20/1.22 |
| `sync.Once` + captured variable | `sync.OnceValue(func() T {…})` / `OnceValues` | 1.21 |
| Custom `min`/`max` helpers | `min(a, b)` / `max(a, b)` builtins (any ordered type) | 1.21 |
| `for k := range m { delete(m, k) }` | `clear(m)` (also zeroes slices) | 1.21 |
| Index+slice or `SplitN(s, sep, 2)` | `strings.Cut(s, sep)` / `bytes.Cut` | 1.18 |
| `TrimPrefix` + check if anything was trimmed | `strings.CutPrefix` / `CutSuffix` (returns ok bool) | 1.20 |
| `strings.Split` + loop when no slice is needed | `strings.SplitSeq` / `Lines` / `FieldsSeq` (iterator, no alloc) | 1.24 |
| `"2006-01-02"` / `"2006-01-02 15:04:05"` / `"15:04:05"` | `time.DateOnly` / `time.DateTime` / `time.TimeOnly` | 1.20 |
| Manual `Before`/`After`/`Equal` chains for comparison | `time.Time.Compare` (returns -1/0/+1; works with `slices.SortFunc`) | 1.20 |
| Loop collecting map keys into slice | `slices.Sorted(maps.Keys(m))` | 1.23 |
| `fmt.Sprintf` + append to `[]byte` | `fmt.Appendf(buf, …)` (also `Append`, `Appendln`) | 1.18 |
| `reflect.TypeOf((*T)(nil)).Elem()` | `reflect.TypeFor[T]()` | 1.22 |
| `*(*[4]byte)(slice)` unsafe cast | `[4]byte(slice)` direct conversion | 1.20 |
| `atomic.LoadInt64` / `StoreInt64` | `atomic.Int64` (also `Bool`, `Uint64`, `Pointer[T]`) | 1.19 |
| `crypto/rand.Read(buf)` + hex/base64 encode | `crypto/rand.Text()` (one call) | 1.24 |
| Checking `crypto/rand.Read` error | don't: return is always nil | 1.24 |
| `time.Sleep` in tests | `testing/synctest` (deterministic fake clock) | 1.24/1.25 |
| `json:",omitempty"` on zero-value structs like `time.Time{}` | `json:",omitzero"` (uses `IsZero()` method) | 1.24 |
| `strings.Title` | `golang.org/x/text/cases` | 1.18 |
| `net.IP` in new code | `net/netip.Addr` (immutable, comparable, lighter) | 1.18 |
| `tools.go` with blank imports | `tool` directive in `go.mod` | 1.24 |
| `runtime.SetFinalizer` | `runtime.AddCleanup` (multiple per object, no pointer cycles) | 1.24 |
| `httputil.ReverseProxy.Director` | `.Rewrite` hook + `ProxyRequest` (Director deprecated in 1.26) | 1.20 |
| `sql.NullString`, `sql.NullInt64`, etc. | `sql.Null[T]` | 1.22 |
| Manual `ctx, cancel := context.WithCancel(…)` + `t.Cleanup(cancel)` | `t.Context()` (auto-canceled when test ends) | 1.24 |
| `if d < 0 { d = -d }` on durations | `d.Abs()` (handles `math.MinInt64`) | 1.19 |
| Implement only `TextMarshaler` | also implement `TextAppender` for alloc-free marshaling | 1.24 |
| Custom `Unwrap() error` on multi-cause errors | `Unwrap() []error` (slice form; required for tree traversal) | 1.20 |
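To make the table concrete, a few of the replacements in working form (illustrative; assumes at least Go 1.21):

```go
package main

import (
	"fmt"
	"strings"
)

func main() {
	// Index/slice or SplitN -> strings.Cut (1.18).
	key, val, ok := strings.Cut("region=us-east", "=")
	fmt.Println(key, val, ok) // region us-east true

	// TrimPrefix plus "did it change?" check -> strings.CutPrefix (1.20).
	if rest, ok := strings.CutPrefix("v2.31.0", "v"); ok {
		fmt.Println(rest) // 2.31.0
	}

	// Custom helpers -> min/max builtins (1.21).
	fmt.Println(min(3, 7), max(3, 7)) // 3 7
}
```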
## New capabilities
These enable things that weren't practical before. Reach for them in the
described situations.
| What | Since | When to use it |
|---|---|---|
| `cmp.Or(a, b, c)` | 1.22 | Defaults/fallback chains: returns first non-zero value. Replaces verbose `if a != "" { return a }` cascades. |
| `context.WithoutCancel(ctx)` | 1.21 | Background work that must outlive the request (e.g. async cleanup after HTTP response). Derived context keeps parent's values but ignores cancellation. |
| `context.AfterFunc(ctx, fn)` | 1.21 | Register cleanup that fires on context cancellation without spawning a goroutine that blocks on `<-ctx.Done()`. |
| `context.WithCancelCause` / `Cause` | 1.20 | When callers need to know WHY a context was canceled, not just that it was. Retrieve cause with `context.Cause(ctx)`. |
| `context.WithDeadlineCause` / `WithTimeoutCause` | 1.21 | Attach a domain-specific error to deadline/timeout expiry (e.g. distinguish "DB query timed out" from "HTTP request timed out"). |
| `errors.ErrUnsupported` | 1.21 | Standard sentinel for "not supported." Use instead of per-package custom sentinels. Check with `errors.Is`. |
| `http.ResponseController` | 1.20 | Per-request flush, hijack, and deadline control without type-asserting `ResponseWriter` to `http.Flusher` or `http.Hijacker`. |
| Enhanced `ServeMux` routing | 1.22 | `"GET /items/{id}"` patterns in `http.ServeMux`. Access with `r.PathValue("id")`. Wildcards: `{name}`, catch-all: `{path...}`, exact: `{$}`. Eliminates many third-party router dependencies. |
| `os.Root` / `OpenRoot` | 1.24 | Confined directory access that prevents symlink escape. 1.25 adds `MkdirAll`, `ReadFile`, `WriteFile` for real use. |
| `os.CopyFS` | 1.23 | Copy an entire `fs.FS` to local filesystem in one call. |
| `os/signal.NotifyContext` with cause | 1.26 | Cancellation cause identifies which signal (SIGTERM vs SIGINT) triggered shutdown. |
| `io/fs.SkipAll` / `filepath.SkipAll` | 1.20 | Return from `WalkDir` callback to stop walking entirely. Cleaner than a sentinel error. |
| `GOMEMLIMIT` env / `debug.SetMemoryLimit` | 1.19 | Soft memory limit for GC. Use alongside or instead of `GOGC` in memory-constrained containers. |
| `net/url.JoinPath` | 1.19 | Join URL path segments correctly. Replaces error-prone string concatenation. |
| `go test -skip` | 1.20 | Skip tests matching a pattern. Useful when running a subset of a large test suite. |
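For example, the enhanced `ServeMux` routing removes the need for a third-party router in simple services (illustrative sketch; assumes Go 1.22+):

```go
package main

import (
	"fmt"
	"log"
	"net/http"
)

func main() {
	mux := http.NewServeMux()
	// Method-qualified pattern with a wildcard segment.
	mux.HandleFunc("GET /items/{id}", func(w http.ResponseWriter, r *http.Request) {
		fmt.Fprintf(w, "item %s\n", r.PathValue("id"))
	})
	// Catch-all wildcard for the remainder of the path.
	mux.HandleFunc("GET /files/{path...}", func(w http.ResponseWriter, r *http.Request) {
		fmt.Fprintf(w, "file %s\n", r.PathValue("path"))
	})
	log.Fatal(http.ListenAndServe(":8080", mux))
}
```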
## Key packages
### `slices` (1.21, iterators added 1.23)
Replaces `sort.Slice`, manual search loops, and manual contains checks.
Search: `Contains`, `ContainsFunc`, `Index`, `IndexFunc`,
`BinarySearch`, `BinarySearchFunc`.
Sort: `Sort`, `SortFunc`, `SortStableFunc`, `IsSorted`, `IsSortedFunc`,
`Min`, `MinFunc`, `Max`, `MaxFunc`.
Transform: `Clone`, `Compact`, `CompactFunc`, `Grow`, `Clip`,
`Concat` (1.22), `Repeat` (1.23), `Reverse`, `Insert`, `Delete`,
`Replace`.
Compare: `Equal`, `EqualFunc`, `Compare`.
Iterators (1.23): `All`, `Values`, `Backward`, `Collect`, `AppendSeq`,
`Sorted`, `SortedFunc`, `SortedStableFunc`, `Chunk`.
### `maps` (1.21, iterators added 1.23)
Core: `Clone`, `Copy`, `Equal`, `EqualFunc`, `DeleteFunc`.
Iterators (1.23): `All`, `Keys`, `Values`, `Insert`, `Collect`.
### `cmp` (1.21, `Or` added 1.22)
`Ordered` constraint for any ordered type. `Compare(a, b)` returns
-1/0/+1. `Less(a, b)` returns bool. `Or(vals...)` returns first
non-zero value.
### `iter` (1.23)
`Seq[V]` is `func(yield func(V) bool)`. `Seq2[K,V]` is
`func(yield func(K, V) bool)`. Return these from your container's
`.All()` methods. Consume with `for v := range seq` or pass to
`slices.Collect`, `slices.Sorted`, `maps.Collect`, etc.
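A minimal container exposing `.All()` and honoring the `yield` contract (illustrative):

```go
package main

import (
	"fmt"
	"iter"
	"slices"
)

// ring is a toy container used only for illustration.
type ring struct{ items []string }

// All returns an iterator over the items. Returning when yield reports
// false is required: consumers may break out of the range loop early.
func (r ring) All() iter.Seq[string] {
	return func(yield func(string) bool) {
		for _, it := range r.items {
			if !yield(it) {
				return
			}
		}
	}
}

func main() {
	r := ring{items: []string{"a", "b", "c"}}
	for v := range r.All() {
		fmt.Println(v)
	}
	// Compose with stdlib helpers instead of looping and appending.
	fmt.Println(slices.Collect(r.All()))
}
```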
### `math/rand/v2` (1.22)
Replaces `math/rand`. `IntN` not `Intn`. Generic `N[T]()` for any
integer type. Default source is `ChaCha8` (crypto-quality). No global
`Seed`. Use `rand.New(source)` for reproducible sequences.
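A small illustrative example (assumes Go 1.22+):

```go
package main

import (
	"fmt"
	"math/rand/v2"
	"time"
)

func main() {
	fmt.Println(rand.IntN(10))           // [0,10) from the auto-seeded global source
	fmt.Println(rand.N(5 * time.Second)) // generic N works for any integer type

	// Reproducible sequence: pass an explicit source instead of seeding a global.
	r := rand.New(rand.NewPCG(1, 2))
	fmt.Println(r.IntN(10))
}
```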
### `log/slog` (1.21)
`slog.Info`, `slog.Warn`, `slog.Error`, `slog.Debug` with key-value
pairs. `slog.With(attrs...)` for logger with preset fields.
`slog.GroupAttrs` (1.25) for clean group creation. Implement
`slog.Handler` for custom backends.
**Note:** This project uses `cdr.dev/slog/v3`, not `log/slog`. The
API is different. Read existing code for usage patterns.
## Pitfalls
Things that are easy to get wrong, even when you know the modern API
exists. Check your output against these.
**Version misuse.** The replacement table has a "Since" column. If the
project's `go.mod` says `go 1.22`, you cannot use `wg.Go` (1.25),
`errors.AsType` (1.26), `new(expr)` (1.26), `b.Loop()` (1.24), or
`testing/synctest` (1.24). Fall back to the older pattern. Always
check before reaching for a replacement.
**`slices.Sort` vs `slices.SortFunc`.** `slices.Sort` requires
`cmp.Ordered` types (int, string, float64, etc.). For structs, custom
types, or multi-field sorting, use `slices.SortFunc` with a comparator
function. Using `slices.Sort` on a non-ordered type is a compile error.
**`for range n` does not bind an index.** `for range n` discards the
index. If you need it, write `for i := range n`. Writing
`for range n` and then trying to use `i` inside the loop is a compile
error.
**Don't hand-roll iterators when the stdlib returns one.** Functions
like `maps.Keys`, `slices.Values`, `strings.SplitSeq`, and
`strings.Lines` already return `iter.Seq` or `iter.Seq2`. Don't
reimplement them. Compose with `slices.Collect`, `slices.Sorted`, etc.
**Don't mix `math/rand` and `math/rand/v2`.** They have different
function names (`Intn` vs `IntN`) and different default sources. Pick
one per package. Prefer v2 for new code. The v1 global source is
auto-seeded since 1.20, so delete `rand.Seed` calls either way.
**Iterator protocol.** When implementing `iter.Seq`, you must respect
the `yield` return value. If `yield` returns `false`, stop iteration
immediately and return. Ignoring it violates the contract and causes
panics when consumers break out of `for range` loops early.
**`errors.Join` with nil.** `errors.Join` skips nil arguments. This is
intentional and useful for aggregating optional errors, but don't
assume the result is always non-nil. `errors.Join(nil, nil)` returns
nil.
**`cmp.Or` evaluates all arguments.** Unlike a chain of `if`
statements, `cmp.Or(a(), b(), c())` calls all three functions. If any
have side effects or are expensive, use `if`/`else` instead.
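Illustrative comparison (the lookup functions are hypothetical):

```go
package main

import (
	"cmp"
	"fmt"
	"os"
)

func lookupA() string { fmt.Println("A ran"); return "a" }
func lookupB() string { fmt.Println("B ran"); return "b" }

func main() {
	// Fine: plain values, no side effects.
	fmt.Println(cmp.Or(os.Getenv("REGION"), "us-east-1"))

	// Not equivalent to if/else: both lookups run ("A ran" and "B ran"
	// both print) even though the first result is returned.
	fmt.Println(cmp.Or(lookupA(), lookupB())) // a
}
```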
**Timer channel semantics changed in 1.23.** Code that checks
`len(timer.C)` to see if a value is pending no longer works (channel
capacity is 0). Use a non-blocking `select` receive instead:
`select { case <-timer.C: default: }`.
**`context.WithoutCancel` still propagates values.** The derived
context inherits all values from the parent. If any middleware stores
request-scoped state (deadlines, trace IDs) via `context.WithValue`,
the background work sees it. This is usually desired but can be
surprising if the values hold references that should not outlive the
request.
## Behavioral changes that affect code
- **Timers** (1.23): unstopped `Timer`/`Ticker` are GC'd immediately.
Channels are unbuffered: no stale values after `Reset`/`Stop`. You no
longer need `defer t.Stop()` to prevent leaks.
- **Error tree traversal** (1.20): `errors.Is`/`As` follow
`Unwrap() []error`, not just `Unwrap() error`. Multi-error types must
expose the slice form for child errors to be found.
- **`math/rand` auto-seeded** (1.20): the global RNG is auto-seeded.
`rand.Seed` is a no-op in 1.24+. Don't call it.
- **GODEBUG compat** (1.21): behavioral changes are gated by `go.mod`'s
`go` line. Upgrading the version opts into new defaults.
- **Build tags** (1.18): `//go:build` is the only syntax. `// +build`
is gone.
- **Tool install** (1.18): `go get` no longer builds. Use
`go install pkg@version`.
- **Doc comments** (1.19): support `[links]`, lists, and headings.
- **`go test -skip`** (1.20): skip tests by name pattern from the
command line.
- **`go fix ./...` modernizers** (1.26): auto-rewrites code to use
newer idioms. Run after Go version upgrades.
## Transparent improvements (no code changes)
Swiss Tables maps, Green Tea GC, PGO, faster `io.ReadAll`,
stack-allocated slices, reduced cgo overhead, container-aware
GOMAXPROCS. Free on upgrade.
+2 -5
View File
@@ -5,9 +5,6 @@ inputs:
version:
description: "The Go version to use."
default: "1.25.7"
use-preinstalled-go:
description: "Whether to use preinstalled Go."
default: "false"
use-cache:
description: "Whether to use the cache."
default: "true"
@@ -15,9 +12,9 @@ runs:
using: "composite"
steps:
- name: Setup Go
uses: actions/setup-go@0a12ed9d6a96ab950c8f026ed9f722fe0da7ef32 # v5.0.2
uses: actions/setup-go@40f1582b2485089dde7abd97c1529aa768e1baff # v5.6.0
with:
go-version: ${{ inputs.use-preinstalled-go == 'false' && inputs.version || '' }}
go-version: ${{ inputs.version }}
cache: ${{ inputs.use-cache }}
- name: Install gotestsum
+2 -97
View File
@@ -422,10 +422,6 @@ jobs:
- name: Setup Go
uses: ./.github/actions/setup-go
with:
# Runners have Go baked-in and Go will automatically
# download the toolchain configured in go.mod, so we don't
# need to reinstall it. It's faster on Windows runners.
use-preinstalled-go: ${{ runner.os == 'Windows' }}
use-cache: true
- name: Setup Terraform
@@ -1042,83 +1038,6 @@ jobs:
echo "Required checks have passed"
# Builds the dylibs and upload it as an artifact so it can be embedded in the main build
build-dylib:
needs: changes
# We always build the dylibs on Go changes to verify we're not merging unbuildable code,
# but they need only be signed and uploaded on coder/coder main.
if: needs.changes.outputs.go == 'true' || needs.changes.outputs.ci == 'true' || github.ref == 'refs/heads/main' || startsWith(github.ref, 'refs/heads/release/')
runs-on: ${{ github.repository_owner == 'coder' && 'depot-macos-latest' || 'macos-latest' }}
steps:
# Harden Runner doesn't work on macOS
- name: Checkout
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
with:
fetch-depth: 0
persist-credentials: false
- name: Setup GNU tools (macOS)
uses: ./.github/actions/setup-gnu-tools
- name: Switch XCode Version
uses: maxim-lobanov/setup-xcode@60606e260d2fc5762a71e64e74b2174e8ea3c8bd # v1.6.0
with:
xcode-version: "16.1.0"
- name: Setup Go
uses: ./.github/actions/setup-go
- name: Install rcodesign
if: ${{ github.repository_owner == 'coder' && (github.ref == 'refs/heads/main' || startsWith(github.ref, 'refs/heads/release/')) }}
run: |
set -euo pipefail
wget -O /tmp/rcodesign.tar.gz https://github.com/indygreg/apple-platform-rs/releases/download/apple-codesign%2F0.22.0/apple-codesign-0.22.0-macos-universal.tar.gz
sudo tar -xzf /tmp/rcodesign.tar.gz \
-C /usr/local/bin \
--strip-components=1 \
apple-codesign-0.22.0-macos-universal/rcodesign
rm /tmp/rcodesign.tar.gz
- name: Setup Apple Developer certificate and API key
if: ${{ github.repository_owner == 'coder' && (github.ref == 'refs/heads/main' || startsWith(github.ref, 'refs/heads/release/')) }}
run: |
set -euo pipefail
touch /tmp/{apple_cert.p12,apple_cert_password.txt,apple_apikey.p8}
chmod 600 /tmp/{apple_cert.p12,apple_cert_password.txt,apple_apikey.p8}
echo "$AC_CERTIFICATE_P12_BASE64" | base64 -d > /tmp/apple_cert.p12
echo "$AC_CERTIFICATE_PASSWORD" > /tmp/apple_cert_password.txt
echo "$AC_APIKEY_P8_BASE64" | base64 -d > /tmp/apple_apikey.p8
env:
AC_CERTIFICATE_P12_BASE64: ${{ secrets.AC_CERTIFICATE_P12_BASE64 }}
AC_CERTIFICATE_PASSWORD: ${{ secrets.AC_CERTIFICATE_PASSWORD }}
AC_APIKEY_P8_BASE64: ${{ secrets.AC_APIKEY_P8_BASE64 }}
- name: Build dylibs
run: |
set -euxo pipefail
./.github/scripts/retry.sh -- go mod download
make gen/mark-fresh
make build/coder-dylib
env:
CODER_SIGN_DARWIN: ${{ (github.ref == 'refs/heads/main' || startsWith(github.ref, 'refs/heads/release/')) && '1' || '0' }}
AC_CERTIFICATE_FILE: /tmp/apple_cert.p12
AC_CERTIFICATE_PASSWORD_FILE: /tmp/apple_cert_password.txt
- name: Upload build artifacts
if: ${{ github.repository_owner == 'coder' && (github.ref == 'refs/heads/main' || startsWith(github.ref, 'refs/heads/release/')) }}
uses: actions/upload-artifact@b7c566a772e6b6bfb58ed0dc250532a479d7789f # v6.0.0
with:
name: dylibs
path: |
./build/*.h
./build/*.dylib
retention-days: 7
- name: Delete Apple Developer certificate and API key
if: ${{ github.repository_owner == 'coder' && (github.ref == 'refs/heads/main' || startsWith(github.ref, 'refs/heads/release/')) }}
run: rm -f /tmp/{apple_cert.p12,apple_cert_password.txt,apple_apikey.p8}
check-build:
# This job runs make build to verify compilation on PRs.
# The build doesn't get signed, and is not suitable for usage, unlike the
@@ -1165,7 +1084,6 @@ jobs:
# to main branch.
needs:
- changes
- build-dylib
if: (github.ref == 'refs/heads/main' || startsWith(github.ref, 'refs/heads/release/')) && needs.changes.outputs.docs-only == 'false' && !github.event.pull_request.head.repo.fork
runs-on: ${{ github.repository_owner == 'coder' && 'depot-ubuntu-22.04-8' || 'ubuntu-22.04' }}
permissions:
@@ -1271,18 +1189,6 @@ jobs:
- name: Setup GCloud SDK
uses: google-github-actions/setup-gcloud@aa5489c8933f4cc7a4f7d45035b3b1440c9c10db # v3.0.1
- name: Download dylibs
uses: actions/download-artifact@37930b1c2abaa49bbe596cd826c3c89aef350131 # v7.0.0
with:
name: dylibs
path: ./build
- name: Insert dylibs
run: |
mv ./build/*amd64.dylib ./site/out/bin/coder-vpn-darwin-amd64.dylib
mv ./build/*arm64.dylib ./site/out/bin/coder-vpn-darwin-arm64.dylib
mv ./build/*arm64.h ./site/out/bin/coder-vpn-darwin-dylib.h
- name: Build
run: |
set -euxo pipefail
@@ -1298,9 +1204,8 @@ jobs:
build/coder_"$version"_windows_amd64.zip \
build/coder_"$version"_linux_amd64.{tar.gz,deb}
env:
# The Windows slim binary must be signed for Coder Desktop to accept
# it. The darwin executables don't need to be signed, but the dylibs
# do (see above).
# The Windows and Darwin slim binaries must be signed for Coder
# Desktop to accept them.
CODER_SIGN_WINDOWS: "1"
CODER_WINDOWS_RESOURCES: "1"
CODER_SIGN_GPG: "1"
+23
View File
@@ -0,0 +1,23 @@
# This workflow triggers a Vercel deploy hook which builds+deploys coder.com
# (a Next.js app), to keep coder.com/docs URLs in sync with docs/manifest.json
#
# https://vercel.com/docs/deploy-hooks#triggering-a-deploy-hook
name: Update coder.com/docs
on:
push:
branches:
- main
paths:
- "docs/manifest.json"
permissions: {}
jobs:
deploy-docs:
runs-on: ubuntu-latest
steps:
- name: Deploy docs site
run: |
curl -X POST "${{ secrets.DEPLOY_DOCS_VERCEL_WEBHOOK }}"
-304
View File
@@ -1,304 +0,0 @@
# This workflow creates a Coder Task to fix a flaky test. It is triggered by
# the flake-investigator bot (via repository_dispatch) after it triages a CI
# failure and creates a flake issue in coder/internal, or manually via
# workflow_dispatch.
#
# The flake issue contains the investigation and root cause analysis. The Task
# reads the issue, implements a fix, verifies it, and opens a PR.
#
# Triggers:
# - repository_dispatch (type: flake-fix): Automated trigger from flake-investigator
# - workflow_dispatch: Manual trigger with flake issue details
name: Flake Fix
on:
repository_dispatch:
types: [flake-fix]
workflow_dispatch:
inputs:
issue_url:
description: "Flake issue URL (in coder/internal)"
required: true
type: string
template_preset:
description: "Template preset to use"
required: false
default: ""
type: string
jobs:
flake-fix:
name: Fix Flaky Test
runs-on: ubuntu-latest
timeout-minutes: 30
env:
CODER_URL: ${{ secrets.FLAKE_BOT_CODER_URL }}
CODER_SESSION_TOKEN: ${{ secrets.FLAKE_BOT_CODER_SESSION_TOKEN }}
permissions:
contents: read
pull-requests: write
actions: write
steps:
- name: Check if secrets are available
id: check-secrets
env:
CODER_URL: ${{ secrets.FLAKE_BOT_CODER_URL }}
CODER_TOKEN: ${{ secrets.FLAKE_BOT_CODER_SESSION_TOKEN }}
run: |
if [[ -z "${CODER_URL}" || -z "${CODER_TOKEN}" ]]; then
echo "skip=true" >> "${GITHUB_OUTPUT}"
echo "Secrets not available - skipping flake fix."
{
echo "⚠️ Workflow skipped: Secrets not available"
echo ""
echo "This workflow requires FLAKE_BOT_CODER_URL and FLAKE_BOT_CODER_SESSION_TOKEN."
} >> "${GITHUB_STEP_SUMMARY}"
else
echo "skip=false" >> "${GITHUB_OUTPUT}"
fi
- name: Setup Coder CLI
if: steps.check-secrets.outputs.skip != 'true'
uses: coder/setup-action@4a607a8113d4e676e2d7c34caa20a814bc88bfda # v1
with:
access_url: ${{ secrets.FLAKE_BOT_CODER_URL }}
coder_session_token: ${{ secrets.FLAKE_BOT_CODER_SESSION_TOKEN }}
- name: Determine Inputs
if: steps.check-secrets.outputs.skip != 'true'
id: determine-inputs
env:
GITHUB_EVENT_NAME: ${{ github.event_name }}
DISPATCH_ISSUE_URL: ${{ github.event.client_payload.issue_url }}
INPUTS_ISSUE_URL: ${{ inputs.issue_url }}
INPUTS_TEMPLATE_PRESET: ${{ inputs.template_preset || '' }}
run: |
if [[ "${GITHUB_EVENT_NAME}" == "repository_dispatch" ]]; then
ISSUE_URL="${DISPATCH_ISSUE_URL}"
TEMPLATE_PRESET=""
elif [[ "${GITHUB_EVENT_NAME}" == "workflow_dispatch" ]]; then
ISSUE_URL="${INPUTS_ISSUE_URL}"
TEMPLATE_PRESET="${INPUTS_TEMPLATE_PRESET}"
else
echo "::error::Unsupported event type: ${GITHUB_EVENT_NAME}"
exit 1
fi
if [[ -z "${ISSUE_URL}" ]]; then
echo "::error::Issue URL is required"
exit 1
fi
echo "issue_url=${ISSUE_URL}" >> "${GITHUB_OUTPUT}"
echo "template_preset=${TEMPLATE_PRESET}" >> "${GITHUB_OUTPUT}"
echo "Fixing flake from issue: ${ISSUE_URL}"
- name: Build Task Prompt
if: steps.check-secrets.outputs.skip != 'true'
id: build-prompt
env:
ISSUE_URL: ${{ steps.determine-inputs.outputs.issue_url }}
run: |
TASK_PROMPT=$(cat <<'EOF'
Fix the flaky test described in ISSUE_URL_PLACEHOLDER
Use the gh CLI to read the issue which contains the investigation and root cause analysis.
Fix requirements:
- Fix the root cause identified in the issue.
- Never suppress or skip the test.
When complete:
1. Verify by running the test multiple times.
2. Commit with format: `fix(test): resolve flaky TestName`
3. Push and create a PR using gh CLI linking to the flake issue.
EOF
)
TASK_PROMPT="${TASK_PROMPT//ISSUE_URL_PLACEHOLDER/${ISSUE_URL}}"
{
echo "task_prompt<<EOFOUTPUT"
echo "${TASK_PROMPT}"
echo "EOFOUTPUT"
} >> "${GITHUB_OUTPUT}"
- name: Checkout create-task-action
if: steps.check-secrets.outputs.skip != 'true'
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
with:
fetch-depth: 1
path: ./.github/actions/create-task-action
persist-credentials: false
ref: main
repository: coder/create-task-action
- name: Create Coder Task for Flake Fix
if: steps.check-secrets.outputs.skip != 'true'
id: create_task
uses: ./.github/actions/create-task-action
with:
coder-url: ${{ secrets.FLAKE_BOT_CODER_URL }}
coder-token: ${{ secrets.FLAKE_BOT_CODER_SESSION_TOKEN }}
coder-organization: "default"
coder-template-name: coder-workflow-bot
coder-template-preset: ${{ steps.determine-inputs.outputs.template_preset }}
coder-task-name-prefix: flake-fix
coder-task-prompt: ${{ steps.build-prompt.outputs.task_prompt }}
coder-username: flake-bot
github-token: ${{ github.token }}
github-issue-url: ${{ steps.determine-inputs.outputs.issue_url }}
comment-on-issue: false
- name: Write Task Info
if: steps.check-secrets.outputs.skip != 'true'
env:
TASK_CREATED: ${{ steps.create_task.outputs.task-created }}
TASK_NAME: ${{ steps.create_task.outputs.task-name }}
TASK_URL: ${{ steps.create_task.outputs.task-url }}
ISSUE_URL: ${{ steps.determine-inputs.outputs.issue_url }}
run: |
{
echo "## Flake Fix Task"
echo ""
echo "**Issue:** ${ISSUE_URL}"
echo "**Task created:** ${TASK_CREATED}"
echo "**Task name:** ${TASK_NAME}"
echo "**Task URL:** ${TASK_URL}"
echo ""
} >> "${GITHUB_STEP_SUMMARY}"
- name: Wait for Task Completion
if: steps.check-secrets.outputs.skip != 'true'
id: wait_task
env:
TASK_NAME: ${{ steps.create_task.outputs.task-name }}
run: |
echo "Waiting for task to complete..."
echo "Task name: ${TASK_NAME}"
if [[ -z "${TASK_NAME}" ]]; then
echo "::error::TASK_NAME is empty"
exit 1
fi
MAX_WAIT=1200 # 20 minutes
WAITED=0
POLL_INTERVAL=5
LAST_STATUS=""
is_workspace_message() {
local msg="$1"
[[ -z "$msg" ]] && return 0
[[ "$msg" =~ ^Workspace ]] && return 0
[[ "$msg" =~ ^Agent ]] && return 0
return 1
}
while [[ $WAITED -lt $MAX_WAIT ]]; do
RAW_OUTPUT=$(coder task status "${TASK_NAME}" -o json 2>&1) || true
STATUS_JSON=$(echo "$RAW_OUTPUT" | grep -v "^version mismatch\|^download v" || true)
if [[ $WAITED -eq 0 ]]; then
echo "Raw status output: ${RAW_OUTPUT:0:500}"
fi
if [[ -z "$STATUS_JSON" ]] || ! echo "$STATUS_JSON" | jq -e . >/dev/null 2>&1; then
if [[ "$LAST_STATUS" != "waiting" ]]; then
echo "[${WAITED}s] Waiting for task status..."
LAST_STATUS="waiting"
fi
sleep $POLL_INTERVAL
WAITED=$((WAITED + POLL_INTERVAL))
continue
fi
TASK_STATE=$(echo "$STATUS_JSON" | jq -r '.current_state.state // "unknown"')
TASK_MESSAGE=$(echo "$STATUS_JSON" | jq -r '.current_state.message // ""')
WORKSPACE_STATUS=$(echo "$STATUS_JSON" | jq -r '.workspace_status // "unknown"')
CURRENT_STATUS="${TASK_STATE}|${WORKSPACE_STATUS}|${TASK_MESSAGE}"
if [[ "$CURRENT_STATUS" != "$LAST_STATUS" ]]; then
if [[ "$TASK_STATE" == "idle" ]] && is_workspace_message "$TASK_MESSAGE"; then
echo "[${WAITED}s] Workspace ready, waiting for Agent..."
else
echo "[${WAITED}s] State: ${TASK_STATE} | Workspace: ${WORKSPACE_STATUS} | ${TASK_MESSAGE}"
fi
LAST_STATUS="$CURRENT_STATUS"
fi
if [[ "$WORKSPACE_STATUS" == "failed" || "$WORKSPACE_STATUS" == "canceled" ]]; then
echo "::error::Workspace failed: ${WORKSPACE_STATUS}"
exit 1
fi
if [[ "$TASK_STATE" == "idle" ]]; then
if ! is_workspace_message "$TASK_MESSAGE"; then
echo ""
echo "Task completed: ${TASK_MESSAGE}"
RESULT_URI=$(echo "$STATUS_JSON" | jq -r '.current_state.uri // ""')
echo "result_uri=${RESULT_URI}" >> "${GITHUB_OUTPUT}"
echo "task_message=${TASK_MESSAGE}" >> "${GITHUB_OUTPUT}"
break
fi
fi
sleep $POLL_INTERVAL
WAITED=$((WAITED + POLL_INTERVAL))
done
if [[ $WAITED -ge $MAX_WAIT ]]; then
echo "::error::Task monitoring timed out after ${MAX_WAIT}s"
exit 1
fi
- name: Fetch Task Logs
if: always() && steps.check-secrets.outputs.skip != 'true'
env:
TASK_NAME: ${{ steps.create_task.outputs.task-name }}
run: |
echo "::group::Task Conversation Log"
if [[ -n "${TASK_NAME}" ]]; then
coder task logs "${TASK_NAME}" 2>&1 || echo "Failed to fetch logs"
else
echo "No task name, skipping log fetch"
fi
echo "::endgroup::"
- name: Cleanup Task
if: always() && steps.check-secrets.outputs.skip != 'true'
env:
TASK_NAME: ${{ steps.create_task.outputs.task-name }}
run: |
if [[ -n "${TASK_NAME}" ]]; then
echo "Deleting task: ${TASK_NAME}"
coder task delete "${TASK_NAME}" -y 2>&1 || echo "Task deletion failed or already deleted"
else
echo "No task name, skipping cleanup"
fi
- name: Write Final Summary
if: always() && steps.check-secrets.outputs.skip != 'true'
env:
TASK_NAME: ${{ steps.create_task.outputs.task-name }}
TASK_MESSAGE: ${{ steps.wait_task.outputs.task_message }}
RESULT_URI: ${{ steps.wait_task.outputs.result_uri }}
ISSUE_URL: ${{ steps.determine-inputs.outputs.issue_url }}
run: |
{
echo ""
echo "---"
echo "### Result"
echo ""
echo "**Issue:** ${ISSUE_URL}"
echo "**Status:** ${TASK_MESSAGE:-Task completed}"
if [[ -n "${RESULT_URI}" ]]; then
echo "**Details:** ${RESULT_URI}"
fi
echo ""
echo "Task \`${TASK_NAME}\` has been cleaned up."
} >> "${GITHUB_STEP_SUMMARY}"
-5
View File
@@ -64,11 +64,6 @@ jobs:
- name: Setup Go
uses: ./.github/actions/setup-go
with:
# Runners have Go baked-in and Go will automatically
# download the toolchain configured in go.mod, so we don't
# need to reinstall it. It's faster on Windows runners.
use-preinstalled-go: ${{ runner.os == 'Windows' }}
- name: Setup Terraform
uses: ./.github/actions/setup-tf
+1 -123
View File
@@ -58,87 +58,9 @@ jobs:
if (!allowed) core.setFailed('Denied: requires maintain or admin');
# build-dylib is a separate job to build the dylib on macOS.
build-dylib:
runs-on: ${{ github.repository_owner == 'coder' && 'depot-macos-latest' || 'macos-latest' }}
needs: check-perms
steps:
# Harden Runner doesn't work on macOS.
- name: Checkout
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
with:
fetch-depth: 0
persist-credentials: false
# If the event that triggered the build was an annotated tag (which our
# tags are supposed to be), actions/checkout has a bug where the tag in
# question is only a lightweight tag and not a full annotated tag. This
# command seems to fix it.
# https://github.com/actions/checkout/issues/290
- name: Fetch git tags
run: git fetch --tags --force
- name: Setup GNU tools (macOS)
uses: ./.github/actions/setup-gnu-tools
- name: Switch XCode Version
uses: maxim-lobanov/setup-xcode@60606e260d2fc5762a71e64e74b2174e8ea3c8bd # v1.6.0
with:
xcode-version: "16.1.0"
- name: Setup Go
uses: ./.github/actions/setup-go
- name: Install rcodesign
run: |
set -euo pipefail
wget -O /tmp/rcodesign.tar.gz https://github.com/indygreg/apple-platform-rs/releases/download/apple-codesign%2F0.22.0/apple-codesign-0.22.0-macos-universal.tar.gz
sudo tar -xzf /tmp/rcodesign.tar.gz \
-C /usr/local/bin \
--strip-components=1 \
apple-codesign-0.22.0-macos-universal/rcodesign
rm /tmp/rcodesign.tar.gz
- name: Setup Apple Developer certificate and API key
run: |
set -euo pipefail
touch /tmp/{apple_cert.p12,apple_cert_password.txt,apple_apikey.p8}
chmod 600 /tmp/{apple_cert.p12,apple_cert_password.txt,apple_apikey.p8}
echo "$AC_CERTIFICATE_P12_BASE64" | base64 -d > /tmp/apple_cert.p12
echo "$AC_CERTIFICATE_PASSWORD" > /tmp/apple_cert_password.txt
echo "$AC_APIKEY_P8_BASE64" | base64 -d > /tmp/apple_apikey.p8
env:
AC_CERTIFICATE_P12_BASE64: ${{ secrets.AC_CERTIFICATE_P12_BASE64 }}
AC_CERTIFICATE_PASSWORD: ${{ secrets.AC_CERTIFICATE_PASSWORD }}
AC_APIKEY_P8_BASE64: ${{ secrets.AC_APIKEY_P8_BASE64 }}
- name: Build dylibs
run: |
set -euxo pipefail
./.github/scripts/retry.sh -- go mod download
make gen/mark-fresh
make build/coder-dylib
env:
CODER_SIGN_DARWIN: 1
AC_CERTIFICATE_FILE: /tmp/apple_cert.p12
AC_CERTIFICATE_PASSWORD_FILE: /tmp/apple_cert_password.txt
- name: Upload build artifacts
uses: actions/upload-artifact@b7c566a772e6b6bfb58ed0dc250532a479d7789f # v6.0.0
with:
name: dylibs
path: |
./build/*.h
./build/*.dylib
retention-days: 7
- name: Delete Apple Developer certificate and API key
run: rm -f /tmp/{apple_cert.p12,apple_cert_password.txt,apple_apikey.p8}
release:
name: Build and publish
needs: [build-dylib, check-perms]
needs: [check-perms]
runs-on: ${{ github.repository_owner == 'coder' && 'depot-ubuntu-22.04-8' || 'ubuntu-latest' }}
permissions:
# Required to publish a release
@@ -320,18 +242,6 @@ jobs:
- name: Setup GCloud SDK
uses: google-github-actions/setup-gcloud@aa5489c8933f4cc7a4f7d45035b3b1440c9c10db # v3.0.1
- name: Download dylibs
uses: actions/download-artifact@37930b1c2abaa49bbe596cd826c3c89aef350131 # v7.0.0
with:
name: dylibs
path: ./build
- name: Insert dylibs
run: |
mv ./build/*amd64.dylib ./site/out/bin/coder-vpn-darwin-amd64.dylib
mv ./build/*arm64.dylib ./site/out/bin/coder-vpn-darwin-arm64.dylib
mv ./build/*arm64.h ./site/out/bin/coder-vpn-darwin-dylib.h
- name: Build binaries
run: |
set -euo pipefail
@@ -955,35 +865,3 @@ jobs:
# different repo.
GH_TOKEN: ${{ secrets.CDRCI_GITHUB_TOKEN }}
VERSION: ${{ needs.release.outputs.version }}
# publish-sqlc pushes the latest schema to sqlc cloud.
# At present these pushes cannot be tagged, so the last push is always the latest.
publish-sqlc:
name: "Publish to schema sqlc cloud"
runs-on: "ubuntu-latest"
needs: release
if: ${{ !inputs.dry_run }}
steps:
- name: Harden Runner
uses: step-security/harden-runner@5ef0c079ce82195b2a36a210272d6b661572d83e # v2.14.2
with:
egress-policy: audit
- name: Checkout
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
with:
fetch-depth: 1
persist-credentials: false
# We need golang to run the migration main.go
- name: Setup Go
uses: ./.github/actions/setup-go
- name: Setup sqlc
uses: ./.github/actions/setup-sqlc
- name: Push schema to sqlc cloud
# Don't block a release on this
continue-on-error: true
run: |
make sqlc-push
+2 -3
View File
@@ -198,13 +198,12 @@ reviewer time and clutters the diff.
**Don't delete existing comments** that explain non-obvious behavior. These
comments preserve important context about why code works a certain way.
**When adding tests for new behavior**, add new test cases instead of modifying
existing ones. This preserves coverage for the original behavior and makes it
clear what the new test covers.
**When adding tests for new behavior**, read existing tests first to understand what's covered. Add new cases for uncovered behavior. Edit existing tests as needed, but don't change what they verify.
## Detailed Development Guides
@.claude/docs/ARCHITECTURE.md
@.claude/docs/GO.md
@.claude/docs/OAUTH2.md
@.claude/docs/TESTING.md
@.claude/docs/TROUBLESHOOTING.md
+30 -32
View File
@@ -94,12 +94,8 @@ PACKAGE_OS_ARCHES := linux_amd64 linux_armv7 linux_arm64
# All architectures we build Docker images for (Linux only).
DOCKER_ARCHES := amd64 arm64 armv7
# All ${OS}_${ARCH} combos we build the desktop dylib for.
DYLIB_ARCHES := darwin_amd64 darwin_arm64
# Computed variables based on the above.
CODER_SLIM_BINARIES := $(addprefix build/coder-slim_$(VERSION)_,$(OS_ARCHES))
CODER_DYLIBS := $(foreach os_arch, $(DYLIB_ARCHES), build/coder-vpn_$(VERSION)_$(os_arch).dylib)
CODER_FAT_BINARIES := $(addprefix build/coder_$(VERSION)_,$(OS_ARCHES))
CODER_ALL_BINARIES := $(CODER_SLIM_BINARIES) $(CODER_FAT_BINARIES)
CODER_TAR_GZ_ARCHIVES := $(foreach os_arch, $(ARCHIVE_TAR_GZ), build/coder_$(VERSION)_$(os_arch).tar.gz)
@@ -261,26 +257,6 @@ $(CODER_ALL_BINARIES): go.mod go.sum \
fi
fi
# This task builds Coder Desktop dylibs
$(CODER_DYLIBS): go.mod go.sum $(MOST_GO_SRC_FILES)
@if [ "$(shell uname)" = "Darwin" ]; then
$(get-mode-os-arch-ext)
./scripts/build_go.sh \
--os "$$os" \
--arch "$$arch" \
--version "$(VERSION)" \
--output "$@" \
--dylib
else
echo "ERROR: Can't build dylib on non-Darwin OS" 1>&2
exit 1
fi
# This task builds both dylibs
build/coder-dylib: $(CODER_DYLIBS)
.PHONY: build/coder-dylib
# This task builds all archives. It parses the target name to get the metadata
# for the build, so it must be specified in this format:
# build/coder_${version}_${os}_${arch}.${format}
@@ -427,6 +403,7 @@ SITE_GEN_FILES := \
site/src/api/typesGenerated.ts \
site/src/api/rbacresourcesGenerated.ts \
site/src/api/countriesGenerated.ts \
site/src/api/chatModelOptionsGenerated.json \
site/src/theme/icons.json
site/out/index.html: \
@@ -654,6 +631,7 @@ GEN_FILES := \
tailnet/proto/tailnet.pb.go \
agent/proto/agent.pb.go \
agent/agentsocket/proto/agentsocket.pb.go \
agent/boundarylogproxy/codec/boundary.pb.go \
provisionersdk/proto/provisioner.pb.go \
provisionerd/proto/provisionerd.pb.go \
vpn/vpn.pb.go \
@@ -709,6 +687,7 @@ gen/mark-fresh:
provisionersdk/proto/provisioner.pb.go \
provisionerd/proto/provisionerd.pb.go \
agent/agentsocket/proto/agentsocket.pb.go \
agent/boundarylogproxy/codec/boundary.pb.go \
vpn/vpn.pb.go \
enterprise/aibridged/proto/aibridged.pb.go \
coderd/database/dump.sql \
@@ -719,6 +698,7 @@ gen/mark-fresh:
coderd/rbac/scopes_constants_gen.go \
site/src/api/rbacresourcesGenerated.ts \
site/src/api/countriesGenerated.ts \
site/src/api/chatModelOptionsGenerated.json \
docs/admin/integrations/prometheus.md \
docs/reference/cli/index.md \
docs/admin/security/audit-logs.md \
@@ -843,6 +823,12 @@ vpn/vpn.pb.go: vpn/vpn.proto
--go_opt=paths=source_relative \
./vpn/vpn.proto
agent/boundarylogproxy/codec/boundary.pb.go: agent/boundarylogproxy/codec/boundary.proto agent/proto/agent.proto
protoc \
--go_out=. \
--go_opt=paths=source_relative \
./agent/boundarylogproxy/codec/boundary.proto
enterprise/aibridged/proto/aibridged.pb.go: enterprise/aibridged/proto/aibridged.proto
protoc \
--go_out=. \
@@ -854,7 +840,7 @@ enterprise/aibridged/proto/aibridged.pb.go: enterprise/aibridged/proto/aibridged
site/src/api/typesGenerated.ts: site/node_modules/.installed $(wildcard scripts/apitypings/*) $(shell find ./codersdk $(FIND_EXCLUSIONS) -type f -name '*.go')
# -C sets the directory for the go run command
go run -C ./scripts/apitypings main.go > $@
(cd site/ && pnpm exec biome format --write src/api/typesGenerated.ts)
./scripts/biome_format.sh src/api/typesGenerated.ts
touch "$@"
site/e2e/provisionerGenerated.ts: site/node_modules/.installed provisionerd/proto/provisionerd.pb.go provisionersdk/proto/provisioner.pb.go
@@ -863,7 +849,7 @@ site/e2e/provisionerGenerated.ts: site/node_modules/.installed provisionerd/prot
site/src/theme/icons.json: site/node_modules/.installed $(wildcard scripts/gensite/*) $(wildcard site/static/icon/*)
go run ./scripts/gensite/ -icons "$@"
(cd site/ && pnpm exec biome format --write src/theme/icons.json)
./scripts/biome_format.sh src/theme/icons.json
touch "$@"
examples/examples.gen.json: scripts/examplegen/main.go examples/examples.go $(shell find ./examples/templates)
@@ -901,14 +887,18 @@ codersdk/apikey_scopes_gen.go: scripts/apikeyscopesgen/main.go coderd/rbac/scope
site/src/api/rbacresourcesGenerated.ts: site/node_modules/.installed scripts/typegen/codersdk.gotmpl scripts/typegen/main.go coderd/rbac/object.go coderd/rbac/policy/policy.go
go run scripts/typegen/main.go rbac typescript > "$@"
(cd site/ && pnpm exec biome format --write src/api/rbacresourcesGenerated.ts)
./scripts/biome_format.sh src/api/rbacresourcesGenerated.ts
touch "$@"
site/src/api/countriesGenerated.ts: site/node_modules/.installed scripts/typegen/countries.tstmpl scripts/typegen/main.go codersdk/countries.go
go run scripts/typegen/main.go countries > "$@"
(cd site/ && pnpm exec biome format --write src/api/countriesGenerated.ts)
./scripts/biome_format.sh src/api/countriesGenerated.ts
touch "$@"
site/src/api/chatModelOptionsGenerated.json: scripts/modeloptionsgen/main.go codersdk/chats.go
go run ./scripts/modeloptionsgen/main.go | tail -n +2 > "$@"
cd site && pnpm biome format --write src/api/chatModelOptionsGenerated.json
scripts/metricsdocgen/generated_metrics: $(GO_SRC_FILES)
go run ./scripts/metricsdocgen/scanner > $@
@@ -950,11 +940,11 @@ coderd/apidoc/.gen: \
touch "$@"
docs/manifest.json: site/node_modules/.installed coderd/apidoc/.gen docs/reference/cli/index.md
(cd site/ && pnpm exec biome format --write ../docs/manifest.json)
./scripts/biome_format.sh ../docs/manifest.json
touch "$@"
coderd/apidoc/swagger.json: site/node_modules/.installed coderd/apidoc/.gen
(cd site/ && pnpm exec biome format --write ../coderd/apidoc/swagger.json)
./scripts/biome_format.sh ../coderd/apidoc/swagger.json
touch "$@"
update-golden-files:
@@ -999,11 +989,19 @@ enterprise/tailnet/testdata/.gen-golden: $(wildcard enterprise/tailnet/testdata/
touch "$@"
helm/coder/tests/testdata/.gen-golden: $(wildcard helm/coder/tests/testdata/*.yaml) $(wildcard helm/coder/tests/testdata/*.golden) $(GO_SRC_FILES) $(wildcard helm/coder/tests/*_test.go)
TZ=UTC go test ./helm/coder/tests -run=TestUpdateGoldenFiles -update
if command -v helm >/dev/null 2>&1; then
TZ=UTC go test ./helm/coder/tests -run=TestUpdateGoldenFiles -update
else
echo "WARNING: helm not found; skipping helm/coder golden generation" >&2
fi
touch "$@"
helm/provisioner/tests/testdata/.gen-golden: $(wildcard helm/provisioner/tests/testdata/*.yaml) $(wildcard helm/provisioner/tests/testdata/*.golden) $(GO_SRC_FILES) $(wildcard helm/provisioner/tests/*_test.go)
TZ=UTC go test ./helm/provisioner/tests -run=TestUpdateGoldenFiles -update
if command -v helm >/dev/null 2>&1; then
TZ=UTC go test ./helm/provisioner/tests -run=TestUpdateGoldenFiles -update
else
echo "WARNING: helm not found; skipping helm/provisioner golden generation" >&2
fi
touch "$@"
coderd/.gen-golden: $(wildcard coderd/testdata/*/*.golden) $(GO_SRC_FILES) $(wildcard coderd/*_test.go)
+15 -3
View File
@@ -41,6 +41,7 @@ import (
"github.com/coder/coder/v2/agent/agentcontainers"
"github.com/coder/coder/v2/agent/agentexec"
"github.com/coder/coder/v2/agent/agentfiles"
"github.com/coder/coder/v2/agent/agentproc"
"github.com/coder/coder/v2/agent/agentscripts"
"github.com/coder/coder/v2/agent/agentsocket"
"github.com/coder/coder/v2/agent/agentssh"
@@ -302,7 +303,8 @@ type agent struct {
containerAPIOptions []agentcontainers.Option
containerAPI *agentcontainers.API
filesAPI *agentfiles.API
filesAPI *agentfiles.API
processAPI *agentproc.API
socketServerEnabled bool
socketPath string
@@ -375,6 +377,7 @@ func (a *agent) init() {
a.containerAPI = agentcontainers.NewAPI(a.logger.Named("containers"), containerAPIOpts...)
a.filesAPI = agentfiles.NewAPI(a.logger.Named("files"), a.filesystem)
a.processAPI = agentproc.NewAPI(a.logger.Named("processes"), a.execer, a.updateCommandEnv)
a.reconnectingPTYServer = reconnectingpty.NewServer(
a.logger.Named("reconnecting-pty"),
@@ -407,7 +410,7 @@ func (a *agent) initSocketServer() {
agentsocket.WithPath(a.socketPath),
)
if err != nil {
a.logger.Warn(a.hardCtx, "failed to create socket server", slog.Error(err), slog.F("path", a.socketPath))
a.logger.Error(a.hardCtx, "failed to create socket server", slog.Error(err), slog.F("path", a.socketPath))
return
}
@@ -417,7 +420,12 @@ func (a *agent) initSocketServer() {
// startBoundaryLogProxyServer starts the boundary log proxy socket server.
func (a *agent) startBoundaryLogProxyServer() {
proxy := boundarylogproxy.NewServer(a.logger, a.boundaryLogProxySocketPath)
if a.boundaryLogProxySocketPath == "" {
a.logger.Warn(a.hardCtx, "boundary log proxy socket path not defined; not starting proxy")
return
}
proxy := boundarylogproxy.NewServer(a.logger, a.boundaryLogProxySocketPath, a.prometheusRegistry)
if err := proxy.Start(); err != nil {
a.logger.Warn(a.hardCtx, "failed to start boundary log proxy", slog.Error(err))
return
@@ -2030,6 +2038,10 @@ func (a *agent) Close() error {
a.logger.Error(a.hardCtx, "container API close", slog.Error(err))
}
if err := a.processAPI.Close(); err != nil {
a.logger.Error(a.hardCtx, "process API close", slog.Error(err))
}
if a.boundaryLogProxy != nil {
err = a.boundaryLogProxy.Close()
if err != nil {
+1
View File
@@ -29,6 +29,7 @@ func (api *API) Routes() http.Handler {
r.Post("/list-directory", api.HandleLS)
r.Get("/read-file", api.HandleReadFile)
r.Get("/read-file-lines", api.HandleReadFileLines)
r.Post("/write-file", api.HandleWriteFile)
r.Post("/edit-files", api.HandleEditFiles)
+283 -7
View File
@@ -10,11 +10,10 @@ import (
"os"
"path/filepath"
"strconv"
"strings"
"syscall"
"github.com/icholy/replace"
"github.com/spf13/afero"
"golang.org/x/text/transform"
"golang.org/x/xerrors"
"cdr.dev/slog/v3"
@@ -23,6 +22,22 @@ import (
"github.com/coder/coder/v2/codersdk/workspacesdk"
)
// ReadFileLinesResponse is the JSON response for the line-based file reader.
type ReadFileLinesResponse struct {
// Success indicates whether the read was successful.
Success bool `json:"success"`
// FileSize is the original file size in bytes.
FileSize int64 `json:"file_size,omitempty"`
// TotalLines is the total number of lines in the file.
TotalLines int `json:"total_lines,omitempty"`
// LinesRead is the count of lines returned in this response.
LinesRead int `json:"lines_read,omitempty"`
// Content is the line-numbered file content.
Content string `json:"content,omitempty"`
// Error is the error message when success is false.
Error string `json:"error,omitempty"`
}
type HTTPResponseCode = int
func (api *API) HandleReadFile(rw http.ResponseWriter, r *http.Request) {
@@ -103,6 +118,166 @@ func (api *API) streamFile(ctx context.Context, rw http.ResponseWriter, path str
return 0, nil
}
func (api *API) HandleReadFileLines(rw http.ResponseWriter, r *http.Request) {
ctx := r.Context()
query := r.URL.Query()
parser := httpapi.NewQueryParamParser().RequiredNotEmpty("path")
path := parser.String(query, "", "path")
offset := parser.PositiveInt64(query, 1, "offset")
limit := parser.PositiveInt64(query, 0, "limit")
maxFileSize := parser.PositiveInt64(query, workspacesdk.DefaultMaxFileSize, "max_file_size")
maxLineBytes := parser.PositiveInt64(query, workspacesdk.DefaultMaxLineBytes, "max_line_bytes")
maxResponseLines := parser.PositiveInt64(query, workspacesdk.DefaultMaxResponseLines, "max_response_lines")
maxResponseBytes := parser.PositiveInt64(query, workspacesdk.DefaultMaxResponseBytes, "max_response_bytes")
parser.ErrorExcessParams(query)
if len(parser.Errors) > 0 {
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
Message: "Query parameters have invalid values.",
Validations: parser.Errors,
})
return
}
resp := api.readFileLines(ctx, path, offset, limit, workspacesdk.ReadFileLinesLimits{
MaxFileSize: maxFileSize,
MaxLineBytes: int(maxLineBytes),
MaxResponseLines: int(maxResponseLines),
MaxResponseBytes: int(maxResponseBytes),
})
httpapi.Write(ctx, rw, http.StatusOK, resp)
}
func (api *API) readFileLines(_ context.Context, path string, offset, limit int64, limits workspacesdk.ReadFileLinesLimits) ReadFileLinesResponse {
errResp := func(msg string) ReadFileLinesResponse {
return ReadFileLinesResponse{Success: false, Error: msg}
}
if !filepath.IsAbs(path) {
return errResp(fmt.Sprintf("file path must be absolute: %q", path))
}
f, err := api.filesystem.Open(path)
if err != nil {
if errors.Is(err, os.ErrNotExist) {
return errResp(fmt.Sprintf("file does not exist: %s", path))
}
if errors.Is(err, os.ErrPermission) {
return errResp(fmt.Sprintf("permission denied: %s", path))
}
return errResp(fmt.Sprintf("open file: %s", err))
}
defer f.Close()
stat, err := f.Stat()
if err != nil {
return errResp(fmt.Sprintf("stat file: %s", err))
}
if stat.IsDir() {
return errResp(fmt.Sprintf("not a file: %s", path))
}
fileSize := stat.Size()
if fileSize > limits.MaxFileSize {
return errResp(fmt.Sprintf(
"file is %d bytes which exceeds the maximum of %d bytes. Use grep, sed, or awk to extract the content you need, or use offset and limit to read a portion.",
fileSize, limits.MaxFileSize,
))
}
// Read the entire file (up to MaxFileSize).
data, err := io.ReadAll(f)
if err != nil {
return errResp(fmt.Sprintf("read file: %s", err))
}
// Split into lines.
content := string(data)
// Handle empty file.
if content == "" {
return ReadFileLinesResponse{
Success: true,
FileSize: fileSize,
TotalLines: 0,
LinesRead: 0,
Content: "",
}
}
lines := strings.Split(content, "\n")
totalLines := len(lines)
// offset is 1-based line number.
if offset < 1 {
offset = 1
}
if offset > int64(totalLines) {
return errResp(fmt.Sprintf(
"offset %d is beyond the file length of %d lines",
offset, totalLines,
))
}
// Default limit.
if limit <= 0 {
limit = int64(limits.MaxResponseLines)
}
startIdx := int(offset - 1) // convert to 0-based
endIdx := startIdx + int(limit)
if endIdx > totalLines {
endIdx = totalLines
}
var numbered []string
totalBytesAccumulated := 0
for i := startIdx; i < endIdx; i++ {
line := lines[i]
// Per-line truncation.
if len(line) > limits.MaxLineBytes {
line = line[:limits.MaxLineBytes] + "... [truncated]"
}
// Format with 1-based line number.
numberedLine := fmt.Sprintf("%d\t%s", i+1, line)
lineBytes := len(numberedLine)
// Check total byte budget.
newTotal := totalBytesAccumulated + lineBytes
if len(numbered) > 0 {
newTotal++ // account for \n joiner
}
if newTotal > limits.MaxResponseBytes {
return errResp(fmt.Sprintf(
"output would exceed %d bytes. Read less at a time using offset and limit parameters.",
limits.MaxResponseBytes,
))
}
// Check line count.
if len(numbered) >= limits.MaxResponseLines {
return errResp(fmt.Sprintf(
"output would exceed %d lines. Read less at a time using offset and limit parameters.",
limits.MaxResponseLines,
))
}
numbered = append(numbered, numberedLine)
totalBytesAccumulated = newTotal
}
return ReadFileLinesResponse{
Success: true,
FileSize: fileSize,
TotalLines: totalLines,
LinesRead: len(numbered),
Content: strings.Join(numbered, "\n"),
}
}
func (api *API) HandleWriteFile(rw http.ResponseWriter, r *http.Request) {
ctx := r.Context()
@@ -245,9 +420,21 @@ func (api *API) editFile(ctx context.Context, path string, edits []workspacesdk.
return http.StatusBadRequest, xerrors.Errorf("open %s: not a file", path)
}
transforms := make([]transform.Transformer, len(edits))
for i, edit := range edits {
transforms[i] = replace.String(edit.Search, edit.Replace)
data, err := io.ReadAll(f)
if err != nil {
return http.StatusInternalServerError, xerrors.Errorf("read %s: %w", path, err)
}
content := string(data)
for _, edit := range edits {
var ok bool
content, ok = fuzzyReplace(content, edit.Search, edit.Replace)
if !ok {
api.logger.Warn(ctx, "edit search string not found, skipping",
slog.F("path", path),
slog.F("search_preview", truncate(edit.Search, 64)),
)
}
}
// Create an adjacent file to ensure it will be on the same device and can be
@@ -258,8 +445,7 @@ func (api *API) editFile(ctx context.Context, path string, edits []workspacesdk.
}
defer tmpfile.Close()
_, err = io.Copy(tmpfile, replace.Chain(f, transforms...))
if err != nil {
if _, err := tmpfile.Write([]byte(content)); err != nil {
if rerr := api.filesystem.Remove(tmpfile.Name()); rerr != nil {
api.logger.Warn(ctx, "unable to clean up temp file", slog.Error(rerr))
}
@@ -273,3 +459,93 @@ func (api *API) editFile(ctx context.Context, path string, edits []workspacesdk.
return 0, nil
}
// fuzzyReplace attempts to find `search` inside `content` and replace it with
// `replace`. It uses a cascading match strategy inspired by
// openai/codex's apply_patch:
//
// 1. Exact substring match (byte-for-byte); all occurrences are replaced.
// 2. Line-by-line match ignoring trailing whitespace on each line.
// 3. Line-by-line match ignoring all leading/trailing whitespace (indentation-tolerant).
//
// When a fuzzy match is found (passes 2 or 3), the replacement is spliced in at
// the first matching line range of the original content so that surrounding text
// (including indentation of untouched lines) is preserved.
//
// Returns the (possibly modified) content and a bool indicating whether a match
// was found.
func fuzzyReplace(content, search, replace string) (string, bool) {
// Pass 1: exact substring (replace all occurrences).
if strings.Contains(content, search) {
return strings.ReplaceAll(content, search, replace), true
}
// For line-level fuzzy matching we split both content and search into lines.
contentLines := strings.SplitAfter(content, "\n")
searchLines := strings.SplitAfter(search, "\n")
// A trailing newline in the search produces an empty final element from
// SplitAfter. Drop it so it doesn't interfere with line matching.
if len(searchLines) > 0 && searchLines[len(searchLines)-1] == "" {
searchLines = searchLines[:len(searchLines)-1]
}
// Pass 2: trim trailing whitespace on each line.
if start, end, ok := seekLines(contentLines, searchLines, func(a, b string) bool {
return strings.TrimRight(a, " \t\r\n") == strings.TrimRight(b, " \t\r\n")
}); ok {
return spliceLines(contentLines, start, end, replace), true
}
// Pass 3: trim all leading and trailing whitespace (indentation-tolerant).
if start, end, ok := seekLines(contentLines, searchLines, func(a, b string) bool {
return strings.TrimSpace(a) == strings.TrimSpace(b)
}); ok {
return spliceLines(contentLines, start, end, replace), true
}
return content, false
}
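// Worked example (a sketch mirroring the TabsVsSpaces test case): the file is
// indented with tabs but the search string uses spaces, so pass 1 fails and
// pass 3 matches after trimming whitespace on each line.
//
//	content := "\tif true {\n\t\tfoo()\n\t}"
//	search := "    if true {\n        foo()\n    }"
//	out, ok := fuzzyReplace(content, search, "\tif true {\n\t\tbar()\n\t}")
//	// ok == true, out == "\tif true {\n\t\tbar()\n\t}"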
// seekLines scans contentLines looking for a contiguous subsequence that matches
// searchLines according to the provided `eq` function. It returns the start and
// end (exclusive) indices into contentLines of the match.
func seekLines(contentLines, searchLines []string, eq func(a, b string) bool) (start, end int, ok bool) {
if len(searchLines) == 0 {
return 0, 0, true
}
if len(searchLines) > len(contentLines) {
return 0, 0, false
}
outer:
for i := 0; i <= len(contentLines)-len(searchLines); i++ {
for j, sLine := range searchLines {
if !eq(contentLines[i+j], sLine) {
continue outer
}
}
return i, i + len(searchLines), true
}
return 0, 0, false
}
// spliceLines replaces contentLines[start:end] with replacement text, returning
// the full content as a single string.
func spliceLines(contentLines []string, start, end int, replacement string) string {
var b strings.Builder
for _, l := range contentLines[:start] {
_, _ = b.WriteString(l)
}
_, _ = b.WriteString(replacement)
for _, l := range contentLines[end:] {
_, _ = b.WriteString(l)
}
return b.String()
}
func truncate(s string, n int) string {
if len(s) <= n {
return s
}
return s[:n] + "..."
}
+285
@@ -649,6 +649,106 @@ func TestEditFiles(t *testing.T) {
filepath.Join(tmpdir, "file3"): "edited3 3",
},
},
{
name: "TrailingWhitespace",
contents: map[string]string{filepath.Join(tmpdir, "trailing-ws"): "foo \nbar\t\t\nbaz"},
edits: []workspacesdk.FileEdits{
{
Path: filepath.Join(tmpdir, "trailing-ws"),
Edits: []workspacesdk.FileEdit{
{
Search: "foo\nbar\nbaz",
Replace: "replaced",
},
},
},
},
expected: map[string]string{filepath.Join(tmpdir, "trailing-ws"): "replaced"},
},
{
name: "TabsVsSpaces",
contents: map[string]string{filepath.Join(tmpdir, "tabs-vs-spaces"): "\tif true {\n\t\tfoo()\n\t}"},
edits: []workspacesdk.FileEdits{
{
Path: filepath.Join(tmpdir, "tabs-vs-spaces"),
Edits: []workspacesdk.FileEdit{
{
// Search uses spaces but file uses tabs.
Search: " if true {\n foo()\n }",
Replace: "\tif true {\n\t\tbar()\n\t}",
},
},
},
},
expected: map[string]string{filepath.Join(tmpdir, "tabs-vs-spaces"): "\tif true {\n\t\tbar()\n\t}"},
},
{
name: "DifferentIndentDepth",
contents: map[string]string{filepath.Join(tmpdir, "indent-depth"): "\t\t\tdeep()\n\t\t\tnested()"},
edits: []workspacesdk.FileEdits{
{
Path: filepath.Join(tmpdir, "indent-depth"),
Edits: []workspacesdk.FileEdit{
{
// Search has wrong indent depth (1 tab instead of 3).
Search: "\tdeep()\n\tnested()",
Replace: "\t\t\tdeep()\n\t\t\tchanged()",
},
},
},
},
expected: map[string]string{filepath.Join(tmpdir, "indent-depth"): "\t\t\tdeep()\n\t\t\tchanged()"},
},
{
name: "ExactMatchPreferred",
contents: map[string]string{filepath.Join(tmpdir, "exact-preferred"): "hello world"},
edits: []workspacesdk.FileEdits{
{
Path: filepath.Join(tmpdir, "exact-preferred"),
Edits: []workspacesdk.FileEdit{
{
Search: "hello world",
Replace: "goodbye world",
},
},
},
},
expected: map[string]string{filepath.Join(tmpdir, "exact-preferred"): "goodbye world"},
},
{
name: "NoMatchStillSucceeds",
contents: map[string]string{filepath.Join(tmpdir, "no-match"): "original content"},
edits: []workspacesdk.FileEdits{
{
Path: filepath.Join(tmpdir, "no-match"),
Edits: []workspacesdk.FileEdit{
{
Search: "this does not exist in the file",
Replace: "whatever",
},
},
},
},
// File should remain unchanged.
expected: map[string]string{filepath.Join(tmpdir, "no-match"): "original content"},
},
{
name: "MixedWhitespaceMultiline",
contents: map[string]string{filepath.Join(tmpdir, "mixed-ws"): "func main() {\n\tresult := compute()\n\tfmt.Println(result)\n}"},
edits: []workspacesdk.FileEdits{
{
Path: filepath.Join(tmpdir, "mixed-ws"),
Edits: []workspacesdk.FileEdit{
{
// Search uses spaces, file uses tabs.
Search: " result := compute()\n fmt.Println(result)\n",
Replace: "\tresult := compute()\n\tlog.Println(result)\n",
},
},
},
},
expected: map[string]string{filepath.Join(tmpdir, "mixed-ws"): "func main() {\n\tresult := compute()\n\tlog.Println(result)\n}"},
},
{
name: "MultiError",
contents: map[string]string{
@@ -737,3 +837,188 @@ func TestEditFiles(t *testing.T) {
})
}
}
func TestReadFileLines(t *testing.T) {
t.Parallel()
tmpdir := os.TempDir()
noPermsFilePath := filepath.Join(tmpdir, "no-perms-lines")
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true}).Leveled(slog.LevelDebug)
fs := newTestFs(afero.NewMemMapFs(), func(call, file string) error {
if file == noPermsFilePath {
return os.ErrPermission
}
return nil
})
api := agentfiles.NewAPI(logger, fs)
dirPath := filepath.Join(tmpdir, "a-directory-lines")
err := fs.MkdirAll(dirPath, 0o755)
require.NoError(t, err)
emptyFilePath := filepath.Join(tmpdir, "empty-file")
err = afero.WriteFile(fs, emptyFilePath, []byte(""), 0o644)
require.NoError(t, err)
basicFilePath := filepath.Join(tmpdir, "basic-file")
err = afero.WriteFile(fs, basicFilePath, []byte("line1\nline2\nline3"), 0o644)
require.NoError(t, err)
longLine := string(bytes.Repeat([]byte("x"), 1025))
longLineFilePath := filepath.Join(tmpdir, "long-line-file")
err = afero.WriteFile(fs, longLineFilePath, []byte(longLine), 0o644)
require.NoError(t, err)
largeFilePath := filepath.Join(tmpdir, "large-file")
err = afero.WriteFile(fs, largeFilePath, bytes.Repeat([]byte("x"), 1<<20+1), 0o644)
require.NoError(t, err)
tests := []struct {
name string
path string
offset int64
limit int64
expSuccess bool
expError string
expContent string
expTotal int
expRead int
expSize int64
// useCodersdk is set for cases where the handler returns
// codersdk.Response (query param validation) instead of ReadFileLinesResponse.
useCodersdk bool
}{
{
name: "NoPath",
path: "",
useCodersdk: true,
expError: "is required",
},
{
name: "RelativePath",
path: "relative/path",
expError: "file path must be absolute",
},
{
name: "NonExistent",
path: filepath.Join(tmpdir, "does-not-exist"),
expError: "file does not exist",
},
{
name: "IsDir",
path: dirPath,
expError: "not a file",
},
{
name: "NoPermissions",
path: noPermsFilePath,
expError: "permission denied",
},
{
name: "EmptyFile",
path: emptyFilePath,
expSuccess: true,
expTotal: 0,
expRead: 0,
expSize: 0,
},
{
name: "BasicRead",
path: basicFilePath,
expSuccess: true,
expContent: "1\tline1\n2\tline2\n3\tline3",
expTotal: 3,
expRead: 3,
expSize: int64(len("line1\nline2\nline3")),
},
{
name: "Offset2",
path: basicFilePath,
offset: 2,
expSuccess: true,
expContent: "2\tline2\n3\tline3",
expTotal: 3,
expRead: 2,
expSize: int64(len("line1\nline2\nline3")),
},
{
name: "Limit1",
path: basicFilePath,
limit: 1,
expSuccess: true,
expContent: "1\tline1",
expTotal: 3,
expRead: 1,
expSize: int64(len("line1\nline2\nline3")),
},
{
name: "Offset2Limit1",
path: basicFilePath,
offset: 2,
limit: 1,
expSuccess: true,
expContent: "2\tline2",
expTotal: 3,
expRead: 1,
expSize: int64(len("line1\nline2\nline3")),
},
{
name: "OffsetBeyondFile",
path: basicFilePath,
offset: 100,
expError: "offset 100 is beyond the file length of 3 lines",
},
{
name: "LongLineTruncation",
path: longLineFilePath,
expSuccess: true,
expContent: "1\t" + string(bytes.Repeat([]byte("x"), 1024)) + "... [truncated]",
expTotal: 1,
expRead: 1,
expSize: 1025,
},
{
name: "LargeFile",
path: largeFilePath,
expError: "exceeds the maximum",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong)
defer cancel()
w := httptest.NewRecorder()
r := httptest.NewRequestWithContext(ctx, http.MethodGet, fmt.Sprintf("/read-file-lines?path=%s&offset=%d&limit=%d", tt.path, tt.offset, tt.limit), nil)
api.Routes().ServeHTTP(w, r)
if tt.useCodersdk {
// Query param validation errors return codersdk.Response.
require.Equal(t, http.StatusBadRequest, w.Code)
require.Contains(t, w.Body.String(), tt.expError)
return
}
var resp agentfiles.ReadFileLinesResponse
err := json.NewDecoder(w.Body).Decode(&resp)
require.NoError(t, err)
if tt.expSuccess {
require.Equal(t, http.StatusOK, w.Code)
require.True(t, resp.Success)
require.Equal(t, tt.expContent, resp.Content)
require.Equal(t, tt.expTotal, resp.TotalLines)
require.Equal(t, tt.expRead, resp.LinesRead)
require.Equal(t, tt.expSize, resp.FileSize)
} else {
require.Equal(t, http.StatusOK, w.Code)
require.False(t, resp.Success)
require.Contains(t, resp.Error, tt.expError)
}
})
}
}
+175
@@ -0,0 +1,175 @@
package agentproc
import (
"encoding/json"
"errors"
"fmt"
"net/http"
"github.com/go-chi/chi/v5"
"cdr.dev/slog/v3"
"github.com/coder/coder/v2/agent/agentexec"
"github.com/coder/coder/v2/coderd/httpapi"
"github.com/coder/coder/v2/codersdk"
"github.com/coder/coder/v2/codersdk/workspacesdk"
)
// API exposes process-related operations through the agent.
type API struct {
logger slog.Logger
manager *manager
}
// NewAPI creates a new process API handler.
func NewAPI(logger slog.Logger, execer agentexec.Execer, updateEnv func(current []string) (updated []string, err error)) *API {
return &API{
logger: logger,
manager: newManager(logger, execer, updateEnv),
}
}
// Close shuts down the process manager, killing all running
// processes.
func (api *API) Close() error {
return api.manager.Close()
}
// Routes returns the HTTP handler for process-related routes.
func (api *API) Routes() http.Handler {
r := chi.NewRouter()
r.Post("/start", api.handleStartProcess)
r.Get("/list", api.handleListProcesses)
r.Get("/{id}/output", api.handleProcessOutput)
r.Post("/{id}/signal", api.handleSignalProcess)
return r
}
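// Illustrative flow (a sketch using the workspacesdk types referenced by the
// handlers below; not a normative client):
//
//	1. POST /start with StartProcessRequest{Command: "sleep 300", Background: true}
//	   returns a StartProcessResponse carrying the new process ID.
//	2. GET /list returns every tracked process, running or exited.
//	3. GET /{id}/output returns the buffered output plus truncation metadata.
//	4. POST /{id}/signal with Signal "kill" or "terminate" stops the process.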
// handleStartProcess starts a new process.
func (api *API) handleStartProcess(rw http.ResponseWriter, r *http.Request) {
ctx := r.Context()
var req workspacesdk.StartProcessRequest
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
Message: "Request body must be valid JSON.",
Detail: err.Error(),
})
return
}
if req.Command == "" {
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
Message: "Command is required.",
})
return
}
proc, err := api.manager.start(req)
if err != nil {
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
Message: "Failed to start process.",
Detail: err.Error(),
})
return
}
httpapi.Write(ctx, rw, http.StatusOK, workspacesdk.StartProcessResponse{
ID: proc.id,
Started: true,
})
}
// handleListProcesses lists all tracked processes.
func (api *API) handleListProcesses(rw http.ResponseWriter, r *http.Request) {
ctx := r.Context()
infos := api.manager.list()
httpapi.Write(ctx, rw, http.StatusOK, workspacesdk.ListProcessesResponse{
Processes: infos,
})
}
// handleProcessOutput returns the output of a process.
func (api *API) handleProcessOutput(rw http.ResponseWriter, r *http.Request) {
ctx := r.Context()
id := chi.URLParam(r, "id")
proc, ok := api.manager.get(id)
if !ok {
httpapi.Write(ctx, rw, http.StatusNotFound, codersdk.Response{
Message: fmt.Sprintf("Process %q not found.", id),
})
return
}
output, truncated := proc.output()
info := proc.info()
httpapi.Write(ctx, rw, http.StatusOK, workspacesdk.ProcessOutputResponse{
Output: output,
Truncated: truncated,
Running: info.Running,
ExitCode: info.ExitCode,
})
}
// handleSignalProcess sends a signal to a running process.
func (api *API) handleSignalProcess(rw http.ResponseWriter, r *http.Request) {
ctx := r.Context()
id := chi.URLParam(r, "id")
var req workspacesdk.SignalProcessRequest
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
Message: "Request body must be valid JSON.",
Detail: err.Error(),
})
return
}
if req.Signal == "" {
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
Message: "Signal is required.",
})
return
}
if req.Signal != "kill" && req.Signal != "terminate" {
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
Message: fmt.Sprintf(
"Unsupported signal %q. Use \"kill\" or \"terminate\".",
req.Signal,
),
})
return
}
if err := api.manager.signal(id, req.Signal); err != nil {
switch {
case errors.Is(err, errProcessNotFound):
httpapi.Write(ctx, rw, http.StatusNotFound, codersdk.Response{
Message: fmt.Sprintf("Process %q not found.", id),
})
case errors.Is(err, errProcessNotRunning):
httpapi.Write(ctx, rw, http.StatusConflict, codersdk.Response{
Message: fmt.Sprintf(
"Process %q is not running.", id,
),
})
default:
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
Message: "Failed to signal process.",
Detail: err.Error(),
})
}
return
}
httpapi.Write(ctx, rw, http.StatusOK, codersdk.Response{
Message: fmt.Sprintf(
"Signal %q sent to process %q.", req.Signal, id,
),
})
}
+691
@@ -0,0 +1,691 @@
package agentproc_test
import (
"bytes"
"context"
"encoding/json"
"fmt"
"net/http"
"net/http/httptest"
"runtime"
"strings"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"cdr.dev/slog/v3"
"cdr.dev/slog/v3/sloggers/slogtest"
"github.com/coder/coder/v2/agent/agentexec"
"github.com/coder/coder/v2/agent/agentproc"
"github.com/coder/coder/v2/codersdk"
"github.com/coder/coder/v2/codersdk/workspacesdk"
"github.com/coder/coder/v2/testutil"
)
// postStart sends a POST /start request and returns the recorder.
func postStart(t *testing.T, handler http.Handler, req workspacesdk.StartProcessRequest) *httptest.ResponseRecorder {
t.Helper()
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong)
defer cancel()
body, err := json.Marshal(req)
require.NoError(t, err)
w := httptest.NewRecorder()
r := httptest.NewRequestWithContext(ctx, http.MethodPost, "/start", bytes.NewReader(body))
handler.ServeHTTP(w, r)
return w
}
// getList sends a GET /list request and returns the recorder.
func getList(t *testing.T, handler http.Handler) *httptest.ResponseRecorder {
t.Helper()
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong)
defer cancel()
w := httptest.NewRecorder()
r := httptest.NewRequestWithContext(ctx, http.MethodGet, "/list", nil)
handler.ServeHTTP(w, r)
return w
}
// getOutput sends a GET /{id}/output request and returns the
// recorder.
func getOutput(t *testing.T, handler http.Handler, id string) *httptest.ResponseRecorder {
t.Helper()
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong)
defer cancel()
w := httptest.NewRecorder()
r := httptest.NewRequestWithContext(ctx, http.MethodGet, fmt.Sprintf("/%s/output", id), nil)
handler.ServeHTTP(w, r)
return w
}
// postSignal sends a POST /{id}/signal request and returns
// the recorder.
func postSignal(t *testing.T, handler http.Handler, id string, req workspacesdk.SignalProcessRequest) *httptest.ResponseRecorder {
t.Helper()
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong)
defer cancel()
body, err := json.Marshal(req)
require.NoError(t, err)
w := httptest.NewRecorder()
r := httptest.NewRequestWithContext(ctx, http.MethodPost, fmt.Sprintf("/%s/signal", id), bytes.NewReader(body))
handler.ServeHTTP(w, r)
return w
}
// newTestAPI creates a new API with a test logger and default
// execer, returning the handler and API.
func newTestAPI(t *testing.T) http.Handler {
t.Helper()
return newTestAPIWithUpdateEnv(t, nil)
}
// newTestAPIWithUpdateEnv creates a new API with an optional
// updateEnv hook for testing environment injection.
func newTestAPIWithUpdateEnv(t *testing.T, updateEnv func([]string) ([]string, error)) http.Handler {
t.Helper()
logger := slogtest.Make(t, &slogtest.Options{
IgnoreErrors: true,
}).Leveled(slog.LevelDebug)
api := agentproc.NewAPI(logger, agentexec.DefaultExecer, updateEnv)
t.Cleanup(func() {
_ = api.Close()
})
return api.Routes()
}
// waitForExit polls the output endpoint until the process is
// no longer running or the context expires.
func waitForExit(t *testing.T, handler http.Handler, id string) workspacesdk.ProcessOutputResponse {
t.Helper()
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong)
defer cancel()
ticker := time.NewTicker(50 * time.Millisecond)
defer ticker.Stop()
for {
select {
case <-ctx.Done():
t.Fatal("timed out waiting for process to exit")
case <-ticker.C:
w := getOutput(t, handler, id)
require.Equal(t, http.StatusOK, w.Code)
var resp workspacesdk.ProcessOutputResponse
err := json.NewDecoder(w.Body).Decode(&resp)
require.NoError(t, err)
if !resp.Running {
return resp
}
}
}
}
// startAndGetID is a helper that starts a process and returns
// the process ID.
func startAndGetID(t *testing.T, handler http.Handler, req workspacesdk.StartProcessRequest) string {
t.Helper()
w := postStart(t, handler, req)
require.Equal(t, http.StatusOK, w.Code)
var resp workspacesdk.StartProcessResponse
err := json.NewDecoder(w.Body).Decode(&resp)
require.NoError(t, err)
require.True(t, resp.Started)
require.NotEmpty(t, resp.ID)
return resp.ID
}
func TestStartProcess(t *testing.T) {
t.Parallel()
t.Run("ForegroundCommand", func(t *testing.T) {
t.Parallel()
handler := newTestAPI(t)
w := postStart(t, handler, workspacesdk.StartProcessRequest{
Command: "echo hello",
})
require.Equal(t, http.StatusOK, w.Code)
var resp workspacesdk.StartProcessResponse
err := json.NewDecoder(w.Body).Decode(&resp)
require.NoError(t, err)
require.True(t, resp.Started)
require.NotEmpty(t, resp.ID)
})
t.Run("BackgroundCommand", func(t *testing.T) {
t.Parallel()
handler := newTestAPI(t)
w := postStart(t, handler, workspacesdk.StartProcessRequest{
Command: "echo background",
Background: true,
})
require.Equal(t, http.StatusOK, w.Code)
var resp workspacesdk.StartProcessResponse
err := json.NewDecoder(w.Body).Decode(&resp)
require.NoError(t, err)
require.True(t, resp.Started)
require.NotEmpty(t, resp.ID)
})
t.Run("EmptyCommand", func(t *testing.T) {
t.Parallel()
handler := newTestAPI(t)
w := postStart(t, handler, workspacesdk.StartProcessRequest{
Command: "",
})
require.Equal(t, http.StatusBadRequest, w.Code)
var resp codersdk.Response
err := json.NewDecoder(w.Body).Decode(&resp)
require.NoError(t, err)
require.Contains(t, resp.Message, "Command is required")
})
t.Run("MalformedJSON", func(t *testing.T) {
t.Parallel()
handler := newTestAPI(t)
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong)
defer cancel()
w := httptest.NewRecorder()
r := httptest.NewRequestWithContext(ctx, http.MethodPost, "/start", strings.NewReader("{invalid json"))
handler.ServeHTTP(w, r)
require.Equal(t, http.StatusBadRequest, w.Code)
var resp codersdk.Response
err := json.NewDecoder(w.Body).Decode(&resp)
require.NoError(t, err)
require.Contains(t, resp.Message, "valid JSON")
})
t.Run("CustomWorkDir", func(t *testing.T) {
t.Parallel()
handler := newTestAPI(t)
tmpDir := t.TempDir()
// Write a marker file to verify the command ran in
// the correct directory. Comparing pwd output is
// unreliable on Windows where Git Bash returns POSIX
// paths.
id := startAndGetID(t, handler, workspacesdk.StartProcessRequest{
Command: "touch marker.txt && ls marker.txt",
WorkDir: tmpDir,
})
resp := waitForExit(t, handler, id)
require.NotNil(t, resp.ExitCode)
require.Equal(t, 0, *resp.ExitCode)
require.Contains(t, resp.Output, "marker.txt")
})
t.Run("CustomEnv", func(t *testing.T) {
t.Parallel()
handler := newTestAPI(t)
// Use a unique env var name to avoid collisions in
// parallel tests.
envKey := fmt.Sprintf("TEST_PROC_ENV_%d", time.Now().UnixNano())
envVal := "custom_value_12345"
id := startAndGetID(t, handler, workspacesdk.StartProcessRequest{
Command: fmt.Sprintf("printenv %s", envKey),
Env: map[string]string{envKey: envVal},
})
resp := waitForExit(t, handler, id)
require.NotNil(t, resp.ExitCode)
require.Equal(t, 0, *resp.ExitCode)
require.Contains(t, strings.TrimSpace(resp.Output), envVal)
})
t.Run("UpdateEnvHook", func(t *testing.T) {
t.Parallel()
envKey := fmt.Sprintf("TEST_UPDATE_ENV_%d", time.Now().UnixNano())
envVal := "injected_by_hook"
handler := newTestAPIWithUpdateEnv(t, func(current []string) ([]string, error) {
return append(current, fmt.Sprintf("%s=%s", envKey, envVal)), nil
})
// The process should see the variable even though it
// was not passed in req.Env.
id := startAndGetID(t, handler, workspacesdk.StartProcessRequest{
Command: fmt.Sprintf("printenv %s", envKey),
})
resp := waitForExit(t, handler, id)
require.NotNil(t, resp.ExitCode)
require.Equal(t, 0, *resp.ExitCode)
require.Contains(t, strings.TrimSpace(resp.Output), envVal)
})
t.Run("UpdateEnvHookOverriddenByReqEnv", func(t *testing.T) {
t.Parallel()
envKey := fmt.Sprintf("TEST_OVERRIDE_%d", time.Now().UnixNano())
hookVal := "from_hook"
reqVal := "from_request"
handler := newTestAPIWithUpdateEnv(t, func(current []string) ([]string, error) {
return append(current, fmt.Sprintf("%s=%s", envKey, hookVal)), nil
})
// req.Env should take precedence over the hook.
id := startAndGetID(t, handler, workspacesdk.StartProcessRequest{
Command: fmt.Sprintf("printenv %s", envKey),
Env: map[string]string{envKey: reqVal},
})
resp := waitForExit(t, handler, id)
require.NotNil(t, resp.ExitCode)
require.Equal(t, 0, *resp.ExitCode)
// When cmd.Env contains duplicate keys, os/exec uses the
// last value. Since req.Env is appended after the hook,
// the request value wins.
require.Contains(t, strings.TrimSpace(resp.Output), reqVal)
})
}
func TestListProcesses(t *testing.T) {
t.Parallel()
t.Run("NoProcesses", func(t *testing.T) {
t.Parallel()
handler := newTestAPI(t)
w := getList(t, handler)
require.Equal(t, http.StatusOK, w.Code)
var resp workspacesdk.ListProcessesResponse
err := json.NewDecoder(w.Body).Decode(&resp)
require.NoError(t, err)
require.NotNil(t, resp.Processes)
require.Empty(t, resp.Processes)
})
t.Run("MixedRunningAndExited", func(t *testing.T) {
t.Parallel()
handler := newTestAPI(t)
// Start a process that exits quickly.
exitedID := startAndGetID(t, handler, workspacesdk.StartProcessRequest{
Command: "echo done",
})
waitForExit(t, handler, exitedID)
// Start a long-running process.
runningID := startAndGetID(t, handler, workspacesdk.StartProcessRequest{
Command: "sleep 300",
Background: true,
})
// List should contain both.
w := getList(t, handler)
require.Equal(t, http.StatusOK, w.Code)
var resp workspacesdk.ListProcessesResponse
err := json.NewDecoder(w.Body).Decode(&resp)
require.NoError(t, err)
require.Len(t, resp.Processes, 2)
procMap := make(map[string]workspacesdk.ProcessInfo)
for _, p := range resp.Processes {
procMap[p.ID] = p
}
exited, ok := procMap[exitedID]
require.True(t, ok, "exited process should be in list")
require.False(t, exited.Running)
require.NotNil(t, exited.ExitCode)
running, ok := procMap[runningID]
require.True(t, ok, "running process should be in list")
require.True(t, running.Running)
// Clean up the long-running process.
sw := postSignal(t, handler, runningID, workspacesdk.SignalProcessRequest{
Signal: "kill",
})
require.Equal(t, http.StatusOK, sw.Code)
})
}
func TestProcessOutput(t *testing.T) {
t.Parallel()
t.Run("ExitedProcess", func(t *testing.T) {
t.Parallel()
handler := newTestAPI(t)
id := startAndGetID(t, handler, workspacesdk.StartProcessRequest{
Command: "echo hello-output",
})
resp := waitForExit(t, handler, id)
require.False(t, resp.Running)
require.NotNil(t, resp.ExitCode)
require.Equal(t, 0, *resp.ExitCode)
require.Contains(t, resp.Output, "hello-output")
})
t.Run("RunningProcess", func(t *testing.T) {
t.Parallel()
handler := newTestAPI(t)
id := startAndGetID(t, handler, workspacesdk.StartProcessRequest{
Command: "sleep 300",
Background: true,
})
w := getOutput(t, handler, id)
require.Equal(t, http.StatusOK, w.Code)
var resp workspacesdk.ProcessOutputResponse
err := json.NewDecoder(w.Body).Decode(&resp)
require.NoError(t, err)
require.True(t, resp.Running)
// Kill and wait for the process so cleanup does
// not hang.
postSignal(
t, handler, id,
workspacesdk.SignalProcessRequest{Signal: "kill"},
)
waitForExit(t, handler, id)
})
t.Run("NonexistentProcess", func(t *testing.T) {
t.Parallel()
handler := newTestAPI(t)
w := getOutput(t, handler, "nonexistent-id-12345")
require.Equal(t, http.StatusNotFound, w.Code)
var resp codersdk.Response
err := json.NewDecoder(w.Body).Decode(&resp)
require.NoError(t, err)
require.Contains(t, resp.Message, "not found")
})
}
func TestSignalProcess(t *testing.T) {
t.Parallel()
t.Run("KillRunning", func(t *testing.T) {
t.Parallel()
handler := newTestAPI(t)
id := startAndGetID(t, handler, workspacesdk.StartProcessRequest{
Command: "sleep 300",
Background: true,
})
w := postSignal(t, handler, id, workspacesdk.SignalProcessRequest{
Signal: "kill",
})
require.Equal(t, http.StatusOK, w.Code)
// Verify the process exits.
resp := waitForExit(t, handler, id)
require.False(t, resp.Running)
})
t.Run("TerminateRunning", func(t *testing.T) {
t.Parallel()
if runtime.GOOS == "windows" {
t.Skip("SIGTERM is not supported on Windows")
}
handler := newTestAPI(t)
id := startAndGetID(t, handler, workspacesdk.StartProcessRequest{
Command: "sleep 300",
Background: true,
})
w := postSignal(t, handler, id, workspacesdk.SignalProcessRequest{
Signal: "terminate",
})
require.Equal(t, http.StatusOK, w.Code)
// Verify the process exits.
resp := waitForExit(t, handler, id)
require.False(t, resp.Running)
})
t.Run("NonexistentProcess", func(t *testing.T) {
t.Parallel()
handler := newTestAPI(t)
w := postSignal(t, handler, "nonexistent-id-12345", workspacesdk.SignalProcessRequest{
Signal: "kill",
})
require.Equal(t, http.StatusNotFound, w.Code)
})
t.Run("AlreadyExitedProcess", func(t *testing.T) {
t.Parallel()
handler := newTestAPI(t)
id := startAndGetID(t, handler, workspacesdk.StartProcessRequest{
Command: "echo done",
})
// Wait for exit first.
waitForExit(t, handler, id)
// Signaling an exited process should return 409
// Conflict via the errProcessNotRunning sentinel.
w := postSignal(t, handler, id, workspacesdk.SignalProcessRequest{
Signal: "kill",
})
assert.Equal(t, http.StatusConflict, w.Code,
"expected 409 for signaling exited process, got %d", w.Code)
})
t.Run("EmptySignal", func(t *testing.T) {
t.Parallel()
handler := newTestAPI(t)
id := startAndGetID(t, handler, workspacesdk.StartProcessRequest{
Command: "sleep 300",
Background: true,
})
w := postSignal(t, handler, id, workspacesdk.SignalProcessRequest{
Signal: "",
})
require.Equal(t, http.StatusBadRequest, w.Code)
var resp codersdk.Response
err := json.NewDecoder(w.Body).Decode(&resp)
require.NoError(t, err)
require.Contains(t, resp.Message, "Signal is required")
// Clean up.
postSignal(t, handler, id, workspacesdk.SignalProcessRequest{
Signal: "kill",
})
})
t.Run("InvalidSignal", func(t *testing.T) {
t.Parallel()
handler := newTestAPI(t)
id := startAndGetID(t, handler, workspacesdk.StartProcessRequest{
Command: "sleep 300",
Background: true,
})
w := postSignal(t, handler, id, workspacesdk.SignalProcessRequest{
Signal: "SIGFOO",
})
require.Equal(t, http.StatusBadRequest, w.Code)
var resp codersdk.Response
err := json.NewDecoder(w.Body).Decode(&resp)
require.NoError(t, err)
require.Contains(t, resp.Message, "Unsupported signal")
// Clean up.
postSignal(t, handler, id, workspacesdk.SignalProcessRequest{
Signal: "kill",
})
})
}
func TestProcessLifecycle(t *testing.T) {
t.Parallel()
t.Run("StartWaitCheckOutput", func(t *testing.T) {
t.Parallel()
handler := newTestAPI(t)
id := startAndGetID(t, handler, workspacesdk.StartProcessRequest{
Command: "echo lifecycle-test && echo second-line",
})
resp := waitForExit(t, handler, id)
require.False(t, resp.Running)
require.NotNil(t, resp.ExitCode)
require.Equal(t, 0, *resp.ExitCode)
require.Contains(t, resp.Output, "lifecycle-test")
require.Contains(t, resp.Output, "second-line")
})
t.Run("NonZeroExitCode", func(t *testing.T) {
t.Parallel()
handler := newTestAPI(t)
id := startAndGetID(t, handler, workspacesdk.StartProcessRequest{
Command: "exit 42",
})
resp := waitForExit(t, handler, id)
require.False(t, resp.Running)
require.NotNil(t, resp.ExitCode)
require.Equal(t, 42, *resp.ExitCode)
})
t.Run("StartSignalVerifyExit", func(t *testing.T) {
t.Parallel()
handler := newTestAPI(t)
// Start a long-running background process.
id := startAndGetID(t, handler, workspacesdk.StartProcessRequest{
Command: "sleep 300",
Background: true,
})
// Verify it's running.
w := getOutput(t, handler, id)
require.Equal(t, http.StatusOK, w.Code)
var running workspacesdk.ProcessOutputResponse
err := json.NewDecoder(w.Body).Decode(&running)
require.NoError(t, err)
require.True(t, running.Running)
// Signal it.
sw := postSignal(t, handler, id, workspacesdk.SignalProcessRequest{
Signal: "kill",
})
require.Equal(t, http.StatusOK, sw.Code)
// Verify it exits.
resp := waitForExit(t, handler, id)
require.False(t, resp.Running)
require.NotNil(t, resp.ExitCode)
})
t.Run("OutputExceedsBuffer", func(t *testing.T) {
t.Parallel()
handler := newTestAPI(t)
// Generate output that exceeds MaxHeadBytes +
// MaxTailBytes (16KB head + 16KB tail). Each echo
// line is over 50 bytes, so dividing the budget by
// 50 and adding a 500-line margin guarantees well
// over 32KB of output.
lineCount := (agentproc.MaxHeadBytes+agentproc.MaxTailBytes)/50 + 500
cmd := fmt.Sprintf(
"for i in $(seq 1 %d); do echo \"line-$i-padding-to-make-this-longer-than-fifty-characters-total\"; done",
lineCount,
)
id := startAndGetID(t, handler, workspacesdk.StartProcessRequest{
Command: cmd,
})
resp := waitForExit(t, handler, id)
require.False(t, resp.Running)
require.NotNil(t, resp.ExitCode)
require.Equal(t, 0, *resp.ExitCode)
// The output should be truncated with head/tail
// strategy metadata.
require.NotNil(t, resp.Truncated, "large output should be truncated")
require.Equal(t, "head_tail", resp.Truncated.Strategy)
require.Greater(t, resp.Truncated.OmittedBytes, 0)
require.Greater(t, resp.Truncated.OriginalBytes, resp.Truncated.RetainedBytes)
// Verify the output contains the omission marker.
require.Contains(t, resp.Output, "... [omitted")
})
t.Run("StderrCaptured", func(t *testing.T) {
t.Parallel()
handler := newTestAPI(t)
id := startAndGetID(t, handler, workspacesdk.StartProcessRequest{
Command: "echo stdout-msg && echo stderr-msg >&2",
})
resp := waitForExit(t, handler, id)
require.False(t, resp.Running)
require.NotNil(t, resp.ExitCode)
require.Equal(t, 0, *resp.ExitCode)
// Both stdout and stderr should be captured.
require.Contains(t, resp.Output, "stdout-msg")
require.Contains(t, resp.Output, "stderr-msg")
})
}
+309
@@ -0,0 +1,309 @@
package agentproc
import (
"fmt"
"strings"
"sync"
"github.com/coder/coder/v2/codersdk/workspacesdk"
)
const (
// MaxHeadBytes is the number of bytes retained from the
// beginning of the output for LLM consumption.
MaxHeadBytes = 16 << 10 // 16KB
// MaxTailBytes is the number of bytes retained from the
// end of the output for LLM consumption.
MaxTailBytes = 16 << 10 // 16KB
// MaxLineLength is the maximum length of a single line
// before it is truncated. This prevents minified files
// or other long single-line output from consuming the
// entire buffer.
MaxLineLength = 2048
// lineTruncationSuffix is appended to lines that exceed
// MaxLineLength.
lineTruncationSuffix = " ... [truncated]"
)
// HeadTailBuffer is a thread-safe buffer that captures process
// output and provides head+tail truncation for LLM consumption.
// It implements io.Writer so it can be used directly as
// cmd.Stdout or cmd.Stderr.
//
// The buffer stores up to MaxHeadBytes from the beginning of
// the output and up to MaxTailBytes from the end in a ring
// buffer, keeping total memory usage bounded regardless of
// how much output is written.
type HeadTailBuffer struct {
mu sync.Mutex
head []byte
tail []byte
tailPos int
tailFull bool
headFull bool
totalBytes int
maxHead int
maxTail int
}
// NewHeadTailBuffer creates a new HeadTailBuffer with the
// default head and tail sizes.
func NewHeadTailBuffer() *HeadTailBuffer {
return &HeadTailBuffer{
maxHead: MaxHeadBytes,
maxTail: MaxTailBytes,
}
}
// NewHeadTailBufferSized creates a HeadTailBuffer with custom
// head and tail sizes. This is useful for testing truncation
// logic with smaller buffers.
func NewHeadTailBufferSized(maxHead, maxTail int) *HeadTailBuffer {
return &HeadTailBuffer{
maxHead: maxHead,
maxTail: maxTail,
}
}
// Write implements io.Writer. It is safe for concurrent use.
// All bytes are accepted; the return value always equals
// len(p) with a nil error.
func (b *HeadTailBuffer) Write(p []byte) (int, error) {
if len(p) == 0 {
return 0, nil
}
b.mu.Lock()
defer b.mu.Unlock()
n := len(p)
b.totalBytes += n
// Fill head buffer if it is not yet full.
if !b.headFull {
remaining := b.maxHead - len(b.head)
if remaining > 0 {
take := remaining
if take > len(p) {
take = len(p)
}
b.head = append(b.head, p[:take]...)
p = p[take:]
if len(b.head) >= b.maxHead {
b.headFull = true
}
}
if len(p) == 0 {
return n, nil
}
}
// Write remaining bytes into the tail ring buffer.
b.writeTail(p)
return n, nil
}
// writeTail appends data to the tail ring buffer. The caller
// must hold b.mu.
func (b *HeadTailBuffer) writeTail(p []byte) {
if b.maxTail <= 0 {
return
}
// Lazily allocate the tail buffer on first use.
if b.tail == nil {
b.tail = make([]byte, b.maxTail)
}
for len(p) > 0 {
// Write as many bytes as fit starting at tailPos.
space := b.maxTail - b.tailPos
take := space
if take > len(p) {
take = len(p)
}
copy(b.tail[b.tailPos:b.tailPos+take], p[:take])
p = p[take:]
b.tailPos += take
if b.tailPos >= b.maxTail {
b.tailPos = 0
b.tailFull = true
}
}
}
// tailBytes returns the current tail contents in order. The
// caller must hold b.mu.
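// For example, with maxTail=5, after writing "abcdefg" the ring holds
// ['f','g','c','d','e'] with tailPos=2 and tailFull=true, and the two copies
// below reassemble it as "cdefg".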
func (b *HeadTailBuffer) tailBytes() []byte {
if b.tail == nil {
return nil
}
if !b.tailFull {
// Haven't wrapped yet; data is [0, tailPos).
return b.tail[:b.tailPos]
}
// Wrapped: data is [tailPos, maxTail) + [0, tailPos).
out := make([]byte, b.maxTail)
n := copy(out, b.tail[b.tailPos:])
copy(out[n:], b.tail[:b.tailPos])
return out
}
// Bytes returns a copy of the raw buffer contents. If no
// truncation has occurred the full output is returned;
// otherwise the head and tail portions are concatenated.
func (b *HeadTailBuffer) Bytes() []byte {
b.mu.Lock()
defer b.mu.Unlock()
tail := b.tailBytes()
if len(tail) == 0 {
out := make([]byte, len(b.head))
copy(out, b.head)
return out
}
out := make([]byte, len(b.head)+len(tail))
copy(out, b.head)
copy(out[len(b.head):], tail)
return out
}
// Len returns the number of bytes currently stored in the
// buffer.
func (b *HeadTailBuffer) Len() int {
b.mu.Lock()
defer b.mu.Unlock()
tailLen := 0
if b.tailFull {
tailLen = b.maxTail
} else if b.tail != nil {
tailLen = b.tailPos
}
return len(b.head) + tailLen
}
// TotalWritten returns the total number of bytes written to
// the buffer, which may exceed the stored capacity.
func (b *HeadTailBuffer) TotalWritten() int {
b.mu.Lock()
defer b.mu.Unlock()
return b.totalBytes
}
// Output returns the truncated output suitable for LLM
// consumption, along with truncation metadata. If the total
// output fits within the head buffer alone, the full output is
// returned with nil truncation info. Otherwise the head and
// tail are joined with an omission marker and long lines are
// truncated.
func (b *HeadTailBuffer) Output() (string, *workspacesdk.ProcessTruncation) {
b.mu.Lock()
head := make([]byte, len(b.head))
copy(head, b.head)
tail := b.tailBytes()
total := b.totalBytes
headFull := b.headFull
b.mu.Unlock()
storedLen := len(head) + len(tail)
// If everything fits, no head/tail split is needed.
if !headFull || len(tail) == 0 {
out := truncateLines(string(head))
if total == 0 {
return "", nil
}
return out, nil
}
// We have both head and tail data, meaning the total
// output exceeded the head capacity. Build the
// combined output with an omission marker.
omitted := total - storedLen
headStr := truncateLines(string(head))
tailStr := truncateLines(string(tail))
var sb strings.Builder
_, _ = sb.WriteString(headStr)
if omitted > 0 {
_, _ = sb.WriteString(fmt.Sprintf(
"\n\n... [omitted %d bytes] ...\n\n",
omitted,
))
} else {
// Head and tail are contiguous but were stored
// separately because the head filled up.
_, _ = sb.WriteString("\n")
}
_, _ = sb.WriteString(tailStr)
result := sb.String()
return result, &workspacesdk.ProcessTruncation{
OriginalBytes: total,
RetainedBytes: len(result),
OmittedBytes: omitted,
Strategy: "head_tail",
}
}
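// Worked example (a sketch using the sized constructor above, mirroring the
// package tests): with maxHead=10 and maxTail=10, writing 100 bytes keeps the
// first 10 bytes in head and the last 10 in the tail ring, so the remaining 80
// are reported as omitted.
//
//	buf := NewHeadTailBufferSized(10, 10)
//	_, _ = buf.Write([]byte(strings.Repeat("A", 50) + strings.Repeat("Z", 50)))
//	out, info := buf.Output()
//	// out starts with "AAAAAAAAAA", ends with "ZZZZZZZZZZ", and contains
//	// "... [omitted 80 bytes] ..."; info.Strategy == "head_tail".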
// truncateLines scans the input line by line and truncates
// any line longer than MaxLineLength.
func truncateLines(s string) string {
if len(s) <= MaxLineLength {
// Fast path: if the entire string is shorter than
// the max line length, no line can exceed it.
return s
}
var b strings.Builder
b.Grow(len(s))
for len(s) > 0 {
idx := strings.IndexByte(s, '\n')
var line string
if idx == -1 {
line = s
s = ""
} else {
line = s[:idx]
s = s[idx+1:]
}
if len(line) > MaxLineLength {
// Cut so the truncated line plus the suffix stays
// within MaxLineLength.
cut := MaxLineLength - len(lineTruncationSuffix)
if cut < 0 {
cut = 0
}
_, _ = b.WriteString(line[:cut])
_, _ = b.WriteString(lineTruncationSuffix)
} else {
_, _ = b.WriteString(line)
}
// Re-add the newline unless this was the final
// segment without a trailing newline.
if idx != -1 {
_ = b.WriteByte('\n')
}
}
return b.String()
}
// Reset clears the buffer, discarding all data.
func (b *HeadTailBuffer) Reset() {
b.mu.Lock()
defer b.mu.Unlock()
b.head = nil
b.tail = nil
b.tailPos = 0
b.tailFull = false
b.headFull = false
b.totalBytes = 0
}
+338
@@ -0,0 +1,338 @@
package agentproc_test
import (
"fmt"
"strings"
"sync"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/coder/coder/v2/agent/agentproc"
)
func TestHeadTailBuffer_EmptyBuffer(t *testing.T) {
t.Parallel()
buf := agentproc.NewHeadTailBuffer()
out, info := buf.Output()
require.Empty(t, out)
require.Nil(t, info)
require.Equal(t, 0, buf.Len())
require.Equal(t, 0, buf.TotalWritten())
require.Empty(t, buf.Bytes())
}
func TestHeadTailBuffer_SmallOutput(t *testing.T) {
t.Parallel()
buf := agentproc.NewHeadTailBuffer()
data := "hello world\n"
n, err := buf.Write([]byte(data))
require.NoError(t, err)
require.Equal(t, len(data), n)
out, info := buf.Output()
require.Equal(t, data, out)
require.Nil(t, info, "small output should not be truncated")
require.Equal(t, len(data), buf.Len())
require.Equal(t, len(data), buf.TotalWritten())
}
func TestHeadTailBuffer_ExactlyHeadSize(t *testing.T) {
t.Parallel()
buf := agentproc.NewHeadTailBuffer()
// Build data that is exactly MaxHeadBytes using short
// lines so that line truncation does not apply.
line := strings.Repeat("x", 79) + "\n" // 80 bytes per line
count := agentproc.MaxHeadBytes / len(line)
pad := agentproc.MaxHeadBytes - (count * len(line))
data := strings.Repeat(line, count) + strings.Repeat("y", pad)
require.Equal(t, agentproc.MaxHeadBytes, len(data),
"test data must be exactly MaxHeadBytes")
n, err := buf.Write([]byte(data))
require.NoError(t, err)
require.Equal(t, agentproc.MaxHeadBytes, n)
out, info := buf.Output()
require.Equal(t, data, out)
require.Nil(t, info, "output fitting in head should not be truncated")
require.Equal(t, agentproc.MaxHeadBytes, buf.Len())
}
func TestHeadTailBuffer_HeadPlusTailNoOmission(t *testing.T) {
t.Parallel()
// Use a small buffer so we can test the boundary where
// head fills and tail starts but nothing is omitted.
// With maxHead=10, maxTail=10, writing exactly 20 bytes
// means head gets 10, tail gets 10, omitted = 0.
buf := agentproc.NewHeadTailBufferSized(10, 10)
data := "0123456789abcdefghij" // 20 bytes
n, err := buf.Write([]byte(data))
require.NoError(t, err)
require.Equal(t, 20, n)
out, info := buf.Output()
require.NotNil(t, info)
require.Equal(t, 0, info.OmittedBytes)
require.Equal(t, "head_tail", info.Strategy)
// The output should contain both head and tail.
require.Contains(t, out, "0123456789")
require.Contains(t, out, "abcdefghij")
}
func TestHeadTailBuffer_LargeOutputTruncation(t *testing.T) {
t.Parallel()
// Use small head/tail so truncation is easy to verify.
buf := agentproc.NewHeadTailBufferSized(10, 10)
// Write 100 bytes: head=10, tail=10, omitted=80.
data := strings.Repeat("A", 50) + strings.Repeat("Z", 50)
n, err := buf.Write([]byte(data))
require.NoError(t, err)
require.Equal(t, 100, n)
out, info := buf.Output()
require.NotNil(t, info)
require.Equal(t, 100, info.OriginalBytes)
require.Equal(t, 80, info.OmittedBytes)
require.Equal(t, "head_tail", info.Strategy)
// Head should be first 10 bytes (all A's).
require.True(t, strings.HasPrefix(out, "AAAAAAAAAA"))
// Tail should be last 10 bytes (all Z's).
require.True(t, strings.HasSuffix(out, "ZZZZZZZZZZ"))
// Omission marker should be present.
require.Contains(t, out, "... [omitted 80 bytes] ...")
require.Equal(t, 20, buf.Len())
require.Equal(t, 100, buf.TotalWritten())
}
func TestHeadTailBuffer_MultiMBStaysBounded(t *testing.T) {
t.Parallel()
buf := agentproc.NewHeadTailBuffer()
// Write 5MB of data in chunks.
chunk := []byte(strings.Repeat("x", 4096) + "\n")
totalWritten := 0
for totalWritten < 5*1024*1024 {
n, err := buf.Write(chunk)
require.NoError(t, err)
require.Equal(t, len(chunk), n)
totalWritten += n
}
// Memory should be bounded to head+tail.
require.LessOrEqual(t, buf.Len(),
agentproc.MaxHeadBytes+agentproc.MaxTailBytes)
require.Equal(t, totalWritten, buf.TotalWritten())
out, info := buf.Output()
require.NotNil(t, info)
require.Equal(t, totalWritten, info.OriginalBytes)
require.Greater(t, info.OmittedBytes, 0)
require.NotEmpty(t, out)
}
func TestHeadTailBuffer_LongLineTruncation(t *testing.T) {
t.Parallel()
buf := agentproc.NewHeadTailBuffer()
// Write a line longer than MaxLineLength.
longLine := strings.Repeat("m", agentproc.MaxLineLength+500)
_, err := buf.Write([]byte(longLine + "\n"))
require.NoError(t, err)
out, _ := buf.Output()
lines := strings.Split(strings.TrimRight(out, "\n"), "\n")
require.Len(t, lines, 1)
require.LessOrEqual(t, len(lines[0]), agentproc.MaxLineLength)
require.True(t, strings.HasSuffix(lines[0], "... [truncated]"))
}
func TestHeadTailBuffer_LongLineInTail(t *testing.T) {
t.Parallel()
// Use small buffers so we can force data into the tail.
buf := agentproc.NewHeadTailBufferSized(20, 5000)
// Fill head with short data.
_, err := buf.Write([]byte("head data goes here\n"))
require.NoError(t, err)
// Now write a very long line into the tail.
longLine := strings.Repeat("T", agentproc.MaxLineLength+100)
_, err = buf.Write([]byte(longLine + "\n"))
require.NoError(t, err)
out, info := buf.Output()
require.NotNil(t, info)
// The long line in the tail should be truncated.
require.Contains(t, out, "... [truncated]")
}
func TestHeadTailBuffer_ConcurrentWrites(t *testing.T) {
t.Parallel()
buf := agentproc.NewHeadTailBuffer()
const goroutines = 10
const writes = 1000
var wg sync.WaitGroup
wg.Add(goroutines)
for g := range goroutines {
go func() {
defer wg.Done()
line := fmt.Sprintf("goroutine-%d: data\n", g)
for range writes {
_, err := buf.Write([]byte(line))
assert.NoError(t, err)
}
}()
}
wg.Wait()
// Verify totals are consistent.
require.Greater(t, buf.TotalWritten(), 0)
require.Greater(t, buf.Len(), 0)
out, _ := buf.Output()
require.NotEmpty(t, out)
}
func TestHeadTailBuffer_TruncationInfoFields(t *testing.T) {
t.Parallel()
buf := agentproc.NewHeadTailBufferSized(10, 10)
// Write enough to cause omission.
data := strings.Repeat("D", 50)
_, err := buf.Write([]byte(data))
require.NoError(t, err)
_, info := buf.Output()
require.NotNil(t, info)
require.Equal(t, 50, info.OriginalBytes)
require.Equal(t, 30, info.OmittedBytes)
require.Equal(t, "head_tail", info.Strategy)
// RetainedBytes is the length of the formatted output
// string including the omission marker.
require.Greater(t, info.RetainedBytes, 0)
}
func TestHeadTailBuffer_MultipleSmallWrites(t *testing.T) {
t.Parallel()
buf := agentproc.NewHeadTailBuffer()
// Write one byte at a time.
expected := "hello world"
for i := range len(expected) {
n, err := buf.Write([]byte{expected[i]})
require.NoError(t, err)
require.Equal(t, 1, n)
}
out, info := buf.Output()
require.Equal(t, expected, out)
require.Nil(t, info)
}
func TestHeadTailBuffer_WriteEmptySlice(t *testing.T) {
t.Parallel()
buf := agentproc.NewHeadTailBuffer()
n, err := buf.Write([]byte{})
require.NoError(t, err)
require.Equal(t, 0, n)
require.Equal(t, 0, buf.TotalWritten())
}
func TestHeadTailBuffer_Reset(t *testing.T) {
t.Parallel()
buf := agentproc.NewHeadTailBuffer()
_, err := buf.Write([]byte("some data"))
require.NoError(t, err)
require.Greater(t, buf.Len(), 0)
buf.Reset()
require.Equal(t, 0, buf.Len())
require.Equal(t, 0, buf.TotalWritten())
out, info := buf.Output()
require.Empty(t, out)
require.Nil(t, info)
}
func TestHeadTailBuffer_BytesReturnsCopy(t *testing.T) {
t.Parallel()
buf := agentproc.NewHeadTailBuffer()
_, err := buf.Write([]byte("original"))
require.NoError(t, err)
b := buf.Bytes()
require.Equal(t, []byte("original"), b)
// Mutating the returned slice should not affect the
// buffer.
b[0] = 'X'
require.Equal(t, []byte("original"), buf.Bytes())
}
func TestHeadTailBuffer_RingBufferWraparound(t *testing.T) {
t.Parallel()
// Use a tail of 10 bytes and write enough to wrap
// around multiple times.
buf := agentproc.NewHeadTailBufferSized(5, 10)
// Fill head (5 bytes).
_, err := buf.Write([]byte("HEADD"))
require.NoError(t, err)
// Write 25 bytes into tail, wrapping 2.5 times.
_, err = buf.Write([]byte("0123456789"))
require.NoError(t, err)
_, err = buf.Write([]byte("abcdefghij"))
require.NoError(t, err)
_, err = buf.Write([]byte("ABCDE"))
require.NoError(t, err)
out, info := buf.Output()
require.NotNil(t, info)
// Tail should contain the last 10 bytes: "fghijABCDE".
require.True(t, strings.HasSuffix(out, "fghijABCDE"),
"expected tail to be last 10 bytes, got: %q", out)
}
func TestHeadTailBuffer_MultipleLinesTruncated(t *testing.T) {
t.Parallel()
buf := agentproc.NewHeadTailBuffer()
short := "short line\n"
long := strings.Repeat("L", agentproc.MaxLineLength+100) + "\n"
_, err := buf.Write([]byte(short + long + short))
require.NoError(t, err)
out, _ := buf.Output()
lines := strings.Split(strings.TrimRight(out, "\n"), "\n")
require.Len(t, lines, 3)
require.Equal(t, "short line", lines[0])
require.True(t, strings.HasSuffix(lines[1], "... [truncated]"))
require.Equal(t, "short line", lines[2])
}
+294
@@ -0,0 +1,294 @@
package agentproc
import (
"context"
"fmt"
"os"
"os/exec"
"sync"
"syscall"
"time"
"github.com/google/uuid"
"golang.org/x/xerrors"
"cdr.dev/slog/v3"
"github.com/coder/coder/v2/agent/agentexec"
"github.com/coder/coder/v2/codersdk/workspacesdk"
"github.com/coder/quartz"
)
var (
errProcessNotFound = xerrors.New("process not found")
errProcessNotRunning = xerrors.New("process is not running")
)
// process represents a running or completed process.
type process struct {
mu sync.Mutex
id string
command string
workDir string
background bool
cmd *exec.Cmd
cancel context.CancelFunc
buf *HeadTailBuffer
running bool
exitCode *int
startedAt int64
exitedAt *int64
done chan struct{} // closed when process exits
}
// info returns a snapshot of the process state.
func (p *process) info() workspacesdk.ProcessInfo {
p.mu.Lock()
defer p.mu.Unlock()
return workspacesdk.ProcessInfo{
ID: p.id,
Command: p.command,
WorkDir: p.workDir,
Background: p.background,
Running: p.running,
ExitCode: p.exitCode,
StartedAt: p.startedAt,
ExitedAt: p.exitedAt,
}
}
// output returns the truncated output from the process buffer
// along with optional truncation metadata.
func (p *process) output() (string, *workspacesdk.ProcessTruncation) {
return p.buf.Output()
}
// manager tracks processes spawned by the agent.
type manager struct {
mu sync.Mutex
logger slog.Logger
execer agentexec.Execer
clock quartz.Clock
procs map[string]*process
closed bool
updateEnv func(current []string) (updated []string, err error)
}
// newManager creates a new process manager.
func newManager(logger slog.Logger, execer agentexec.Execer, updateEnv func(current []string) (updated []string, err error)) *manager {
return &manager{
logger: logger,
execer: execer,
clock: quartz.NewReal(),
procs: make(map[string]*process),
updateEnv: updateEnv,
}
}
// start spawns a new process. Both foreground and background
// processes use a long-lived context so the process survives
// the HTTP request lifecycle. The background flag only affects
// client-side polling behavior.
func (m *manager) start(req workspacesdk.StartProcessRequest) (*process, error) {
m.mu.Lock()
if m.closed {
m.mu.Unlock()
return nil, xerrors.New("manager is closed")
}
m.mu.Unlock()
id := uuid.New().String()
// Use a cancellable context so Close() can terminate
// all processes. context.Background() is the parent so
// the process is not tied to any HTTP request.
ctx, cancel := context.WithCancel(context.Background())
cmd := m.execer.CommandContext(ctx, "sh", "-c", req.Command)
if req.WorkDir != "" {
cmd.Dir = req.WorkDir
}
cmd.Stdin = nil
// WaitDelay ensures cmd.Wait returns promptly after
// the process is killed, even if child processes are
// still holding the stdout/stderr pipes open.
cmd.WaitDelay = 5 * time.Second
buf := NewHeadTailBuffer()
cmd.Stdout = buf
cmd.Stderr = buf
// Build the process environment. If the manager has an
// updateEnv hook (provided by the agent), use it to get the
// full agent environment including GIT_ASKPASS, CODER_* vars,
// etc. Otherwise fall back to the current process env.
baseEnv := os.Environ()
if m.updateEnv != nil {
updated, err := m.updateEnv(baseEnv)
if err != nil {
m.logger.Warn(
context.Background(),
"failed to update command environment, falling back to os env",
slog.Error(err),
)
} else {
baseEnv = updated
}
}
// Always set cmd.Env explicitly so that req.Env overrides
// are applied on top of the full agent environment.
cmd.Env = baseEnv
for k, v := range req.Env {
cmd.Env = append(cmd.Env, fmt.Sprintf("%s=%s", k, v))
}
if err := cmd.Start(); err != nil {
cancel()
return nil, xerrors.Errorf("start process: %w", err)
}
now := m.clock.Now().Unix()
proc := &process{
id: id,
command: req.Command,
workDir: req.WorkDir,
background: req.Background,
cmd: cmd,
cancel: cancel,
buf: buf,
running: true,
startedAt: now,
done: make(chan struct{}),
}
m.mu.Lock()
if m.closed {
m.mu.Unlock()
// Manager closed between our check and now. Kill the
// process we just started.
cancel()
_ = cmd.Wait()
return nil, xerrors.New("manager is closed")
}
m.procs[id] = proc
m.mu.Unlock()
go func() {
err := cmd.Wait()
exitedAt := m.clock.Now().Unix()
proc.mu.Lock()
proc.running = false
proc.exitedAt = &exitedAt
code := 0
if err != nil {
// Extract the exit code from the error.
var exitErr *exec.ExitError
if xerrors.As(err, &exitErr) {
code = exitErr.ExitCode()
} else {
// Unknown error; use -1 as a sentinel.
code = -1
m.logger.Warn(
context.Background(),
"process wait returned non-exit error",
slog.F("id", id),
slog.Error(err),
)
}
}
proc.exitCode = &code
proc.mu.Unlock()
close(proc.done)
}()
return proc, nil
}
// get returns a process by ID.
func (m *manager) get(id string) (*process, bool) {
m.mu.Lock()
defer m.mu.Unlock()
proc, ok := m.procs[id]
return proc, ok
}
// list returns info about all tracked processes.
func (m *manager) list() []workspacesdk.ProcessInfo {
m.mu.Lock()
defer m.mu.Unlock()
infos := make([]workspacesdk.ProcessInfo, 0, len(m.procs))
for _, proc := range m.procs {
infos = append(infos, proc.info())
}
return infos
}
// signal sends a signal to a running process. It returns
// sentinel errors errProcessNotFound and errProcessNotRunning
// so callers can distinguish failure modes.
func (m *manager) signal(id string, sig string) error {
m.mu.Lock()
proc, ok := m.procs[id]
m.mu.Unlock()
if !ok {
return errProcessNotFound
}
proc.mu.Lock()
defer proc.mu.Unlock()
if !proc.running {
return errProcessNotRunning
}
switch sig {
case "kill":
if err := proc.cmd.Process.Kill(); err != nil {
return xerrors.Errorf("kill process: %w", err)
}
case "terminate":
//nolint:revive // syscall.SIGTERM is portable enough
// for our supported platforms.
if err := proc.cmd.Process.Signal(syscall.SIGTERM); err != nil {
return xerrors.Errorf("terminate process: %w", err)
}
default:
return xerrors.Errorf("unsupported signal %q", sig)
}
return nil
}
// Close kills all running processes and prevents new ones from
// starting. It cancels each process's context, which causes
// CommandContext to kill the process and its pipe goroutines to
// drain.
func (m *manager) Close() error {
m.mu.Lock()
if m.closed {
m.mu.Unlock()
return nil
}
m.closed = true
procs := make([]*process, 0, len(m.procs))
for _, p := range m.procs {
procs = append(procs, p)
}
m.mu.Unlock()
for _, p := range procs {
p.cancel()
}
// Wait for all processes to exit.
for _, p := range procs {
<-p.done
}
return nil
}
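// Typical lifecycle (a sketch; this mirrors how the HTTP handlers in this
// package drive the manager):
//
//	m := newManager(logger, agentexec.DefaultExecer, nil)
//	proc, _ := m.start(workspacesdk.StartProcessRequest{Command: "echo hi"})
//	// Poll proc.output() and proc.info() until info().Running is false.
//	_ = m.signal(proc.id, "terminate") // errProcessNotRunning once it has exited
//	_ = m.Close()                      // kills anything still running and waits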
+1
@@ -24,6 +24,7 @@ func New(t testing.TB, coderURL *url.URL, agentToken string, opts ...func(*agent
var o agent.Options
log := testutil.Logger(t).Named("agent")
o.Logger = log
o.SocketPath = testutil.AgentSocketPath(t)
for _, opt := range opts {
opt(&o)
+4
@@ -235,6 +235,10 @@ type FakeAgentAPI struct {
pushResourcesMonitoringUsageFunc func(*agentproto.PushResourcesMonitoringUsageRequest) (*agentproto.PushResourcesMonitoringUsageResponse, error)
}
func (*FakeAgentAPI) UpdateAppStatus(context.Context, *agentproto.UpdateAppStatusRequest) (*agentproto.UpdateAppStatusResponse, error) {
panic("unimplemented")
}
func (f *FakeAgentAPI) GetManifest(context.Context, *agentproto.GetManifestRequest) (*agentproto.Manifest, error) {
return f.manifest, nil
}
+1
@@ -28,6 +28,7 @@ func (a *agent) apiHandler() http.Handler {
})
r.Mount("/api/v0", a.filesAPI.Routes())
r.Mount("/api/v0/processes", a.processAPI.Routes())
if a.devcontainers {
r.Mount("/api/v0/containers", a.containerAPI.Routes())
+2 -1
@@ -10,6 +10,7 @@ import (
"testing"
"github.com/google/uuid"
"github.com/prometheus/client_golang/prometheus"
"github.com/stretchr/testify/require"
"google.golang.org/protobuf/proto"
"google.golang.org/protobuf/types/known/timestamppb"
@@ -69,7 +70,7 @@ func TestBoundaryLogs_EndToEnd(t *testing.T) {
t.Parallel()
socketPath := filepath.Join(testutil.TempDirUnixSocket(t), "boundary.sock")
-srv := boundarylogproxy.NewServer(testutil.Logger(t), socketPath)
srv := boundarylogproxy.NewServer(testutil.Logger(t), socketPath, prometheus.NewRegistry())
err := srv.Start()
require.NoError(t, err)
+286
@@ -0,0 +1,286 @@
// Code generated by protoc-gen-go. DO NOT EDIT.
// versions:
// protoc-gen-go v1.30.0
// protoc v4.23.4
// source: agent/boundarylogproxy/codec/boundary.proto
package codec
import (
proto "github.com/coder/coder/v2/agent/proto"
protoreflect "google.golang.org/protobuf/reflect/protoreflect"
protoimpl "google.golang.org/protobuf/runtime/protoimpl"
reflect "reflect"
sync "sync"
)
const (
// Verify that this generated code is sufficiently up-to-date.
_ = protoimpl.EnforceVersion(20 - protoimpl.MinVersion)
// Verify that runtime/protoimpl is sufficiently up-to-date.
_ = protoimpl.EnforceVersion(protoimpl.MaxVersion - 20)
)
// BoundaryMessage is the envelope for all TagV2 messages sent over the
// boundary <-> agent unix socket. TagV1 carries a bare
// ReportBoundaryLogsRequest for backwards compatibility; TagV2 wraps
// everything in this envelope so the protocol can be extended with new
// message types without adding more tags.
type BoundaryMessage struct {
state protoimpl.MessageState
sizeCache protoimpl.SizeCache
unknownFields protoimpl.UnknownFields
// Types that are assignable to Msg:
//
// *BoundaryMessage_Logs
// *BoundaryMessage_Status
Msg isBoundaryMessage_Msg `protobuf_oneof:"msg"`
}
func (x *BoundaryMessage) Reset() {
*x = BoundaryMessage{}
if protoimpl.UnsafeEnabled {
mi := &file_agent_boundarylogproxy_codec_boundary_proto_msgTypes[0]
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
ms.StoreMessageInfo(mi)
}
}
func (x *BoundaryMessage) String() string {
return protoimpl.X.MessageStringOf(x)
}
func (*BoundaryMessage) ProtoMessage() {}
func (x *BoundaryMessage) ProtoReflect() protoreflect.Message {
mi := &file_agent_boundarylogproxy_codec_boundary_proto_msgTypes[0]
if protoimpl.UnsafeEnabled && x != nil {
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
if ms.LoadMessageInfo() == nil {
ms.StoreMessageInfo(mi)
}
return ms
}
return mi.MessageOf(x)
}
// Deprecated: Use BoundaryMessage.ProtoReflect.Descriptor instead.
func (*BoundaryMessage) Descriptor() ([]byte, []int) {
return file_agent_boundarylogproxy_codec_boundary_proto_rawDescGZIP(), []int{0}
}
func (m *BoundaryMessage) GetMsg() isBoundaryMessage_Msg {
if m != nil {
return m.Msg
}
return nil
}
func (x *BoundaryMessage) GetLogs() *proto.ReportBoundaryLogsRequest {
if x, ok := x.GetMsg().(*BoundaryMessage_Logs); ok {
return x.Logs
}
return nil
}
func (x *BoundaryMessage) GetStatus() *BoundaryStatus {
if x, ok := x.GetMsg().(*BoundaryMessage_Status); ok {
return x.Status
}
return nil
}
type isBoundaryMessage_Msg interface {
isBoundaryMessage_Msg()
}
type BoundaryMessage_Logs struct {
Logs *proto.ReportBoundaryLogsRequest `protobuf:"bytes,1,opt,name=logs,proto3,oneof"`
}
type BoundaryMessage_Status struct {
Status *BoundaryStatus `protobuf:"bytes,2,opt,name=status,proto3,oneof"`
}
func (*BoundaryMessage_Logs) isBoundaryMessage_Msg() {}
func (*BoundaryMessage_Status) isBoundaryMessage_Msg() {}
// BoundaryStatus carries operational metadata from boundary to the agent.
// The agent records these values as Prometheus metrics. This message is
// never forwarded to coderd.
type BoundaryStatus struct {
state protoimpl.MessageState
sizeCache protoimpl.SizeCache
unknownFields protoimpl.UnknownFields
// Logs dropped because boundary's internal channel buffer was full.
DroppedChannelFull int64 `protobuf:"varint,1,opt,name=dropped_channel_full,json=droppedChannelFull,proto3" json:"dropped_channel_full,omitempty"`
// Logs dropped because boundary's batch buffer was full after a
// failed flush attempt.
DroppedBatchFull int64 `protobuf:"varint,2,opt,name=dropped_batch_full,json=droppedBatchFull,proto3" json:"dropped_batch_full,omitempty"`
}
func (x *BoundaryStatus) Reset() {
*x = BoundaryStatus{}
if protoimpl.UnsafeEnabled {
mi := &file_agent_boundarylogproxy_codec_boundary_proto_msgTypes[1]
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
ms.StoreMessageInfo(mi)
}
}
func (x *BoundaryStatus) String() string {
return protoimpl.X.MessageStringOf(x)
}
func (*BoundaryStatus) ProtoMessage() {}
func (x *BoundaryStatus) ProtoReflect() protoreflect.Message {
mi := &file_agent_boundarylogproxy_codec_boundary_proto_msgTypes[1]
if protoimpl.UnsafeEnabled && x != nil {
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
if ms.LoadMessageInfo() == nil {
ms.StoreMessageInfo(mi)
}
return ms
}
return mi.MessageOf(x)
}
// Deprecated: Use BoundaryStatus.ProtoReflect.Descriptor instead.
func (*BoundaryStatus) Descriptor() ([]byte, []int) {
return file_agent_boundarylogproxy_codec_boundary_proto_rawDescGZIP(), []int{1}
}
func (x *BoundaryStatus) GetDroppedChannelFull() int64 {
if x != nil {
return x.DroppedChannelFull
}
return 0
}
func (x *BoundaryStatus) GetDroppedBatchFull() int64 {
if x != nil {
return x.DroppedBatchFull
}
return 0
}
var File_agent_boundarylogproxy_codec_boundary_proto protoreflect.FileDescriptor
var file_agent_boundarylogproxy_codec_boundary_proto_rawDesc = []byte{
0x0a, 0x2b, 0x61, 0x67, 0x65, 0x6e, 0x74, 0x2f, 0x62, 0x6f, 0x75, 0x6e, 0x64, 0x61, 0x72, 0x79,
0x6c, 0x6f, 0x67, 0x70, 0x72, 0x6f, 0x78, 0x79, 0x2f, 0x63, 0x6f, 0x64, 0x65, 0x63, 0x2f, 0x62,
0x6f, 0x75, 0x6e, 0x64, 0x61, 0x72, 0x79, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x12, 0x1f, 0x63,
0x6f, 0x64, 0x65, 0x72, 0x2e, 0x62, 0x6f, 0x75, 0x6e, 0x64, 0x61, 0x72, 0x79, 0x6c, 0x6f, 0x67,
0x70, 0x72, 0x6f, 0x78, 0x79, 0x2e, 0x63, 0x6f, 0x64, 0x65, 0x63, 0x2e, 0x76, 0x31, 0x1a, 0x17,
0x61, 0x67, 0x65, 0x6e, 0x74, 0x2f, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x2f, 0x61, 0x67, 0x65, 0x6e,
0x74, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x22, 0xa4, 0x01, 0x0a, 0x0f, 0x42, 0x6f, 0x75, 0x6e,
0x64, 0x61, 0x72, 0x79, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x12, 0x3f, 0x0a, 0x04, 0x6c,
0x6f, 0x67, 0x73, 0x18, 0x01, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x29, 0x2e, 0x63, 0x6f, 0x64, 0x65,
0x72, 0x2e, 0x61, 0x67, 0x65, 0x6e, 0x74, 0x2e, 0x76, 0x32, 0x2e, 0x52, 0x65, 0x70, 0x6f, 0x72,
0x74, 0x42, 0x6f, 0x75, 0x6e, 0x64, 0x61, 0x72, 0x79, 0x4c, 0x6f, 0x67, 0x73, 0x52, 0x65, 0x71,
0x75, 0x65, 0x73, 0x74, 0x48, 0x00, 0x52, 0x04, 0x6c, 0x6f, 0x67, 0x73, 0x12, 0x49, 0x0a, 0x06,
0x73, 0x74, 0x61, 0x74, 0x75, 0x73, 0x18, 0x02, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x2f, 0x2e, 0x63,
0x6f, 0x64, 0x65, 0x72, 0x2e, 0x62, 0x6f, 0x75, 0x6e, 0x64, 0x61, 0x72, 0x79, 0x6c, 0x6f, 0x67,
0x70, 0x72, 0x6f, 0x78, 0x79, 0x2e, 0x63, 0x6f, 0x64, 0x65, 0x63, 0x2e, 0x76, 0x31, 0x2e, 0x42,
0x6f, 0x75, 0x6e, 0x64, 0x61, 0x72, 0x79, 0x53, 0x74, 0x61, 0x74, 0x75, 0x73, 0x48, 0x00, 0x52,
0x06, 0x73, 0x74, 0x61, 0x74, 0x75, 0x73, 0x42, 0x05, 0x0a, 0x03, 0x6d, 0x73, 0x67, 0x22, 0x70,
0x0a, 0x0e, 0x42, 0x6f, 0x75, 0x6e, 0x64, 0x61, 0x72, 0x79, 0x53, 0x74, 0x61, 0x74, 0x75, 0x73,
0x12, 0x30, 0x0a, 0x14, 0x64, 0x72, 0x6f, 0x70, 0x70, 0x65, 0x64, 0x5f, 0x63, 0x68, 0x61, 0x6e,
0x6e, 0x65, 0x6c, 0x5f, 0x66, 0x75, 0x6c, 0x6c, 0x18, 0x01, 0x20, 0x01, 0x28, 0x03, 0x52, 0x12,
0x64, 0x72, 0x6f, 0x70, 0x70, 0x65, 0x64, 0x43, 0x68, 0x61, 0x6e, 0x6e, 0x65, 0x6c, 0x46, 0x75,
0x6c, 0x6c, 0x12, 0x2c, 0x0a, 0x12, 0x64, 0x72, 0x6f, 0x70, 0x70, 0x65, 0x64, 0x5f, 0x62, 0x61,
0x74, 0x63, 0x68, 0x5f, 0x66, 0x75, 0x6c, 0x6c, 0x18, 0x02, 0x20, 0x01, 0x28, 0x03, 0x52, 0x10,
0x64, 0x72, 0x6f, 0x70, 0x70, 0x65, 0x64, 0x42, 0x61, 0x74, 0x63, 0x68, 0x46, 0x75, 0x6c, 0x6c,
0x42, 0x38, 0x5a, 0x36, 0x67, 0x69, 0x74, 0x68, 0x75, 0x62, 0x2e, 0x63, 0x6f, 0x6d, 0x2f, 0x63,
0x6f, 0x64, 0x65, 0x72, 0x2f, 0x63, 0x6f, 0x64, 0x65, 0x72, 0x2f, 0x76, 0x32, 0x2f, 0x61, 0x67,
0x65, 0x6e, 0x74, 0x2f, 0x62, 0x6f, 0x75, 0x6e, 0x64, 0x61, 0x72, 0x79, 0x6c, 0x6f, 0x67, 0x70,
0x72, 0x6f, 0x78, 0x79, 0x2f, 0x63, 0x6f, 0x64, 0x65, 0x63, 0x62, 0x06, 0x70, 0x72, 0x6f, 0x74,
0x6f, 0x33,
}
var (
file_agent_boundarylogproxy_codec_boundary_proto_rawDescOnce sync.Once
file_agent_boundarylogproxy_codec_boundary_proto_rawDescData = file_agent_boundarylogproxy_codec_boundary_proto_rawDesc
)
func file_agent_boundarylogproxy_codec_boundary_proto_rawDescGZIP() []byte {
file_agent_boundarylogproxy_codec_boundary_proto_rawDescOnce.Do(func() {
file_agent_boundarylogproxy_codec_boundary_proto_rawDescData = protoimpl.X.CompressGZIP(file_agent_boundarylogproxy_codec_boundary_proto_rawDescData)
})
return file_agent_boundarylogproxy_codec_boundary_proto_rawDescData
}
var file_agent_boundarylogproxy_codec_boundary_proto_msgTypes = make([]protoimpl.MessageInfo, 2)
var file_agent_boundarylogproxy_codec_boundary_proto_goTypes = []interface{}{
(*BoundaryMessage)(nil), // 0: coder.boundarylogproxy.codec.v1.BoundaryMessage
(*BoundaryStatus)(nil), // 1: coder.boundarylogproxy.codec.v1.BoundaryStatus
(*proto.ReportBoundaryLogsRequest)(nil), // 2: coder.agent.v2.ReportBoundaryLogsRequest
}
var file_agent_boundarylogproxy_codec_boundary_proto_depIdxs = []int32{
2, // 0: coder.boundarylogproxy.codec.v1.BoundaryMessage.logs:type_name -> coder.agent.v2.ReportBoundaryLogsRequest
1, // 1: coder.boundarylogproxy.codec.v1.BoundaryMessage.status:type_name -> coder.boundarylogproxy.codec.v1.BoundaryStatus
2, // [2:2] is the sub-list for method output_type
2, // [2:2] is the sub-list for method input_type
2, // [2:2] is the sub-list for extension type_name
2, // [2:2] is the sub-list for extension extendee
0, // [0:2] is the sub-list for field type_name
}
func init() { file_agent_boundarylogproxy_codec_boundary_proto_init() }
func file_agent_boundarylogproxy_codec_boundary_proto_init() {
if File_agent_boundarylogproxy_codec_boundary_proto != nil {
return
}
if !protoimpl.UnsafeEnabled {
file_agent_boundarylogproxy_codec_boundary_proto_msgTypes[0].Exporter = func(v interface{}, i int) interface{} {
switch v := v.(*BoundaryMessage); i {
case 0:
return &v.state
case 1:
return &v.sizeCache
case 2:
return &v.unknownFields
default:
return nil
}
}
file_agent_boundarylogproxy_codec_boundary_proto_msgTypes[1].Exporter = func(v interface{}, i int) interface{} {
switch v := v.(*BoundaryStatus); i {
case 0:
return &v.state
case 1:
return &v.sizeCache
case 2:
return &v.unknownFields
default:
return nil
}
}
}
file_agent_boundarylogproxy_codec_boundary_proto_msgTypes[0].OneofWrappers = []interface{}{
(*BoundaryMessage_Logs)(nil),
(*BoundaryMessage_Status)(nil),
}
type x struct{}
out := protoimpl.TypeBuilder{
File: protoimpl.DescBuilder{
GoPackagePath: reflect.TypeOf(x{}).PkgPath(),
RawDescriptor: file_agent_boundarylogproxy_codec_boundary_proto_rawDesc,
NumEnums: 0,
NumMessages: 2,
NumExtensions: 0,
NumServices: 0,
},
GoTypes: file_agent_boundarylogproxy_codec_boundary_proto_goTypes,
DependencyIndexes: file_agent_boundarylogproxy_codec_boundary_proto_depIdxs,
MessageInfos: file_agent_boundarylogproxy_codec_boundary_proto_msgTypes,
}.Build()
File_agent_boundarylogproxy_codec_boundary_proto = out.File
file_agent_boundarylogproxy_codec_boundary_proto_rawDesc = nil
file_agent_boundarylogproxy_codec_boundary_proto_goTypes = nil
file_agent_boundarylogproxy_codec_boundary_proto_depIdxs = nil
}
@@ -0,0 +1,29 @@
syntax = "proto3";
option go_package = "github.com/coder/coder/v2/agent/boundarylogproxy/codec";
package coder.boundarylogproxy.codec.v1;
import "agent/proto/agent.proto";
// BoundaryMessage is the envelope for all TagV2 messages sent over the
// boundary <-> agent unix socket. TagV1 carries a bare
// ReportBoundaryLogsRequest for backwards compatibility; TagV2 wraps
// everything in this envelope so the protocol can be extended with new
// message types without adding more tags.
message BoundaryMessage {
oneof msg {
coder.agent.v2.ReportBoundaryLogsRequest logs = 1;
BoundaryStatus status = 2;
}
}
// BoundaryStatus carries operational metadata from boundary to the agent.
// The agent records these values as Prometheus metrics. This message is
// never forwarded to coderd.
message BoundaryStatus {
// Logs dropped because boundary's internal channel buffer was full.
int64 dropped_channel_full = 1;
// Logs dropped because boundary's batch buffer was full after a
// failed flush attempt.
int64 dropped_batch_full = 2;
}
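A rough sketch of the wire story these comments describe: a TagV1 frame carries a bare ReportBoundaryLogsRequest, while a TagV2 frame carries the BoundaryMessage envelope, so new message variants can ride the same stream. The bytes.Buffer stands in for the unix socket; only the codec helpers and message types introduced in this change are assumed.
package main
import (
"bytes"
"fmt"
"github.com/coder/coder/v2/agent/boundarylogproxy/codec"
agentproto "github.com/coder/coder/v2/agent/proto"
)
func main() {
var stream bytes.Buffer
// Legacy senders frame a bare request with TagV1...
req := &agentproto.ReportBoundaryLogsRequest{}
if err := codec.WriteMessage(&stream, codec.TagV1, req); err != nil {
panic(err)
}
// ...while new senders wrap the same request in the TagV2 envelope, which
// leaves room for other variants (such as BoundaryStatus) on the same stream.
envelope := &codec.BoundaryMessage{
Msg: &codec.BoundaryMessage_Logs{Logs: req},
}
if err := codec.WriteMessage(&stream, codec.TagV2, envelope); err != nil {
panic(err)
}
fmt.Printf("wrote %d framed bytes\n", stream.Len())
}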
+73 -14
View File
@@ -14,14 +14,23 @@ import (
"io"
"golang.org/x/xerrors"
"google.golang.org/protobuf/proto"
agentproto "github.com/coder/coder/v2/agent/proto"
)
type Tag uint8
const (
// TagV1 identifies the first revision of the protocol. This version has a maximum
// data length of MaxMessageSizeV1.
// TagV1 identifies the first revision of the protocol. The payload is a
// bare ReportBoundaryLogsRequest. This version has a maximum data length
// of MaxMessageSizeV1.
TagV1 Tag = 1
// TagV2 identifies the second revision of the protocol. The payload is
// a BoundaryMessage envelope. This version has a maximum data length of
// MaxMessageSizeV2.
TagV2 Tag = 2
)
const (
@@ -35,6 +44,9 @@ const (
// over the wire for the TagV1 tag. While the wire format allows 24 bits for
// length, TagV1 only uses 15 bits.
MaxMessageSizeV1 uint32 = 1 << 15
// MaxMessageSizeV2 is the maximum data length for TagV2.
MaxMessageSizeV2 = MaxMessageSizeV1
)
var (
@@ -48,12 +60,9 @@ var (
// WriteFrame writes a framed message with the given tag and data. The data
// must not exceed 2^DataLength in length.
func WriteFrame(w io.Writer, tag Tag, data []byte) error {
var maxSize uint32
switch tag {
case TagV1:
maxSize = MaxMessageSizeV1
default:
return xerrors.Errorf("%w: %d", ErrUnsupportedTag, tag)
maxSize, err := maxSizeForTag(tag)
if err != nil {
return err
}
if len(data) > int(maxSize) {
@@ -101,12 +110,9 @@ func ReadFrame(r io.Reader, buf []byte) (Tag, []byte, error) {
}
tag := Tag(shifted)
var maxSize uint32
switch tag {
case TagV1:
maxSize = MaxMessageSizeV1
default:
return 0, nil, xerrors.Errorf("%w: %d", ErrUnsupportedTag, tag)
maxSize, err := maxSizeForTag(tag)
if err != nil {
return 0, nil, err
}
if length > maxSize {
@@ -125,3 +131,56 @@ func ReadFrame(r io.Reader, buf []byte) (Tag, []byte, error) {
return tag, buf[:length], nil
}
// maxSizeForTag returns the maximum payload size for the given tag.
func maxSizeForTag(tag Tag) (uint32, error) {
switch tag {
case TagV1:
return MaxMessageSizeV1, nil
case TagV2:
return MaxMessageSizeV2, nil
default:
return 0, xerrors.Errorf("%w: %d", ErrUnsupportedTag, tag)
}
}
// ReadMessage reads a framed message and unmarshals it based on tag. The
// returned buf should be passed back on the next call for buffer reuse.
func ReadMessage(r io.Reader, buf []byte) (proto.Message, []byte, error) {
tag, data, err := ReadFrame(r, buf)
if err != nil {
return nil, data, err
}
var msg proto.Message
switch tag {
case TagV1:
var req agentproto.ReportBoundaryLogsRequest
if err := proto.Unmarshal(data, &req); err != nil {
return nil, data, xerrors.Errorf("unmarshal TagV1: %w", err)
}
msg = &req
case TagV2:
var envelope BoundaryMessage
if err := proto.Unmarshal(data, &envelope); err != nil {
return nil, data, xerrors.Errorf("unmarshal TagV2: %w", err)
}
msg = &envelope
default:
// maxSizeForTag already rejects unknown tags during ReadFrame,
// but handle it here for safety.
return nil, data, xerrors.Errorf("%w: %d", ErrUnsupportedTag, tag)
}
return msg, data, nil
}
// WriteMessage marshals a proto message and writes it as a framed message
// with the given tag.
func WriteMessage(w io.Writer, tag Tag, msg proto.Message) error {
data, err := proto.Marshal(msg)
if err != nil {
return xerrors.Errorf("marshal: %w", err)
}
return WriteFrame(w, tag, data)
}
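A sketch of how a consumer could drive ReadMessage, reusing the returned buffer between calls and switching on the decoded type, in the spirit of the server's handleConnection loop shown later in this change; the reader source and the print statements are placeholders.
package readsketch
import (
"errors"
"fmt"
"io"
"google.golang.org/protobuf/proto"
"github.com/coder/coder/v2/agent/boundarylogproxy/codec"
agentproto "github.com/coder/coder/v2/agent/proto"
)
// readLoop decodes framed messages from r until EOF, handling both the legacy
// TagV1 payload and the TagV2 envelope.
func readLoop(r io.Reader) error {
buf := make([]byte, 1<<10) // grown by ReadMessage if necessary
for {
var (
msg proto.Message
err error
)
// Pass buf back in so the codec can reuse (and grow) the same backing array.
msg, buf, err = codec.ReadMessage(r, buf)
if errors.Is(err, io.EOF) {
return nil
}
if err != nil {
return err
}
switch m := msg.(type) {
case *agentproto.ReportBoundaryLogsRequest: // TagV1 legacy frame
fmt.Println("v1 logs:", len(m.Logs))
case *codec.BoundaryMessage: // TagV2 envelope
switch inner := m.Msg.(type) {
case *codec.BoundaryMessage_Logs:
fmt.Println("v2 logs:", len(inner.Logs.Logs))
case *codec.BoundaryMessage_Status:
fmt.Println("status drops:", inner.Status.DroppedChannelFull)
}
}
}
}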
+2 -2
View File
@@ -89,7 +89,7 @@ func TestReadFrameInvalidTag(t *testing.T) {
// reading the invalid tag.
const (
dataLength uint32 = 10
bogusTag uint32 = 2
bogusTag uint32 = 222
)
header := bogusTag<<codec.DataLength | dataLength
data := make([]byte, 4)
@@ -139,7 +139,7 @@ func TestWriteFrameInvalidTag(t *testing.T) {
var buf bytes.Buffer
data := make([]byte, 1)
const bogusTag = 2
const bogusTag = 222
err := codec.WriteFrame(&buf, codec.Tag(bogusTag), data)
require.ErrorIs(t, err, codec.ErrUnsupportedTag)
}
+77
View File
@@ -0,0 +1,77 @@
package boundarylogproxy
import "github.com/prometheus/client_golang/prometheus"
// Metrics tracks observability for the boundary -> agent -> coderd audit log
// pipeline.
//
// Audit logs from boundary workspaces pass through several async buffers
// before reaching coderd, and any stage can silently drop data. These
// metrics make that loss visible so operators/devs can:
//
// - Bubble up data loss: a non-zero drop rate means audit logs are being
// lost, which may have auditing implications.
// - Identify the bottleneck: the reason label pinpoints where drops
// occur (boundary's internal buffers, the agent's channel, or the
// RPC to coderd).
// - Tune buffer sizes: sustained "buffer_full" drops indicate the
// agent's channel (or boundary's batch buffer) is too small for the
// workload. Combined with batches_forwarded_total you can compute a
// drop rate: drops / (drops + forwards).
// - Detect batch forwarding issues: "forward_failed" drops increase when
// the agent cannot reach coderd.
//
// Drops are captured at two stages:
// - Agent-side: the agent's channel buffer overflows (reason
// "buffer_full") or the RPC forward to coderd fails (reason
// "forward_failed").
// - Boundary-reported: boundary self-reports drops via BoundaryStatus
// messages (reasons "boundary_channel_full", "boundary_batch_full").
// These arrive on the next successful flush from boundary.
//
// There are circumstances where metrics could be lost, e.g. agent restarts,
// boundary crashes, or the agent shutting down while the DRPC connection is down.
type Metrics struct {
batchesDropped *prometheus.CounterVec
logsDropped *prometheus.CounterVec
batchesForwarded prometheus.Counter
}
func newMetrics(registerer prometheus.Registerer) *Metrics {
batchesDropped := prometheus.NewCounterVec(prometheus.CounterOpts{
Namespace: "agent",
Subsystem: "boundary_log_proxy",
Name: "batches_dropped_total",
Help: "Total number of boundary log batches dropped before reaching coderd. " +
"Reason: buffer_full = the agent's internal buffer is full, meaning boundary is producing logs faster than the agent can forward them to coderd; " +
"forward_failed = the agent failed to send the batch to coderd, potentially because coderd is unreachable or the connection was interrupted.",
}, []string{"reason"})
registerer.MustRegister(batchesDropped)
logsDropped := prometheus.NewCounterVec(prometheus.CounterOpts{
Namespace: "agent",
Subsystem: "boundary_log_proxy",
Name: "logs_dropped_total",
Help: "Total number of individual boundary log entries dropped before reaching coderd. " +
"Reason: buffer_full = the agent's internal buffer is full; " +
"forward_failed = the agent failed to send the batch to coderd; " +
"boundary_channel_full = boundary's internal send channel overflowed, meaning boundary is generating logs faster than it can batch and send them; " +
"boundary_batch_full = boundary's outgoing batch buffer overflowed after a failed flush, meaning boundary could not write to the agent's socket.",
}, []string{"reason"})
registerer.MustRegister(logsDropped)
batchesForwarded := prometheus.NewCounter(prometheus.CounterOpts{
Namespace: "agent",
Subsystem: "boundary_log_proxy",
Name: "batches_forwarded_total",
Help: "Total number of boundary log batches successfully forwarded to coderd. " +
"Compare with batches_dropped_total to compute a drop rate.",
})
registerer.MustRegister(batchesForwarded)
return &Metrics{
batchesDropped: batchesDropped,
logsDropped: logsDropped,
batchesForwarded: batchesForwarded,
}
}
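The drop-rate formula from the comment above, drops / (drops + forwards), as a tiny hedged sketch; the literal counter values are made up for illustration and would normally come from the agent_boundary_log_proxy_batches_dropped_total and agent_boundary_log_proxy_batches_forwarded_total series.
package main
import "fmt"
// dropRate computes drops / (drops + forwards), guarding against a zero total.
func dropRate(dropped, forwarded float64) float64 {
total := dropped + forwarded
if total == 0 {
return 0
}
return dropped / total
}
func main() {
// e.g. 4 buffer_full drops + 1 forward_failed drop vs 95 forwarded batches.
fmt.Printf("drop rate: %.1f%%\n", dropRate(4+1, 95)*100) // prints: drop rate: 5.0%
}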
+60 -23
View File
@@ -11,6 +11,7 @@ import (
"path/filepath"
"sync"
"github.com/prometheus/client_golang/prometheus"
"golang.org/x/xerrors"
"google.golang.org/protobuf/proto"
@@ -26,6 +27,13 @@ const (
logBufferSize = 100
)
const (
droppedReasonBoundaryChannelFull = "boundary_channel_full"
droppedReasonBoundaryBatchFull = "boundary_batch_full"
droppedReasonBufferFull = "buffer_full"
droppedReasonForwardFailed = "forward_failed"
)
// DefaultSocketPath returns the default path for the boundary audit log socket.
func DefaultSocketPath() string {
return filepath.Join(os.TempDir(), "boundary-audit.sock")
@@ -43,6 +51,7 @@ type Reporter interface {
type Server struct {
logger slog.Logger
socketPath string
metrics *Metrics
listener net.Listener
cancel context.CancelFunc
@@ -53,10 +62,11 @@ type Server struct {
}
// NewServer creates a new boundary log proxy server.
func NewServer(logger slog.Logger, socketPath string) *Server {
func NewServer(logger slog.Logger, socketPath string, registerer prometheus.Registerer) *Server {
return &Server{
logger: logger.Named("boundary-log-proxy"),
socketPath: socketPath,
metrics: newMetrics(registerer),
logs: make(chan *agentproto.ReportBoundaryLogsRequest, logBufferSize),
}
}
@@ -100,9 +110,13 @@ func (s *Server) RunForwarder(ctx context.Context, sender Reporter) error {
s.logger.Warn(ctx, "failed to forward boundary logs",
slog.Error(err),
slog.F("log_count", len(req.Logs)))
s.metrics.batchesDropped.WithLabelValues(droppedReasonForwardFailed).Inc()
s.metrics.logsDropped.WithLabelValues(droppedReasonForwardFailed).Add(float64(len(req.Logs)))
// Continue forwarding other logs. The current batch is lost,
// but the socket stays alive.
continue
}
s.metrics.batchesForwarded.Inc()
}
}
}
@@ -139,8 +153,8 @@ func (s *Server) handleConnection(ctx context.Context, conn net.Conn) {
_ = conn.Close()
}()
// This is intended to be a sane starting point for the read buffer size. It may be
// grown by codec.ReadFrame if necessary.
// This is intended to be a sane starting point for the read buffer size.
// It may be grown by codec.ReadMessage if necessary.
const initBufSize = 1 << 10
buf := make([]byte, initBufSize)
@@ -151,36 +165,59 @@ func (s *Server) handleConnection(ctx context.Context, conn net.Conn) {
default:
}
var (
tag codec.Tag
err error
)
tag, buf, err = codec.ReadFrame(conn, buf)
var err error
var msg proto.Message
msg, buf, err = codec.ReadMessage(conn, buf)
switch {
case errors.Is(err, io.EOF) || errors.Is(err, net.ErrClosed):
return
case err != nil:
case errors.Is(err, codec.ErrUnsupportedTag) || errors.Is(err, codec.ErrMessageTooLarge):
s.logger.Warn(ctx, "read frame error", slog.Error(err))
return
}
if tag != codec.TagV1 {
s.logger.Warn(ctx, "invalid tag value", slog.F("tag", tag))
return
}
var req agentproto.ReportBoundaryLogsRequest
if err := proto.Unmarshal(buf, &req); err != nil {
s.logger.Warn(ctx, "proto unmarshal error", slog.Error(err))
case err != nil:
s.logger.Warn(ctx, "read message error", slog.Error(err))
continue
}
select {
case s.logs <- &req:
s.handleMessage(ctx, msg)
}
}
func (s *Server) handleMessage(ctx context.Context, msg proto.Message) {
switch m := msg.(type) {
case *agentproto.ReportBoundaryLogsRequest:
s.bufferLogs(ctx, m)
case *codec.BoundaryMessage:
switch inner := m.Msg.(type) {
case *codec.BoundaryMessage_Logs:
s.bufferLogs(ctx, inner.Logs)
case *codec.BoundaryMessage_Status:
s.recordBoundaryStatus(inner.Status)
default:
s.logger.Warn(ctx, "dropping boundary logs, buffer full",
slog.F("log_count", len(req.Logs)))
s.logger.Warn(ctx, "unknown BoundaryMessage variant")
}
default:
s.logger.Warn(ctx, "unexpected message type")
}
}
func (s *Server) recordBoundaryStatus(status *codec.BoundaryStatus) {
if n := status.DroppedChannelFull; n > 0 {
s.metrics.logsDropped.WithLabelValues(droppedReasonBoundaryChannelFull).Add(float64(n))
}
if n := status.DroppedBatchFull; n > 0 {
s.metrics.logsDropped.WithLabelValues(droppedReasonBoundaryBatchFull).Add(float64(n))
}
}
func (s *Server) bufferLogs(ctx context.Context, req *agentproto.ReportBoundaryLogsRequest) {
select {
case s.logs <- req:
default:
s.logger.Warn(ctx, "dropping boundary logs, buffer full",
slog.F("log_count", len(req.Logs)))
s.metrics.batchesDropped.WithLabelValues(droppedReasonBufferFull).Inc()
s.metrics.logsDropped.WithLabelValues(droppedReasonBufferFull).Add(float64(len(req.Logs)))
}
}
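Putting the pieces together from boundary's side of the socket, a hedged sketch: dial the agent's audit socket, emit a status report, then a batch of logs. On the agent, handleMessage above routes the status straight into metrics and buffers the logs for RunForwarder. The socket path, log fields, and error handling here are illustrative only; boundary's real sender is not shown in this change.
package main
import (
"net"
"google.golang.org/protobuf/types/known/timestamppb"
"github.com/coder/coder/v2/agent/boundarylogproxy"
"github.com/coder/coder/v2/agent/boundarylogproxy/codec"
agentproto "github.com/coder/coder/v2/agent/proto"
)
func main() {
conn, err := net.Dial("unix", boundarylogproxy.DefaultSocketPath())
if err != nil {
panic(err)
}
defer conn.Close()
// Self-report drops accumulated since the last flush.
status := &codec.BoundaryMessage{
Msg: &codec.BoundaryMessage_Status{
Status: &codec.BoundaryStatus{DroppedChannelFull: 2},
},
}
if err := codec.WriteMessage(conn, codec.TagV2, status); err != nil {
panic(err)
}
// Then forward a batch of audit logs in the same envelope format.
logs := &codec.BoundaryMessage{
Msg: &codec.BoundaryMessage_Logs{
Logs: &agentproto.ReportBoundaryLogsRequest{
Logs: []*agentproto.BoundaryLog{{
Allowed: true,
Time:    timestamppb.Now(),
}},
},
},
}
if err := codec.WriteMessage(conn, codec.TagV2, logs); err != nil {
panic(err)
}
}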
+303 -26
View File
@@ -11,8 +11,8 @@ import (
"testing"
"time"
"github.com/prometheus/client_golang/prometheus"
"github.com/stretchr/testify/require"
"google.golang.org/protobuf/proto"
"google.golang.org/protobuf/types/known/timestamppb"
"github.com/coder/coder/v2/agent/boundarylogproxy"
@@ -21,20 +21,42 @@ import (
"github.com/coder/coder/v2/testutil"
)
// sendMessage writes a framed protobuf message to the connection.
func sendMessage(t *testing.T, conn net.Conn, req *agentproto.ReportBoundaryLogsRequest) {
// sendLogsV1 writes a bare ReportBoundaryLogsRequest using TagV1, the
// legacy framing that existing boundary deployments use.
func sendLogsV1(t *testing.T, conn net.Conn, req *agentproto.ReportBoundaryLogsRequest) {
t.Helper()
data, err := proto.Marshal(req)
err := codec.WriteMessage(conn, codec.TagV1, req)
if err != nil {
//nolint:gocritic // In tests we're not worried about conn being nil.
t.Errorf("%s marshal req: %s", conn.LocalAddr().String(), err)
t.Errorf("write v1 logs: %s", err)
}
}
err = codec.WriteFrame(conn, codec.TagV1, data)
// sendLogs writes a BoundaryMessage envelope containing logs to the
// connection using TagV2.
func sendLogs(t *testing.T, conn net.Conn, req *agentproto.ReportBoundaryLogsRequest) {
t.Helper()
msg := &codec.BoundaryMessage{
Msg: &codec.BoundaryMessage_Logs{Logs: req},
}
err := codec.WriteMessage(conn, codec.TagV2, msg)
if err != nil {
//nolint:gocritic // In tests we're not worried about conn being nil.
t.Errorf("%s write frame: %s", conn.LocalAddr().String(), err)
t.Errorf("write logs: %s", err)
}
}
// sendStatus writes a BoundaryMessage envelope containing a BoundaryStatus
// to the connection using TagV2.
func sendStatus(t *testing.T, conn net.Conn, status *codec.BoundaryStatus) {
t.Helper()
msg := &codec.BoundaryMessage{
Msg: &codec.BoundaryMessage_Status{Status: status},
}
err := codec.WriteMessage(conn, codec.TagV2, msg)
if err != nil {
t.Errorf("write status: %s", err)
}
}
@@ -80,7 +102,7 @@ func TestServer_StartAndClose(t *testing.T) {
t.Parallel()
socketPath := filepath.Join(testutil.TempDirUnixSocket(t), "boundary.sock")
srv := boundarylogproxy.NewServer(testutil.Logger(t), socketPath)
srv := boundarylogproxy.NewServer(testutil.Logger(t), socketPath, prometheus.NewRegistry())
err := srv.Start()
require.NoError(t, err)
@@ -99,7 +121,7 @@ func TestServer_ReceiveAndForwardLogs(t *testing.T) {
t.Parallel()
socketPath := filepath.Join(testutil.TempDirUnixSocket(t), "boundary.sock")
srv := boundarylogproxy.NewServer(testutil.Logger(t), socketPath)
srv := boundarylogproxy.NewServer(testutil.Logger(t), socketPath, prometheus.NewRegistry())
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
@@ -136,7 +158,7 @@ func TestServer_ReceiveAndForwardLogs(t *testing.T) {
},
}
sendMessage(t, conn, req)
sendLogs(t, conn, req)
// Wait for the reporter to receive the log.
require.Eventually(t, func() bool {
@@ -159,7 +181,7 @@ func TestServer_MultipleMessages(t *testing.T) {
t.Parallel()
socketPath := filepath.Join(testutil.TempDirUnixSocket(t), "boundary.sock")
srv := boundarylogproxy.NewServer(testutil.Logger(t), socketPath)
srv := boundarylogproxy.NewServer(testutil.Logger(t), socketPath, prometheus.NewRegistry())
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
@@ -195,7 +217,7 @@ func TestServer_MultipleMessages(t *testing.T) {
},
},
}
sendMessage(t, conn, req)
sendLogs(t, conn, req)
}
require.Eventually(t, func() bool {
@@ -211,7 +233,7 @@ func TestServer_MultipleConnections(t *testing.T) {
t.Parallel()
socketPath := filepath.Join(testutil.TempDirUnixSocket(t), "boundary.sock")
srv := boundarylogproxy.NewServer(testutil.Logger(t), socketPath)
srv := boundarylogproxy.NewServer(testutil.Logger(t), socketPath, prometheus.NewRegistry())
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
@@ -254,7 +276,7 @@ func TestServer_MultipleConnections(t *testing.T) {
},
},
}
sendMessage(t, conn, req)
sendLogs(t, conn, req)
}(i)
}
wg.Wait()
@@ -272,7 +294,7 @@ func TestServer_MessageTooLarge(t *testing.T) {
t.Parallel()
socketPath := filepath.Join(testutil.TempDirUnixSocket(t), "boundary.sock")
srv := boundarylogproxy.NewServer(testutil.Logger(t), socketPath)
srv := boundarylogproxy.NewServer(testutil.Logger(t), socketPath, prometheus.NewRegistry())
err := srv.Start()
require.NoError(t, err)
@@ -300,7 +322,7 @@ func TestServer_ForwarderContinuesAfterError(t *testing.T) {
t.Parallel()
socketPath := filepath.Join(testutil.TempDirUnixSocket(t), "boundary.sock")
srv := boundarylogproxy.NewServer(testutil.Logger(t), socketPath)
srv := boundarylogproxy.NewServer(testutil.Logger(t), socketPath, prometheus.NewRegistry())
err := srv.Start()
require.NoError(t, err)
@@ -342,7 +364,7 @@ func TestServer_ForwarderContinuesAfterError(t *testing.T) {
},
},
}
sendMessage(t, conn, req1)
sendLogs(t, conn, req1)
select {
case <-reportNotify:
@@ -365,7 +387,7 @@ func TestServer_ForwarderContinuesAfterError(t *testing.T) {
},
},
}
sendMessage(t, conn, req2)
sendLogs(t, conn, req2)
// Only the second message should be recorded.
require.Eventually(t, func() bool {
@@ -385,7 +407,7 @@ func TestServer_CloseStopsForwarder(t *testing.T) {
t.Parallel()
socketPath := filepath.Join(testutil.TempDirUnixSocket(t), "boundary.sock")
srv := boundarylogproxy.NewServer(testutil.Logger(t), socketPath)
srv := boundarylogproxy.NewServer(testutil.Logger(t), socketPath, prometheus.NewRegistry())
err := srv.Start()
require.NoError(t, err)
@@ -414,7 +436,7 @@ func TestServer_InvalidProtobuf(t *testing.T) {
t.Parallel()
socketPath := filepath.Join(testutil.TempDirUnixSocket(t), "boundary.sock")
srv := boundarylogproxy.NewServer(testutil.Logger(t), socketPath)
srv := boundarylogproxy.NewServer(testutil.Logger(t), socketPath, prometheus.NewRegistry())
err := srv.Start()
require.NoError(t, err)
@@ -458,7 +480,7 @@ func TestServer_InvalidProtobuf(t *testing.T) {
},
},
}
sendMessage(t, conn, req)
sendLogs(t, conn, req)
require.Eventually(t, func() bool {
logs := reporter.getLogs()
@@ -473,7 +495,7 @@ func TestServer_InvalidHeader(t *testing.T) {
t.Parallel()
socketPath := filepath.Join(testutil.TempDirUnixSocket(t), "boundary.sock")
srv := boundarylogproxy.NewServer(testutil.Logger(t), socketPath)
srv := boundarylogproxy.NewServer(testutil.Logger(t), socketPath, prometheus.NewRegistry())
err := srv.Start()
require.NoError(t, err)
@@ -523,7 +545,7 @@ func TestServer_AllowRequest(t *testing.T) {
t.Parallel()
socketPath := filepath.Join(testutil.TempDirUnixSocket(t), "boundary.sock")
srv := boundarylogproxy.NewServer(testutil.Logger(t), socketPath)
srv := boundarylogproxy.NewServer(testutil.Logger(t), socketPath, prometheus.NewRegistry())
err := srv.Start()
require.NoError(t, err)
@@ -559,7 +581,7 @@ func TestServer_AllowRequest(t *testing.T) {
},
},
}
sendMessage(t, conn, req)
sendLogs(t, conn, req)
require.Eventually(t, func() bool {
logs := reporter.getLogs()
@@ -576,3 +598,258 @@ func TestServer_AllowRequest(t *testing.T) {
cancel()
<-forwarderDone
}
func TestServer_TagV1BackwardsCompatibility(t *testing.T) {
t.Parallel()
socketPath := filepath.Join(testutil.TempDirUnixSocket(t), "boundary.sock")
srv := boundarylogproxy.NewServer(testutil.Logger(t), socketPath, prometheus.NewRegistry())
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
err := srv.Start()
require.NoError(t, err)
t.Cleanup(func() { require.NoError(t, srv.Close()) })
reporter := &fakeReporter{}
forwarderDone := make(chan error, 1)
go func() {
forwarderDone <- srv.RunForwarder(ctx, reporter)
}()
conn, err := net.Dial("unix", socketPath)
require.NoError(t, err)
defer conn.Close()
// Send a TagV1 message (bare ReportBoundaryLogsRequest) to verify
// the server still handles the legacy framing used by existing
// boundary deployments.
v1Req := &agentproto.ReportBoundaryLogsRequest{
Logs: []*agentproto.BoundaryLog{
{
Allowed: true,
Time: timestamppb.Now(),
Resource: &agentproto.BoundaryLog_HttpRequest_{
HttpRequest: &agentproto.BoundaryLog_HttpRequest{
Method: "GET",
Url: "https://example.com/v1",
},
},
},
},
}
sendLogsV1(t, conn, v1Req)
require.Eventually(t, func() bool {
return len(reporter.getLogs()) == 1
}, testutil.WaitShort, testutil.IntervalFast)
// Now send a TagV2 message on the same connection to verify both
// tag versions work interleaved.
v2Req := &agentproto.ReportBoundaryLogsRequest{
Logs: []*agentproto.BoundaryLog{
{
Allowed: false,
Time: timestamppb.Now(),
Resource: &agentproto.BoundaryLog_HttpRequest_{
HttpRequest: &agentproto.BoundaryLog_HttpRequest{
Method: "POST",
Url: "https://example.com/v2",
},
},
},
},
}
sendLogs(t, conn, v2Req)
require.Eventually(t, func() bool {
return len(reporter.getLogs()) == 2
}, testutil.WaitShort, testutil.IntervalFast)
logs := reporter.getLogs()
require.Equal(t, "https://example.com/v1", logs[0].Logs[0].GetHttpRequest().Url)
require.Equal(t, "https://example.com/v2", logs[1].Logs[0].GetHttpRequest().Url)
cancel()
<-forwarderDone
}
func TestServer_Metrics(t *testing.T) {
t.Parallel()
makeReq := func(n int) *agentproto.ReportBoundaryLogsRequest {
logs := make([]*agentproto.BoundaryLog, n)
for i := range n {
logs[i] = &agentproto.BoundaryLog{
Allowed: true,
Time: timestamppb.Now(),
Resource: &agentproto.BoundaryLog_HttpRequest_{
HttpRequest: &agentproto.BoundaryLog_HttpRequest{
Method: "GET",
Url: "https://example.com",
},
},
}
}
return &agentproto.ReportBoundaryLogsRequest{Logs: logs}
}
// BufferFull needs its own setup because it intentionally does not run
// a forwarder so the channel fills up.
t.Run("BufferFull", func(t *testing.T) {
t.Parallel()
reg := prometheus.NewRegistry()
socketPath := filepath.Join(testutil.TempDirUnixSocket(t), "boundary.sock")
srv := boundarylogproxy.NewServer(testutil.Logger(t), socketPath, reg)
err := srv.Start()
require.NoError(t, err)
t.Cleanup(func() { require.NoError(t, srv.Close()) })
conn, err := net.Dial("unix", socketPath)
require.NoError(t, err)
defer conn.Close()
// Fill the buffer (size 100) without running a forwarder so nothing
// drains. Then send one more to trigger the drop path.
for range 101 {
sendLogs(t, conn, makeReq(1))
}
require.Eventually(t, func() bool {
return getCounterVecValue(t, reg, "agent_boundary_log_proxy_batches_dropped_total", "buffer_full") >= 1
}, testutil.WaitShort, testutil.IntervalFast)
require.GreaterOrEqual(t,
getCounterVecValue(t, reg, "agent_boundary_log_proxy_logs_dropped_total", "buffer_full"),
float64(1))
})
// The remaining metrics share one server, forwarder, and connection. The
// phases run sequentially so metrics accumulate.
t.Run("Forwarding", func(t *testing.T) {
t.Parallel()
reg := prometheus.NewRegistry()
socketPath := filepath.Join(testutil.TempDirUnixSocket(t), "boundary.sock")
srv := boundarylogproxy.NewServer(testutil.Logger(t), socketPath, reg)
err := srv.Start()
require.NoError(t, err)
t.Cleanup(func() { require.NoError(t, srv.Close()) })
reportNotify := make(chan struct{}, 4)
reporter := &fakeReporter{
err: context.DeadlineExceeded,
errOnce: true,
reportCb: func() {
select {
case reportNotify <- struct{}{}:
default:
}
},
}
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
forwarderDone := make(chan error, 1)
go func() {
forwarderDone <- srv.RunForwarder(ctx, reporter)
}()
conn, err := net.Dial("unix", socketPath)
require.NoError(t, err)
defer conn.Close()
// Phase 1: the first forward errors
sendLogs(t, conn, makeReq(2))
select {
case <-reportNotify:
case <-time.After(testutil.WaitShort):
t.Fatal("timed out waiting for forward attempt")
}
// The metric is incremented after ReportBoundaryLogs returns, so we
// need to poll briefly.
require.Eventually(t, func() bool {
return getCounterVecValue(t, reg, "agent_boundary_log_proxy_batches_dropped_total", "forward_failed") >= 1
}, testutil.WaitShort, testutil.IntervalFast)
require.Equal(t, float64(2),
getCounterVecValue(t, reg, "agent_boundary_log_proxy_logs_dropped_total", "forward_failed"))
// Phase 2: forward succeeds.
sendLogs(t, conn, makeReq(1))
require.Eventually(t, func() bool {
return len(reporter.getLogs()) >= 1
}, testutil.WaitShort, testutil.IntervalFast)
require.Equal(t, float64(1),
getCounterValue(t, reg, "agent_boundary_log_proxy_batches_forwarded_total"))
// Phase 3: boundary-reported drop counts arrive as a separate BoundaryStatus
// message, not piggybacked on log batches.
sendStatus(t, conn, &codec.BoundaryStatus{
DroppedChannelFull: 5,
DroppedBatchFull: 3,
})
// Status is handled immediately by the reader goroutine, not by the
// forwarder, so poll metrics directly.
require.Eventually(t, func() bool {
return getCounterVecValue(t, reg, "agent_boundary_log_proxy_logs_dropped_total", "boundary_channel_full") >= 5
}, testutil.WaitShort, testutil.IntervalFast)
require.Equal(t, float64(5),
getCounterVecValue(t, reg, "agent_boundary_log_proxy_logs_dropped_total", "boundary_channel_full"))
require.Equal(t, float64(3),
getCounterVecValue(t, reg, "agent_boundary_log_proxy_logs_dropped_total", "boundary_batch_full"))
cancel()
<-forwarderDone
})
}
// getCounterVecValue returns the current value of a CounterVec metric filtered
// by the given reason label.
func getCounterVecValue(t *testing.T, reg *prometheus.Registry, name, reason string) float64 {
t.Helper()
metrics, err := reg.Gather()
require.NoError(t, err)
for _, mf := range metrics {
if mf.GetName() != name {
continue
}
for _, m := range mf.GetMetric() {
for _, lp := range m.GetLabel() {
if lp.GetName() == "reason" && lp.GetValue() == reason {
return m.GetCounter().GetValue()
}
}
}
}
return 0
}
// getCounterValue returns the current value of a Counter metric.
func getCounterValue(t *testing.T, reg *prometheus.Registry, name string) float64 {
t.Helper()
metrics, err := reg.Gather()
require.NoError(t, err)
for _, mf := range metrics {
if mf.GetName() != name {
continue
}
for _, m := range mf.GetMetric() {
return m.GetCounter().GetValue()
}
}
return 0
}
+316
View File
@@ -0,0 +1,316 @@
package filefinder_test
import (
"context"
"fmt"
"math/rand"
"os"
"path/filepath"
"runtime"
"sync"
"testing"
"github.com/stretchr/testify/require"
"cdr.dev/slog/v3"
"cdr.dev/slog/v3/sloggers/slogtest"
"github.com/coder/coder/v2/agent/filefinder"
)
var (
dirNames = []string{
"cmd", "internal", "pkg", "api", "auth", "database", "server", "client", "middleware",
"handler", "config", "utils", "models", "service", "worker", "scheduler", "notification",
"provisioner", "template", "workspace", "agent", "proxy", "crypto", "telemetry", "billing",
}
fileExts = []string{
".go", ".ts", ".tsx", ".js", ".py", ".sql", ".yaml", ".json", ".md", ".proto", ".sh",
}
fileStems = []string{
"main", "handler", "middleware", "service", "model", "query", "config", "utils", "helpers",
"types", "interface", "test", "mock", "factory", "builder", "adapter", "observer", "provider",
"resolver", "schema", "migration", "fixture", "snapshot", "checkpoint",
}
)
// generateFileTree creates n files under root in a realistic nested directory structure.
func generateFileTree(t testing.TB, root string, n int, seed int64) {
t.Helper()
rng := rand.New(rand.NewSource(seed)) //nolint:gosec // deterministic benchmarks
numDirs := n / 5
if numDirs < 10 {
numDirs = 10
}
dirs := make([]string, 0, numDirs)
for i := 0; i < numDirs; i++ {
depth := rng.Intn(6) + 1
parts := make([]string, depth)
for d := 0; d < depth; d++ {
parts[d] = dirNames[rng.Intn(len(dirNames))]
}
dirs = append(dirs, filepath.Join(parts...))
}
created := make(map[string]struct{})
for _, d := range dirs {
full := filepath.Join(root, d)
if _, ok := created[full]; ok {
continue
}
require.NoError(t, os.MkdirAll(full, 0o755))
created[full] = struct{}{}
}
for i := 0; i < n; i++ {
dir := dirs[rng.Intn(len(dirs))]
stem := fileStems[rng.Intn(len(fileStems))]
ext := fileExts[rng.Intn(len(fileExts))]
name := fmt.Sprintf("%s_%d%s", stem, i, ext)
full := filepath.Join(root, dir, name)
f, err := os.Create(full)
require.NoError(t, err)
_ = f.Close()
}
}
// buildIndex walks root and returns a populated Index, the same
// way Engine.AddRoot does but without starting a watcher.
func buildIndex(t testing.TB, root string) *filefinder.Index {
t.Helper()
absRoot, err := filepath.Abs(root)
require.NoError(t, err)
idx, err := filefinder.BuildTestIndex(absRoot)
require.NoError(t, err)
return idx
}
func BenchmarkBuildIndex(b *testing.B) {
scales := []struct {
name string
n int
}{
{"1K", 1_000},
{"10K", 10_000},
{"100K", 100_000},
}
for _, sc := range scales {
b.Run(sc.name, func(b *testing.B) {
if sc.n >= 100_000 && testing.Short() {
b.Skip("skipping large-scale benchmark")
}
dir := b.TempDir()
generateFileTree(b, dir, sc.n, 42)
b.ResetTimer()
for i := 0; i < b.N; i++ {
idx := buildIndex(b, dir)
if idx.Len() == 0 {
b.Fatal("expected non-empty index")
}
}
b.StopTimer()
idx := buildIndex(b, dir)
b.ReportMetric(float64(idx.Len())/b.Elapsed().Seconds(), "files/sec")
})
}
}
func BenchmarkSearch_ByScale(b *testing.B) {
queries := []struct {
name string
query string
}{
{"exact_basename", "handler.go"},
{"short_query", "ha"},
{"fuzzy_basename", "hndlr"},
{"path_structured", "internal/handler"},
{"multi_token", "api handler"},
}
scales := []struct {
name string
n int
}{
{"1K", 1_000},
{"10K", 10_000},
{"100K", 100_000},
}
for _, sc := range scales {
b.Run(sc.name, func(b *testing.B) {
if sc.n >= 100_000 && testing.Short() {
b.Skip("skipping large-scale benchmark")
}
dir := b.TempDir()
generateFileTree(b, dir, sc.n, 42)
idx := buildIndex(b, dir)
snap := idx.Snapshot()
opts := filefinder.DefaultSearchOptions()
for _, q := range queries {
b.Run(q.name, func(b *testing.B) {
p := filefinder.NewQueryPlanForTest(q.query)
b.ResetTimer()
for i := 0; i < b.N; i++ {
_ = filefinder.SearchSnapshotForTest(p, snap, opts.MaxCandidates)
}
})
}
})
}
}
func BenchmarkSearch_ConcurrentReads(b *testing.B) {
dir := b.TempDir()
generateFileTree(b, dir, 10_000, 42)
logger := slogtest.Make(b, &slogtest.Options{IgnoreErrors: true}).Leveled(slog.LevelError)
ctx := context.Background()
eng := filefinder.NewEngine(logger)
require.NoError(b, eng.AddRoot(ctx, dir))
b.Cleanup(func() { _ = eng.Close() })
opts := filefinder.DefaultSearchOptions()
goroutines := []int{1, 4, 16, 64}
for _, g := range goroutines {
b.Run(fmt.Sprintf("goroutines_%d", g), func(b *testing.B) {
b.SetParallelism(g)
b.ResetTimer()
b.RunParallel(func(pb *testing.PB) {
for pb.Next() {
results, err := eng.Search(ctx, "handler", opts)
if err != nil {
b.Fatal(err)
}
_ = results
}
})
})
}
}
func BenchmarkDeltaUpdate(b *testing.B) {
dir := b.TempDir()
generateFileTree(b, dir, 10_000, 42)
addCounts := []int{1, 10, 100}
for _, count := range addCounts {
b.Run(fmt.Sprintf("add_%d_files", count), func(b *testing.B) {
paths := make([]string, count)
for i := range paths {
paths[i] = fmt.Sprintf("injected/dir_%d/newfile_%d.go", i%10, i)
}
b.ResetTimer()
for i := 0; i < b.N; i++ {
b.StopTimer()
idx := buildIndex(b, dir)
b.StartTimer()
for _, p := range paths {
idx.Add(p, 0)
}
}
b.ReportMetric(float64(count), "files_added/op")
})
}
b.Run("search_after_100_additions", func(b *testing.B) {
idx := buildIndex(b, dir)
for i := 0; i < 100; i++ {
idx.Add(fmt.Sprintf("injected/extra/file_%d.go", i), 0)
}
snap := idx.Snapshot()
plan := filefinder.NewQueryPlanForTest("handler")
opts := filefinder.DefaultSearchOptions()
b.ResetTimer()
for i := 0; i < b.N; i++ {
_ = filefinder.SearchSnapshotForTest(plan, snap, opts.MaxCandidates)
}
})
}
func BenchmarkMemoryProfile(b *testing.B) {
scales := []struct {
name string
n int
}{
{"10K", 10_000},
{"100K", 100_000},
}
for _, sc := range scales {
b.Run(sc.name, func(b *testing.B) {
if sc.n >= 100_000 && testing.Short() {
b.Skip("skipping large-scale memory profile")
}
dir := b.TempDir()
generateFileTree(b, dir, sc.n, 42)
b.ResetTimer()
for i := 0; i < b.N; i++ {
idx := buildIndex(b, dir)
_ = idx.Snapshot()
}
b.StopTimer()
// After the timed loop, measure allocations for one more index build.
runtime.GC()
var before runtime.MemStats
runtime.ReadMemStats(&before)
idx := buildIndex(b, dir)
var after runtime.MemStats
runtime.ReadMemStats(&after)
allocDelta := after.TotalAlloc - before.TotalAlloc
b.ReportMetric(float64(allocDelta)/float64(idx.Len()), "bytes/file")
runtime.GC()
runtime.ReadMemStats(&before)
snap := idx.Snapshot()
_ = snap
runtime.GC()
runtime.ReadMemStats(&after)
snapAlloc := after.TotalAlloc - before.TotalAlloc
b.ReportMetric(float64(snapAlloc)/float64(idx.Len()), "snap-bytes/file")
})
}
}
func BenchmarkSearch_ConcurrentReads_Throughput(b *testing.B) {
dir := b.TempDir()
generateFileTree(b, dir, 10_000, 42)
idx := buildIndex(b, dir)
snap := idx.Snapshot()
goroutines := []int{1, 4, 16, 64}
plan := filefinder.NewQueryPlanForTest("handler.go")
maxCands := filefinder.DefaultSearchOptions().MaxCandidates
for _, g := range goroutines {
b.Run(fmt.Sprintf("goroutines_%d", g), func(b *testing.B) {
b.ResetTimer()
var wg sync.WaitGroup
perGoroutine := b.N / g
if perGoroutine < 1 {
perGoroutine = 1
}
for gi := 0; gi < g; gi++ {
wg.Add(1)
go func() {
defer wg.Done()
for j := 0; j < perGoroutine; j++ {
_ = filefinder.SearchSnapshotForTest(plan, snap, maxCands)
}
}()
}
wg.Wait()
totalOps := float64(g * perGoroutine)
b.ReportMetric(totalOps/b.Elapsed().Seconds(), "searches/sec")
})
}
}
+125
View File
@@ -0,0 +1,125 @@
package filefinder
import "strings"
// FileFlag represents the type of filesystem entry.
type FileFlag uint16
const (
FlagFile FileFlag = 0
FlagDir FileFlag = 1
FlagSymlink FileFlag = 2
)
type doc struct {
path string
baseOff int
baseLen int
depth int
flags uint16
}
// Index is an append-only in-memory file index with snapshot support.
type Index struct {
docs []doc
byGram map[uint32][]uint32
byPrefix1 [256][]uint32
byPrefix2 map[uint16][]uint32
byPath map[string]uint32
deleted map[uint32]bool
}
// Snapshot is a frozen, read-only view of the index at a point in time.
type Snapshot struct {
docs []doc
deleted map[uint32]bool
byGram map[uint32][]uint32
byPrefix1 [256][]uint32
byPrefix2 map[uint16][]uint32
}
// NewIndex creates an empty Index.
func NewIndex() *Index {
return &Index{
byGram: make(map[uint32][]uint32),
byPrefix2: make(map[uint16][]uint32),
byPath: make(map[string]uint32),
deleted: make(map[uint32]bool),
}
}
// Add inserts a path into the index, tombstoning any previous entry.
func (idx *Index) Add(path string, flags uint16) uint32 {
norm := string(normalizePathBytes([]byte(path)))
if oldID, ok := idx.byPath[norm]; ok {
idx.deleted[oldID] = true
}
id := uint32(len(idx.docs)) //nolint:gosec // Index will never exceed 2^32 docs.
baseOff, baseLen := extractBasename([]byte(norm))
idx.docs = append(idx.docs, doc{
path: norm, baseOff: baseOff, baseLen: baseLen,
depth: strings.Count(norm, "/"), flags: flags,
})
idx.byPath[norm] = id
for _, g := range extractTrigrams([]byte(norm)) {
idx.byGram[g] = append(idx.byGram[g], id)
}
if baseLen > 0 {
basename := []byte(norm[baseOff : baseOff+baseLen])
p1 := prefix1(basename)
idx.byPrefix1[p1] = append(idx.byPrefix1[p1], id)
p2 := prefix2(basename)
idx.byPrefix2[p2] = append(idx.byPrefix2[p2], id)
}
return id
}
// Remove marks the entry for path as deleted.
func (idx *Index) Remove(path string) bool {
norm := string(normalizePathBytes([]byte(path)))
id, ok := idx.byPath[norm]
if !ok {
return false
}
idx.deleted[id] = true
delete(idx.byPath, norm)
return true
}
// Has reports whether path exists (not deleted) in the index.
func (idx *Index) Has(path string) bool {
_, ok := idx.byPath[string(normalizePathBytes([]byte(path)))]
return ok
}
// Len returns the number of live (non-deleted) documents.
func (idx *Index) Len() int { return len(idx.byPath) }
func copyPostings[K comparable](m map[K][]uint32) map[K][]uint32 {
cp := make(map[K][]uint32, len(m))
for k, v := range m {
cp[k] = v[:len(v):len(v)]
}
return cp
}
// Snapshot returns a frozen read-only view of the index.
func (idx *Index) Snapshot() *Snapshot {
del := make(map[uint32]bool, len(idx.deleted))
for id := range idx.deleted {
del[id] = true
}
var p1Copy [256][]uint32
for i, ids := range idx.byPrefix1 {
if len(ids) > 0 {
p1Copy[i] = ids[:len(ids):len(ids)]
}
}
return &Snapshot{
docs: idx.docs[:len(idx.docs):len(idx.docs)],
deleted: del,
byGram: copyPostings(idx.byGram),
byPrefix1: p1Copy,
byPrefix2: copyPostings(idx.byPrefix2),
}
}
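A minimal usage sketch of the append-only Index defined above: Add tombstones any previous entry for the same path, Has works on normalized (lowercased) paths, and Snapshot freezes a read-only view that later Adds do not affect. The paths are illustrative.
package main
import (
"fmt"
"github.com/coder/coder/v2/agent/filefinder"
)
func main() {
idx := filefinder.NewIndex()
idx.Add("src", uint16(filefinder.FlagDir))
idx.Add("src/handler.go", uint16(filefinder.FlagFile))
idx.Add("src/handler.go", uint16(filefinder.FlagFile)) // re-add tombstones the old doc
fmt.Println(idx.Len())                 // 2 live entries
fmt.Println(idx.Has("SRC/Handler.GO")) // true: paths are normalized to lowercase
snap := idx.Snapshot() // frozen view, safe to search without holding locks
idx.Add("src/new.go", uint16(filefinder.FlagFile))
_ = snap // the snapshot still reflects only the two entries captured above
}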
+120
View File
@@ -0,0 +1,120 @@
package filefinder_test
import (
"testing"
"github.com/coder/coder/v2/agent/filefinder"
)
func TestIndex_AddAndLen(t *testing.T) {
t.Parallel()
idx := filefinder.NewIndex()
idx.Add("foo/bar.go", 0)
idx.Add("foo/baz.go", 0)
if idx.Len() != 2 {
t.Fatalf("expected 2, got %d", idx.Len())
}
}
func TestIndex_Has(t *testing.T) {
t.Parallel()
idx := filefinder.NewIndex()
idx.Add("foo/bar.go", 0)
if !idx.Has("foo/bar.go") {
t.Fatal("expected Has to return true")
}
if idx.Has("foo/missing.go") {
t.Fatal("expected Has to return false for missing path")
}
}
func TestIndex_Remove(t *testing.T) {
t.Parallel()
idx := filefinder.NewIndex()
idx.Add("foo/bar.go", 0)
if !idx.Remove("foo/bar.go") {
t.Fatal("expected Remove to return true")
}
if idx.Has("foo/bar.go") {
t.Fatal("expected Has to return false after Remove")
}
if idx.Len() != 0 {
t.Fatalf("expected Len 0 after Remove, got %d", idx.Len())
}
}
func TestIndex_AddOverwrite(t *testing.T) {
t.Parallel()
idx := filefinder.NewIndex()
idx.Add("foo/bar.go", uint16(filefinder.FlagFile))
idx.Add("foo/bar.go", uint16(filefinder.FlagDir)) // overwrite
if idx.Len() != 1 {
t.Fatalf("expected 1 after overwrite, got %d", idx.Len())
}
// The old entry should be tombstoned.
if !filefinder.IndexIsDeleted(idx, 0) {
t.Fatal("expected old entry to be deleted")
}
if filefinder.IndexIsDeleted(idx, 1) {
t.Fatal("expected new entry to be live")
}
}
func TestIndex_Snapshot(t *testing.T) {
t.Parallel()
idx := filefinder.NewIndex()
idx.Add("foo/bar.go", 0)
idx.Add("foo/baz.go", 0)
snap := idx.Snapshot()
if filefinder.SnapshotCount(snap) != 2 {
t.Fatalf("expected snapshot count 2, got %d", filefinder.SnapshotCount(snap))
}
// Adding more docs after snapshot doesn't affect it.
idx.Add("foo/qux.go", 0)
if filefinder.SnapshotCount(snap) != 2 {
t.Fatal("snapshot count should not change after new adds")
}
}
func TestIndex_TrigramIndex(t *testing.T) {
t.Parallel()
idx := filefinder.NewIndex()
idx.Add("handler.go", 0)
// Adding "handler.go" should populate the trigram index;
// check that at least one trigram was recorded.
if filefinder.IndexByGramLen(idx) == 0 {
t.Fatal("expected non-empty trigram index")
}
}
func TestIndex_PrefixIndex(t *testing.T) {
t.Parallel()
idx := filefinder.NewIndex()
idx.Add("handler.go", 0)
// basename is "handler.go", first byte is 'h'
if filefinder.IndexByPrefix1Len(idx, 'h') == 0 {
t.Fatal("expected prefix1['h'] to be non-empty")
}
}
func TestIndex_RemoveNonexistent(t *testing.T) {
t.Parallel()
idx := filefinder.NewIndex()
if idx.Remove("nonexistent.go") {
t.Fatal("expected Remove to return false for missing path")
}
}
func TestIndex_PathNormalization(t *testing.T) {
t.Parallel()
idx := filefinder.NewIndex()
idx.Add("Foo/Bar.go", 0)
// Should be findable with lowercase.
if !idx.Has("foo/bar.go") {
t.Fatal("expected case-insensitive Has")
}
}
+364
View File
@@ -0,0 +1,364 @@
// Package filefinder provides an in-memory file index with trigram
// matching, fuzzy search, and filesystem watching. It is designed
// to power file-finding features on workspace agents.
package filefinder
import (
"context"
"os"
"path/filepath"
"slices"
"strings"
"sync"
"sync/atomic"
"golang.org/x/xerrors"
"cdr.dev/slog/v3"
)
// SearchOptions controls search behavior.
type SearchOptions struct {
Limit int
MaxCandidates int
}
// DefaultSearchOptions returns sensible default search options.
func DefaultSearchOptions() SearchOptions {
return SearchOptions{Limit: 100, MaxCandidates: 10000}
}
type rootSnapshot struct {
root string
snap *Snapshot
}
// Engine is the main file finder. Safe for concurrent use.
type Engine struct {
snap atomic.Pointer[[]*rootSnapshot]
logger slog.Logger
mu sync.Mutex
roots map[string]*rootState
eventCh chan rootEvent
closeCh chan struct{}
closed atomic.Bool
wg sync.WaitGroup
}
type rootState struct {
root string
index *Index
watcher *fsWatcher
cancel context.CancelFunc
}
type rootEvent struct {
root string
events []FSEvent
}
// walkRoot performs a full filesystem walk of absRoot and returns
// a populated Index containing all discovered files and directories.
func walkRoot(absRoot string) (*Index, error) {
idx := NewIndex()
err := filepath.Walk(absRoot, func(path string, info os.FileInfo, walkErr error) error {
if walkErr != nil {
return nil //nolint:nilerr
}
base := filepath.Base(path)
if _, skip := skipDirs[base]; skip && info.IsDir() {
return filepath.SkipDir
}
if path == absRoot {
return nil
}
relPath, relErr := filepath.Rel(absRoot, path)
if relErr != nil {
return nil //nolint:nilerr
}
relPath = filepath.ToSlash(relPath)
var flags uint16
if info.IsDir() {
flags = uint16(FlagDir)
} else if info.Mode()&os.ModeSymlink != 0 {
flags = uint16(FlagSymlink)
}
idx.Add(relPath, flags)
return nil
})
return idx, err
}
// NewEngine creates a new Engine.
func NewEngine(logger slog.Logger) *Engine {
e := &Engine{
logger: logger,
roots: make(map[string]*rootState),
eventCh: make(chan rootEvent, 256),
closeCh: make(chan struct{}),
}
empty := make([]*rootSnapshot, 0)
e.snap.Store(&empty)
e.wg.Add(1)
go e.start()
return e
}
// ErrClosed is returned when operations are attempted on a
// closed engine.
var ErrClosed = xerrors.New("engine is closed")
// AddRoot adds a directory root to the engine.
func (e *Engine) AddRoot(ctx context.Context, root string) error {
absRoot, err := filepath.Abs(root)
if err != nil {
return xerrors.Errorf("resolve root: %w", err)
}
e.mu.Lock()
if e.closed.Load() {
e.mu.Unlock()
return ErrClosed
}
if _, exists := e.roots[absRoot]; exists {
e.mu.Unlock()
return nil
}
e.mu.Unlock()
// Walk and create the watcher outside the lock to avoid
// blocking the event pipeline on filesystem I/O.
idx, walkErr := walkRoot(absRoot)
if walkErr != nil {
return xerrors.Errorf("walk root: %w", walkErr)
}
wCtx, wCancel := context.WithCancel(context.Background())
w, wErr := newFSWatcher(absRoot, e.logger)
if wErr != nil {
wCancel()
return xerrors.Errorf("create watcher: %w", wErr)
}
e.mu.Lock()
// Re-check after re-acquiring the lock: another goroutine
// may have added this root or closed the engine while we
// were walking.
if e.closed.Load() {
e.mu.Unlock()
wCancel()
_ = w.Close()
return ErrClosed
}
if _, exists := e.roots[absRoot]; exists {
e.mu.Unlock()
wCancel()
_ = w.Close()
return nil
}
rs := &rootState{root: absRoot, index: idx, watcher: w, cancel: wCancel}
e.roots[absRoot] = rs
w.Start(wCtx)
e.wg.Add(1)
go e.forwardEvents(wCtx, absRoot, w)
e.publishSnapshot()
fileCount := idx.Len()
e.mu.Unlock()
e.logger.Info(ctx, "added root to engine",
slog.F("root", absRoot),
slog.F("files", fileCount),
)
return nil
}
// RemoveRoot stops watching a root and removes it.
func (e *Engine) RemoveRoot(root string) error {
absRoot, err := filepath.Abs(root)
if err != nil {
return xerrors.Errorf("resolve root: %w", err)
}
e.mu.Lock()
defer e.mu.Unlock()
rs, exists := e.roots[absRoot]
if !exists {
return xerrors.Errorf("root %q not found", absRoot)
}
rs.cancel()
_ = rs.watcher.Close()
delete(e.roots, absRoot)
e.publishSnapshot()
return nil
}
// Search performs a fuzzy file search across all roots.
func (e *Engine) Search(_ context.Context, query string, opts SearchOptions) ([]Result, error) {
if e.closed.Load() {
return nil, ErrClosed
}
snapPtr := e.snap.Load()
if snapPtr == nil || len(*snapPtr) == 0 {
return nil, nil
}
roots := *snapPtr
plan := newQueryPlan(query)
if len(plan.Normalized) == 0 {
return nil, nil
}
if opts.Limit <= 0 {
opts.Limit = 100
}
if opts.MaxCandidates <= 0 {
opts.MaxCandidates = 10000
}
params := defaultScoreParams()
var allCands []candidate
for _, rs := range roots {
allCands = append(allCands, searchSnapshot(plan, rs.snap, opts.MaxCandidates)...)
}
results := mergeAndScore(allCands, plan, params, opts.Limit)
return results, nil
}
// Close shuts down the engine.
func (e *Engine) Close() error {
if e.closed.Swap(true) {
return nil
}
close(e.closeCh)
e.mu.Lock()
for _, rs := range e.roots {
rs.cancel()
_ = rs.watcher.Close()
}
e.roots = make(map[string]*rootState)
e.mu.Unlock()
e.wg.Wait()
return nil
}
// Rebuild forces a complete re-walk and re-index of a root.
func (e *Engine) Rebuild(ctx context.Context, root string) error {
absRoot, err := filepath.Abs(root)
if err != nil {
return xerrors.Errorf("resolve root: %w", err)
}
// Walk outside the lock to avoid blocking the event
// pipeline on potentially slow filesystem I/O.
idx, walkErr := walkRoot(absRoot)
if walkErr != nil {
return xerrors.Errorf("rebuild walk: %w", walkErr)
}
e.mu.Lock()
rs, exists := e.roots[absRoot]
if !exists {
e.mu.Unlock()
return xerrors.Errorf("root %q not found", absRoot)
}
rs.index = idx
e.publishSnapshot()
fileCount := idx.Len()
e.mu.Unlock()
e.logger.Info(ctx, "rebuilt root in engine",
slog.F("root", absRoot),
slog.F("files", fileCount),
)
return nil
}
func (e *Engine) start() {
defer e.wg.Done()
for {
select {
case <-e.closeCh:
return
case re, ok := <-e.eventCh:
if !ok {
return
}
e.applyEvents(re)
}
}
}
func (e *Engine) forwardEvents(ctx context.Context, root string, w *fsWatcher) {
defer e.wg.Done()
for {
select {
case <-ctx.Done():
return
case <-e.closeCh:
return
case evts, ok := <-w.Events():
if !ok {
return
}
select {
case e.eventCh <- rootEvent{root: root, events: evts}:
case <-ctx.Done():
return
case <-e.closeCh:
return
}
}
}
}
func (e *Engine) applyEvents(re rootEvent) {
e.mu.Lock()
defer e.mu.Unlock()
rs, exists := e.roots[re.root]
if !exists {
return
}
changed := false
for _, ev := range re.events {
relPath, err := filepath.Rel(rs.root, ev.Path)
if err != nil {
continue
}
relPath = filepath.ToSlash(relPath)
switch ev.Op {
case OpCreate:
if rs.index.Has(relPath) {
continue
}
var flags uint16
if ev.IsDir {
flags = uint16(FlagDir)
}
rs.index.Add(relPath, flags)
changed = true
case OpRemove, OpRename:
if rs.index.Remove(relPath) {
changed = true
}
if ev.IsDir || ev.Op == OpRename {
prefix := strings.ToLower(filepath.ToSlash(relPath)) + "/"
for path := range rs.index.byPath {
if strings.HasPrefix(path, prefix) {
rs.index.Remove(path)
changed = true
}
}
}
case OpModify:
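// Content modifications do not change the set of paths, so the index needs no update.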
}
}
if changed {
e.publishSnapshot()
}
}
// publishSnapshot builds and atomically publishes a new snapshot.
// Must be called with e.mu held.
func (e *Engine) publishSnapshot() {
roots := make([]*rootSnapshot, 0, len(e.roots))
for _, rs := range e.roots {
roots = append(roots, &rootSnapshot{
root: rs.root,
snap: rs.index.Snapshot(),
})
}
slices.SortFunc(roots, func(a, b *rootSnapshot) int {
return strings.Compare(a.root, b.root)
})
e.snap.Store(&roots)
}
+233
View File
@@ -0,0 +1,233 @@
package filefinder_test
import (
"context"
"os"
"path/filepath"
"sort"
"testing"
"github.com/stretchr/testify/require"
"cdr.dev/slog/v3"
"cdr.dev/slog/v3/sloggers/slogtest"
"github.com/coder/coder/v2/agent/filefinder"
"github.com/coder/coder/v2/testutil"
)
func newTestEngine(t *testing.T) (*filefinder.Engine, context.Context) {
t.Helper()
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true}).Leveled(slog.LevelDebug)
eng := filefinder.NewEngine(logger)
t.Cleanup(func() { _ = eng.Close() })
return eng, context.Background()
}
func requireResultHasPath(t *testing.T, results []filefinder.Result, path string) {
t.Helper()
for _, r := range results {
if r.Path == path {
return
}
}
t.Errorf("expected %q in results, got %v", path, resultPaths(results))
}
func TestEngine_SearchFindsKnownFile(t *testing.T) {
t.Parallel()
dir := t.TempDir()
createFile(t, dir, "src/main.go", "package main")
createFile(t, dir, "src/handler.go", "package main")
createFile(t, dir, "README.md", "# hello")
eng, ctx := newTestEngine(t)
require.NoError(t, eng.AddRoot(ctx, dir))
results, err := eng.Search(ctx, "main.go", filefinder.DefaultSearchOptions())
require.NoError(t, err)
require.NotEmpty(t, results, "expected to find main.go")
requireResultHasPath(t, results, "src/main.go")
}
func TestEngine_SearchFuzzyMatch(t *testing.T) {
t.Parallel()
dir := t.TempDir()
createFile(t, dir, "src/controllers/user_handler.go", "package controllers")
createFile(t, dir, "src/models/user.go", "package models")
createFile(t, dir, "docs/api.md", "# API")
eng, ctx := newTestEngine(t)
require.NoError(t, eng.AddRoot(ctx, dir))
// "handler" should match "user_handler.go".
results, err := eng.Search(ctx, "handler", filefinder.DefaultSearchOptions())
require.NoError(t, err)
// The query is a subsequence of "user_handler.go" so it
// should appear somewhere in the results.
requireResultHasPath(t, results, "src/controllers/user_handler.go")
}
func TestEngine_IndexPicksUpNewFile(t *testing.T) {
t.Parallel()
dir := t.TempDir()
createFile(t, dir, "existing.txt", "hello")
eng, ctx := newTestEngine(t)
require.NoError(t, eng.AddRoot(ctx, dir))
createFile(t, dir, "newfile_unique.txt", "world")
require.Eventually(t, func() bool {
results, sErr := eng.Search(ctx, "newfile_unique", filefinder.DefaultSearchOptions())
if sErr != nil {
return false
}
for _, r := range results {
if r.Path == "newfile_unique.txt" {
return true
}
}
return false
}, testutil.WaitShort, testutil.IntervalFast, "expected newfile_unique.txt to appear via watcher")
}
func TestEngine_IndexRemovesDeletedFile(t *testing.T) {
t.Parallel()
dir := t.TempDir()
createFile(t, dir, "deleteme_unique.txt", "goodbye")
createFile(t, dir, "keeper.txt", "stay")
eng, ctx := newTestEngine(t)
require.NoError(t, eng.AddRoot(ctx, dir))
results, err := eng.Search(ctx, "deleteme_unique", filefinder.DefaultSearchOptions())
require.NoError(t, err)
require.NotEmpty(t, results, "expected to find deleteme_unique.txt initially")
require.NoError(t, os.Remove(filepath.Join(dir, "deleteme_unique.txt")))
require.Eventually(t, func() bool {
results, sErr := eng.Search(ctx, "deleteme_unique", filefinder.DefaultSearchOptions())
if sErr != nil {
return false
}
for _, r := range results {
if r.Path == "deleteme_unique.txt" {
return false // still found
}
}
return true
}, testutil.WaitShort, testutil.IntervalFast, "expected deleteme_unique.txt to disappear after removal")
}
func TestEngine_MultipleRoots(t *testing.T) {
t.Parallel()
dir1 := t.TempDir()
dir2 := t.TempDir()
createFile(t, dir1, "alpha_unique.go", "package alpha")
createFile(t, dir2, "beta_unique.go", "package beta")
eng, ctx := newTestEngine(t)
require.NoError(t, eng.AddRoot(ctx, dir1))
require.NoError(t, eng.AddRoot(ctx, dir2))
results, err := eng.Search(ctx, "alpha_unique", filefinder.DefaultSearchOptions())
require.NoError(t, err)
requireResultHasPath(t, results, "alpha_unique.go")
results, err = eng.Search(ctx, "beta_unique", filefinder.DefaultSearchOptions())
require.NoError(t, err)
requireResultHasPath(t, results, "beta_unique.go")
}
func TestEngine_EmptyQueryReturnsEmpty(t *testing.T) {
t.Parallel()
dir := t.TempDir()
createFile(t, dir, "something.txt", "data")
eng, ctx := newTestEngine(t)
require.NoError(t, eng.AddRoot(ctx, dir))
results, err := eng.Search(ctx, "", filefinder.DefaultSearchOptions())
require.NoError(t, err)
require.Empty(t, results, "empty query should return no results")
}
func TestEngine_CloseIsClean(t *testing.T) {
t.Parallel()
dir := t.TempDir()
createFile(t, dir, "file.txt", "data")
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true}).Leveled(slog.LevelDebug)
ctx := context.Background()
eng := filefinder.NewEngine(logger)
require.NoError(t, eng.AddRoot(ctx, dir))
require.NoError(t, eng.Close())
_, err := eng.Search(ctx, "file", filefinder.DefaultSearchOptions())
require.Error(t, err)
}
func TestEngine_AddRootIdempotent(t *testing.T) {
t.Parallel()
dir := t.TempDir()
createFile(t, dir, "file.txt", "data")
eng, ctx := newTestEngine(t)
require.NoError(t, eng.AddRoot(ctx, dir))
require.NoError(t, eng.AddRoot(ctx, dir))
snapLen := filefinder.EngineSnapLen(eng)
require.Equal(t, 1, snapLen, "expected exactly one root after duplicate add")
}
func TestEngine_RemoveRoot(t *testing.T) {
t.Parallel()
dir := t.TempDir()
createFile(t, dir, "file.txt", "data")
eng, ctx := newTestEngine(t)
require.NoError(t, eng.AddRoot(ctx, dir))
results, err := eng.Search(ctx, "file", filefinder.DefaultSearchOptions())
require.NoError(t, err)
require.NotEmpty(t, results)
require.NoError(t, eng.RemoveRoot(dir))
results, err = eng.Search(ctx, "file", filefinder.DefaultSearchOptions())
require.NoError(t, err)
require.Empty(t, results)
}
func TestEngine_Rebuild(t *testing.T) {
t.Parallel()
dir := t.TempDir()
createFile(t, dir, "original.txt", "data")
eng, ctx := newTestEngine(t)
require.NoError(t, eng.AddRoot(ctx, dir))
createFile(t, dir, "sneaky_rebuild.txt", "hidden")
require.NoError(t, eng.Rebuild(ctx, dir))
results, err := eng.Search(ctx, "sneaky_rebuild", filefinder.DefaultSearchOptions())
require.NoError(t, err)
requireResultHasPath(t, results, "sneaky_rebuild.txt")
}
// createFile creates a file (and parent dirs) at relPath under dir.
func createFile(t *testing.T, dir, relPath, content string) {
t.Helper()
full := filepath.Join(dir, relPath)
require.NoError(t, os.MkdirAll(filepath.Dir(full), 0o755))
require.NoError(t, os.WriteFile(full, []byte(content), 0o600))
}
func resultPaths(results []filefinder.Result) []string {
paths := make([]string, len(results))
for i, r := range results {
paths[i] = r.Path
}
sort.Strings(paths)
return paths
}
+85
View File
@@ -0,0 +1,85 @@
package filefinder
// Test helpers that need internal access.
// MakeTestSnapshot builds a Snapshot from a list of paths. Useful for
// query-level tests that don't need a real filesystem.
func MakeTestSnapshot(paths []string) *Snapshot {
idx := NewIndex()
for _, p := range paths {
idx.Add(p, 0)
}
return idx.Snapshot()
}
// BuildTestIndex walks root and returns a populated Index, the same
// way Engine.AddRoot does but without starting a watcher.
func BuildTestIndex(root string) (*Index, error) {
return walkRoot(root)
}
// IndexIsDeleted reports whether the document at id is tombstoned.
func IndexIsDeleted(idx *Index, id uint32) bool {
return idx.deleted[id]
}
// IndexByGramLen returns the number of entries in the trigram index.
func IndexByGramLen(idx *Index) int {
return len(idx.byGram)
}
// IndexByPrefix1Len returns the number of posting-list entries for
// the given single-byte prefix.
func IndexByPrefix1Len(idx *Index, b byte) int {
return len(idx.byPrefix1[b])
}
// SnapshotCount returns the number of documents in a Snapshot.
func SnapshotCount(snap *Snapshot) int {
return len(snap.docs)
}
// EngineSnapLen returns the number of root snapshots currently held
// by the engine, or -1 if the pointer is nil.
func EngineSnapLen(eng *Engine) int {
p := eng.snap.Load()
if p == nil {
return -1
}
return len(*p)
}
// DefaultScoreParamsForTest exposes defaultScoreParams for tests.
var DefaultScoreParamsForTest = defaultScoreParams
// ScoreParamsForTest is a type alias for scoreParams.
type ScoreParamsForTest = scoreParams
// Exported aliases for internal functions used in tests.
var (
NewQueryPlanForTest = newQueryPlan
SearchSnapshotForTest = searchSnapshot
IntersectSortedForTest = intersectSorted
IntersectAllForTest = intersectAll
MergeAndScoreForTest = mergeAndScore
NormalizeQueryForTest = normalizeQuery
NormalizePathBytesForTest = normalizePathBytes
ExtractTrigramsForTest = extractTrigrams
ExtractBasenameForTest = extractBasename
ExtractSegmentsForTest = extractSegments
Prefix1ForTest = prefix1
Prefix2ForTest = prefix2
IsSubsequenceForTest = isSubsequence
LongestContiguousMatchForTest = longestContiguousMatch
IsBoundaryForTest = isBoundary
CountBoundaryHitsForTest = countBoundaryHits
EqualFoldASCIIForTest = equalFoldASCII
ScorePathForTest = scorePath
PackTrigramForTest = packTrigram
)
// Type aliases for internal types used in tests.
type (
CandidateForTest = candidate
QueryPlanForTest = queryPlan
)
+299
View File
@@ -0,0 +1,299 @@
package filefinder
import (
"container/heap"
"slices"
"strings"
)
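// candidate is an unscored document pulled from a snapshot during retrieval.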
type candidate struct {
DocID uint32
Path string
BaseOff int
BaseLen int
Depth int
Flags uint16
}
// Result is a scored search result returned to callers.
type Result struct {
Path string
Score float32
IsDir bool
}
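// queryPlan is the parsed form of a search query: the normalized text, its
// slash- and space-separated tokens, precomputed trigrams, and the
// basename/directory split used during scoring.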
type queryPlan struct {
Original string
Normalized string
Tokens [][]byte
Trigrams []uint32
IsShort bool
HasSlash bool
BasenameQ []byte
DirTokens [][]byte
}
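// newQueryPlan normalizes and tokenizes q. Queries whose tokens are all shorter
// than three bytes are marked short and skip trigram extraction.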
func newQueryPlan(q string) *queryPlan {
norm := normalizeQuery(q)
p := &queryPlan{Original: q, Normalized: norm}
if len(norm) == 0 {
p.IsShort = true
return p
}
raw := strings.ReplaceAll(norm, "/", " ")
parts := strings.Fields(raw)
p.HasSlash = strings.ContainsRune(norm, '/')
for _, part := range parts {
p.Tokens = append(p.Tokens, []byte(part))
}
if len(p.Tokens) > 0 {
p.BasenameQ = p.Tokens[len(p.Tokens)-1]
if len(p.Tokens) > 1 {
p.DirTokens = p.Tokens[:len(p.Tokens)-1]
}
}
p.IsShort = true
for _, tok := range p.Tokens {
if len(tok) >= 3 {
p.IsShort = false
break
}
}
if !p.IsShort {
p.Trigrams = extractQueryTrigrams(p.Tokens)
}
return p
}
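// extractQueryTrigrams returns the deduplicated trigrams of every token of length three or more.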
func extractQueryTrigrams(tokens [][]byte) []uint32 {
seen := make(map[uint32]struct{})
for _, tok := range tokens {
if len(tok) < 3 {
continue
}
for i := 0; i <= len(tok)-3; i++ {
seen[packTrigram(tok[i], tok[i+1], tok[i+2])] = struct{}{}
}
}
if len(seen) == 0 {
return nil
}
result := make([]uint32, 0, len(seen))
for g := range seen {
result = append(result, g)
}
return result
}
func packTrigram(a, b, c byte) uint32 {
return uint32(toLowerASCII(a))<<16 | uint32(toLowerASCII(b))<<8 | uint32(toLowerASCII(c))
}
// searchSnapshot runs the full search pipeline against a single
// root snapshot: it selects a strategy (prefix, trigram, or
// fuzzy fallback) based on query length, retrieves candidate
// doc IDs, and converts them into candidate structs.
func searchSnapshot(plan *queryPlan, snap *Snapshot, limit int) []candidate {
if snap == nil || len(snap.docs) == 0 || len(plan.Normalized) == 0 {
return nil
}
var ids []uint32
if plan.IsShort {
ids = searchShort(plan, snap)
} else {
ids = searchTrigrams(plan, snap)
if len(ids) == 0 && len(plan.BasenameQ) > 0 {
ids = searchFuzzyFallback(plan, snap)
}
}
if len(ids) == 0 {
return nil
}
cands := make([]candidate, 0, min(len(ids), limit))
for _, id := range ids {
if snap.deleted[id] || int(id) >= len(snap.docs) {
continue
}
d := snap.docs[id]
cands = append(cands, candidate{
DocID: id, Path: d.path, BaseOff: d.baseOff,
BaseLen: d.baseLen, Depth: d.depth, Flags: d.flags,
})
if len(cands) >= limit {
break
}
}
return cands
}
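// searchShort answers queries too short for trigrams by looking up the two-byte
// basename prefix bucket, falling back to the one-byte bucket.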
func searchShort(plan *queryPlan, snap *Snapshot) []uint32 {
if len(plan.BasenameQ) == 0 {
return nil
}
if len(plan.BasenameQ) >= 2 {
if ids := snap.byPrefix2[prefix2(plan.BasenameQ)]; len(ids) > 0 {
return ids
}
}
return snap.byPrefix1[prefix1(plan.BasenameQ)]
}
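// searchTrigrams intersects the posting lists of every query trigram; a missing
// list for any trigram means no document can match.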
func searchTrigrams(plan *queryPlan, snap *Snapshot) []uint32 {
if len(plan.Trigrams) == 0 {
return nil
}
lists := make([][]uint32, 0, len(plan.Trigrams))
for _, g := range plan.Trigrams {
ids, ok := snap.byGram[g]
if !ok || len(ids) == 0 {
return nil
}
lists = append(lists, ids)
}
return intersectAll(lists)
}
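// searchFuzzyFallback handles queries whose trigrams matched nothing: it
// subsequence-matches the basename query against the first-letter bucket, then
// against a bounded scan of the whole snapshot.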
func searchFuzzyFallback(plan *queryPlan, snap *Snapshot) []uint32 {
if len(plan.BasenameQ) == 0 {
return nil
}
bucket := snap.byPrefix1[prefix1(plan.BasenameQ)]
if len(bucket) == 0 {
return searchSubsequenceScan(plan, snap, 5000)
}
var ids []uint32
for _, id := range bucket {
if snap.deleted[id] || int(id) >= len(snap.docs) {
continue
}
if isSubsequence([]byte(snap.docs[id].path), plan.BasenameQ) {
ids = append(ids, id)
}
}
if len(ids) == 0 {
return searchSubsequenceScan(plan, snap, 5000)
}
return ids
}
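// searchSubsequenceScan linearly checks up to maxCheck live documents for a
// subsequence match against the basename query.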
func searchSubsequenceScan(plan *queryPlan, snap *Snapshot, maxCheck int) []uint32 {
if len(plan.BasenameQ) == 0 {
return nil
}
var ids []uint32
checked := 0
for id := 0; id < len(snap.docs) && checked < maxCheck; id++ {
uid := uint32(id) //nolint:gosec // Snapshot count is bounded well below 2^32.
if snap.deleted[uid] {
continue
}
checked++
if isSubsequence([]byte(snap.docs[id].path), plan.BasenameQ) {
ids = append(ids, uid)
}
}
return ids
}
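// intersectSorted intersects two ascending-sorted ID lists in a single merge pass.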
func intersectSorted(a, b []uint32) []uint32 {
if len(a) == 0 || len(b) == 0 {
return nil
}
var result []uint32
ai, bi := 0, 0
for ai < len(a) && bi < len(b) {
switch {
case a[ai] < b[bi]:
ai++
case a[ai] > b[bi]:
bi++
default:
result = append(result, a[ai])
ai++
bi++
}
}
return result
}
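// intersectAll intersects all lists smallest-first, stopping early once the
// running result is empty.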
func intersectAll(lists [][]uint32) []uint32 {
if len(lists) == 0 {
return nil
}
if len(lists) == 1 {
return lists[0]
}
slices.SortFunc(lists, func(a, b []uint32) int { return len(a) - len(b) })
result := lists[0]
for i := 1; i < len(lists) && len(result) > 0; i++ {
result = intersectSorted(result, lists[i])
}
return result
}
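// mergeAndScore scores each candidate, applies the directory-token bonus, and
// keeps the top-K results in a min-heap before returning them sorted by
// descending score.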
func mergeAndScore(cands []candidate, plan *queryPlan, params scoreParams, topK int) []Result {
if topK <= 0 || len(cands) == 0 {
return nil
}
query := []byte(plan.Normalized)
h := &resultHeap{}
heap.Init(h)
for i := range cands {
c := &cands[i]
s := scorePath([]byte(c.Path), c.BaseOff, c.BaseLen, c.Depth, query, plan.Tokens, params)
if s <= 0 {
continue
}
// DirTokenHit is applied here rather than in scorePath because
// it depends on the query plan's directory tokens, which are
// split from the full query during planning. scorePath operates
// on raw query bytes without knowledge of token boundaries.
if len(plan.DirTokens) > 0 {
segments := extractSegments([]byte(c.Path))
for _, dt := range plan.DirTokens {
for _, seg := range segments {
if equalFoldASCII(seg, dt) {
s += params.DirTokenHit
break
}
}
}
}
r := Result{Path: c.Path, Score: s, IsDir: c.Flags == uint16(FlagDir)}
if h.Len() < topK {
heap.Push(h, r)
} else if s > (*h)[0].Score {
(*h)[0] = r
heap.Fix(h, 0)
}
}
n := h.Len()
results := make([]Result, n)
for i := n - 1; i >= 0; i-- {
v := heap.Pop(h)
if r, ok := v.(Result); ok {
results[i] = r
}
}
return results
}
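// resultHeap is a min-heap of Results ordered by score, used to retain only the
// top-K candidates.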
type resultHeap []Result
func (h resultHeap) Len() int { return len(h) }
func (h resultHeap) Less(i, j int) bool { return h[i].Score < h[j].Score }
func (h resultHeap) Swap(i, j int) { h[i], h[j] = h[j], h[i] }
func (h *resultHeap) Push(x interface{}) {
r, ok := x.(Result)
if ok {
*h = append(*h, r)
}
}
func (h *resultHeap) Pop() interface{} {
old := *h
n := len(old)
x := old[n-1]
*h = old[:n-1]
return x
}
+343
View File
@@ -0,0 +1,343 @@
package filefinder_test
import (
"slices"
"testing"
"github.com/coder/coder/v2/agent/filefinder"
)
func TestNewQueryPlan(t *testing.T) {
t.Parallel()
tests := []struct {
name string
query string
wantNorm string
wantShort bool
wantSlash bool
wantBase string
wantTokens []string
wantDirTok []string
wantTriCnt int // -1 to skip check
}{
{"Simple", "foo", "foo", false, false, "foo", []string{"foo"}, nil, 1},
{"MultiToken", "foo bar", "foo bar", false, false, "bar", []string{"foo", "bar"}, []string{"foo"}, -1},
{"Slash", "internal/foo", "internal/foo", false, true, "foo", []string{"internal", "foo"}, []string{"internal"}, -1},
{"SingleChar", "a", "a", true, false, "a", []string{"a"}, nil, 0},
{"TwoChars", "ab", "ab", true, false, "ab", []string{"ab"}, nil, -1},
{"ThreeChars", "abc", "abc", false, false, "abc", []string{"abc"}, nil, 1},
{"DotPrefix", ".go", ".go", false, false, ".go", []string{".go"}, nil, -1},
{"UpperCase", "FOO", "foo", false, false, "foo", []string{"foo"}, nil, -1},
{"Empty", "", "", true, false, "", nil, nil, -1},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
plan := filefinder.NewQueryPlanForTest(tt.query)
if plan.Normalized != tt.wantNorm {
t.Errorf("normalized = %q, want %q", plan.Normalized, tt.wantNorm)
}
if plan.IsShort != tt.wantShort {
t.Errorf("isShort = %v, want %v", plan.IsShort, tt.wantShort)
}
if plan.HasSlash != tt.wantSlash {
t.Errorf("hasSlash = %v, want %v", plan.HasSlash, tt.wantSlash)
}
if string(plan.BasenameQ) != tt.wantBase {
t.Errorf("basenameQ = %q, want %q", plan.BasenameQ, tt.wantBase)
}
if tt.wantTokens == nil {
if len(plan.Tokens) != 0 {
t.Errorf("expected 0 tokens, got %d", len(plan.Tokens))
}
} else {
if len(plan.Tokens) != len(tt.wantTokens) {
t.Fatalf("tokens len = %d, want %d", len(plan.Tokens), len(tt.wantTokens))
}
for i, tok := range plan.Tokens {
if string(tok) != tt.wantTokens[i] {
t.Errorf("tokens[%d] = %q, want %q", i, tok, tt.wantTokens[i])
}
}
}
if tt.wantDirTok != nil {
if len(plan.DirTokens) != len(tt.wantDirTok) {
t.Fatalf("dirTokens len = %d, want %d", len(plan.DirTokens), len(tt.wantDirTok))
}
for i, tok := range plan.DirTokens {
if string(tok) != tt.wantDirTok[i] {
t.Errorf("dirTokens[%d] = %q, want %q", i, tok, tt.wantDirTok[i])
}
}
}
if tt.wantTriCnt >= 0 && len(plan.Trigrams) != tt.wantTriCnt {
t.Errorf("trigram count = %d, want %d", len(plan.Trigrams), tt.wantTriCnt)
}
})
}
// ThreeChars: verify the actual trigram value.
plan := filefinder.NewQueryPlanForTest("abc")
if want := filefinder.PackTrigramForTest('a', 'b', 'c'); plan.Trigrams[0] != want {
t.Errorf("trigram = %x, want %x", plan.Trigrams[0], want)
}
// ShortMultiToken: both tokens < 3 chars so isShort should be true.
plan = filefinder.NewQueryPlanForTest("ab cd")
if !plan.IsShort {
t.Error("expected isShort=true when all tokens < 3 chars")
}
// One token >= 3 chars, so isShort should be false.
plan = filefinder.NewQueryPlanForTest("ab cde")
if plan.IsShort {
t.Error("expected isShort=false when any token >= 3 chars")
}
}
func requireCandHasPath(t *testing.T, cands []filefinder.CandidateForTest, path string) {
t.Helper()
for _, c := range cands {
if c.Path == path {
return
}
}
t.Errorf("expected to find %q in candidates", path)
}
func TestSearchSnapshot_TrigramMatch(t *testing.T) {
t.Parallel()
snap := filefinder.MakeTestSnapshot([]string{"src/handler.go", "src/router.go", "lib/utils.go"})
cands := filefinder.SearchSnapshotForTest(filefinder.NewQueryPlanForTest("handler"), snap, 100)
if len(cands) == 0 {
t.Fatal("expected at least 1 candidate for 'handler'")
}
requireCandHasPath(t, cands, "src/handler.go")
}
func TestSearchSnapshot_ShortQuery(t *testing.T) {
t.Parallel()
snap := filefinder.MakeTestSnapshot([]string{"foo.go", "bar.go", "fab.go"})
cands := filefinder.SearchSnapshotForTest(filefinder.NewQueryPlanForTest("fo"), snap, 100)
if len(cands) == 0 {
t.Fatal("expected at least 1 candidate for 'fo'")
}
requireCandHasPath(t, cands, "foo.go")
}
func TestSearchSnapshot_FuzzyFallback(t *testing.T) {
t.Parallel()
snap := filefinder.MakeTestSnapshot([]string{"src/handler.go", "src/router.go", "lib/utils.go"})
cands := filefinder.SearchSnapshotForTest(filefinder.NewQueryPlanForTest("hndlr"), snap, 100)
if len(cands) == 0 {
t.Fatal("expected fuzzy fallback to find 'handler.go' for query 'hndlr'")
}
requireCandHasPath(t, cands, "src/handler.go")
}
func TestSearchSnapshot_FuzzyFallbackNoFirstCharMatch(t *testing.T) {
t.Parallel()
snap := filefinder.MakeTestSnapshot([]string{"src/xylophone.go", "lib/extra.go"})
cands := filefinder.SearchSnapshotForTest(filefinder.NewQueryPlanForTest("xylo"), snap, 100)
if len(cands) == 0 {
t.Fatal("expected at least 1 candidate for 'xylo'")
}
requireCandHasPath(t, cands, "src/xylophone.go")
}
func TestSearchSnapshot_NilSnapshot(t *testing.T) {
t.Parallel()
cands := filefinder.SearchSnapshotForTest(filefinder.NewQueryPlanForTest("foo"), nil, 100)
if cands != nil {
t.Errorf("expected nil for nil snapshot, got %v", cands)
}
}
func TestSearchSnapshot_EmptyQuery(t *testing.T) {
t.Parallel()
snap := filefinder.MakeTestSnapshot([]string{"foo.go"})
cands := filefinder.SearchSnapshotForTest(filefinder.NewQueryPlanForTest(""), snap, 100)
if cands != nil {
t.Errorf("expected nil for empty query, got %v", cands)
}
}
func TestSearchSnapshot_DeletedDocsExcluded(t *testing.T) {
t.Parallel()
idx := filefinder.NewIndex()
idx.Add("handler.go", 0)
idx.Remove("handler.go")
snap := idx.Snapshot()
cands := filefinder.SearchSnapshotForTest(filefinder.NewQueryPlanForTest("handler"), snap, 100)
for _, c := range cands {
if c.Path == "handler.go" {
t.Error("deleted doc should not appear in results")
}
}
}
func TestSearchSnapshot_Limit(t *testing.T) {
t.Parallel()
paths := make([]string, 50)
for i := range paths {
paths[i] = "handler" + string(rune('a'+i%26)) + ".go"
}
snap := filefinder.MakeTestSnapshot(paths)
cands := filefinder.SearchSnapshotForTest(filefinder.NewQueryPlanForTest("handler"), snap, 3)
if len(cands) > 3 {
t.Errorf("expected at most 3 candidates, got %d", len(cands))
}
}
func TestIntersectSorted(t *testing.T) {
t.Parallel()
tests := []struct {
name string
a, b []uint32
want []uint32
}{
{"both empty", nil, nil, nil},
{"a empty", nil, []uint32{1, 2}, nil},
{"b empty", []uint32{1, 2}, nil, nil},
{"no overlap", []uint32{1, 3, 5}, []uint32{2, 4, 6}, nil},
{"full overlap", []uint32{1, 2, 3}, []uint32{1, 2, 3}, []uint32{1, 2, 3}},
{"partial overlap", []uint32{1, 2, 3, 5}, []uint32{2, 4, 5}, []uint32{2, 5}},
{"single match", []uint32{1, 2, 3}, []uint32{2}, []uint32{2}},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
got := filefinder.IntersectSortedForTest(tt.a, tt.b)
if len(tt.want) == 0 {
if len(got) != 0 {
t.Errorf("got %v, want empty/nil", got)
}
return
}
if !slices.Equal(got, tt.want) {
t.Errorf("got %v, want %v", got, tt.want)
}
})
}
}
func TestIntersectAll(t *testing.T) {
t.Parallel()
t.Run("empty", func(t *testing.T) {
t.Parallel()
if got := filefinder.IntersectAllForTest(nil); got != nil {
t.Errorf("got %v, want nil", got)
}
})
t.Run("single", func(t *testing.T) {
t.Parallel()
if got := filefinder.IntersectAllForTest([][]uint32{{1, 2, 3}}); len(got) != 3 {
t.Fatalf("len = %d, want 3", len(got))
}
})
t.Run("multiple", func(t *testing.T) {
t.Parallel()
got := filefinder.IntersectAllForTest([][]uint32{{1, 2, 3, 4, 5}, {2, 3, 5}, {3, 5, 7}})
if !slices.Equal(got, []uint32{3, 5}) {
t.Errorf("got %v, want [3 5]", got)
}
})
t.Run("no overlap", func(t *testing.T) {
t.Parallel()
if got := filefinder.IntersectAllForTest([][]uint32{{1, 2}, {3, 4}}); got != nil {
t.Errorf("got %v, want nil", got)
}
})
}
func TestMergeAndScore_SortedDescending(t *testing.T) {
t.Parallel()
plan := filefinder.NewQueryPlanForTest("foo")
params := filefinder.DefaultScoreParamsForTest()
cands := []filefinder.CandidateForTest{
{DocID: 0, Path: "a/b/c/d/e/foo", BaseOff: 10, BaseLen: 3, Depth: 5},
{DocID: 1, Path: "src/foo", BaseOff: 4, BaseLen: 3, Depth: 1},
{DocID: 2, Path: "foo", BaseOff: 0, BaseLen: 3, Depth: 0},
}
results := filefinder.MergeAndScoreForTest(cands, plan, params, 10)
if len(results) == 0 {
t.Fatal("expected non-empty results")
}
for i := 1; i < len(results); i++ {
if results[i].Score > results[i-1].Score {
t.Errorf("results not sorted: [%d].Score=%f > [%d].Score=%f",
i, results[i].Score, i-1, results[i-1].Score)
}
}
}
func TestMergeAndScore_TopKLimit(t *testing.T) {
t.Parallel()
plan := filefinder.NewQueryPlanForTest("f")
params := filefinder.DefaultScoreParamsForTest()
var cands []filefinder.CandidateForTest
for i := range 20 {
p := "f" + string(rune('a'+i))
cands = append(cands, filefinder.CandidateForTest{DocID: uint32(i), Path: p, BaseOff: 0, BaseLen: len(p), Depth: 0}) //nolint:gosec // test index is tiny
}
if results := filefinder.MergeAndScoreForTest(cands, plan, params, 5); len(results) != 5 {
t.Errorf("expected 5 results, got %d", len(results))
}
}
func TestMergeAndScore_ZeroTopK(t *testing.T) {
t.Parallel()
plan := filefinder.NewQueryPlanForTest("foo")
cands := []filefinder.CandidateForTest{{DocID: 0, Path: "foo", BaseOff: 0, BaseLen: 3, Depth: 0}}
if results := filefinder.MergeAndScoreForTest(cands, plan, filefinder.DefaultScoreParamsForTest(), 0); len(results) != 0 {
t.Errorf("expected 0 results for topK=0, got %d", len(results))
}
}
func TestMergeAndScore_NoMatchCandidatesDropped(t *testing.T) {
t.Parallel()
plan := filefinder.NewQueryPlanForTest("xyz")
cands := []filefinder.CandidateForTest{
{DocID: 0, Path: "abc", BaseOff: 0, BaseLen: 3, Depth: 0},
{DocID: 1, Path: "def", BaseOff: 0, BaseLen: 3, Depth: 0},
}
if results := filefinder.MergeAndScoreForTest(cands, plan, filefinder.DefaultScoreParamsForTest(), 10); len(results) != 0 {
t.Errorf("expected 0 results for non-matching candidates, got %d", len(results))
}
}
func TestMergeAndScore_IsDirFlag(t *testing.T) {
t.Parallel()
plan := filefinder.NewQueryPlanForTest("foo")
cands := []filefinder.CandidateForTest{
{DocID: 0, Path: "foo", BaseOff: 0, BaseLen: 3, Depth: 0, Flags: uint16(filefinder.FlagDir)},
}
results := filefinder.MergeAndScoreForTest(cands, plan, filefinder.DefaultScoreParamsForTest(), 10)
if len(results) != 1 {
t.Fatalf("expected 1 result, got %d", len(results))
}
if !results[0].IsDir {
t.Error("expected IsDir=true for FlagDir candidate")
}
}
func TestMergeAndScore_EmptyCandidates(t *testing.T) {
t.Parallel()
if results := filefinder.MergeAndScoreForTest(nil, filefinder.NewQueryPlanForTest("foo"), filefinder.DefaultScoreParamsForTest(), 10); len(results) != 0 {
t.Errorf("expected 0 results for nil candidates, got %d", len(results))
}
}
func TestSearchSnapshot_FuzzyFallbackEndToEnd(t *testing.T) {
t.Parallel()
snap := filefinder.MakeTestSnapshot([]string{"src/handler.go", "src/middleware.go", "pkg/config.go"})
plan := filefinder.NewQueryPlanForTest("hndlr")
results := filefinder.MergeAndScoreForTest(filefinder.SearchSnapshotForTest(plan, snap, 100), plan, filefinder.DefaultScoreParamsForTest(), 10)
if len(results) == 0 {
t.Fatal("expected fuzzy fallback to produce scored results for 'hndlr'")
}
if results[0].Path != "src/handler.go" {
t.Errorf("expected top result 'src/handler.go', got %q", results[0].Path)
}
}
+288
View File
@@ -0,0 +1,288 @@
package filefinder
import "slices"
func toLowerASCII(b byte) byte {
if b >= 'A' && b <= 'Z' {
return b + ('a' - 'A')
}
return b
}
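// normalizeQuery lowercases ASCII, converts backslashes to forward slashes, and
// collapses runs of spaces, trimming any trailing space.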
func normalizeQuery(q string) string {
b := make([]byte, 0, len(q))
prevSpace := true
for i := 0; i < len(q); i++ {
c := q[i]
if c == '\\' {
c = '/'
}
c = toLowerASCII(c)
if c == ' ' {
if prevSpace {
continue
}
prevSpace = true
} else {
prevSpace = false
}
b = append(b, c)
}
if len(b) > 0 && b[len(b)-1] == ' ' {
b = b[:len(b)-1]
}
return string(b)
}
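// normalizePathBytes normalizes a path in place: ASCII lowercasing,
// backslash-to-slash conversion, and collapsing of repeated slashes.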
func normalizePathBytes(p []byte) []byte {
j := 0
prevSlash := false
for i := 0; i < len(p); i++ {
c := p[i]
if c == '\\' {
c = '/'
}
c = toLowerASCII(c)
if c == '/' {
if prevSlash {
continue
}
prevSlash = true
} else {
prevSlash = false
}
p[j] = c
j++
}
return p[:j]
}
// extractTrigrams returns deduplicated, sorted trigrams (three-byte
// subsequences) from s. Trigrams are the primary index key: a
// document matches a query only if every query trigram appears in
// the document, giving O(1) candidate filtering per trigram.
func extractTrigrams(s []byte) []uint32 {
if len(s) < 3 {
return nil
}
seen := make(map[uint32]struct{}, len(s))
for i := 0; i <= len(s)-3; i++ {
b0 := toLowerASCII(s[i])
b1 := toLowerASCII(s[i+1])
b2 := toLowerASCII(s[i+2])
gram := uint32(b0)<<16 | uint32(b1)<<8 | uint32(b2)
seen[gram] = struct{}{}
}
result := make([]uint32, 0, len(seen))
for g := range seen {
result = append(result, g)
}
slices.Sort(result)
return result
}
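// extractBasename returns the offset and length of the final path segment,
// ignoring a single trailing slash.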
func extractBasename(path []byte) (offset int, length int) {
end := len(path)
if end > 0 && path[end-1] == '/' {
end--
}
if end == 0 {
return 0, 0
}
i := end - 1
for i >= 0 && path[i] != '/' {
i--
}
start := i + 1
return start, end - start
}
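// extractSegments splits a slash-separated path into its non-empty segments without copying.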
func extractSegments(path []byte) [][]byte {
var segments [][]byte
start := 0
for i := 0; i <= len(path); i++ {
if i == len(path) || path[i] == '/' {
if i > start {
segments = append(segments, path[start:i])
}
start = i + 1
}
}
return segments
}
func prefix1(name []byte) byte {
if len(name) == 0 {
return 0
}
return toLowerASCII(name[0])
}
func prefix2(name []byte) uint16 {
if len(name) == 0 {
return 0
}
hi := uint16(toLowerASCII(name[0])) << 8
if len(name) < 2 {
return hi
}
return hi | uint16(toLowerASCII(name[1]))
}
// scoreParams controls the weights for each scoring signal.
type scoreParams struct {
BasenameMatch float32
BasenamePrefix float32
ExactSegment float32
BoundaryHit float32
ContiguousRun float32
DirTokenHit float32
DepthPenalty float32
LengthPenalty float32
}
func defaultScoreParams() scoreParams {
return scoreParams{
BasenameMatch: 6.0,
BasenamePrefix: 3.5,
ExactSegment: 2.5,
BoundaryHit: 1.8,
ContiguousRun: 1.2,
DirTokenHit: 0.4,
DepthPenalty: 0.08,
LengthPenalty: 0.01,
}
}
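// isSubsequence reports whether needle occurs in haystack as a case-insensitive
// (ASCII) subsequence.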
func isSubsequence(haystack, needle []byte) bool {
if len(needle) == 0 {
return true
}
ni := 0
for _, hb := range haystack {
if toLowerASCII(hb) == toLowerASCII(needle[ni]) {
ni++
if ni == len(needle) {
return true
}
}
}
return false
}
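// longestContiguousMatch returns the length of the longest prefix of needle that
// appears as a contiguous, case-insensitive run in haystack.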
func longestContiguousMatch(haystack, needle []byte) int {
if len(needle) == 0 || len(haystack) == 0 {
return 0
}
best := 0
ni := 0
run := 0
for _, hb := range haystack {
if ni < len(needle) && toLowerASCII(hb) == toLowerASCII(needle[ni]) {
run++
ni++
if run > best {
best = run
}
} else {
run = 0
ni = 0
if ni < len(needle) && toLowerASCII(hb) == toLowerASCII(needle[ni]) {
run = 1
ni = 1
if run > best {
best = run
}
}
}
}
return best
}
func isBoundary(b byte) bool {
return b == '/' || b == '.' || b == '_' || b == '-'
}
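// countBoundaryHits counts query bytes matched in order at word boundaries of
// path (start of string, or after '/', '.', '_', '-').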
func countBoundaryHits(path []byte, query []byte) int {
if len(query) == 0 || len(path) == 0 {
return 0
}
hits := 0
qi := 0
for pi := 0; pi < len(path) && qi < len(query); pi++ {
atBoundary := pi == 0 || isBoundary(path[pi-1])
if atBoundary && toLowerASCII(path[pi]) == toLowerASCII(query[qi]) {
hits++
qi++
}
}
return hits
}
func equalFoldASCII(a, b []byte) bool {
if len(a) != len(b) {
return false
}
for i := range a {
if toLowerASCII(a[i]) != toLowerASCII(b[i]) {
return false
}
}
return true
}
func hasPrefixFoldASCII(haystack, prefix []byte) bool {
if len(prefix) > len(haystack) {
return false
}
for i := range prefix {
if toLowerASCII(haystack[i]) != toLowerASCII(prefix[i]) {
return false
}
}
return true
}
// scorePath computes a relevance score for a candidate path
// against a query. The score combines several signals:
// basename match, basename prefix, exact segment match,
// word-boundary hits, longest contiguous run, and penalties
// for depth and length. A return value of 0 means no match
// (the query is not a subsequence of the path).
func scorePath(
path []byte,
baseOff int,
baseLen int,
depth int,
query []byte,
queryTokens [][]byte,
params scoreParams,
) float32 {
if !isSubsequence(path, query) {
return 0
}
var score float32
basename := path[baseOff : baseOff+baseLen]
if isSubsequence(basename, query) {
score += params.BasenameMatch
}
if hasPrefixFoldASCII(basename, query) {
score += params.BasenamePrefix
}
segments := extractSegments(path)
for _, token := range queryTokens {
for _, seg := range segments {
if equalFoldASCII(seg, token) {
score += params.ExactSegment
break
}
}
}
bh := countBoundaryHits(path, query)
score += float32(bh) * params.BoundaryHit
lcm := longestContiguousMatch(path, query)
score += float32(lcm) * params.ContiguousRun
score -= float32(depth) * params.DepthPenalty
score -= float32(len(path)) * params.LengthPenalty
return score
}
+388
View File
@@ -0,0 +1,388 @@
package filefinder_test
import (
"slices"
"testing"
"github.com/coder/coder/v2/agent/filefinder"
)
func TestNormalizeQuery(t *testing.T) {
t.Parallel()
tests := []struct {
name string
input string
want string
}{
{"empty", "", ""},
{"leading and trailing spaces", " hello ", "hello"},
{"multiple internal spaces", "foo bar baz", "foo bar baz"},
{"uppercase to lower", "FooBar", "foobar"},
{"backslash to slash", `foo\bar\baz`, "foo/bar/baz"},
{"mixed case and spaces", " Hello World ", "hello world"},
{"unicode passthrough", "héllo wörld", "héllo wörld"},
{"only spaces", " ", ""},
{"single char", "A", "a"},
{"slashes preserved", "/foo/bar/", "/foo/bar/"},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
got := filefinder.NormalizeQueryForTest(tt.input)
if got != tt.want {
t.Errorf("normalizeQuery(%q) = %q, want %q", tt.input, got, tt.want)
}
})
}
}
func TestExtractTrigrams(t *testing.T) {
t.Parallel()
tests := []struct {
name string
input string
want []uint32
}{
{"too short", "ab", nil},
{"exactly three bytes", "abc", []uint32{uint32('a')<<16 | uint32('b')<<8 | uint32('c')}},
{"case insensitive", "ABC", []uint32{uint32('a')<<16 | uint32('b')<<8 | uint32('c')}},
{"deduplication", "aaaa", []uint32{uint32('a')<<16 | uint32('a')<<8 | uint32('a')}},
{"four bytes produces two trigrams", "abcd", []uint32{
uint32('a')<<16 | uint32('b')<<8 | uint32('c'),
uint32('b')<<16 | uint32('c')<<8 | uint32('d'),
}},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
got := filefinder.ExtractTrigramsForTest([]byte(tt.input))
if !slices.Equal(got, tt.want) {
t.Errorf("extractTrigrams(%q) = %v, want %v", tt.input, got, tt.want)
}
})
}
}
func TestExtractBasename(t *testing.T) {
t.Parallel()
tests := []struct {
name string
path string
wantOff int
wantName string
}{
{"full path", "/foo/bar/baz.go", 9, "baz.go"},
{"bare filename", "baz.go", 0, "baz.go"},
{"trailing slash", "/a/b/", 3, "b"},
{"root slash", "/", 0, ""},
{"empty", "", 0, ""},
{"single dir with slash", "/foo", 1, "foo"},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
off, length := filefinder.ExtractBasenameForTest([]byte(tt.path))
if off != tt.wantOff {
t.Errorf("extractBasename(%q) offset = %d, want %d", tt.path, off, tt.wantOff)
}
gotName := string([]byte(tt.path)[off : off+length])
if gotName != tt.wantName {
t.Errorf("extractBasename(%q) name = %q, want %q", tt.path, gotName, tt.wantName)
}
})
}
}
func TestExtractSegments(t *testing.T) {
t.Parallel()
tests := []struct {
name string
path string
want []string
}{
{"absolute path", "/foo/bar/baz", []string{"foo", "bar", "baz"}},
{"relative path", "foo/bar", []string{"foo", "bar"}},
{"trailing slash", "/a/b/", []string{"a", "b"}},
{"multiple slashes", "//a///b//", []string{"a", "b"}},
{"empty", "", nil},
{"single segment", "foo", []string{"foo"}},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
got := filefinder.ExtractSegmentsForTest([]byte(tt.path))
if len(got) != len(tt.want) {
t.Fatalf("extractSegments(%q) got %d segments, want %d", tt.path, len(got), len(tt.want))
}
for i := range got {
if string(got[i]) != tt.want[i] {
t.Errorf("extractSegments(%q)[%d] = %q, want %q", tt.path, i, got[i], tt.want[i])
}
}
})
}
}
func TestPrefix1(t *testing.T) {
t.Parallel()
tests := []struct {
name string
in string
want byte
}{
{"lowercase", "foo", 'f'},
{"uppercase", "Foo", 'f'},
{"empty", "", 0},
{"digit", "1abc", '1'},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
got := filefinder.Prefix1ForTest([]byte(tt.in))
if got != tt.want {
t.Errorf("prefix1(%q) = %d (%c), want %d (%c)", tt.in, got, got, tt.want, tt.want)
}
})
}
}
func TestPrefix2(t *testing.T) {
t.Parallel()
tests := []struct {
name string
in string
want uint16
}{
{"two chars", "ab", uint16('a')<<8 | uint16('b')},
{"uppercase", "AB", uint16('a')<<8 | uint16('b')},
{"single char", "A", uint16('a') << 8},
{"empty", "", 0},
{"longer string", "Hello", uint16('h')<<8 | uint16('e')},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
got := filefinder.Prefix2ForTest([]byte(tt.in))
if got != tt.want {
t.Errorf("prefix2(%q) = %d, want %d", tt.in, got, tt.want)
}
})
}
}
func TestNormalizePathBytes(t *testing.T) {
t.Parallel()
tests := []struct {
name string
input string
want string
}{
{"backslash to slash", `C:\Users\test`, "c:/users/test"},
{"collapse slashes", "//foo///bar//", "/foo/bar/"},
{"lowercase", "FooBar", "foobar"},
{"empty", "", ""},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
buf := []byte(tt.input)
got := string(filefinder.NormalizePathBytesForTest(buf))
if got != tt.want {
t.Errorf("normalizePathBytes(%q) = %q, want %q", tt.input, got, tt.want)
}
})
}
}
func TestIsSubsequence(t *testing.T) {
t.Parallel()
tests := []struct {
name string
haystack string
needle string
want bool
}{
{"empty needle", "anything", "", true},
{"empty both", "", "", true},
{"empty haystack", "", "a", false},
{"exact match", "abc", "abc", true},
{"scattered", "axbycz", "abc", true},
{"prefix", "abcdef", "abc", true},
{"suffix", "xyzabc", "abc", true},
{"case insensitive", "AbCdEf", "ace", true},
{"case insensitive reverse", "abcdef", "ACE", true},
{"no match", "abcdef", "xyz", false},
{"partial match", "abcdef", "abz", false},
{"longer needle", "ab", "abc", false},
{"single char match", "hello", "l", true},
{"single char no match", "hello", "z", false},
{"path like", "src/internal/foo.go", "sif", true},
{"path like no match", "src/internal/foo.go", "zzz", false},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
got := filefinder.IsSubsequenceForTest([]byte(tt.haystack), []byte(tt.needle))
if got != tt.want {
t.Errorf("isSubsequence(%q, %q) = %v, want %v", tt.haystack, tt.needle, got, tt.want)
}
})
}
}
func TestLongestContiguousMatch(t *testing.T) {
t.Parallel()
tests := []struct {
name string
haystack string
needle string
want int
}{
{"empty needle", "abc", "", 0},
{"empty haystack", "", "abc", 0},
{"full match", "abc", "abc", 3},
{"prefix match", "abcdef", "abc", 3},
{"middle match", "xxabcyy", "abc", 3},
{"suffix match", "xxabc", "abc", 3},
{"partial", "axbc", "abc", 1},
{"scattered no contiguous", "axbxcx", "abc", 1},
{"case insensitive", "ABCdef", "abc", 3},
{"no match", "xyz", "abc", 0},
{"single char", "abc", "b", 1},
{"repeated", "aababc", "abc", 3},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
got := filefinder.LongestContiguousMatchForTest([]byte(tt.haystack), []byte(tt.needle))
if got != tt.want {
t.Errorf("longestContiguousMatch(%q, %q) = %d, want %d", tt.haystack, tt.needle, got, tt.want)
}
})
}
}
func TestIsBoundary(t *testing.T) {
t.Parallel()
for _, b := range []byte{'/', '.', '_', '-'} {
if !filefinder.IsBoundaryForTest(b) {
t.Errorf("isBoundary(%q) = false, want true", b)
}
}
for _, b := range []byte{'a', 'Z', '0', ' ', '('} {
if filefinder.IsBoundaryForTest(b) {
t.Errorf("isBoundary(%q) = true, want false", b)
}
}
}
func TestCountBoundaryHits(t *testing.T) {
t.Parallel()
tests := []struct {
name string
path string
query string
want int
}{
{"start of string", "foo/bar", "f", 1},
{"after slash", "foo/bar", "fb", 2},
{"after dot", "foo.bar", "fb", 2},
{"after underscore", "foo_bar", "fb", 2},
{"no hits", "xxxx", "y", 0},
{"empty query", "foo", "", 0},
{"empty path", "", "f", 0},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
got := filefinder.CountBoundaryHitsForTest([]byte(tt.path), []byte(tt.query))
if got != tt.want {
t.Errorf("countBoundaryHits(%q, %q) = %d, want %d", tt.path, tt.query, got, tt.want)
}
})
}
}
func TestScorePath_NoSubsequenceReturnsZero(t *testing.T) {
t.Parallel()
path := []byte("src/internal/handler.go")
query := []byte("zzz")
tokens := [][]byte{[]byte("zzz")}
params := filefinder.DefaultScoreParamsForTest()
s := filefinder.ScorePathForTest(path, 13, 10, 2, query, tokens, params)
if s != 0 {
t.Errorf("expected 0 for no subsequence match, got %f", s)
}
}
func TestScorePath_ExactBasenameOverPartial(t *testing.T) {
t.Parallel()
params := filefinder.DefaultScoreParamsForTest()
query := []byte("main")
tokens := [][]byte{query}
pathExact := []byte("src/main")
scoreExact := filefinder.ScorePathForTest(pathExact, 4, 4, 1, query, tokens, params)
pathPartial := []byte("module/amazing")
scorePartial := filefinder.ScorePathForTest(pathPartial, 7, 7, 1, query, tokens, params)
if scoreExact <= scorePartial {
t.Errorf("exact basename (%f) should score higher than partial (%f)", scoreExact, scorePartial)
}
}
func TestScorePath_BasenamePrefixOverScattered(t *testing.T) {
t.Parallel()
params := filefinder.DefaultScoreParamsForTest()
query := []byte("han")
tokens := [][]byte{query}
pathPrefix := []byte("src/handler.go")
scorePrefix := filefinder.ScorePathForTest(pathPrefix, 4, 10, 1, query, tokens, params)
pathScattered := []byte("has/another/thing")
scoreScattered := filefinder.ScorePathForTest(pathScattered, 12, 5, 2, query, tokens, params)
if scorePrefix <= scoreScattered {
t.Errorf("basename prefix (%f) should score higher than scattered (%f)", scorePrefix, scoreScattered)
}
}
func TestScorePath_ShallowOverDeep(t *testing.T) {
t.Parallel()
params := filefinder.DefaultScoreParamsForTest()
query := []byte("foo")
tokens := [][]byte{query}
pathShallow := []byte("src/foo.go")
scoreShallow := filefinder.ScorePathForTest(pathShallow, 4, 6, 1, query, tokens, params)
pathDeep := []byte("a/b/c/d/e/foo.go")
scoreDeep := filefinder.ScorePathForTest(pathDeep, 10, 6, 5, query, tokens, params)
if scoreShallow <= scoreDeep {
t.Errorf("shallow path (%f) should score higher than deep (%f)", scoreShallow, scoreDeep)
}
}
func TestScorePath_ShorterOverLongerSameMatch(t *testing.T) {
t.Parallel()
params := filefinder.DefaultScoreParamsForTest()
query := []byte("foo")
tokens := [][]byte{query}
pathShort := []byte("x/foo")
scoreShort := filefinder.ScorePathForTest(pathShort, 2, 3, 1, query, tokens, params)
pathLong := []byte("x/foo_extremely_long_suffix_name")
scoreLong := filefinder.ScorePathForTest(pathLong, 2, 30, 1, query, tokens, params)
if scoreShort <= scoreLong {
t.Errorf("shorter path (%f) should score higher than longer (%f)", scoreShort, scoreLong)
}
}
func BenchmarkScorePath(b *testing.B) {
path := []byte("src/internal/coderd/database/queries/workspaces.sql")
query := []byte("workspace")
tokens := [][]byte{query}
params := filefinder.DefaultScoreParamsForTest()
baseOff, baseLen := filefinder.ExtractBasenameForTest(path)
s := filefinder.ScorePathForTest(path, baseOff, baseLen, 4, query, tokens, params)
if s == 0 {
b.Fatal("expected non-zero score for benchmark path")
}
b.ResetTimer()
for b.Loop() {
filefinder.ScorePathForTest(path, baseOff, baseLen, 4, query, tokens, params)
}
}
+210
View File
@@ -0,0 +1,210 @@
package filefinder
import (
"context"
"os"
"path/filepath"
"sync"
"time"
"github.com/fsnotify/fsnotify"
"cdr.dev/slog/v3"
)
// FSEvent represents a filesystem change event.
type FSEvent struct {
Op FSEventOp
Path string
IsDir bool
}
// FSEventOp represents the type of filesystem operation.
type FSEventOp uint8
// Filesystem operations reported by the watcher.
const (
OpCreate FSEventOp = iota
OpRemove
OpRename
OpModify
)
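// skipDirs lists directory names that are never watched or indexed.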
var skipDirs = map[string]struct{}{
".git": {}, "node_modules": {}, ".hg": {}, ".svn": {},
"__pycache__": {}, ".cache": {}, ".venv": {}, "vendor": {}, ".terraform": {},
}
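// fsWatcher watches a single root recursively via fsnotify and emits debounced
// batches of FSEvents on its events channel.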
type fsWatcher struct {
w *fsnotify.Watcher
root string
events chan []FSEvent
logger slog.Logger
mu sync.Mutex
closed bool
done chan struct{}
}
func newFSWatcher(root string, logger slog.Logger) (*fsWatcher, error) {
w, err := fsnotify.NewWatcher()
if err != nil {
return nil, err
}
return &fsWatcher{
w: w,
root: root,
events: make(chan []FSEvent, 64),
logger: logger,
done: make(chan struct{}),
}, nil
}
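// Start walks the root, registering watches and emitting synthetic create events
// for existing entries, then launches the event loop.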
func (fw *fsWatcher) Start(ctx context.Context) {
initEvents := fw.addRecursive(fw.root)
if len(initEvents) > 0 {
select {
case fw.events <- initEvents:
case <-ctx.Done():
return
}
}
fw.logger.Debug(ctx, "fs watcher started", slog.F("root", fw.root))
go fw.loop(ctx)
}
func (fw *fsWatcher) Events() <-chan []FSEvent { return fw.events }
func (fw *fsWatcher) Close() error {
fw.mu.Lock()
if fw.closed {
fw.mu.Unlock()
return nil
}
fw.closed = true
fw.mu.Unlock()
err := fw.w.Close()
<-fw.done
return err
}
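// loop reads raw fsnotify events, deduplicates them by path, and flushes each
// batch 50ms after its first event; a batch is dropped (with a warning) if the
// receiver is not keeping up.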
func (fw *fsWatcher) loop(ctx context.Context) {
defer close(fw.done)
const batchWindow = 50 * time.Millisecond
var (
batch []FSEvent
seen = make(map[string]struct{})
timer *time.Timer
timerC <-chan time.Time
)
flush := func() {
if len(batch) == 0 {
return
}
select {
case fw.events <- batch:
default:
fw.logger.Warn(ctx, "fs watcher dropping batch", slog.F("count", len(batch)))
}
batch = nil
seen = make(map[string]struct{})
if timer != nil {
timer.Stop()
}
timer = nil
timerC = nil
}
addToBatch := func(ev FSEvent) {
if _, dup := seen[ev.Path]; dup {
return
}
seen[ev.Path] = struct{}{}
batch = append(batch, ev)
if timer == nil {
timer = time.NewTimer(batchWindow)
timerC = timer.C
}
}
for {
select {
case <-ctx.Done():
flush()
return
case ev, ok := <-fw.w.Events:
if !ok {
flush()
return
}
fsev := translateEvent(ev)
if fsev == nil {
continue
}
if fsev.IsDir && fsev.Op == OpCreate {
for _, s := range fw.addRecursive(fsev.Path) {
addToBatch(s)
}
}
addToBatch(*fsev)
case err, ok := <-fw.w.Errors:
if !ok {
flush()
return
}
fw.logger.Warn(ctx, "fsnotify watcher error", slog.Error(err))
case <-timerC:
flush()
}
}
}
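// addRecursive walks dir, adds a watch on every non-skipped directory, and
// returns synthetic create events for the files and subdirectories it finds.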
func (fw *fsWatcher) addRecursive(dir string) []FSEvent {
var events []FSEvent
_ = filepath.Walk(dir, func(path string, info os.FileInfo, err error) error {
if err != nil {
return nil //nolint:nilerr // best-effort
}
base := filepath.Base(path)
if _, skip := skipDirs[base]; skip && info.IsDir() {
return filepath.SkipDir
}
if info.IsDir() {
if addErr := fw.w.Add(path); addErr != nil {
fw.logger.Debug(context.Background(), "failed to add watch",
slog.F("path", path), slog.Error(addErr))
}
if path != dir {
events = append(events, FSEvent{Op: OpCreate, Path: path, IsDir: true})
}
return nil
}
events = append(events, FSEvent{Op: OpCreate, Path: path, IsDir: false})
return nil
})
return events
}
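// translateEvent converts an fsnotify.Event into an FSEvent, stat-ing creates and
// writes to detect directories and dropping events for skipped directory names.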
func translateEvent(ev fsnotify.Event) *FSEvent {
var op FSEventOp
switch {
case ev.Op&fsnotify.Create != 0:
op = OpCreate
case ev.Op&fsnotify.Remove != 0:
op = OpRemove
case ev.Op&fsnotify.Rename != 0:
op = OpRename
case ev.Op&fsnotify.Write != 0:
op = OpModify
default:
return nil
}
isDir := false
if op == OpCreate || op == OpModify {
fi, err := os.Lstat(ev.Name)
if err == nil {
isDir = fi.IsDir()
}
}
if isDir {
if _, skip := skipDirs[filepath.Base(ev.Name)]; skip {
return nil
}
}
return &FSEvent{Op: op, Path: ev.Name, IsDir: isDir}
}
+543 -329
View File
File diff suppressed because it is too large.
+20 -1
View File
@@ -436,7 +436,7 @@ message CreateSubAgentRequest {
}
repeated DisplayApp display_apps = 6;
optional bytes id = 7;
}
@@ -494,6 +494,24 @@ message ReportBoundaryLogsRequest {
message ReportBoundaryLogsResponse {}
// UpdateAppStatusRequest updates the given Workspace App's status. c.f. agentsdk.PatchAppStatus
message UpdateAppStatusRequest {
string slug = 1;
enum AppStatusState {
WORKING = 0;
IDLE = 1;
COMPLETE = 2;
FAILURE = 3;
}
AppStatusState state = 2;
string message = 3;
string uri = 4;
}
message UpdateAppStatusResponse {}
service Agent {
rpc GetManifest(GetManifestRequest) returns (Manifest);
rpc GetServiceBanner(GetServiceBannerRequest) returns (ServiceBanner);
@@ -512,4 +530,5 @@ service Agent {
rpc DeleteSubAgent(DeleteSubAgentRequest) returns (DeleteSubAgentResponse);
rpc ListSubAgents(ListSubAgentsRequest) returns (ListSubAgentsResponse);
rpc ReportBoundaryLogs(ReportBoundaryLogsRequest) returns (ReportBoundaryLogsResponse);
rpc UpdateAppStatus(UpdateAppStatusRequest) returns (UpdateAppStatusResponse);
}
+41 -1
View File
@@ -56,6 +56,7 @@ type DRPCAgentClient interface {
DeleteSubAgent(ctx context.Context, in *DeleteSubAgentRequest) (*DeleteSubAgentResponse, error)
ListSubAgents(ctx context.Context, in *ListSubAgentsRequest) (*ListSubAgentsResponse, error)
ReportBoundaryLogs(ctx context.Context, in *ReportBoundaryLogsRequest) (*ReportBoundaryLogsResponse, error)
UpdateAppStatus(ctx context.Context, in *UpdateAppStatusRequest) (*UpdateAppStatusResponse, error)
}
type drpcAgentClient struct {
@@ -221,6 +222,15 @@ func (c *drpcAgentClient) ReportBoundaryLogs(ctx context.Context, in *ReportBoun
return out, nil
}
func (c *drpcAgentClient) UpdateAppStatus(ctx context.Context, in *UpdateAppStatusRequest) (*UpdateAppStatusResponse, error) {
out := new(UpdateAppStatusResponse)
err := c.cc.Invoke(ctx, "/coder.agent.v2.Agent/UpdateAppStatus", drpcEncoding_File_agent_proto_agent_proto{}, in, out)
if err != nil {
return nil, err
}
return out, nil
}
type DRPCAgentServer interface {
GetManifest(context.Context, *GetManifestRequest) (*Manifest, error)
GetServiceBanner(context.Context, *GetServiceBannerRequest) (*ServiceBanner, error)
@@ -239,6 +249,7 @@ type DRPCAgentServer interface {
DeleteSubAgent(context.Context, *DeleteSubAgentRequest) (*DeleteSubAgentResponse, error)
ListSubAgents(context.Context, *ListSubAgentsRequest) (*ListSubAgentsResponse, error)
ReportBoundaryLogs(context.Context, *ReportBoundaryLogsRequest) (*ReportBoundaryLogsResponse, error)
UpdateAppStatus(context.Context, *UpdateAppStatusRequest) (*UpdateAppStatusResponse, error)
}
type DRPCAgentUnimplementedServer struct{}
@@ -311,9 +322,13 @@ func (s *DRPCAgentUnimplementedServer) ReportBoundaryLogs(context.Context, *Repo
return nil, drpcerr.WithCode(errors.New("Unimplemented"), drpcerr.Unimplemented)
}
func (s *DRPCAgentUnimplementedServer) UpdateAppStatus(context.Context, *UpdateAppStatusRequest) (*UpdateAppStatusResponse, error) {
return nil, drpcerr.WithCode(errors.New("Unimplemented"), drpcerr.Unimplemented)
}
type DRPCAgentDescription struct{}
func (DRPCAgentDescription) NumMethods() int { return 17 }
func (DRPCAgentDescription) NumMethods() int { return 18 }
func (DRPCAgentDescription) Method(n int) (string, drpc.Encoding, drpc.Receiver, interface{}, bool) {
switch n {
@@ -470,6 +485,15 @@ func (DRPCAgentDescription) Method(n int) (string, drpc.Encoding, drpc.Receiver,
in1.(*ReportBoundaryLogsRequest),
)
}, DRPCAgentServer.ReportBoundaryLogs, true
case 17:
return "/coder.agent.v2.Agent/UpdateAppStatus", drpcEncoding_File_agent_proto_agent_proto{},
func(srv interface{}, ctx context.Context, in1, in2 interface{}) (drpc.Message, error) {
return srv.(DRPCAgentServer).
UpdateAppStatus(
ctx,
in1.(*UpdateAppStatusRequest),
)
}, DRPCAgentServer.UpdateAppStatus, true
default:
return "", nil, nil, nil, false
}
@@ -750,3 +774,19 @@ func (x *drpcAgent_ReportBoundaryLogsStream) SendAndClose(m *ReportBoundaryLogsR
}
return x.CloseSend()
}
type DRPCAgent_UpdateAppStatusStream interface {
drpc.Stream
SendAndClose(*UpdateAppStatusResponse) error
}
type drpcAgent_UpdateAppStatusStream struct {
drpc.Stream
}
func (x *drpcAgent_UpdateAppStatusStream) SendAndClose(m *UpdateAppStatusResponse) error {
if err := x.MsgSend(m, drpcEncoding_File_agent_proto_agent_proto{}); err != nil {
return err
}
return x.CloseSend()
}
+7 -3
View File
@@ -73,9 +73,13 @@ type DRPCAgentClient27 interface {
ReportBoundaryLogs(ctx context.Context, in *ReportBoundaryLogsRequest) (*ReportBoundaryLogsResponse, error)
}
// DRPCAgentClient28 is the Agent API at v2.8. It adds a SubagentId field to the
// WorkspaceAgentDevcontainer message, and a Id field to the CreateSubAgentRequest
// message. Compatible with Coder v2.31+
// DRPCAgentClient28 is the Agent API at v2.8. It adds:
// - a SubagentId field to the WorkspaceAgentDevcontainer message
// - an Id field to the CreateSubAgentRequest message
// - the UpdateAppStatus RPC
//
// Compatible with Coder v2.31+
type DRPCAgentClient28 interface {
DRPCAgentClient27
UpdateAppStatus(ctx context.Context, in *UpdateAppStatusRequest) (*UpdateAppStatusResponse, error)
}
+1 -1
View File
@@ -489,7 +489,7 @@ func workspaceAgent() *serpent.Command {
},
{
Flag: "socket-server-enabled",
Default: "false",
Default: "true",
Env: "CODER_AGENT_SOCKET_SERVER_ENABLED",
Description: "Enable the agent socket server.",
Value: serpent.BoolOf(&socketServerEnabled),
+4
View File
@@ -44,6 +44,7 @@ func TestWorkspaceAgent(t *testing.T) {
"--agent-token", r.AgentToken,
"--agent-url", client.URL.String(),
"--log-dir", logDir,
"--socket-path", testutil.AgentSocketPath(t),
)
clitest.Start(t, inv)
@@ -76,6 +77,7 @@ func TestWorkspaceAgent(t *testing.T) {
"--agent-token", r.AgentToken,
"--agent-url", client.URL.String(),
"--log-dir", logDir,
"--socket-path", testutil.AgentSocketPath(t),
)
// Set the subsystems for the agent.
inv.Environ.Set(agent.EnvAgentSubsystem, fmt.Sprintf("%s,%s", codersdk.AgentSubsystemExectrace, codersdk.AgentSubsystemEnvbox))
@@ -158,6 +160,7 @@ func TestWorkspaceAgent(t *testing.T) {
"--agent-header", "X-Testing=agent",
"--agent-header", "Cool-Header=Ethan was Here!",
"--agent-header-command", "printf X-Process-Testing=very-wow-"+coderURLEnv+"'\\r\\n'X-Process-Testing2=more-wow",
"--socket-path", testutil.AgentSocketPath(t),
)
clitest.Start(t, agentInv)
coderdtest.NewWorkspaceAgentWaiter(t, client, r.Workspace.ID).
@@ -199,6 +202,7 @@ func TestWorkspaceAgent(t *testing.T) {
"--pprof-address", "",
"--prometheus-address", "",
"--debug-address", "",
"--socket-path", testutil.AgentSocketPath(t),
)
clitest.Start(t, inv)
+9 -3
View File
@@ -30,9 +30,15 @@ func RichParameter(inv *serpent.Invocation, templateVersionParameter codersdk.Te
_, _ = fmt.Fprint(inv.Stdout, "\033[1A")
var defaults []string
err = json.Unmarshal([]byte(templateVersionParameter.DefaultValue), &defaults)
if err != nil {
return "", err
defaultSource := defaultValue
if defaultSource == "" {
defaultSource = templateVersionParameter.DefaultValue
}
if defaultSource != "" {
err = json.Unmarshal([]byte(defaultSource), &defaults)
if err != nil {
return "", err
}
}
values, err := RichMultiSelect(inv, RichMultiSelectOptions{
+50 -45
View File
@@ -10,6 +10,7 @@ import (
"path/filepath"
"slices"
"strings"
"time"
"github.com/mark3labs/mcp-go/mcp"
"github.com/mark3labs/mcp-go/server"
@@ -23,6 +24,7 @@ import (
"github.com/coder/coder/v2/codersdk"
"github.com/coder/coder/v2/codersdk/agentsdk"
"github.com/coder/coder/v2/codersdk/toolsdk"
"github.com/coder/retry"
"github.com/coder/serpent"
)
@@ -539,7 +541,6 @@ func (r *RootCmd) mcpServer() *serpent.Command {
defer cancel()
defer srv.queue.Close()
cliui.Infof(inv.Stderr, "Failed to watch screen events")
// Start the reporter, watcher, and server. These are all tied to the
// lifetime of the MCP server, which is itself tied to the lifetime of the
// AI agent.
@@ -613,48 +614,51 @@ func (s *mcpServer) startReporter(ctx context.Context, inv *serpent.Invocation)
}
func (s *mcpServer) startWatcher(ctx context.Context, inv *serpent.Invocation) {
eventsCh, errCh, err := s.aiAgentAPIClient.SubscribeEvents(ctx)
if err != nil {
cliui.Warnf(inv.Stderr, "Failed to watch screen events: %s", err)
return
}
go func() {
for {
select {
case <-ctx.Done():
return
case event := <-eventsCh:
switch ev := event.(type) {
case agentapi.EventStatusChange:
// If the screen is stable, report idle.
state := codersdk.WorkspaceAppStatusStateWorking
if ev.Status == agentapi.StatusStable {
state = codersdk.WorkspaceAppStatusStateIdle
}
err := s.queue.Push(taskReport{
state: state,
})
if err != nil {
cliui.Warnf(inv.Stderr, "Failed to queue update: %s", err)
for retrier := retry.New(time.Second, 30*time.Second); retrier.Wait(ctx); {
eventsCh, errCh, err := s.aiAgentAPIClient.SubscribeEvents(ctx)
if err == nil {
retrier.Reset()
loop:
for {
select {
case <-ctx.Done():
return
}
case agentapi.EventMessageUpdate:
if ev.Role == agentapi.RoleUser {
err := s.queue.Push(taskReport{
messageID: &ev.Id,
state: codersdk.WorkspaceAppStatusStateWorking,
})
if err != nil {
cliui.Warnf(inv.Stderr, "Failed to queue update: %s", err)
return
case event := <-eventsCh:
switch ev := event.(type) {
case agentapi.EventStatusChange:
state := codersdk.WorkspaceAppStatusStateWorking
if ev.Status == agentapi.StatusStable {
state = codersdk.WorkspaceAppStatusStateIdle
}
err := s.queue.Push(taskReport{
state: state,
})
if err != nil {
cliui.Warnf(inv.Stderr, "Failed to queue update: %s", err)
return
}
case agentapi.EventMessageUpdate:
if ev.Role == agentapi.RoleUser {
err := s.queue.Push(taskReport{
messageID: &ev.Id,
state: codersdk.WorkspaceAppStatusStateWorking,
})
if err != nil {
cliui.Warnf(inv.Stderr, "Failed to queue update: %s", err)
return
}
}
}
case err := <-errCh:
if !errors.Is(err, context.Canceled) {
cliui.Warnf(inv.Stderr, "Received error from screen event watcher: %s", err)
}
break loop
}
}
case err := <-errCh:
if !errors.Is(err, context.Canceled) {
cliui.Warnf(inv.Stderr, "Received error from screen event watcher: %s", err)
}
return
} else {
cliui.Warnf(inv.Stderr, "Failed to watch screen events: %s", err)
}
}
}()
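For context, a self-contained sketch of the subscribe-with-backoff pattern the watcher rewrite above adopts, assuming the github.com/coder/retry Retrier API (New, Wait, Reset) behaves as used in the diff; the event and error channel types and the function names here are placeholders, not the actual watcher code.
package watchexample

import (
	"context"
	"time"

	"github.com/coder/retry"
)

// watchWithReconnect keeps resubscribing until ctx is canceled, backing off
// between failed attempts and resetting the backoff after each successful
// subscription, as in the hunk above.
func watchWithReconnect(
	ctx context.Context,
	subscribe func(context.Context) (<-chan string, <-chan error, error),
	handle func(string),
) {
	for r := retry.New(time.Second, 30*time.Second); r.Wait(ctx); {
		events, errs, err := subscribe(ctx)
		if err != nil {
			continue // subscription failed; wait and retry
		}
		r.Reset() // connected, so the next failure starts backoff fresh
	loop:
		for {
			select {
			case <-ctx.Done():
				return
			case ev := <-events:
				handle(ev)
			case <-errs:
				break loop // connection dropped; fall through to resubscribe
			}
		}
	}
}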
@@ -692,13 +696,14 @@ func (s *mcpServer) startServer(ctx context.Context, inv *serpent.Invocation, in
// Add tool dependencies.
toolOpts := []func(*toolsdk.Deps){
toolsdk.WithTaskReporter(func(args toolsdk.ReportTaskArgs) error {
// The agent does not reliably report its status correctly. If AgentAPI
// is enabled, we will always set the status to "working" when we get an
// MCP message, and rely on the screen watcher to eventually catch the
// idle state.
state := codersdk.WorkspaceAppStatusStateWorking
if s.aiAgentAPIClient == nil {
state = codersdk.WorkspaceAppStatusState(args.State)
state := codersdk.WorkspaceAppStatusState(args.State)
// The agent does not reliably report idle, so when AgentAPI is
// enabled we override idle to working and let the screen watcher
// detect the real idle via StatusStable. Final states (failure,
// complete) are trusted from the agent since the screen watcher
// cannot produce them.
if s.aiAgentAPIClient != nil && state == codersdk.WorkspaceAppStatusStateIdle {
state = codersdk.WorkspaceAppStatusStateWorking
}
return s.queue.Push(taskReport{
link: args.Link,
+185 -1
@@ -921,7 +921,7 @@ func TestExpMcpReporter(t *testing.T) {
},
},
},
// We ignore the state from the agent and assume "working".
// We override idle from the agent to working, but trust final states.
{
name: "IgnoreAgentState",
// AI agent reports that it is finished but the summary says it is doing
@@ -953,6 +953,46 @@ func TestExpMcpReporter(t *testing.T) {
Message: "finished",
},
},
// Agent reports failure; trusted even with AgentAPI enabled.
{
state: codersdk.WorkspaceAppStatusStateFailure,
summary: "something broke",
expected: &codersdk.WorkspaceAppStatus{
State: codersdk.WorkspaceAppStatusStateFailure,
Message: "something broke",
},
},
// After failure, watcher reports stable -> idle.
{
event: makeStatusEvent(agentapi.StatusStable),
expected: &codersdk.WorkspaceAppStatus{
State: codersdk.WorkspaceAppStatusStateIdle,
Message: "something broke",
},
},
},
},
// Final states pass through with AgentAPI enabled.
{
name: "AllowFinalStates",
tests: []test{
{
state: codersdk.WorkspaceAppStatusStateWorking,
summary: "doing work",
expected: &codersdk.WorkspaceAppStatus{
State: codersdk.WorkspaceAppStatusStateWorking,
Message: "doing work",
},
},
// Agent reports complete; not overridden.
{
state: codersdk.WorkspaceAppStatusStateComplete,
summary: "all done",
expected: &codersdk.WorkspaceAppStatus{
State: codersdk.WorkspaceAppStatusStateComplete,
Message: "all done",
},
},
},
},
// When AgentAPI is not being used, we accept agent state updates as-is.
@@ -1110,4 +1150,148 @@ func TestExpMcpReporter(t *testing.T) {
<-cmdDone
})
}
t.Run("Reconnect", func(t *testing.T) {
t.Parallel()
// Create a test deployment and workspace.
client, db := coderdtest.NewWithDatabase(t, nil)
user := coderdtest.CreateFirstUser(t, client)
client, user2 := coderdtest.CreateAnotherUser(t, client, user.OrganizationID)
r := dbfake.WorkspaceBuild(t, db, database.WorkspaceTable{
OrganizationID: user.OrganizationID,
OwnerID: user2.ID,
}).WithAgent(func(a []*proto.Agent) []*proto.Agent {
a[0].Apps = []*proto.App{
{
Slug: "vscode",
},
}
return a
}).Do()
ctx, cancel := context.WithCancel(testutil.Context(t, testutil.WaitLong))
// Watch the workspace for changes.
watcher, err := client.WatchWorkspace(ctx, r.Workspace.ID)
require.NoError(t, err)
var lastAppStatus codersdk.WorkspaceAppStatus
nextUpdate := func() codersdk.WorkspaceAppStatus {
for {
select {
case <-ctx.Done():
require.FailNow(t, "timed out waiting for status update")
case w, ok := <-watcher:
require.True(t, ok, "watch channel closed")
if w.LatestAppStatus != nil && w.LatestAppStatus.ID != lastAppStatus.ID {
t.Logf("Got status update: %s > %s", lastAppStatus.State, w.LatestAppStatus.State)
lastAppStatus = *w.LatestAppStatus
return lastAppStatus
}
}
}
}
// Mock AI AgentAPI server that supports disconnect/reconnect.
disconnect := make(chan struct{})
listening := make(chan func(sse codersdk.ServerSentEvent) error)
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// Create a cancelable context so we can stop the SSE sender
// goroutine on disconnect without waiting for the HTTP
// serve loop to cancel r.Context().
sseCtx, sseCancel := context.WithCancel(r.Context())
defer sseCancel()
r = r.WithContext(sseCtx)
send, closed, err := httpapi.ServerSentEventSender(w, r)
if err != nil {
httpapi.Write(sseCtx, w, http.StatusInternalServerError, codersdk.Response{
Message: "Internal error setting up server-sent events.",
Detail: err.Error(),
})
return
}
// Send initial message so the watcher knows the agent is active.
send(*makeMessageEvent(0, agentapi.RoleAgent))
select {
case listening <- send:
case <-r.Context().Done():
return
}
select {
case <-closed:
case <-disconnect:
sseCancel()
<-closed
}
}))
t.Cleanup(srv.Close)
inv, _ := clitest.New(t,
"exp", "mcp", "server",
"--agent-url", client.URL.String(),
"--agent-token", r.AgentToken,
"--app-status-slug", "vscode",
"--allowed-tools=coder_report_task",
"--ai-agentapi-url", srv.URL,
)
inv = inv.WithContext(ctx)
pty := ptytest.New(t)
inv.Stdin = pty.Input()
inv.Stdout = pty.Output()
stderr := ptytest.New(t)
inv.Stderr = stderr.Output()
// Run the MCP server.
clitest.Start(t, inv)
// Initialize.
payload := `{"jsonrpc":"2.0","id":1,"method":"initialize"}`
pty.WriteLine(payload)
_ = pty.ReadLine(ctx) // ignore echo
_ = pty.ReadLine(ctx) // ignore init response
// Get first sender from the initial SSE connection.
sender := testutil.RequireReceive(ctx, t, listening)
// Self-report a working status via tool call.
toolPayload := `{"jsonrpc":"2.0","id":2,"method":"tools/call","params":{"name":"coder_report_task","arguments":{"state":"working","summary":"doing work","link":""}}}`
pty.WriteLine(toolPayload)
_ = pty.ReadLine(ctx) // ignore echo
_ = pty.ReadLine(ctx) // ignore response
got := nextUpdate()
require.Equal(t, codersdk.WorkspaceAppStatusStateWorking, got.State)
require.Equal(t, "doing work", got.Message)
// Watcher sends stable, verify idle is reported.
err = sender(*makeStatusEvent(agentapi.StatusStable))
require.NoError(t, err)
got = nextUpdate()
require.Equal(t, codersdk.WorkspaceAppStatusStateIdle, got.State)
// Disconnect the SSE connection by signaling the handler to return.
testutil.RequireSend(ctx, t, disconnect, struct{}{})
// Wait for the watcher to reconnect and get the new sender.
sender = testutil.RequireReceive(ctx, t, listening)
// After reconnect, self-report a working status again.
toolPayload = `{"jsonrpc":"2.0","id":3,"method":"tools/call","params":{"name":"coder_report_task","arguments":{"state":"working","summary":"reconnected","link":""}}}`
pty.WriteLine(toolPayload)
_ = pty.ReadLine(ctx) // ignore echo
_ = pty.ReadLine(ctx) // ignore response
got = nextUpdate()
require.Equal(t, codersdk.WorkspaceAppStatusStateWorking, got.State)
require.Equal(t, "reconnected", got.Message)
// Verify the watcher still processes events after reconnect.
err = sender(*makeStatusEvent(agentapi.StatusStable))
require.NoError(t, err)
got = nextUpdate()
require.Equal(t, codersdk.WorkspaceAppStatusStateIdle, got.State)
cancel()
})
}
+12
@@ -29,6 +29,7 @@ func (r *RootCmd) scaletestPrebuilds() *serpent.Command {
templateVersionJobTimeout time.Duration
prebuildWorkspaceTimeout time.Duration
noCleanup bool
provisionerTags []string
tracingFlags = &scaletestTracingFlags{}
timeoutStrategy = &timeoutFlags{}
@@ -111,10 +112,16 @@ func (r *RootCmd) scaletestPrebuilds() *serpent.Command {
th := harness.NewTestHarness(timeoutStrategy.wrapStrategy(harness.ConcurrentExecutionStrategy{}), cleanupStrategy.toStrategy())
tags, err := ParseProvisionerTags(provisionerTags)
if err != nil {
return err
}
for i := range numTemplates {
id := strconv.Itoa(int(i))
cfg := prebuilds.Config{
OrganizationID: me.OrganizationIDs[0],
ProvisionerTags: tags,
NumPresets: int(numPresets),
NumPresetPrebuilds: int(numPresetPrebuilds),
TemplateVersionJobTimeout: templateVersionJobTimeout,
@@ -283,6 +290,11 @@ func (r *RootCmd) scaletestPrebuilds() *serpent.Command {
Description: "Skip cleanup (deletion test) and leave resources intact.",
Value: serpent.BoolOf(&noCleanup),
},
{
Flag: "provisioner-tag",
Description: "Specify a set of tags to target provisioner daemons.",
Value: serpent.StringArrayOf(&provisionerTags),
},
}
tracingFlags.attach(&cmd.Options)
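As a rough illustration of how the new --provisioner-tag flag would feed prebuilds.Config.ProvisionerTags: the key=value convention below is an assumption based on how provisioner tags are passed elsewhere in the CLI, and parseTags is a hypothetical stand-in for ParseProvisionerTags, not the actual implementation.
package main

import (
	"fmt"
	"strings"
)

// parseTags is a hypothetical stand-in shown only to illustrate the assumed
// key=value shape of --provisioner-tag values.
func parseTags(raw []string) (map[string]string, error) {
	tags := make(map[string]string, len(raw))
	for _, r := range raw {
		k, v, ok := strings.Cut(r, "=")
		if !ok {
			return nil, fmt.Errorf("invalid tag %q, want key=value", r)
		}
		tags[k] = v
	}
	return tags, nil
}

func main() {
	tags, err := parseTags([]string{"team=backend", "region=us-east"})
	fmt.Println(tags, err) // map[region:us-east team:backend] <nil>
}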
+46 -1
@@ -4,6 +4,9 @@ import (
"errors"
"fmt"
"net/http"
"os"
"os/exec"
"strings"
"time"
"golang.org/x/xerrors"
@@ -16,6 +19,29 @@ import (
"github.com/coder/serpent"
)
// detectGitRef attempts to resolve the current git branch and remote
// origin URL from the given working directory. These are sent to the
// control plane so it can look up PR/diff status via the GitHub API
// without SSHing into the workspace. Failures are silently ignored
// since this is best-effort.
func detectGitRef(workingDirectory string) (branch string, remoteOrigin string) {
run := func(args ...string) string {
//nolint:gosec
cmd := exec.Command(args[0], args[1:]...)
if workingDirectory != "" {
cmd.Dir = workingDirectory
}
out, err := cmd.Output()
if err != nil {
return ""
}
return strings.TrimSpace(string(out))
}
branch = run("git", "rev-parse", "--abbrev-ref", "HEAD")
remoteOrigin = run("git", "config", "--get", "remote.origin.url")
return branch, remoteOrigin
}
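A brief usage fragment for detectGitRef as defined above (the working directory is illustrative); it relies on the surrounding file's imports and returns empty strings when git is unavailable or the directory is not a repository.
branch, origin := detectGitRef("/home/coder/project")
if branch != "" {
	// e.g. "main" and "git@github.com:coder/coder.git"
	fmt.Printf("branch=%s origin=%s\n", branch, origin)
}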
// gitAskpass is used by the Coder agent to automatically authenticate
// with Git providers based on a hostname.
func gitAskpass(agentAuth *AgentAuth) *serpent.Command {
@@ -38,8 +64,21 @@ func gitAskpass(agentAuth *AgentAuth) *serpent.Command {
return xerrors.Errorf("create agent client: %w", err)
}
workingDirectory, err := os.Getwd()
if err != nil {
workingDirectory = ""
}
// Detect the current git branch and remote origin so
// the control plane can resolve diffs without needing
// to SSH back into the workspace.
gitBranch, gitRemoteOrigin := detectGitRef(workingDirectory)
token, err := client.ExternalAuth(ctx, agentsdk.ExternalAuthRequest{
Match: host,
Match: host,
GitBranch: gitBranch,
GitRemoteOrigin: gitRemoteOrigin,
ChatID: inv.Environ.Get("CODER_CHAT_ID"),
})
if err != nil {
var apiError *codersdk.Error
@@ -58,6 +97,12 @@ func gitAskpass(agentAuth *AgentAuth) *serpent.Command {
return xerrors.Errorf("get git token: %w", err)
}
if token.URL != "" {
// This is to help the agent authenticate with Git.
if inv.Environ.Get("CODER_CHAT_AGENT") == "true" {
_, _ = fmt.Fprintf(inv.Stderr, "You must notify the user to authenticate with Git.\n\nThe URL is: %s\n", token.URL)
return cliui.ErrCanceled
}
if err := openURL(inv, token.URL); err == nil {
cliui.Infof(inv.Stderr, "Your browser has been opened to authenticate with Git:\n%s", token.URL)
} else {
+29 -5
@@ -1,6 +1,7 @@
package cli
import (
"encoding/json"
"fmt"
"strings"
@@ -231,7 +232,7 @@ next:
continue // immutables should not be passed to consecutive builds
}
if len(tvp.Options) > 0 && !isValidTemplateParameterOption(buildParameter, tvp.Options) {
if len(tvp.Options) > 0 && !isValidTemplateParameterOption(buildParameter, *tvp) {
continue // do not propagate invalid options
}
@@ -297,7 +298,7 @@ func (pr *ParameterResolver) verifyConstraints(resolved []codersdk.WorkspaceBuil
return xerrors.Errorf("ephemeral parameter %q can be used only with --prompt-ephemeral-parameters or --ephemeral-parameter flag", r.Name)
}
if !tvp.Mutable && action != WorkspaceCreate {
if !tvp.Mutable && action != WorkspaceCreate && !pr.isFirstTimeUse(r.Name) {
return xerrors.Errorf("parameter %q is immutable and cannot be updated", r.Name)
}
}
@@ -365,7 +366,7 @@ func (pr *ParameterResolver) isLastBuildParameterInvalidOption(templateVersionPa
for _, buildParameter := range pr.lastBuildParameters {
if buildParameter.Name == templateVersionParameter.Name {
return !isValidTemplateParameterOption(buildParameter, templateVersionParameter.Options)
return !isValidTemplateParameterOption(buildParameter, templateVersionParameter)
}
}
return false
@@ -389,8 +390,31 @@ func findWorkspaceBuildParameter(parameterName string, params []codersdk.Workspa
return nil
}
func isValidTemplateParameterOption(buildParameter codersdk.WorkspaceBuildParameter, options []codersdk.TemplateVersionParameterOption) bool {
for _, opt := range options {
func isValidTemplateParameterOption(buildParameter codersdk.WorkspaceBuildParameter, templateVersionParameter codersdk.TemplateVersionParameter) bool {
// Multi-select parameters store values as a JSON array (e.g.
// '["vim","emacs"]'), so we need to parse the array and validate
// each element individually against the allowed options.
if templateVersionParameter.Type == "list(string)" {
var values []string
if err := json.Unmarshal([]byte(buildParameter.Value), &values); err != nil {
return false
}
for _, v := range values {
found := false
for _, opt := range templateVersionParameter.Options {
if opt.Value == v {
found = true
break
}
}
if !found {
return false
}
}
return true
}
for _, opt := range templateVersionParameter.Options {
if opt.Value == buildParameter.Value {
return true
}
+85
@@ -0,0 +1,85 @@
package cli
import (
"testing"
"github.com/stretchr/testify/assert"
"github.com/coder/coder/v2/codersdk"
)
func TestIsValidTemplateParameterOption(t *testing.T) {
t.Parallel()
options := []codersdk.TemplateVersionParameterOption{
{Name: "Vim", Value: "vim"},
{Name: "Emacs", Value: "emacs"},
{Name: "VS Code", Value: "vscode"},
}
t.Run("SingleSelectValid", func(t *testing.T) {
t.Parallel()
bp := codersdk.WorkspaceBuildParameter{Name: "editor", Value: "vim"}
tvp := codersdk.TemplateVersionParameter{
Name: "editor",
Type: "string",
Options: options,
}
assert.True(t, isValidTemplateParameterOption(bp, tvp))
})
t.Run("SingleSelectInvalid", func(t *testing.T) {
t.Parallel()
bp := codersdk.WorkspaceBuildParameter{Name: "editor", Value: "notepad"}
tvp := codersdk.TemplateVersionParameter{
Name: "editor",
Type: "string",
Options: options,
}
assert.False(t, isValidTemplateParameterOption(bp, tvp))
})
t.Run("MultiSelectAllValid", func(t *testing.T) {
t.Parallel()
bp := codersdk.WorkspaceBuildParameter{Name: "editors", Value: `["vim","emacs"]`}
tvp := codersdk.TemplateVersionParameter{
Name: "editors",
Type: "list(string)",
Options: options,
}
assert.True(t, isValidTemplateParameterOption(bp, tvp))
})
t.Run("MultiSelectOneInvalid", func(t *testing.T) {
t.Parallel()
bp := codersdk.WorkspaceBuildParameter{Name: "editors", Value: `["vim","notepad"]`}
tvp := codersdk.TemplateVersionParameter{
Name: "editors",
Type: "list(string)",
Options: options,
}
assert.False(t, isValidTemplateParameterOption(bp, tvp))
})
t.Run("MultiSelectEmptyArray", func(t *testing.T) {
t.Parallel()
bp := codersdk.WorkspaceBuildParameter{Name: "editors", Value: `[]`}
tvp := codersdk.TemplateVersionParameter{
Name: "editors",
Type: "list(string)",
Options: options,
}
assert.True(t, isValidTemplateParameterOption(bp, tvp))
})
t.Run("MultiSelectInvalidJSON", func(t *testing.T) {
t.Parallel()
bp := codersdk.WorkspaceBuildParameter{Name: "editors", Value: `not-json`}
tvp := codersdk.TemplateVersionParameter{
Name: "editors",
Type: "list(string)",
Options: options,
}
assert.False(t, isValidTemplateParameterOption(bp, tvp))
})
}
+26 -6
@@ -2,6 +2,7 @@ package cli_test
import (
"bytes"
"cmp"
"context"
"database/sql"
"encoding/json"
@@ -20,7 +21,6 @@ import (
"github.com/coder/coder/v2/coderd/database/dbauthz"
"github.com/coder/coder/v2/coderd/database/dbgen"
"github.com/coder/coder/v2/coderd/database/dbtestutil"
"github.com/coder/coder/v2/coderd/database/dbtime"
"github.com/coder/coder/v2/coderd/rbac"
"github.com/coder/coder/v2/codersdk"
)
@@ -35,7 +35,10 @@ func TestProvisioners_Golden(t *testing.T) {
provisioners, err := coderdAPI.Database.GetProvisionerDaemons(systemCtx)
require.NoError(t, err)
slices.SortFunc(provisioners, func(a, b database.ProvisionerDaemon) int {
return a.CreatedAt.Compare(b.CreatedAt)
return cmp.Or(
a.CreatedAt.Compare(b.CreatedAt),
bytes.Compare(a.ID[:], b.ID[:]),
)
})
pIdx := 0
for _, p := range provisioners {
@@ -47,7 +50,10 @@ func TestProvisioners_Golden(t *testing.T) {
jobs, err := coderdAPI.Database.GetProvisionerJobsCreatedAfter(systemCtx, time.Time{})
require.NoError(t, err)
slices.SortFunc(jobs, func(a, b database.ProvisionerJob) int {
return a.CreatedAt.Compare(b.CreatedAt)
return cmp.Or(
a.CreatedAt.Compare(b.CreatedAt),
bytes.Compare(a.ID[:], b.ID[:]),
)
})
jIdx := 0
for _, j := range jobs {
@@ -76,11 +82,15 @@ func TestProvisioners_Golden(t *testing.T) {
firstProvisioner := coderdtest.NewTaggedProvisionerDaemon(t, coderdAPI, "default-provisioner", map[string]string{"owner": "", "scope": "organization"})
t.Cleanup(func() { _ = firstProvisioner.Close() })
version := coderdtest.CreateTemplateVersion(t, client, owner.OrganizationID, completeWithAgent())
coderdtest.AwaitTemplateVersionJobCompleted(t, client, version.ID)
version = coderdtest.AwaitTemplateVersionJobCompleted(t, client, version.ID)
require.Equal(t, codersdk.ProvisionerJobSucceeded, version.Job.Status,
"template version import should succeed, got error: %s", version.Job.Error)
template := coderdtest.CreateTemplate(t, client, owner.OrganizationID, version.ID)
workspace := coderdtest.CreateWorkspace(t, client, template.ID)
coderdtest.AwaitWorkspaceBuildJobCompleted(t, client, workspace.LatestBuild.ID)
wb := coderdtest.AwaitWorkspaceBuildJobCompleted(t, client, workspace.LatestBuild.ID)
require.Equal(t, codersdk.ProvisionerJobSucceeded, wb.Job.Status,
"workspace build job should succeed, got error: %s", wb.Job.Error)
// Stop the provisioner so it doesn't grab any more jobs.
firstProvisioner.Close()
@@ -89,7 +99,17 @@ func TestProvisioners_Golden(t *testing.T) {
replace[version.ID.String()] = "00000000-0000-0000-cccc-000000000000"
replace[workspace.LatestBuild.ID.String()] = "00000000-0000-0000-dddd-000000000000"
now := dbtime.Now()
// Base synthetic times off the latest real job's CreatedAt, not the
// wall clock. Using dbtime.Now() here is racy because NTP clock
// steps can make it return a time before the real jobs' CreatedAt.
systemCtx := dbauthz.AsSystemRestricted(context.Background())
existingJobs, err := coderdAPI.Database.GetProvisionerJobsCreatedAfter(systemCtx, time.Time{})
require.NoError(t, err)
require.NotEmpty(t, existingJobs, "expected at least one provisioner job")
latestJob := slices.MaxFunc(existingJobs, func(a, b database.ProvisionerJob) int {
return a.CreatedAt.Compare(b.CreatedAt)
})
now := latestJob.CreatedAt.Add(time.Second)
// Create a provisioner that's working on a job.
pd1 := dbgen.ProvisionerDaemon(t, coderdAPI.Database, database.ProvisionerDaemon{
+112 -44
@@ -617,28 +617,8 @@ func (r *RootCmd) Server(newAPI func(context.Context, *coderd.Options) (*coderd.
}
}
extAuthEnv, err := ReadExternalAuthProvidersFromEnv(os.Environ())
if err != nil {
return xerrors.Errorf("read external auth providers from env: %w", err)
}
promRegistry := prometheus.NewRegistry()
oauthInstrument := promoauth.NewFactory(promRegistry)
vals.ExternalAuthConfigs.Value = append(vals.ExternalAuthConfigs.Value, extAuthEnv...)
externalAuthConfigs, err := externalauth.ConvertConfig(
oauthInstrument,
vals.ExternalAuthConfigs.Value,
vals.AccessURL.Value(),
)
if err != nil {
return xerrors.Errorf("convert external auth config: %w", err)
}
for _, c := range externalAuthConfigs {
logger.Debug(
ctx, "loaded external auth config",
slog.F("id", c.ID),
)
}
realIPConfig, err := httpmw.ParseRealIPConfig(vals.ProxyTrustedHeaders, vals.ProxyTrustedOrigins)
if err != nil {
@@ -669,7 +649,7 @@ func (r *RootCmd) Server(newAPI func(context.Context, *coderd.Options) (*coderd.
Pubsub: nil,
CacheDir: cacheDir,
GoogleTokenValidator: googleTokenValidator,
ExternalAuthConfigs: externalAuthConfigs,
ExternalAuthConfigs: nil,
RealIPConfig: realIPConfig,
SSHKeygenAlgorithm: sshKeygenAlgorithm,
TracerProvider: tracerProvider,
@@ -829,9 +809,43 @@ func (r *RootCmd) Server(newAPI func(context.Context, *coderd.Options) (*coderd.
return xerrors.Errorf("set deployment id: %w", err)
}
extAuthEnv, err := ReadExternalAuthProvidersFromEnv(os.Environ())
if err != nil {
return xerrors.Errorf("read external auth providers from env: %w", err)
}
mergedExternalAuthProviders := append([]codersdk.ExternalAuthConfig{}, vals.ExternalAuthConfigs.Value...)
mergedExternalAuthProviders = append(mergedExternalAuthProviders, extAuthEnv...)
vals.ExternalAuthConfigs.Value = mergedExternalAuthProviders
mergedExternalAuthProviders, err = maybeAppendDefaultGithubExternalAuthProvider(
ctx,
options.Logger,
options.Database,
vals,
mergedExternalAuthProviders,
)
if err != nil {
return xerrors.Errorf("maybe append default github external auth provider: %w", err)
}
options.ExternalAuthConfigs, err = externalauth.ConvertConfig(
oauthInstrument,
mergedExternalAuthProviders,
vals.AccessURL.Value(),
)
if err != nil {
return xerrors.Errorf("convert external auth config: %w", err)
}
for _, c := range options.ExternalAuthConfigs {
logger.Debug(
ctx, "loaded external auth config",
slog.F("id", c.ID),
)
}
// Manage push notifications.
experiments := coderd.ReadExperiments(options.Logger, options.DeploymentValues.Experiments.Value())
if experiments.Enabled(codersdk.ExperimentWebPush) {
if experiments.Enabled(codersdk.ExperimentWebPush) || buildinfo.IsDev() {
if !strings.HasPrefix(options.AccessURL.String(), "https://") {
options.Logger.Warn(ctx, "access URL is not HTTPS, so web push notifications may not work on some browsers", slog.F("access_url", options.AccessURL.String()))
}
@@ -1926,6 +1940,79 @@ type githubOAuth2ConfigParams struct {
enterpriseBaseURL string
}
func isDeploymentEligibleForGithubDefaultProvider(ctx context.Context, db database.Store) (bool, error) {
// We want to enable the default provider only for new deployments, and avoid
// enabling it if a deployment was upgraded from an older version.
// nolint:gocritic // Requires system privileges
defaultEligible, err := db.GetOAuth2GithubDefaultEligible(dbauthz.AsSystemRestricted(ctx))
if err != nil && !errors.Is(err, sql.ErrNoRows) {
return false, xerrors.Errorf("get github default eligible: %w", err)
}
defaultEligibleNotSet := errors.Is(err, sql.ErrNoRows)
if defaultEligibleNotSet {
// nolint:gocritic // User count requires system privileges
userCount, err := db.GetUserCount(dbauthz.AsSystemRestricted(ctx), false)
if err != nil {
return false, xerrors.Errorf("get user count: %w", err)
}
// We check if a deployment is new by checking if it has any users.
defaultEligible = userCount == 0
// nolint:gocritic // Requires system privileges
if err := db.UpsertOAuth2GithubDefaultEligible(dbauthz.AsSystemRestricted(ctx), defaultEligible); err != nil {
return false, xerrors.Errorf("upsert github default eligible: %w", err)
}
}
return defaultEligible, nil
}
func maybeAppendDefaultGithubExternalAuthProvider(
ctx context.Context,
logger slog.Logger,
db database.Store,
vals *codersdk.DeploymentValues,
mergedExplicitProviders []codersdk.ExternalAuthConfig,
) ([]codersdk.ExternalAuthConfig, error) {
if !vals.ExternalAuthGithubDefaultProviderEnable.Value() {
logger.Info(ctx, "default github external auth provider suppressed",
slog.F("reason", "disabled by configuration"),
slog.F("flag", "external-auth-github-default-provider-enable"),
)
return mergedExplicitProviders, nil
}
if len(mergedExplicitProviders) > 0 {
logger.Info(ctx, "default github external auth provider suppressed",
slog.F("reason", "explicit external auth providers configured"),
slog.F("provider_count", len(mergedExplicitProviders)),
)
return mergedExplicitProviders, nil
}
defaultEligible, err := isDeploymentEligibleForGithubDefaultProvider(ctx, db)
if err != nil {
return nil, err
}
if !defaultEligible {
logger.Info(ctx, "default github external auth provider suppressed",
slog.F("reason", "deployment is not eligible"),
)
return mergedExplicitProviders, nil
}
logger.Info(ctx, "injecting default github external auth provider",
slog.F("type", codersdk.EnhancedExternalAuthProviderGitHub.String()),
slog.F("client_id", GithubOAuth2DefaultProviderClientID),
slog.F("device_flow", GithubOAuth2DefaultProviderDeviceFlow),
)
return append(mergedExplicitProviders, codersdk.ExternalAuthConfig{
Type: codersdk.EnhancedExternalAuthProviderGitHub.String(),
ClientID: GithubOAuth2DefaultProviderClientID,
DeviceFlow: GithubOAuth2DefaultProviderDeviceFlow,
}), nil
}
func getGithubOAuth2ConfigParams(ctx context.Context, db database.Store, vals *codersdk.DeploymentValues) (*githubOAuth2ConfigParams, error) {
params := githubOAuth2ConfigParams{
accessURL: vals.AccessURL.Value(),
@@ -1950,28 +2037,9 @@ func getGithubOAuth2ConfigParams(ctx context.Context, db database.Store, vals *c
return nil, nil //nolint:nilnil
}
// Check if the deployment is eligible for the default GitHub OAuth2 provider.
// We want to enable it only for new deployments, and avoid enabling it
// if a deployment was upgraded from an older version.
// nolint:gocritic // Requires system privileges
defaultEligible, err := db.GetOAuth2GithubDefaultEligible(dbauthz.AsSystemRestricted(ctx))
if err != nil && !errors.Is(err, sql.ErrNoRows) {
return nil, xerrors.Errorf("get github default eligible: %w", err)
}
defaultEligibleNotSet := errors.Is(err, sql.ErrNoRows)
if defaultEligibleNotSet {
// nolint:gocritic // User count requires system privileges
userCount, err := db.GetUserCount(dbauthz.AsSystemRestricted(ctx), false)
if err != nil {
return nil, xerrors.Errorf("get user count: %w", err)
}
// We check if a deployment is new by checking if it has any users.
defaultEligible = userCount == 0
// nolint:gocritic // Requires system privileges
if err := db.UpsertOAuth2GithubDefaultEligible(dbauthz.AsSystemRestricted(ctx), defaultEligible); err != nil {
return nil, xerrors.Errorf("upsert github default eligible: %w", err)
}
defaultEligible, err := isDeploymentEligibleForGithubDefaultProvider(ctx, db)
if err != nil {
return nil, err
}
if !defaultEligible {
+151
@@ -53,6 +53,7 @@ import (
"github.com/coder/coder/v2/coderd/database/migrations"
"github.com/coder/coder/v2/coderd/httpapi"
"github.com/coder/coder/v2/coderd/telemetry"
"github.com/coder/coder/v2/coderd/userpassword"
"github.com/coder/coder/v2/codersdk"
"github.com/coder/coder/v2/cryptorand"
"github.com/coder/coder/v2/pty/ptytest"
@@ -302,6 +303,7 @@ func TestServer(t *testing.T) {
"open install.sh: file does not exist",
"telemetry disabled, unable to notify of security issues",
"installed terraform version newer than expected",
"report generator",
}
countLines := func(fullOutput string) int {
@@ -1805,6 +1807,155 @@ func TestServer(t *testing.T) {
})
}
//nolint:tparallel,paralleltest // This test sets environment variables.
func TestServer_ExternalAuthGitHubDefaultProvider(t *testing.T) {
type testCase struct {
name string
args []string
env map[string]string
createUserPreStart bool
expectedProviders []string
}
run := func(t *testing.T, tc testCase) {
ctx := testutil.Context(t, testutil.WaitLong)
unsetPrefixedEnv := func(prefix string) {
t.Helper()
for _, envVar := range os.Environ() {
envKey, _, found := strings.Cut(envVar, "=")
if !found || !strings.HasPrefix(envKey, prefix) {
continue
}
value, had := os.LookupEnv(envKey)
require.True(t, had)
require.NoError(t, os.Unsetenv(envKey))
keyCopy := envKey
valueCopy := value
t.Cleanup(func() {
// This is for setting/unsetting a number of prefixed env vars.
// t.Setenv doesn't cover this use case.
// nolint:usetesting
_ = os.Setenv(keyCopy, valueCopy)
})
}
}
unsetPrefixedEnv("CODER_EXTERNAL_AUTH_")
unsetPrefixedEnv("CODER_GITAUTH_")
dbURL, err := dbtestutil.Open(t)
require.NoError(t, err)
db, _ := dbtestutil.NewDB(t, dbtestutil.WithURL(dbURL))
const (
existingUserEmail = "existing-user@coder.com"
existingUserUsername = "existing-user"
existingUserPassword = "SomeSecurePassword!"
)
if tc.createUserPreStart {
hashedPassword, err := userpassword.Hash(existingUserPassword)
require.NoError(t, err)
_ = dbgen.User(t, db, database.User{
Email: existingUserEmail,
Username: existingUserUsername,
HashedPassword: []byte(hashedPassword),
})
}
args := []string{
"server",
"--postgres-url", dbURL,
"--http-address", ":0",
"--access-url", "https://example.com",
}
args = append(args, tc.args...)
inv, cfg := clitest.New(t, args...)
for envKey, value := range tc.env {
t.Setenv(envKey, value)
}
clitest.Start(t, inv)
accessURL := waitAccessURL(t, cfg)
client := codersdk.New(accessURL)
if tc.createUserPreStart {
loginResp, err := client.LoginWithPassword(ctx, codersdk.LoginWithPasswordRequest{
Email: existingUserEmail,
Password: existingUserPassword,
})
require.NoError(t, err)
client.SetSessionToken(loginResp.SessionToken)
} else {
_ = coderdtest.CreateFirstUser(t, client)
}
externalAuthResp, err := client.ListExternalAuths(ctx)
require.NoError(t, err)
gotProviders := map[string]codersdk.ExternalAuthLinkProvider{}
for _, provider := range externalAuthResp.Providers {
gotProviders[provider.ID] = provider
}
require.Len(t, gotProviders, len(tc.expectedProviders))
for _, providerID := range tc.expectedProviders {
provider, ok := gotProviders[providerID]
require.Truef(t, ok, "expected provider %q to be configured", providerID)
if providerID == codersdk.EnhancedExternalAuthProviderGitHub.String() {
require.Equal(t, codersdk.EnhancedExternalAuthProviderGitHub.String(), provider.Type)
require.True(t, provider.Device)
}
}
}
for _, tc := range []testCase{
{
name: "NewDeployment_NoExplicitProviders_InjectsDefaultGithub",
expectedProviders: []string{codersdk.EnhancedExternalAuthProviderGitHub.String()},
},
{
name: "ExistingDeployment_DoesNotInjectDefaultGithub",
createUserPreStart: true,
expectedProviders: nil,
},
{
name: "DefaultProviderDisabled_DoesNotInjectDefaultGithub",
args: []string{
"--external-auth-github-default-provider-enable=false",
},
expectedProviders: nil,
},
{
name: "ExplicitProviderViaConfig_DoesNotInjectDefaultGithub",
args: []string{
`--external-auth-providers=[{"type":"gitlab","client_id":"config-client-id"}]`,
},
expectedProviders: []string{codersdk.EnhancedExternalAuthProviderGitLab.String()},
},
{
name: "ExplicitProviderViaEnv_DoesNotInjectDefaultGithub",
env: map[string]string{
"CODER_EXTERNAL_AUTH_0_TYPE": codersdk.EnhancedExternalAuthProviderGitLab.String(),
"CODER_EXTERNAL_AUTH_0_CLIENT_ID": "env-client-id",
},
expectedProviders: []string{codersdk.EnhancedExternalAuthProviderGitLab.String()},
},
{
name: "ExplicitProviderViaLegacyEnv_DoesNotInjectDefaultGithub",
env: map[string]string{
"CODER_GITAUTH_0_TYPE": codersdk.EnhancedExternalAuthProviderGitLab.String(),
"CODER_GITAUTH_0_CLIENT_ID": "legacy-env-client-id",
},
expectedProviders: []string{codersdk.EnhancedExternalAuthProviderGitLab.String()},
},
} {
t.Run(tc.name, func(t *testing.T) {
run(t, tc)
})
}
}
//nolint:tparallel,paralleltest // This test sets environment variables.
func TestServer_Logging_NoParallel(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+1 -1
@@ -120,7 +120,7 @@ func (r *RootCmd) start() *serpent.Command {
func buildWorkspaceStartRequest(inv *serpent.Invocation, client *codersdk.Client, workspace codersdk.Workspace, parameterFlags workspaceParameterFlags, buildFlags buildFlags, action WorkspaceCLIAction) (codersdk.CreateWorkspaceBuildRequest, error) {
version := workspace.LatestBuild.TemplateVersionID
if workspace.AutomaticUpdates == codersdk.AutomaticUpdatesAlways || action == WorkspaceUpdate {
if workspace.AutomaticUpdates == codersdk.AutomaticUpdatesAlways || workspace.TemplateRequireActiveVersion || action == WorkspaceUpdate {
version = workspace.TemplateActiveVersionID
if version != workspace.LatestBuild.TemplateVersionID {
action = WorkspaceUpdate
+4 -4
@@ -33,7 +33,7 @@ func TestStatePull(t *testing.T) {
OrganizationID: owner.OrganizationID,
OwnerID: taUser.ID,
}).
Seed(database.WorkspaceBuild{ProvisionerState: wantState}).
Seed(database.WorkspaceBuild{}).ProvisionerState(wantState).
Do()
statefilePath := filepath.Join(t.TempDir(), "state")
inv, root := clitest.New(t, "state", "pull", r.Workspace.Name, statefilePath)
@@ -54,7 +54,7 @@ func TestStatePull(t *testing.T) {
OrganizationID: owner.OrganizationID,
OwnerID: taUser.ID,
}).
Seed(database.WorkspaceBuild{ProvisionerState: wantState}).
Seed(database.WorkspaceBuild{}).ProvisionerState(wantState).
Do()
inv, root := clitest.New(t, "state", "pull", r.Workspace.Name)
var gotState bytes.Buffer
@@ -74,7 +74,7 @@ func TestStatePull(t *testing.T) {
OrganizationID: owner.OrganizationID,
OwnerID: taUser.ID,
}).
Seed(database.WorkspaceBuild{ProvisionerState: wantState}).
Seed(database.WorkspaceBuild{}).ProvisionerState(wantState).
Do()
inv, root := clitest.New(t, "state", "pull", taUser.Username+"/"+r.Workspace.Name,
"--build", fmt.Sprintf("%d", r.Build.BuildNumber))
@@ -170,7 +170,7 @@ func TestStatePush(t *testing.T) {
OrganizationID: owner.OrganizationID,
OwnerID: taUser.ID,
}).
Seed(database.WorkspaceBuild{ProvisionerState: initialState}).
Seed(database.WorkspaceBuild{}).ProvisionerState(initialState).
Do()
wantState := []byte("updated state")
stateFile, err := os.CreateTemp(t.TempDir(), "")
+1 -1
@@ -74,7 +74,7 @@ OPTIONS:
--socket-path string, $CODER_AGENT_SOCKET_PATH
Specify the path for the agent socket.
--socket-server-enabled bool, $CODER_AGENT_SOCKET_SERVER_ENABLED (default: false)
--socket-server-enabled bool, $CODER_AGENT_SOCKET_SERVER_ENABLED (default: true)
Enable the agent socket server.
--ssh-max-timeout duration, $CODER_AGENT_SSH_MAX_TIMEOUT (default: 72h)
+6 -1
@@ -62,6 +62,9 @@ OPTIONS:
Separate multiple experiments with commas, or enter '*' to opt-in to
all available experiments.
--external-auth-github-default-provider-enable bool, $CODER_EXTERNAL_AUTH_GITHUB_DEFAULT_PROVIDER_ENABLE (default: true)
Enable the default GitHub external auth provider managed by Coder.
--postgres-auth password|awsiamrds, $CODER_PG_AUTH (default: password)
Type of auth to use when connecting to postgres. For AWS RDS, using
IAM authentication (awsiamrds) is recommended.
@@ -391,7 +394,9 @@ NETWORKING OPTIONS:
--host-prefix-cookie bool, $CODER_HOST_PREFIX_COOKIE (default: false)
Recommended to be enabled. Enables `__Host-` prefix for cookies to
guarantee they are only set by the right domain.
guarantee they are only set by the right domain. This change is
disruptive to any workspaces built before release 2.31, requiring a
workspace restart.
NETWORKING / DERP OPTIONS:
Most Coder deployments never have to think about DERP because all connections
+5 -1
@@ -182,7 +182,8 @@ networking:
# (default: lax, type: enum[lax\|none])
sameSiteAuthCookie: lax
# Recommended to be enabled. Enables `__Host-` prefix for cookies to guarantee
# they are only set by the right domain.
# they are only set by the right domain. This change is disruptive to any
# workspaces built before release 2.31, requiring a workspace restart.
# (default: false, type: bool)
hostPrefixCookie: false
# Whether Coder only allows connections to workspaces via the browser.
@@ -563,6 +564,9 @@ supportLinks: []
# External Authentication providers.
# (default: <unset>, type: struct[[]codersdk.ExternalAuthConfig])
externalAuthProviders: []
# Enable the default GitHub external auth provider managed by Coder.
# (default: true, type: bool)
externalAuthGithubDefaultProviderEnable: true
# Hostname of HTTPS server that runs https://github.com/coder/wgtunnel. By
# default, this will pick the best available wgtunnel server hosted by Coder. e.g.
# "tunnel.example.com".
+2 -15
@@ -241,26 +241,13 @@ func (r *RootCmd) listTokens() *serpent.Command {
}
tokens, err := client.Tokens(inv.Context(), codersdk.Me, codersdk.TokensFilter{
IncludeAll: all,
IncludeAll: all,
IncludeExpired: includeExpired,
})
if err != nil {
return xerrors.Errorf("list tokens: %w", err)
}
// Filter out expired tokens unless --include-expired is set
// TODO(Cian): This _could_ get too big for client-side filtering.
// If it causes issues, we can filter server-side.
if !includeExpired {
now := time.Now()
filtered := make([]codersdk.APIKeyWithOwner, 0, len(tokens))
for _, token := range tokens {
if token.ExpiresAt.After(now) {
filtered = append(filtered, token)
}
}
tokens = filtered
}
displayTokens = make([]tokenListRow, len(tokens))
for i, token := range tokens {
+70
@@ -990,4 +990,74 @@ func TestUpdateValidateRichParameters(t *testing.T) {
_ = testutil.TryReceive(ctx, t, doneChan)
})
t.Run("NewImmutableParameterViaFlag", func(t *testing.T) {
t.Parallel()
// Create template and workspace with only a mutable parameter.
client := coderdtest.New(t, &coderdtest.Options{IncludeProvisionerDaemon: true})
owner := coderdtest.CreateFirstUser(t, client)
member, memberUser := coderdtest.CreateAnotherUser(t, client, owner.OrganizationID)
templateParameters := []*proto.RichParameter{
{Name: stringParameterName, Type: "string", Mutable: true, Required: true, Options: []*proto.RichParameterOption{
{Name: "First option", Description: "This is first option", Value: "1st"},
{Name: "Second option", Description: "This is second option", Value: "2nd"},
}},
}
version := coderdtest.CreateTemplateVersion(t, client, owner.OrganizationID, prepareEchoResponses(templateParameters))
coderdtest.AwaitTemplateVersionJobCompleted(t, client, version.ID)
template := coderdtest.CreateTemplate(t, client, owner.OrganizationID, version.ID)
inv, root := clitest.New(t, "create", "my-workspace", "--yes", "--template", template.Name, "--parameter", fmt.Sprintf("%s=%s", stringParameterName, "1st"))
clitest.SetupConfig(t, member, root)
err := inv.Run()
require.NoError(t, err)
// Update template: add a new immutable parameter.
updatedTemplateParameters := []*proto.RichParameter{
templateParameters[0],
{Name: immutableParameterName, Type: "string", Mutable: false, Required: true, Options: []*proto.RichParameterOption{
{Name: "fir", Description: "First option for immutable parameter", Value: "I"},
{Name: "sec", Description: "Second option for immutable parameter", Value: "II"},
}},
}
updatedVersion := coderdtest.UpdateTemplateVersion(t, client, owner.OrganizationID, prepareEchoResponses(updatedTemplateParameters), template.ID)
coderdtest.AwaitTemplateVersionJobCompleted(t, client, updatedVersion.ID)
err = client.UpdateActiveTemplateVersion(context.Background(), template.ID, codersdk.UpdateActiveTemplateVersion{
ID: updatedVersion.ID,
})
require.NoError(t, err)
// Update workspace, supplying the new immutable parameter via
// the --parameter flag. This should succeed because it's the
// first time this parameter is being set.
inv, root = clitest.New(t, "update", "my-workspace",
"--parameter", fmt.Sprintf("%s=%s", immutableParameterName, "II"))
clitest.SetupConfig(t, member, root)
pty := ptytest.New(t).Attach(inv)
doneChan := make(chan struct{})
go func() {
defer close(doneChan)
err := inv.Run()
assert.NoError(t, err)
}()
pty.ExpectMatch("Planning workspace")
ctx := testutil.Context(t, testutil.WaitLong)
_ = testutil.TryReceive(ctx, t, doneChan)
// Verify the immutable parameter was set correctly.
workspace, err := client.WorkspaceByOwnerAndName(ctx, memberUser.ID.String(), "my-workspace", codersdk.WorkspaceOptions{})
require.NoError(t, err)
actualParameters, err := client.WorkspaceBuildParameters(ctx, workspace.LatestBuild.ID)
require.NoError(t, err)
require.Contains(t, actualParameters, codersdk.WorkspaceBuildParameter{
Name: immutableParameterName,
Value: "II",
})
})
}
+2
@@ -179,6 +179,8 @@ func New(opts Options, workspace database.Workspace) *API {
Database: opts.Database,
Log: opts.Log,
PublishWorkspaceUpdateFn: api.publishWorkspaceUpdate,
Clock: opts.Clock,
NotificationsEnqueuer: opts.NotificationsEnqueuer,
}
api.MetadataAPI = &MetadataAPI{
+240
@@ -2,6 +2,10 @@ package agentapi
import (
"context"
"database/sql"
"fmt"
"net/http"
"time"
"github.com/google/uuid"
"golang.org/x/xerrors"
@@ -9,7 +13,14 @@ import (
"cdr.dev/slog/v3"
agentproto "github.com/coder/coder/v2/agent/proto"
"github.com/coder/coder/v2/coderd/database"
"github.com/coder/coder/v2/coderd/database/dbauthz"
"github.com/coder/coder/v2/coderd/database/dbtime"
"github.com/coder/coder/v2/coderd/notifications"
strutil "github.com/coder/coder/v2/coderd/util/strings"
"github.com/coder/coder/v2/coderd/workspacestats"
"github.com/coder/coder/v2/coderd/wspubsub"
"github.com/coder/coder/v2/codersdk"
"github.com/coder/quartz"
)
type AppsAPI struct {
@@ -17,6 +28,8 @@ type AppsAPI struct {
Database database.Store
Log slog.Logger
PublishWorkspaceUpdateFn func(context.Context, *database.WorkspaceAgent, wspubsub.WorkspaceEventKind) error
NotificationsEnqueuer notifications.Enqueuer
Clock quartz.Clock
}
func (a *AppsAPI) BatchUpdateAppHealths(ctx context.Context, req *agentproto.BatchUpdateAppHealthRequest) (*agentproto.BatchUpdateAppHealthResponse, error) {
@@ -104,3 +117,230 @@ func (a *AppsAPI) BatchUpdateAppHealths(ctx context.Context, req *agentproto.Bat
}
return &agentproto.BatchUpdateAppHealthResponse{}, nil
}
func (a *AppsAPI) UpdateAppStatus(ctx context.Context, req *agentproto.UpdateAppStatusRequest) (*agentproto.UpdateAppStatusResponse, error) {
if len(req.Message) > 160 {
return nil, codersdk.NewError(http.StatusBadRequest, codersdk.Response{
Message: "Message is too long.",
Detail: "Message must not exceed 160 characters.",
Validations: []codersdk.ValidationError{
{Field: "message", Detail: "Message must not exceed 160 characters."},
},
})
}
var dbState database.WorkspaceAppStatusState
switch req.State {
case agentproto.UpdateAppStatusRequest_COMPLETE:
dbState = database.WorkspaceAppStatusStateComplete
case agentproto.UpdateAppStatusRequest_FAILURE:
dbState = database.WorkspaceAppStatusStateFailure
case agentproto.UpdateAppStatusRequest_WORKING:
dbState = database.WorkspaceAppStatusStateWorking
case agentproto.UpdateAppStatusRequest_IDLE:
dbState = database.WorkspaceAppStatusStateIdle
default:
return nil, codersdk.NewError(http.StatusBadRequest, codersdk.Response{
Message: "Invalid state provided.",
Detail: fmt.Sprintf("invalid state: %q", req.State),
Validations: []codersdk.ValidationError{
{Field: "state", Detail: "State must be one of: complete, failure, working, idle."},
},
})
}
workspaceAgent, err := a.AgentFn(ctx)
if err != nil {
return nil, err
}
app, err := a.Database.GetWorkspaceAppByAgentIDAndSlug(ctx, database.GetWorkspaceAppByAgentIDAndSlugParams{
AgentID: workspaceAgent.ID,
Slug: req.Slug,
})
if err != nil {
return nil, codersdk.NewError(http.StatusBadRequest, codersdk.Response{
Message: "Failed to get workspace app.",
Detail: fmt.Sprintf("No app found with slug %q", req.Slug),
})
}
workspace, err := a.Database.GetWorkspaceByAgentID(ctx, workspaceAgent.ID)
if err != nil {
return nil, codersdk.NewError(http.StatusBadRequest, codersdk.Response{
Message: "Failed to get workspace.",
Detail: err.Error(),
})
}
// Treat the message as untrusted input.
cleaned := strutil.UISanitize(req.Message)
// Get the latest status for the workspace app to detect no-op updates
// nolint:gocritic // This is a system restricted operation.
latestAppStatus, err := a.Database.GetLatestWorkspaceAppStatusByAppID(dbauthz.AsSystemRestricted(ctx), app.ID)
if err != nil && !xerrors.Is(err, sql.ErrNoRows) {
return nil, codersdk.NewError(http.StatusInternalServerError, codersdk.Response{
Message: "Failed to get latest workspace app status.",
Detail: err.Error(),
})
}
// If no rows found, latestAppStatus will be a zero-value struct (ID == uuid.Nil)
// nolint:gocritic // This is a system restricted operation.
_, err = a.Database.InsertWorkspaceAppStatus(dbauthz.AsSystemRestricted(ctx), database.InsertWorkspaceAppStatusParams{
ID: uuid.New(),
CreatedAt: dbtime.Now(),
WorkspaceID: workspace.ID,
AgentID: workspaceAgent.ID,
AppID: app.ID,
State: dbState,
Message: cleaned,
Uri: sql.NullString{
String: req.Uri,
Valid: req.Uri != "",
},
})
if err != nil {
return nil, codersdk.NewError(http.StatusInternalServerError, codersdk.Response{
Message: "Failed to insert workspace app status.",
Detail: err.Error(),
})
}
if a.PublishWorkspaceUpdateFn != nil {
err = a.PublishWorkspaceUpdateFn(ctx, &workspaceAgent, wspubsub.WorkspaceEventKindAgentAppStatusUpdate)
if err != nil {
return nil, codersdk.NewError(http.StatusInternalServerError, codersdk.Response{
Message: "Failed to publish workspace update.",
Detail: err.Error(),
})
}
}
// Notify on state change to Working/Idle for AI tasks
a.enqueueAITaskStateNotification(ctx, app.ID, latestAppStatus, dbState, workspace, workspaceAgent)
if shouldBump(dbState, latestAppStatus) {
// We pass time.Time{} for nextAutostart since we don't have access to
// TemplateScheduleStore here. The activity bump logic handles this by
// defaulting to the template's activity_bump duration (typically 1 hour).
workspacestats.ActivityBumpWorkspace(ctx, a.Log, a.Database, workspace.ID, time.Time{})
}
// Just return an empty response; it has no settable fields at present.
return new(agentproto.UpdateAppStatusResponse), nil
}
func shouldBump(dbState database.WorkspaceAppStatusState, latestAppStatus database.WorkspaceAppStatus) bool {
// Bump deadline when agent reports working or transitions away from working.
// This prevents auto-pause during active work and gives users time to interact
// after work completes.
// Bump if reporting working state.
if dbState == database.WorkspaceAppStatusStateWorking {
return true
}
// Bump if transitioning away from working state.
if latestAppStatus.ID != uuid.Nil {
prevState := latestAppStatus.State
if prevState == database.WorkspaceAppStatusStateWorking {
return true
}
}
return false
}
// enqueueAITaskStateNotification enqueues a notification when an AI task's app
// transitions to Working or Idle.
// No-op if:
// - the workspace agent app isn't configured as an AI task,
// - the new state equals the latest persisted state,
// - the workspace agent is not ready (still starting up).
func (a *AppsAPI) enqueueAITaskStateNotification(
ctx context.Context,
appID uuid.UUID,
latestAppStatus database.WorkspaceAppStatus,
newAppStatus database.WorkspaceAppStatusState,
workspace database.Workspace,
agent database.WorkspaceAgent,
) {
var notificationTemplate uuid.UUID
switch newAppStatus {
case database.WorkspaceAppStatusStateWorking:
notificationTemplate = notifications.TemplateTaskWorking
case database.WorkspaceAppStatusStateIdle:
notificationTemplate = notifications.TemplateTaskIdle
case database.WorkspaceAppStatusStateComplete:
notificationTemplate = notifications.TemplateTaskCompleted
case database.WorkspaceAppStatusStateFailure:
notificationTemplate = notifications.TemplateTaskFailed
default:
// Not a notifiable state, do nothing
return
}
if !workspace.TaskID.Valid {
// Workspace has no task ID, do nothing.
return
}
// Only send notifications when the agent is ready. We want to skip
// any state transitions that occur whilst the workspace is starting
// up as it doesn't make sense to receive them.
if agent.LifecycleState != database.WorkspaceAgentLifecycleStateReady {
a.Log.Debug(ctx, "skipping AI task notification because agent is not ready",
slog.F("agent_id", agent.ID),
slog.F("lifecycle_state", agent.LifecycleState),
slog.F("new_app_status", newAppStatus),
)
return
}
task, err := a.Database.GetTaskByID(ctx, workspace.TaskID.UUID)
if err != nil {
a.Log.Warn(ctx, "failed to get task", slog.Error(err))
return
}
if !task.WorkspaceAppID.Valid || task.WorkspaceAppID.UUID != appID {
// Non-task app, do nothing.
return
}
// Skip if the latest persisted state equals the new state (no new transition)
// Note: uuid.Nil check is valid here. If no previous status exists,
// GetLatestWorkspaceAppStatusByAppID returns sql.ErrNoRows and we get a zero-value struct.
if latestAppStatus.ID != uuid.Nil && latestAppStatus.State == newAppStatus {
return
}
// Skip the initial "Working" notification when the task first starts.
// This is obvious to the user since they just created the task.
// We still notify on the first "Idle" status and all subsequent transitions.
if latestAppStatus.ID == uuid.Nil && newAppStatus == database.WorkspaceAppStatusStateWorking {
return
}
if _, err := a.NotificationsEnqueuer.EnqueueWithData(
// nolint:gocritic // Need notifier actor to enqueue notifications
dbauthz.AsNotifier(ctx),
workspace.OwnerID,
notificationTemplate,
map[string]string{
"task": task.Name,
"workspace": workspace.Name,
},
map[string]any{
// Use a 1-minute bucketed timestamp to bypass per-day dedupe,
// allowing identical content to resend within the same day
// (but not more than once every 10s).
"dedupe_bypass_ts": a.Clock.Now().UTC().Truncate(time.Minute),
},
"api-workspace-agent-app-status",
// Associate this notification with related entities
workspace.ID, workspace.OwnerID, workspace.OrganizationID, appID,
); err != nil {
a.Log.Warn(ctx, "failed to notify of task state", slog.Error(err))
return
}
}
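A small, self-contained sketch of the 1-minute dedupe bucket described above: sends within the same minute share a bucket value (and therefore a dedupe key), while a send in the next minute produces a fresh one. The timestamps are illustrative.
package main

import (
	"fmt"
	"time"
)

func main() {
	// Two sends 20 seconds apart fall into the same minute bucket.
	a := time.Date(2026, 3, 4, 15, 50, 10, 0, time.UTC).Truncate(time.Minute)
	b := time.Date(2026, 3, 4, 15, 50, 30, 0, time.UTC).Truncate(time.Minute)
	fmt.Println(a.Equal(b)) // true: same bucket, so identical content dedupes

	// A send in the next minute gets a new bucket and is not deduped.
	c := time.Date(2026, 3, 4, 15, 51, 5, 0, time.UTC).Truncate(time.Minute)
	fmt.Println(a.Equal(c)) // false
}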
+115
@@ -0,0 +1,115 @@
package agentapi
import (
"testing"
"github.com/google/uuid"
"github.com/stretchr/testify/require"
"github.com/coder/coder/v2/coderd/database"
"github.com/coder/coder/v2/coderd/util/ptr"
)
func TestShouldBump(t *testing.T) {
t.Parallel()
tests := []struct {
name string
prevState *database.WorkspaceAppStatusState // nil means no previous state
newState database.WorkspaceAppStatusState
shouldBump bool
}{
{
name: "FirstStatusBumps",
prevState: nil,
newState: database.WorkspaceAppStatusStateWorking,
shouldBump: true,
},
{
name: "WorkingToIdleBumps",
prevState: ptr.Ref(database.WorkspaceAppStatusStateWorking),
newState: database.WorkspaceAppStatusStateIdle,
shouldBump: true,
},
{
name: "WorkingToCompleteBumps",
prevState: ptr.Ref(database.WorkspaceAppStatusStateWorking),
newState: database.WorkspaceAppStatusStateComplete,
shouldBump: true,
},
{
name: "CompleteToIdleNoBump",
prevState: ptr.Ref(database.WorkspaceAppStatusStateComplete),
newState: database.WorkspaceAppStatusStateIdle,
shouldBump: false,
},
{
name: "CompleteToCompleteNoBump",
prevState: ptr.Ref(database.WorkspaceAppStatusStateComplete),
newState: database.WorkspaceAppStatusStateComplete,
shouldBump: false,
},
{
name: "FailureToIdleNoBump",
prevState: ptr.Ref(database.WorkspaceAppStatusStateFailure),
newState: database.WorkspaceAppStatusStateIdle,
shouldBump: false,
},
{
name: "FailureToFailureNoBump",
prevState: ptr.Ref(database.WorkspaceAppStatusStateFailure),
newState: database.WorkspaceAppStatusStateFailure,
shouldBump: false,
},
{
name: "CompleteToWorkingBumps",
prevState: ptr.Ref(database.WorkspaceAppStatusStateComplete),
newState: database.WorkspaceAppStatusStateWorking,
shouldBump: true,
},
{
name: "FailureToCompleteNoBump",
prevState: ptr.Ref(database.WorkspaceAppStatusStateFailure),
newState: database.WorkspaceAppStatusStateComplete,
shouldBump: false,
},
{
name: "WorkingToFailureBumps",
prevState: ptr.Ref(database.WorkspaceAppStatusStateWorking),
newState: database.WorkspaceAppStatusStateFailure,
shouldBump: true,
},
{
name: "IdleToIdleNoBump",
prevState: ptr.Ref(database.WorkspaceAppStatusStateIdle),
newState: database.WorkspaceAppStatusStateIdle,
shouldBump: false,
},
{
name: "IdleToWorkingBumps",
prevState: ptr.Ref(database.WorkspaceAppStatusStateIdle),
newState: database.WorkspaceAppStatusStateWorking,
shouldBump: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
var prevAppStatus database.WorkspaceAppStatus
// If there's a previous state, seed the prior app status with it.
if tt.prevState != nil {
prevAppStatus.ID = uuid.UUID{1}
prevAppStatus.State = *tt.prevState
}
didBump := shouldBump(tt.newState, prevAppStatus)
if tt.shouldBump {
require.True(t, didBump, "wanted deadline to bump but it didn't")
} else {
require.False(t, didBump, "wanted deadline not to bump but it did")
}
})
}
}
+188
@@ -2,9 +2,13 @@ package agentapi_test
import (
"context"
"database/sql"
"net/http"
"strings"
"testing"
"github.com/google/uuid"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"go.uber.org/mock/gomock"
@@ -12,8 +16,12 @@ import (
"github.com/coder/coder/v2/coderd/agentapi"
"github.com/coder/coder/v2/coderd/database"
"github.com/coder/coder/v2/coderd/database/dbmock"
"github.com/coder/coder/v2/coderd/notifications"
"github.com/coder/coder/v2/coderd/notifications/notificationstest"
"github.com/coder/coder/v2/coderd/wspubsub"
"github.com/coder/coder/v2/codersdk"
"github.com/coder/coder/v2/testutil"
"github.com/coder/quartz"
)
func TestBatchUpdateAppHealths(t *testing.T) {
@@ -253,3 +261,183 @@ func TestBatchUpdateAppHealths(t *testing.T) {
require.Nil(t, resp)
})
}
func TestWorkspaceAgentAppStatus(t *testing.T) {
t.Parallel()
t.Run("Success", func(t *testing.T) {
t.Parallel()
ctx := testutil.Context(t, testutil.WaitShort)
ctrl := gomock.NewController(t)
mDB := dbmock.NewMockStore(ctrl)
fEnq := &notificationstest.FakeEnqueuer{}
mClock := quartz.NewMock(t)
agent := database.WorkspaceAgent{
ID: uuid.UUID{2},
LifecycleState: database.WorkspaceAgentLifecycleStateReady,
}
workspaceUpdates := make(chan wspubsub.WorkspaceEventKind, 100)
api := &agentapi.AppsAPI{
AgentFn: func(context.Context) (database.WorkspaceAgent, error) {
return agent, nil
},
Database: mDB,
Log: testutil.Logger(t),
PublishWorkspaceUpdateFn: func(_ context.Context, agnt *database.WorkspaceAgent, kind wspubsub.WorkspaceEventKind) error {
assert.Equal(t, *agnt, agent)
testutil.AssertSend(ctx, t, workspaceUpdates, kind)
return nil
},
NotificationsEnqueuer: fEnq,
Clock: mClock,
}
app := database.WorkspaceApp{
ID: uuid.UUID{8},
}
mDB.EXPECT().GetWorkspaceAppByAgentIDAndSlug(gomock.Any(), database.GetWorkspaceAppByAgentIDAndSlugParams{
AgentID: agent.ID,
Slug: "vscode",
}).Times(1).Return(app, nil)
task := database.Task{
ID: uuid.UUID{7},
WorkspaceAppID: uuid.NullUUID{
Valid: true,
UUID: app.ID,
},
}
mDB.EXPECT().GetTaskByID(gomock.Any(), task.ID).Times(1).Return(task, nil)
workspace := database.Workspace{
ID: uuid.UUID{9},
TaskID: uuid.NullUUID{
Valid: true,
UUID: task.ID,
},
}
mDB.EXPECT().GetWorkspaceByAgentID(gomock.Any(), agent.ID).Times(1).Return(workspace, nil)
appStatus := database.WorkspaceAppStatus{
ID: uuid.UUID{6},
}
mDB.EXPECT().GetLatestWorkspaceAppStatusByAppID(gomock.Any(), app.ID).Times(1).Return(appStatus, nil)
mDB.EXPECT().InsertWorkspaceAppStatus(
gomock.Any(),
gomock.Cond(func(params database.InsertWorkspaceAppStatusParams) bool {
if params.AgentID == agent.ID && params.AppID == app.ID {
assert.Equal(t, "testing", params.Message)
assert.Equal(t, database.WorkspaceAppStatusStateComplete, params.State)
assert.True(t, params.Uri.Valid)
assert.Equal(t, "https://example.com", params.Uri.String)
return true
}
return false
})).Times(1).Return(database.WorkspaceAppStatus{}, nil)
_, err := api.UpdateAppStatus(ctx, &agentproto.UpdateAppStatusRequest{
Slug: "vscode",
Message: "testing",
Uri: "https://example.com",
State: agentproto.UpdateAppStatusRequest_COMPLETE,
})
require.NoError(t, err)
kind := testutil.RequireReceive(ctx, t, workspaceUpdates)
require.Equal(t, wspubsub.WorkspaceEventKindAgentAppStatusUpdate, kind)
sent := fEnq.Sent(notificationstest.WithTemplateID(notifications.TemplateTaskCompleted))
require.Len(t, sent, 1)
})
t.Run("FailUnknownApp", func(t *testing.T) {
t.Parallel()
ctx := testutil.Context(t, testutil.WaitShort)
ctrl := gomock.NewController(t)
mDB := dbmock.NewMockStore(ctrl)
agent := database.WorkspaceAgent{
ID: uuid.UUID{2},
LifecycleState: database.WorkspaceAgentLifecycleStateReady,
}
mDB.EXPECT().GetWorkspaceAppByAgentIDAndSlug(gomock.Any(), gomock.Any()).
Times(1).
Return(database.WorkspaceApp{}, sql.ErrNoRows)
api := &agentapi.AppsAPI{
AgentFn: func(context.Context) (database.WorkspaceAgent, error) {
return agent, nil
},
Database: mDB,
Log: testutil.Logger(t),
}
_, err := api.UpdateAppStatus(ctx, &agentproto.UpdateAppStatusRequest{
Slug: "unknown",
Message: "testing",
Uri: "https://example.com",
State: agentproto.UpdateAppStatusRequest_COMPLETE,
})
require.ErrorContains(t, err, "No app found with slug")
var sdkErr *codersdk.Error
require.ErrorAs(t, err, &sdkErr)
require.Equal(t, http.StatusBadRequest, sdkErr.StatusCode())
})
t.Run("FailUnknownState", func(t *testing.T) {
t.Parallel()
ctx := testutil.Context(t, testutil.WaitShort)
ctrl := gomock.NewController(t)
mDB := dbmock.NewMockStore(ctrl)
agent := database.WorkspaceAgent{
ID: uuid.UUID{2},
LifecycleState: database.WorkspaceAgentLifecycleStateReady,
}
api := &agentapi.AppsAPI{
AgentFn: func(context.Context) (database.WorkspaceAgent, error) {
return agent, nil
},
Database: mDB,
Log: testutil.Logger(t),
}
_, err := api.UpdateAppStatus(ctx, &agentproto.UpdateAppStatusRequest{
Slug: "vscode",
Message: "testing",
Uri: "https://example.com",
State: 77,
})
require.ErrorContains(t, err, "Invalid state")
var sdkErr *codersdk.Error
require.ErrorAs(t, err, &sdkErr)
require.Equal(t, http.StatusBadRequest, sdkErr.StatusCode())
})
t.Run("FailTooLong", func(t *testing.T) {
t.Parallel()
ctx := testutil.Context(t, testutil.WaitShort)
ctrl := gomock.NewController(t)
mDB := dbmock.NewMockStore(ctrl)
agent := database.WorkspaceAgent{
ID: uuid.UUID{2},
LifecycleState: database.WorkspaceAgentLifecycleStateReady,
}
api := &agentapi.AppsAPI{
AgentFn: func(context.Context) (database.WorkspaceAgent, error) {
return agent, nil
},
Database: mDB,
Log: testutil.Logger(t),
}
_, err := api.UpdateAppStatus(ctx, &agentproto.UpdateAppStatusRequest{
Slug: "vscode",
Message: strings.Repeat("a", 161),
Uri: "https://example.com",
State: agentproto.UpdateAppStatusRequest_COMPLETE,
})
require.ErrorContains(t, err, "Message is too long")
var sdkErr *codersdk.Error
require.ErrorAs(t, err, &sdkErr)
require.Equal(t, http.StatusBadRequest, sdkErr.StatusCode())
})
}
+4 -3
@@ -192,7 +192,8 @@ func (api *API) tasksCreate(rw http.ResponseWriter, r *http.Request) {
})
defer commitAuditWS()
workspace, err := createWorkspace(ctx, aReqWS, apiKey.UserID, api, owner, createReq, r, &createWorkspaceOptions{
workspace, err := createWorkspace(ctx, aReqWS, apiKey.UserID, api, owner, createReq, &createWorkspaceOptions{
remoteAddr: r.RemoteAddr,
// Before creating the workspace, ensure that this task can be created.
preCreateInTX: func(ctx context.Context, tx database.Store) error {
// Create task record in the database before creating the workspace so that
@@ -1248,7 +1249,7 @@ func (api *API) postWorkspaceAgentTaskLogSnapshot(rw http.ResponseWriter, r *htt
// @Summary Pause task
// @ID pause-task
// @Security CoderSessionToken
// @Accept json
// @Produce json
// @Tags Tasks
// @Param user path string true "Username, user ID, or 'me' for the authenticated user"
// @Param task path string true "Task ID" format(uuid)
@@ -1325,7 +1326,7 @@ func (api *API) pauseTask(rw http.ResponseWriter, r *http.Request) {
// @Summary Resume task
// @ID resume-task
// @Security CoderSessionToken
// @Accept json
// @Produce json
// @Tags Tasks
// @Param user path string true "Username, user ID, or 'me' for the authenticated user"
// @Param task path string true "Task ID" format(uuid)
+1 -1
@@ -832,7 +832,7 @@ func TestTasks(t *testing.T) {
t.Run("SendToNonActiveStates", func(t *testing.T) {
t.Parallel()
client, db := coderdtest.NewWithDatabase(t, &coderdtest.Options{IncludeProvisionerDaemon: true})
client, db := coderdtest.NewWithDatabase(t, &coderdtest.Options{})
owner := coderdtest.CreateFirstUser(t, client)
ctx := testutil.Context(t, testutil.WaitMedium)
+94 -5
@@ -135,6 +135,34 @@ const docTemplate = `{
}
}
},
"/aibridge/models": {
"get": {
"security": [
{
"CoderSessionToken": []
}
],
"produces": [
"application/json"
],
"tags": [
"AI Bridge"
],
"summary": "List AI Bridge models",
"operationId": "list-ai-bridge-models",
"responses": {
"200": {
"description": "OK",
"schema": {
"type": "array",
"items": {
"type": "string"
}
}
}
}
}
},
"/appearance": {
"get": {
"security": [
@@ -453,6 +481,34 @@ const docTemplate = `{
}
}
},
"/chats/{chat}/archive": {
"post": {
"tags": [
"Chats"
],
"summary": "Archive a chat",
"operationId": "archive-chat",
"responses": {
"204": {
"description": "No Content"
}
}
}
},
"/chats/{chat}/unarchive": {
"post": {
"tags": [
"Chats"
],
"summary": "Unarchive a chat",
"operationId": "unarchive-chat",
"responses": {
"204": {
"description": "No Content"
}
}
}
},
"/connectionlog": {
"get": {
"security": [
@@ -1747,12 +1803,17 @@ const docTemplate = `{
"summary": "Get insights about user status counts",
"operationId": "get-insights-about-user-status-counts",
"parameters": [
{
"type": "string",
"description": "IANA timezone name (e.g. America/St_Johns)",
"name": "timezone",
"in": "query"
},
{
"type": "integer",
"description": "Time-zone offset (e.g. -2)",
"description": "Deprecated: Time-zone offset (e.g. -2). Use timezone instead.",
"name": "tz_offset",
"in": "query",
"required": true
"in": "query"
}
],
"responses": {
@@ -5894,7 +5955,7 @@ const docTemplate = `{
"CoderSessionToken": []
}
],
"consumes": [
"produces": [
"application/json"
],
"tags": [
@@ -5936,7 +5997,7 @@ const docTemplate = `{
"CoderSessionToken": []
}
],
"consumes": [
"produces": [
"application/json"
],
"tags": [
@@ -8238,6 +8299,12 @@ const docTemplate = `{
"name": "user",
"in": "path",
"required": true
},
{
"type": "boolean",
"description": "Include expired tokens in the list",
"name": "include_expired",
"in": "query"
}
],
"responses": {
@@ -9545,6 +9612,7 @@ const docTemplate = `{
],
"summary": "Patch workspace agent app status",
"operationId": "patch-workspace-agent-app-status",
"deprecated": true,
"parameters": [
{
"description": "app status",
@@ -12748,6 +12816,11 @@ const docTemplate = `{
"boundary_usage:delete",
"boundary_usage:read",
"boundary_usage:update",
"chat:*",
"chat:create",
"chat:delete",
"chat:read",
"chat:update",
"coder:all",
"coder:apikeys.manage_self",
"coder:application_connect",
@@ -12952,6 +13025,11 @@ const docTemplate = `{
"APIKeyScopeBoundaryUsageDelete",
"APIKeyScopeBoundaryUsageRead",
"APIKeyScopeBoundaryUsageUpdate",
"APIKeyScopeChatAll",
"APIKeyScopeChatCreate",
"APIKeyScopeChatDelete",
"APIKeyScopeChatRead",
"APIKeyScopeChatUpdate",
"APIKeyScopeCoderAll",
"APIKeyScopeCoderApikeysManageSelf",
"APIKeyScopeCoderApplicationConnect",
@@ -14813,6 +14891,9 @@ const docTemplate = `{
"external_auth": {
"$ref": "#/definitions/serpent.Struct-array_codersdk_ExternalAuthConfig"
},
"external_auth_github_default_provider_enable": {
"type": "boolean"
},
"external_token_encryption_keys": {
"type": "array",
"items": {
@@ -15097,9 +15178,11 @@ const docTemplate = `{
"workspace-usage",
"web-push",
"oauth2",
"agents",
"mcp-server-http"
],
"x-enum-comments": {
"ExperimentAgents": "Enables agent-powered chat functionality.",
"ExperimentAutoFillParameters": "This should not be taken out of experiments until we have redesigned the feature.",
"ExperimentExample": "This isn't used for anything.",
"ExperimentMCPServerHTTP": "Enables the MCP HTTP server functionality.",
@@ -15115,6 +15198,7 @@ const docTemplate = `{
"Enables the new workspace usage tracking.",
"Enables web push notifications through the browser.",
"Enables OAuth2 provider functionality.",
"Enables agent-powered chat functionality.",
"Enables the MCP HTTP server functionality."
],
"x-enum-varnames": [
@@ -15124,6 +15208,7 @@ const docTemplate = `{
"ExperimentWorkspaceUsage",
"ExperimentWebPush",
"ExperimentOAuth2",
"ExperimentAgents",
"ExperimentMCPServerHTTP"
]
},
@@ -18064,6 +18149,7 @@ const docTemplate = `{
"assign_role",
"audit_log",
"boundary_usage",
"chat",
"connection_log",
"crypto_key",
"debug_info",
@@ -18109,6 +18195,7 @@ const docTemplate = `{
"ResourceAssignRole",
"ResourceAuditLog",
"ResourceBoundaryUsage",
"ResourceChat",
"ResourceConnectionLog",
"ResourceCryptoKey",
"ResourceDebugInfo",
@@ -19779,6 +19866,7 @@ const docTemplate = `{
"type": "string",
"enum": [
"",
"geist-mono",
"ibm-plex-mono",
"fira-code",
"source-code-pro",
@@ -19786,6 +19874,7 @@ const docTemplate = `{
],
"x-enum-varnames": [
"TerminalFontUnknown",
"TerminalFontGeistMono",
"TerminalFontIBMPlexMono",
"TerminalFontFiraCode",
"TerminalFontSourceCodePro",
+86 -5
@@ -112,6 +112,30 @@
}
}
},
"/aibridge/models": {
"get": {
"security": [
{
"CoderSessionToken": []
}
],
"produces": ["application/json"],
"tags": ["AI Bridge"],
"summary": "List AI Bridge models",
"operationId": "list-ai-bridge-models",
"responses": {
"200": {
"description": "OK",
"schema": {
"type": "array",
"items": {
"type": "string"
}
}
}
}
}
},
"/appearance": {
"get": {
"security": [
@@ -386,6 +410,30 @@
}
}
},
"/chats/{chat}/archive": {
"post": {
"tags": ["Chats"],
"summary": "Archive a chat",
"operationId": "archive-chat",
"responses": {
"204": {
"description": "No Content"
}
}
}
},
"/chats/{chat}/unarchive": {
"post": {
"tags": ["Chats"],
"summary": "Unarchive a chat",
"operationId": "unarchive-chat",
"responses": {
"204": {
"description": "No Content"
}
}
}
},
"/connectionlog": {
"get": {
"security": [
@@ -1527,12 +1575,17 @@
"summary": "Get insights about user status counts",
"operationId": "get-insights-about-user-status-counts",
"parameters": [
{
"type": "string",
"description": "IANA timezone name (e.g. America/St_Johns)",
"name": "timezone",
"in": "query"
},
{
"type": "integer",
"description": "Time-zone offset (e.g. -2)",
"description": "Deprecated: Time-zone offset (e.g. -2). Use timezone instead.",
"name": "tz_offset",
"in": "query",
"required": true
"in": "query"
}
],
"responses": {
@@ -5213,7 +5266,7 @@
"CoderSessionToken": []
}
],
"consumes": ["application/json"],
"produces": ["application/json"],
"tags": ["Tasks"],
"summary": "Pause task",
"operationId": "pause-task",
@@ -5251,7 +5304,7 @@
"CoderSessionToken": []
}
],
"consumes": ["application/json"],
"produces": ["application/json"],
"tags": ["Tasks"],
"summary": "Resume task",
"operationId": "resume-task",
@@ -7285,6 +7338,12 @@
"name": "user",
"in": "path",
"required": true
},
{
"type": "boolean",
"description": "Include expired tokens in the list",
"name": "include_expired",
"in": "query"
}
],
"responses": {
@@ -8444,6 +8503,7 @@
"tags": ["Agents"],
"summary": "Patch workspace agent app status",
"operationId": "patch-workspace-agent-app-status",
"deprecated": true,
"parameters": [
{
"description": "app status",
@@ -11356,6 +11416,11 @@
"boundary_usage:delete",
"boundary_usage:read",
"boundary_usage:update",
"chat:*",
"chat:create",
"chat:delete",
"chat:read",
"chat:update",
"coder:all",
"coder:apikeys.manage_self",
"coder:application_connect",
@@ -11560,6 +11625,11 @@
"APIKeyScopeBoundaryUsageDelete",
"APIKeyScopeBoundaryUsageRead",
"APIKeyScopeBoundaryUsageUpdate",
"APIKeyScopeChatAll",
"APIKeyScopeChatCreate",
"APIKeyScopeChatDelete",
"APIKeyScopeChatRead",
"APIKeyScopeChatUpdate",
"APIKeyScopeCoderAll",
"APIKeyScopeCoderApikeysManageSelf",
"APIKeyScopeCoderApplicationConnect",
@@ -13347,6 +13417,9 @@
"external_auth": {
"$ref": "#/definitions/serpent.Struct-array_codersdk_ExternalAuthConfig"
},
"external_auth_github_default_provider_enable": {
"type": "boolean"
},
"external_token_encryption_keys": {
"type": "array",
"items": {
@@ -13624,9 +13697,11 @@
"workspace-usage",
"web-push",
"oauth2",
"agents",
"mcp-server-http"
],
"x-enum-comments": {
"ExperimentAgents": "Enables agent-powered chat functionality.",
"ExperimentAutoFillParameters": "This should not be taken out of experiments until we have redesigned the feature.",
"ExperimentExample": "This isn't used for anything.",
"ExperimentMCPServerHTTP": "Enables the MCP HTTP server functionality.",
@@ -13642,6 +13717,7 @@
"Enables the new workspace usage tracking.",
"Enables web push notifications through the browser.",
"Enables OAuth2 provider functionality.",
"Enables agent-powered chat functionality.",
"Enables the MCP HTTP server functionality."
],
"x-enum-varnames": [
@@ -13651,6 +13727,7 @@
"ExperimentWorkspaceUsage",
"ExperimentWebPush",
"ExperimentOAuth2",
"ExperimentAgents",
"ExperimentMCPServerHTTP"
]
},
@@ -16476,6 +16553,7 @@
"assign_role",
"audit_log",
"boundary_usage",
"chat",
"connection_log",
"crypto_key",
"debug_info",
@@ -16521,6 +16599,7 @@
"ResourceAssignRole",
"ResourceAuditLog",
"ResourceBoundaryUsage",
"ResourceChat",
"ResourceConnectionLog",
"ResourceCryptoKey",
"ResourceDebugInfo",
@@ -18116,6 +18195,7 @@
"type": "string",
"enum": [
"",
"geist-mono",
"ibm-plex-mono",
"fira-code",
"source-code-pro",
@@ -18123,6 +18203,7 @@
],
"x-enum-varnames": [
"TerminalFontUnknown",
"TerminalFontGeistMono",
"TerminalFontIBMPlexMono",
"TerminalFontFiraCode",
"TerminalFontSourceCodePro",
+14 -8
@@ -307,20 +307,26 @@ func (api *API) apiKeyByName(rw http.ResponseWriter, r *http.Request) {
// @Tags Users
// @Param user path string true "User ID, name, or me"
// @Success 200 {array} codersdk.APIKey
// @Param include_expired query bool false "Include expired tokens in the list"
// @Router /users/{user}/keys/tokens [get]
func (api *API) tokens(rw http.ResponseWriter, r *http.Request) {
var (
ctx = r.Context()
user = httpmw.UserParam(r)
keys []database.APIKey
err error
queryStr = r.URL.Query().Get("include_all")
includeAll, _ = strconv.ParseBool(queryStr)
ctx = r.Context()
user = httpmw.UserParam(r)
keys []database.APIKey
err error
queryStr = r.URL.Query().Get("include_all")
includeAll, _ = strconv.ParseBool(queryStr)
expiredStr = r.URL.Query().Get("include_expired")
includeExpired, _ = strconv.ParseBool(expiredStr)
)
if includeAll {
// get tokens for all users
keys, err = api.Database.GetAPIKeysByLoginType(ctx, database.LoginTypeToken)
keys, err = api.Database.GetAPIKeysByLoginType(ctx, database.GetAPIKeysByLoginTypeParams{
LoginType: database.LoginTypeToken,
IncludeExpired: includeExpired,
})
if err != nil {
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
Message: "Internal error fetching API keys.",
@@ -330,7 +336,7 @@ func (api *API) tokens(rw http.ResponseWriter, r *http.Request) {
}
} else {
// get user's tokens only
keys, err = api.Database.GetAPIKeysByUserID(ctx, database.GetAPIKeysByUserIDParams{LoginType: database.LoginTypeToken, UserID: user.ID})
keys, err = api.Database.GetAPIKeysByUserID(ctx, database.GetAPIKeysByUserIDParams{LoginType: database.LoginTypeToken, UserID: user.ID, IncludeExpired: includeExpired})
if err != nil {
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
Message: "Internal error fetching API keys.",
+1 -1
@@ -113,7 +113,7 @@ func Generate(params CreateParams) (database.InsertAPIKeyParams, string, error)
return database.InsertAPIKeyParams{
ID: keyID,
UserID: params.UserID,
LastUsed: time.Time{},
LastUsed: time.Unix(0, 0).UTC(),
LifetimeSeconds: params.LifetimeSeconds,
IPAddress: pqtype.Inet{
IPNet: net.IPNet{
+38
@@ -69,6 +69,44 @@ func TestTokenCRUD(t *testing.T) {
require.Equal(t, database.AuditActionDelete, auditor.AuditLogs()[numLogs-1].Action)
}
func TestTokensFilterExpired(t *testing.T) {
t.Parallel()
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong)
defer cancel()
adminClient := coderdtest.New(t, nil)
_ = coderdtest.CreateFirstUser(t, adminClient)
// Create a token.
res, err := adminClient.CreateToken(ctx, codersdk.Me, codersdk.CreateTokenRequest{
Lifetime: time.Hour * 24 * 7,
})
require.NoError(t, err)
keyID := strings.Split(res.Key, "-")[0]
// List tokens without including expired - should see the token.
keys, err := adminClient.Tokens(ctx, codersdk.Me, codersdk.TokensFilter{})
require.NoError(t, err)
require.Len(t, keys, 1)
// Expire the token.
err = adminClient.ExpireAPIKey(ctx, codersdk.Me, keyID)
require.NoError(t, err)
// List tokens without including expired - should NOT see expired token.
keys, err = adminClient.Tokens(ctx, codersdk.Me, codersdk.TokensFilter{})
require.NoError(t, err)
require.Empty(t, keys)
// List tokens WITH including expired - should see expired token.
keys, err = adminClient.Tokens(ctx, codersdk.Me, codersdk.TokensFilter{
IncludeExpired: true,
})
require.NoError(t, err)
require.Len(t, keys, 1)
require.Equal(t, keyID, keys[0].ID)
}
func TestTokenScoped(t *testing.T) {
t.Parallel()
+48
@@ -1,6 +1,7 @@
package coderd
import (
"context"
"fmt"
"net/http"
@@ -8,6 +9,7 @@ import (
"golang.org/x/xerrors"
"cdr.dev/slog/v3"
"github.com/coder/coder/v2/coderd/database/dbauthz"
"github.com/coder/coder/v2/coderd/httpapi"
"github.com/coder/coder/v2/coderd/httpmw"
"github.com/coder/coder/v2/coderd/rbac"
@@ -91,6 +93,36 @@ func (h *HTTPAuthorizer) Authorize(r *http.Request, action policy.Action, object
return true
}
// AuthorizeContext checks whether the RBAC subject on the context
// is authorized to perform the given action. The subject must have
// been set via dbauthz.As or the ExtractAPIKey middleware. Returns
// false if the subject is missing or unauthorized.
func (h *HTTPAuthorizer) AuthorizeContext(ctx context.Context, action policy.Action, object rbac.Objecter) bool {
roles, ok := dbauthz.ActorFromContext(ctx)
if !ok {
h.Logger.Error(ctx, "no authorization actor in context")
return false
}
err := h.Authorizer.Authorize(ctx, roles, action, object.RBACObject())
if err != nil {
internalError := new(rbac.UnauthorizedError)
logger := h.Logger
if xerrors.As(err, internalError) {
logger = h.Logger.With(slog.F("internal_error", internalError.Internal()))
}
logger.Warn(ctx, "requester is not authorized to access the object",
slog.F("roles", roles.SafeRoleNames()),
slog.F("actor_id", roles.ID),
slog.F("actor_name", roles),
slog.F("scope", roles.SafeScopeName()),
slog.F("action", action),
slog.F("object", object),
)
return false
}
return true
}
// AuthorizeSQLFilter returns an authorization filter that can be used in a
// SQL 'WHERE' clause. If the filter is used, the resulting rows returned
// from postgres are already authorized, and the caller does not need to
@@ -106,6 +138,22 @@ func (h *HTTPAuthorizer) AuthorizeSQLFilter(r *http.Request, action policy.Actio
return prepared, nil
}
// AuthorizeSQLFilterContext is like AuthorizeSQLFilter but reads the
// RBAC subject from the context directly rather than from an
// *http.Request. The subject must have been set via dbauthz.As.
func (h *HTTPAuthorizer) AuthorizeSQLFilterContext(ctx context.Context, action policy.Action, objectType string) (rbac.PreparedAuthorized, error) {
roles, ok := dbauthz.ActorFromContext(ctx)
if !ok {
return nil, xerrors.New("no authorization actor in context")
}
prepared, err := h.Authorizer.Prepare(ctx, roles, action, objectType)
if err != nil {
return nil, xerrors.Errorf("prepare filter: %w", err)
}
return prepared, nil
}
// checkAuthorization reports whether the current API key can use the given
// permissions, factoring in the current user's roles and the API key scopes.
//
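A sketch of how an endpoint might use the new context-based helpers above (illustrative only; the action constant, the object, and the 404 helper are assumptions, not part of this diff):

// The RBAC actor must already be on the context, e.g. via dbauthz.As
// or the ExtractAPIKey middleware; object is any rbac.Objecter.
if !api.HTTPAuth.AuthorizeContext(ctx, policy.ActionRead, object) {
	httpapi.ResourceNotFound(rw)
	return
}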
+1 -1
@@ -500,7 +500,7 @@ func (e *Executor) runOnce(t time.Time) Stats {
"task": task.Name,
"task_id": task.ID.String(),
"workspace": ws.Name,
"pause_reason": "inactivity exceeded the dormancy threshold",
"pause_reason": "idle timeout",
},
"lifecycle_executor",
ws.ID, ws.OwnerID, ws.OrganizationID,
+1 -1
@@ -2082,6 +2082,6 @@ func TestExecutorTaskWorkspace(t *testing.T) {
require.Equal(t, task.Name, sent[0].Labels["task"])
require.Equal(t, task.ID.String(), sent[0].Labels["task_id"])
require.Equal(t, workspace.Name, sent[0].Labels["workspace"])
require.Equal(t, "inactivity exceeded the dormancy threshold", sent[0].Labels["pause_reason"])
require.Equal(t, "idle timeout", sent[0].Labels["pause_reason"])
})
}
File diff suppressed because it is too large
+86
@@ -0,0 +1,86 @@
package chatd
import (
"context"
"testing"
"github.com/google/uuid"
"github.com/stretchr/testify/require"
"golang.org/x/xerrors"
"github.com/coder/coder/v2/coderd/database"
)
func TestRefreshChatWorkspaceSnapshot_NoReloadWhenWorkspacePresent(t *testing.T) {
t.Parallel()
workspaceID := uuid.New()
chat := database.Chat{
ID: uuid.New(),
WorkspaceID: uuid.NullUUID{
UUID: workspaceID,
Valid: true,
},
}
calls := 0
refreshed, err := refreshChatWorkspaceSnapshot(
context.Background(),
chat,
func(context.Context, uuid.UUID) (database.Chat, error) {
calls++
return database.Chat{}, nil
},
)
require.NoError(t, err)
require.Equal(t, chat, refreshed)
require.Equal(t, 0, calls)
}
func TestRefreshChatWorkspaceSnapshot_ReloadsWhenWorkspaceMissing(t *testing.T) {
t.Parallel()
chatID := uuid.New()
workspaceID := uuid.New()
chat := database.Chat{ID: chatID}
reloaded := database.Chat{
ID: chatID,
WorkspaceID: uuid.NullUUID{
UUID: workspaceID,
Valid: true,
},
}
calls := 0
refreshed, err := refreshChatWorkspaceSnapshot(
context.Background(),
chat,
func(_ context.Context, id uuid.UUID) (database.Chat, error) {
calls++
require.Equal(t, chatID, id)
return reloaded, nil
},
)
require.NoError(t, err)
require.Equal(t, reloaded, refreshed)
require.Equal(t, 1, calls)
}
func TestRefreshChatWorkspaceSnapshot_ReturnsReloadError(t *testing.T) {
t.Parallel()
chat := database.Chat{ID: uuid.New()}
loadErr := xerrors.New("boom")
refreshed, err := refreshChatWorkspaceSnapshot(
context.Background(),
chat,
func(context.Context, uuid.UUID) (database.Chat, error) {
return database.Chat{}, loadErr
},
)
require.Error(t, err)
require.ErrorContains(t, err, "reload chat workspace state")
require.ErrorContains(t, err, loadErr.Error())
require.Equal(t, chat, refreshed)
}
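The implementation under test is in a diff suppressed below; inferred purely from the assertions above, its shape would be roughly the following (a sketch, not the actual source):

// refreshChatWorkspaceSnapshot returns the chat unchanged when it
// already references a workspace, and otherwise reloads it through the
// provided loader, wrapping any error and keeping the original chat.
func refreshChatWorkspaceSnapshot(
	ctx context.Context,
	chat database.Chat,
	reload func(context.Context, uuid.UUID) (database.Chat, error),
) (database.Chat, error) {
	if chat.WorkspaceID.Valid {
		return chat, nil
	}
	reloaded, err := reload(ctx, chat.ID)
	if err != nil {
		return chat, xerrors.Errorf("reload chat workspace state: %w", err)
	}
	return reloaded, nil
}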
File diff suppressed because it is too large
+955
@@ -0,0 +1,955 @@
package chatloop
import (
"context"
"database/sql"
"encoding/json"
"errors"
"slices"
"strconv"
"strings"
"time"
"charm.land/fantasy"
fantasyanthropic "charm.land/fantasy/providers/anthropic"
"charm.land/fantasy/schema"
"golang.org/x/xerrors"
"github.com/coder/coder/v2/coderd/chatd/chatprompt"
"github.com/coder/coder/v2/coderd/chatd/chatretry"
"github.com/coder/coder/v2/codersdk"
)
const (
interruptedToolResultErrorMessage = "tool call was interrupted before it produced a result"
)
var ErrInterrupted = xerrors.New("chat interrupted")
// PersistedStep contains the full content of a completed or
// interrupted agent step. Content includes both assistant blocks
// (text, reasoning, tool calls) and tool result blocks. The
// persistence layer is responsible for splitting these into
// separate database messages by role.
type PersistedStep struct {
Content []fantasy.Content
Usage fantasy.Usage
ContextLimit sql.NullInt64
}
// RunOptions configures a single streaming chat loop run.
type RunOptions struct {
Model fantasy.LanguageModel
Messages []fantasy.Message
Tools []fantasy.AgentTool
MaxSteps int
ActiveTools []string
ContextLimitFallback int64
// ModelConfig holds per-call LLM parameters (temperature,
// max tokens, etc.) read from the chat model configuration.
ModelConfig codersdk.ChatModelCallConfig
// ProviderOptions are provider-specific call options
// converted from ModelConfig.ProviderOptions. This is a
// separate field because the conversion requires knowledge
// of the provider, which lives in chatd, not chatloop.
ProviderOptions fantasy.ProviderOptions
PersistStep func(context.Context, PersistedStep) error
PublishMessagePart func(
role fantasy.MessageRole,
part codersdk.ChatMessagePart,
)
Compaction *CompactionOptions
ReloadMessages func(context.Context) ([]fantasy.Message, error)
// OnRetry is called before each retry attempt when the LLM
// stream fails with a retryable error. It provides the attempt
// number, error, and backoff delay so callers can publish status
// events to connected clients.
OnRetry chatretry.OnRetryFn
OnInterruptedPersistError func(error)
}
// stepResult holds the accumulated output of a single streaming
// step. Since we own the stream consumer, all content is tracked
// directly here — no shadow draft state needed.
type stepResult struct {
content []fantasy.Content
usage fantasy.Usage
providerMetadata fantasy.ProviderMetadata
finishReason fantasy.FinishReason
toolCalls []fantasy.ToolCallContent
shouldContinue bool
}
// toResponseMessages converts step content into messages suitable
// for appending to the conversation. Mirrors fantasy's
// toResponseMessages logic.
func (r stepResult) toResponseMessages() []fantasy.Message {
var assistantParts []fantasy.MessagePart
var toolParts []fantasy.MessagePart
for _, c := range r.content {
switch c.GetType() {
case fantasy.ContentTypeText:
text, ok := fantasy.AsContentType[fantasy.TextContent](c)
if !ok {
continue
}
assistantParts = append(assistantParts, fantasy.TextPart{
Text: text.Text,
ProviderOptions: fantasy.ProviderOptions(text.ProviderMetadata),
})
case fantasy.ContentTypeReasoning:
reasoning, ok := fantasy.AsContentType[fantasy.ReasoningContent](c)
if !ok {
continue
}
assistantParts = append(assistantParts, fantasy.ReasoningPart{
Text: reasoning.Text,
ProviderOptions: fantasy.ProviderOptions(reasoning.ProviderMetadata),
})
case fantasy.ContentTypeToolCall:
toolCall, ok := fantasy.AsContentType[fantasy.ToolCallContent](c)
if !ok {
continue
}
assistantParts = append(assistantParts, fantasy.ToolCallPart{
ToolCallID: toolCall.ToolCallID,
ToolName: toolCall.ToolName,
Input: toolCall.Input,
ProviderExecuted: toolCall.ProviderExecuted,
ProviderOptions: fantasy.ProviderOptions(toolCall.ProviderMetadata),
})
case fantasy.ContentTypeFile:
file, ok := fantasy.AsContentType[fantasy.FileContent](c)
if !ok {
continue
}
assistantParts = append(assistantParts, fantasy.FilePart{
Data: file.Data,
MediaType: file.MediaType,
ProviderOptions: fantasy.ProviderOptions(file.ProviderMetadata),
})
case fantasy.ContentTypeSource:
// Sources are metadata about references; they don't
// need to be included in conversation messages.
continue
case fantasy.ContentTypeToolResult:
result, ok := fantasy.AsContentType[fantasy.ToolResultContent](c)
if !ok {
continue
}
toolParts = append(toolParts, fantasy.ToolResultPart{
ToolCallID: result.ToolCallID,
Output: result.Result,
ProviderOptions: fantasy.ProviderOptions(result.ProviderMetadata),
})
default:
continue
}
}
var messages []fantasy.Message
if len(assistantParts) > 0 {
messages = append(messages, fantasy.Message{
Role: fantasy.MessageRoleAssistant,
Content: assistantParts,
})
}
if len(toolParts) > 0 {
messages = append(messages, fantasy.Message{
Role: fantasy.MessageRoleTool,
Content: toolParts,
})
}
return messages
}
// reasoningState accumulates reasoning content and provider
// metadata while the stream is in flight.
type reasoningState struct {
text string
options fantasy.ProviderMetadata
}
// Run executes the chat step-stream loop and delegates
// persistence/publishing to callbacks.
func Run(ctx context.Context, opts RunOptions) error {
if opts.Model == nil {
return xerrors.New("chat model is required")
}
if opts.PersistStep == nil {
return xerrors.New("persist step callback is required")
}
if opts.MaxSteps <= 0 {
opts.MaxSteps = 1
}
publishMessagePart := func(role fantasy.MessageRole, part codersdk.ChatMessagePart) {
if opts.PublishMessagePart == nil {
return
}
opts.PublishMessagePart(role, part)
}
tools := buildToolDefinitions(opts.Tools, opts.ActiveTools)
applyAnthropicCaching := shouldApplyAnthropicPromptCaching(opts.Model)
messages := opts.Messages
alreadyCompacted := false
var lastUsage fantasy.Usage
var lastProviderMetadata fantasy.ProviderMetadata
for step := 0; step < opts.MaxSteps; step++ {
// Copy messages so that provider-specific caching
// mutations don't leak back to the caller's slice.
// copy copies Message structs by value, so field
// reassignments in addAnthropicPromptCaching only
// affect the prepared slice.
prepared := make([]fantasy.Message, len(messages))
copy(prepared, messages)
if applyAnthropicCaching {
addAnthropicPromptCaching(prepared)
}
call := fantasy.Call{
Prompt: prepared,
Tools: tools,
MaxOutputTokens: opts.ModelConfig.MaxOutputTokens,
Temperature: opts.ModelConfig.Temperature,
TopP: opts.ModelConfig.TopP,
TopK: opts.ModelConfig.TopK,
PresencePenalty: opts.ModelConfig.PresencePenalty,
FrequencyPenalty: opts.ModelConfig.FrequencyPenalty,
ProviderOptions: opts.ProviderOptions,
}
var result stepResult
err := chatretry.Retry(ctx, func(retryCtx context.Context) error {
stream, streamErr := opts.Model.Stream(retryCtx, call)
if streamErr != nil {
return streamErr
}
var processErr error
result, processErr = processStepStream(retryCtx, stream, publishMessagePart)
return processErr
}, func(attempt int, retryErr error, delay time.Duration) {
// Reset result from the failed attempt so the next
// attempt starts clean.
result = stepResult{}
if opts.OnRetry != nil {
opts.OnRetry(attempt, retryErr, delay)
}
})
if err != nil {
if errors.Is(err, ErrInterrupted) {
persistInterruptedStep(ctx, opts, &result)
return ErrInterrupted
}
return xerrors.Errorf("stream response: %w", err)
}
// Execute tools before persisting so that tool results
// are included in the persisted step content. The
// persistence layer splits assistant and tool-result
// blocks into separate database messages by role.
var toolResults []fantasy.ToolResultContent
if result.shouldContinue {
// Check for context cancellation before starting
// tool execution. If the chat was interrupted
// between stream completion and here, persist
// what we have and bail out.
if ctx.Err() != nil {
if errors.Is(context.Cause(ctx), ErrInterrupted) {
persistInterruptedStep(ctx, opts, &result)
return ErrInterrupted
}
return ctx.Err()
}
toolResults = executeTools(ctx, opts.Tools, result.toolCalls, func(tr fantasy.ToolResultContent) {
publishMessagePart(
fantasy.MessageRoleTool,
chatprompt.PartFromContent(tr),
)
})
for _, tr := range toolResults {
result.content = append(result.content, tr)
}
}
// Extract context limit from provider metadata.
contextLimit := extractContextLimit(result.providerMetadata)
if !contextLimit.Valid && opts.ContextLimitFallback > 0 {
contextLimit = sql.NullInt64{
Int64: opts.ContextLimitFallback,
Valid: true,
}
}
// Persist the step — errors propagate directly.
if err := opts.PersistStep(ctx, PersistedStep{
Content: result.content,
Usage: result.usage,
ContextLimit: contextLimit,
}); err != nil {
return xerrors.Errorf("persist step: %w", err)
}
lastUsage = result.usage
lastProviderMetadata = result.providerMetadata
// Inline compaction.
if opts.Compaction != nil && opts.ReloadMessages != nil {
did, compactErr := tryCompact(
ctx,
opts.Model,
opts.Compaction,
opts.ContextLimitFallback,
result.usage,
result.providerMetadata,
messages,
)
if compactErr != nil && opts.Compaction.OnError != nil {
opts.Compaction.OnError(compactErr)
}
if did {
alreadyCompacted = true
reloaded, reloadErr := opts.ReloadMessages(ctx)
if reloadErr != nil {
return xerrors.Errorf("reload messages after compaction: %w", reloadErr)
}
messages = reloaded
}
}
if !result.shouldContinue {
break
}
// Build messages from the step for the next iteration.
// toResponseMessages produces assistant-role content
// (text, reasoning, tool calls) and tool-result content.
stepMessages := result.toResponseMessages()
messages = append(messages, stepMessages...)
}
// Post-run compaction safety net: if we never compacted
// during the loop, try once at the end.
if !alreadyCompacted && opts.Compaction != nil {
if _, err := tryCompact(
ctx,
opts.Model,
opts.Compaction,
opts.ContextLimitFallback,
lastUsage,
lastProviderMetadata,
messages,
); err != nil {
if opts.Compaction.OnError != nil {
opts.Compaction.OnError(err)
}
}
}
return nil
}
// processStepStream consumes a fantasy StreamResponse and
// accumulates all content into a stepResult. Callbacks fire
// inline and their errors propagate directly.
func processStepStream(
ctx context.Context,
stream fantasy.StreamResponse,
publishMessagePart func(fantasy.MessageRole, codersdk.ChatMessagePart),
) (stepResult, error) {
var result stepResult
activeToolCalls := make(map[string]*fantasy.ToolCallContent)
activeTextContent := make(map[string]string)
activeReasoningContent := make(map[string]reasoningState)
// Track tool names by ID for input delta publishing.
toolNames := make(map[string]string)
// Track reasoning text/titles for title extraction.
reasoningTitles := make(map[string]string)
reasoningText := make(map[string]string)
setReasoningTitleFromText := func(id string, text string) {
if id == "" || strings.TrimSpace(text) == "" {
return
}
if reasoningTitles[id] != "" {
return
}
reasoningText[id] += text
if !strings.ContainsAny(reasoningText[id], "\r\n") {
return
}
title := chatprompt.ReasoningTitleFromFirstLine(reasoningText[id])
if title == "" {
return
}
reasoningTitles[id] = title
}
for part := range stream {
switch part.Type {
case fantasy.StreamPartTypeTextStart:
activeTextContent[part.ID] = ""
case fantasy.StreamPartTypeTextDelta:
if _, exists := activeTextContent[part.ID]; exists {
activeTextContent[part.ID] += part.Delta
}
publishMessagePart(fantasy.MessageRoleAssistant, codersdk.ChatMessagePart{
Type: codersdk.ChatMessagePartTypeText,
Text: part.Delta,
})
case fantasy.StreamPartTypeTextEnd:
if text, exists := activeTextContent[part.ID]; exists {
result.content = append(result.content, fantasy.TextContent{
Text: text,
ProviderMetadata: part.ProviderMetadata,
})
delete(activeTextContent, part.ID)
}
case fantasy.StreamPartTypeReasoningStart:
activeReasoningContent[part.ID] = reasoningState{
text: part.Delta,
options: part.ProviderMetadata,
}
case fantasy.StreamPartTypeReasoningDelta:
if active, exists := activeReasoningContent[part.ID]; exists {
active.text += part.Delta
active.options = part.ProviderMetadata
activeReasoningContent[part.ID] = active
}
setReasoningTitleFromText(part.ID, part.Delta)
title := reasoningTitles[part.ID]
publishMessagePart(fantasy.MessageRoleAssistant, codersdk.ChatMessagePart{
Type: codersdk.ChatMessagePartTypeReasoning,
Text: part.Delta,
Title: title,
})
case fantasy.StreamPartTypeReasoningEnd:
if active, exists := activeReasoningContent[part.ID]; exists {
if part.ProviderMetadata != nil {
active.options = part.ProviderMetadata
}
content := fantasy.ReasoningContent{
Text: active.text,
ProviderMetadata: active.options,
}
result.content = append(result.content, content)
delete(activeReasoningContent, part.ID)
// Derive reasoning title at end of reasoning
// block if we haven't yet.
if reasoningTitles[part.ID] == "" {
reasoningTitles[part.ID] = chatprompt.ReasoningTitleFromFirstLine(
reasoningText[part.ID],
)
}
title := reasoningTitles[part.ID]
if title != "" {
publishMessagePart(fantasy.MessageRoleAssistant, codersdk.ChatMessagePart{
Type: codersdk.ChatMessagePartTypeReasoning,
Title: title,
})
}
}
case fantasy.StreamPartTypeToolInputStart:
activeToolCalls[part.ID] = &fantasy.ToolCallContent{
ToolCallID: part.ID,
ToolName: part.ToolCallName,
Input: "",
ProviderExecuted: part.ProviderExecuted,
}
if strings.TrimSpace(part.ToolCallName) != "" {
toolNames[part.ID] = part.ToolCallName
}
case fantasy.StreamPartTypeToolInputDelta:
if toolCall, exists := activeToolCalls[part.ID]; exists {
toolCall.Input += part.Delta
}
toolName := toolNames[part.ID]
publishMessagePart(fantasy.MessageRoleAssistant, codersdk.ChatMessagePart{
Type: codersdk.ChatMessagePartTypeToolCall,
ToolCallID: part.ID,
ToolName: toolName,
ArgsDelta: part.Delta,
})
case fantasy.StreamPartTypeToolInputEnd:
// No callback needed; the full tool call arrives in
// StreamPartTypeToolCall.
case fantasy.StreamPartTypeToolCall:
tc := fantasy.ToolCallContent{
ToolCallID: part.ID,
ToolName: part.ToolCallName,
Input: part.ToolCallInput,
ProviderExecuted: part.ProviderExecuted,
ProviderMetadata: part.ProviderMetadata,
}
result.toolCalls = append(result.toolCalls, tc)
result.content = append(result.content, tc)
if strings.TrimSpace(part.ToolCallName) != "" {
toolNames[part.ID] = part.ToolCallName
}
// Clean up active tool call tracking.
delete(activeToolCalls, part.ID)
publishMessagePart(
fantasy.MessageRoleAssistant,
chatprompt.PartFromContent(tc),
)
case fantasy.StreamPartTypeSource:
sourceContent := fantasy.SourceContent{
SourceType: part.SourceType,
ID: part.ID,
URL: part.URL,
Title: part.Title,
ProviderMetadata: part.ProviderMetadata,
}
result.content = append(result.content, sourceContent)
publishMessagePart(
fantasy.MessageRoleAssistant,
chatprompt.PartFromContent(sourceContent),
)
case fantasy.StreamPartTypeFinish:
result.usage = part.Usage
result.finishReason = part.FinishReason
result.providerMetadata = part.ProviderMetadata
case fantasy.StreamPartTypeError:
// Detect interruption: context canceled with
// ErrInterrupted as the cause.
if errors.Is(part.Error, context.Canceled) &&
errors.Is(context.Cause(ctx), ErrInterrupted) {
// Flush in-progress content so that
// persistInterruptedStep has access to partial
// text, reasoning, and tool calls that were
// still streaming when the interrupt arrived.
flushActiveState(
&result,
activeTextContent,
activeReasoningContent,
activeToolCalls,
toolNames,
)
return result, ErrInterrupted
}
return result, part.Error
}
}
result.shouldContinue = len(result.toolCalls) > 0 &&
result.finishReason == fantasy.FinishReasonToolCalls
return result, nil
}
// executeTools runs each tool call sequentially after the stream
// completes. Results are published via onResult as each tool
// finishes.
func executeTools(
ctx context.Context,
allTools []fantasy.AgentTool,
toolCalls []fantasy.ToolCallContent,
onResult func(fantasy.ToolResultContent),
) []fantasy.ToolResultContent {
if len(toolCalls) == 0 {
return nil
}
toolMap := make(map[string]fantasy.AgentTool, len(allTools))
for _, t := range allTools {
toolMap[t.Info().Name] = t
}
results := make([]fantasy.ToolResultContent, 0, len(toolCalls))
for _, tc := range toolCalls {
tr := executeSingleTool(ctx, toolMap, tc)
results = append(results, tr)
if onResult != nil {
onResult(tr)
}
}
return results
}
// executeSingleTool executes one tool call and converts the
// response into a ToolResultContent.
func executeSingleTool(
ctx context.Context,
toolMap map[string]fantasy.AgentTool,
tc fantasy.ToolCallContent,
) fantasy.ToolResultContent {
result := fantasy.ToolResultContent{
ToolCallID: tc.ToolCallID,
ToolName: tc.ToolName,
ProviderExecuted: false,
}
tool, exists := toolMap[tc.ToolName]
if !exists {
result.Result = fantasy.ToolResultOutputContentError{
Error: xerrors.New("Tool not found: " + tc.ToolName),
}
return result
}
resp, err := tool.Run(ctx, fantasy.ToolCall{
ID: tc.ToolCallID,
Name: tc.ToolName,
Input: tc.Input,
})
if err != nil {
result.Result = fantasy.ToolResultOutputContentError{
Error: err,
}
result.ClientMetadata = resp.Metadata
return result
}
result.ClientMetadata = resp.Metadata
switch {
case resp.IsError:
result.Result = fantasy.ToolResultOutputContentError{
Error: xerrors.New(resp.Content),
}
case resp.Type == "image" || resp.Type == "media":
result.Result = fantasy.ToolResultOutputContentMedia{
Data: string(resp.Data),
MediaType: resp.MediaType,
Text: resp.Content,
}
default:
result.Result = fantasy.ToolResultOutputContentText{
Text: resp.Content,
}
}
return result
}
// flushActiveState moves any in-progress text, reasoning, and
// tool calls from the active tracking maps into result.content
// and result.toolCalls. This is called on interruption so that
// partial content from an incomplete stream is available for
// persistence.
func flushActiveState(
result *stepResult,
activeText map[string]string,
activeReasoning map[string]reasoningState,
activeToolCalls map[string]*fantasy.ToolCallContent,
toolNames map[string]string,
) {
// Flush partial text content.
for _, text := range activeText {
if text != "" {
result.content = append(result.content, fantasy.TextContent{Text: text})
}
}
// Flush partial reasoning content.
for _, rs := range activeReasoning {
if rs.text != "" {
result.content = append(result.content, fantasy.ReasoningContent{
Text: rs.text,
ProviderMetadata: rs.options,
})
}
}
// Flush in-progress tool calls. These haven't received a
// StreamPartTypeToolCall yet, so they only exist in
// activeToolCalls. We add them to both content and toolCalls
// so persistInterruptedStep can generate synthetic error
// results for them.
for id, tc := range activeToolCalls {
if tc == nil {
continue
}
// Prefer the tool name from the toolNames map since
// ToolInputStart may provide a cleaner name.
toolName := tc.ToolName
if name, ok := toolNames[id]; ok && strings.TrimSpace(name) != "" {
toolName = name
}
flushed := fantasy.ToolCallContent{
ToolCallID: tc.ToolCallID,
ToolName: toolName,
Input: tc.Input,
ProviderExecuted: tc.ProviderExecuted,
}
result.content = append(result.content, flushed)
result.toolCalls = append(result.toolCalls, flushed)
}
}
// persistInterruptedStep saves all accumulated content from a
// partial stream. Since we own the stepResult directly, no shadow
// state is needed.
func persistInterruptedStep(
ctx context.Context,
opts RunOptions,
result *stepResult,
) {
if result == nil || (len(result.content) == 0 && len(result.toolCalls) == 0) {
return
}
// Track which tool calls already have results in the content.
answeredToolCalls := make(map[string]struct{})
for _, c := range result.content {
tr, ok := fantasy.AsContentType[fantasy.ToolResultContent](c)
if ok && tr.ToolCallID != "" {
answeredToolCalls[tr.ToolCallID] = struct{}{}
}
}
// Build combined content: all accumulated content + synthetic
// interrupted results for any unanswered tool calls.
content := make([]fantasy.Content, 0, len(result.content))
content = append(content, result.content...)
for _, tc := range result.toolCalls {
if tc.ToolCallID == "" {
continue
}
if _, exists := answeredToolCalls[tc.ToolCallID]; exists {
continue
}
content = append(content, fantasy.ToolResultContent{
ToolCallID: tc.ToolCallID,
ToolName: tc.ToolName,
Result: fantasy.ToolResultOutputContentError{
Error: xerrors.New(interruptedToolResultErrorMessage),
},
})
answeredToolCalls[tc.ToolCallID] = struct{}{}
}
persistCtx := context.WithoutCancel(ctx)
if err := opts.PersistStep(persistCtx, PersistedStep{
Content: content,
}); err != nil {
if opts.OnInterruptedPersistError != nil {
opts.OnInterruptedPersistError(err)
}
}
}
// buildToolDefinitions converts AgentTool definitions into the
// fantasy.Tool slice expected by fantasy.Call. When activeTools
// is non-empty, only tools whose name appears in the list are
// included. This mirrors fantasy's agent.prepareTools filtering.
func buildToolDefinitions(tools []fantasy.AgentTool, activeTools []string) []fantasy.Tool {
prepared := make([]fantasy.Tool, 0, len(tools))
for _, tool := range tools {
info := tool.Info()
if len(activeTools) > 0 && !slices.Contains(activeTools, info.Name) {
continue
}
inputSchema := map[string]any{
"type": "object",
"properties": info.Parameters,
"required": info.Required,
}
schema.Normalize(inputSchema)
prepared = append(prepared, fantasy.FunctionTool{
Name: info.Name,
Description: info.Description,
InputSchema: inputSchema,
ProviderOptions: tool.ProviderOptions(),
})
}
return prepared
}
func shouldApplyAnthropicPromptCaching(model fantasy.LanguageModel) bool {
if model == nil {
return false
}
return model.Provider() == fantasyanthropic.Name
}
// addAnthropicPromptCaching mutates messages in-place, setting
// ProviderOptions for Anthropic prompt caching on the last system
// message and the final two messages.
func addAnthropicPromptCaching(messages []fantasy.Message) {
for i := range messages {
messages[i].ProviderOptions = nil
}
providerOption := fantasy.ProviderOptions{
fantasyanthropic.Name: &fantasyanthropic.ProviderCacheControlOptions{
CacheControl: fantasyanthropic.CacheControl{Type: "ephemeral"},
},
}
lastSystemRoleIdx := -1
systemMessageUpdated := false
for i, msg := range messages {
if msg.Role == fantasy.MessageRoleSystem {
lastSystemRoleIdx = i
} else if !systemMessageUpdated && lastSystemRoleIdx >= 0 {
messages[lastSystemRoleIdx].ProviderOptions = providerOption
systemMessageUpdated = true
}
if i > len(messages)-3 {
messages[i].ProviderOptions = providerOption
}
}
}
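// Worked example (illustrative; consistent with the prepare-behavior
// test later in this diff): for a prepared slice of
//
//	[system, system, user, assistant, user]
//
// indexes 1 (the last system message), 3, and 4 (the final two
// messages) end up with the ephemeral cache-control option, while
// indexes 0 and 2 have their ProviderOptions cleared.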
func extractContextLimit(metadata fantasy.ProviderMetadata) sql.NullInt64 {
if len(metadata) == 0 {
return sql.NullInt64{}
}
encoded, err := json.Marshal(metadata)
if err != nil || len(encoded) == 0 {
return sql.NullInt64{}
}
var payload any
if err := json.Unmarshal(encoded, &payload); err != nil {
return sql.NullInt64{}
}
limit, ok := findContextLimitValue(payload)
if !ok {
return sql.NullInt64{}
}
return sql.NullInt64{
Int64: limit,
Valid: true,
}
}
func findContextLimitValue(value any) (int64, bool) {
var (
limit int64
found bool
)
collectContextLimitValues(value, func(candidate int64) {
if !found || candidate > limit {
limit = candidate
found = true
}
})
return limit, found
}
func collectContextLimitValues(value any, onValue func(int64)) {
switch typed := value.(type) {
case map[string]any:
for key, child := range typed {
if isContextLimitKey(key) {
if numeric, ok := numericContextLimitValue(child); ok {
onValue(numeric)
}
}
collectContextLimitValues(child, onValue)
}
case []any:
for _, child := range typed {
collectContextLimitValues(child, onValue)
}
}
}
func isContextLimitKey(key string) bool {
normalized := normalizeMetadataKey(key)
if normalized == "" {
return false
}
switch normalized {
case
"contextlimit",
"contextwindow",
"contextlength",
"maxcontext",
"maxcontexttokens",
"maxinputtokens",
"maxinputtoken",
"inputtokenlimit":
return true
}
return strings.Contains(normalized, "context") &&
(strings.Contains(normalized, "limit") ||
strings.Contains(normalized, "window") ||
strings.Contains(normalized, "length") ||
strings.HasPrefix(normalized, "max"))
}
func normalizeMetadataKey(key string) string {
var b strings.Builder
b.Grow(len(key))
for _, r := range key {
switch {
case r >= 'a' && r <= 'z':
_, _ = b.WriteRune(r)
case r >= 'A' && r <= 'Z':
_, _ = b.WriteRune(r + ('a' - 'A'))
case r >= '0' && r <= '9':
_, _ = b.WriteRune(r)
}
}
return b.String()
}
func numericContextLimitValue(value any) (int64, bool) {
switch typed := value.(type) {
case int64:
return positiveInt64(typed)
case int32:
return positiveInt64(int64(typed))
case int:
return positiveInt64(int64(typed))
case float64:
casted := int64(typed)
if typed > 0 && float64(casted) == typed {
return casted, true
}
case string:
parsed, err := strconv.ParseInt(strings.TrimSpace(typed), 10, 64)
if err == nil {
return positiveInt64(parsed)
}
case json.Number:
parsed, err := typed.Int64()
if err == nil {
return positiveInt64(parsed)
}
}
return 0, false
}
func positiveInt64(value int64) (int64, bool) {
if value <= 0 {
return 0, false
}
return value, true
}
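A few worked examples of the metadata helpers above (input values are illustrative only):

// normalizeMetadataKey("Max_Input-Tokens") -> "maxinputtokens"
// isContextLimitKey("Max_Input-Tokens")    -> true  (exact match in the switch)
// isContextLimitKey("context.window_size") -> true  (contains "context" and "window")
// numericContextLimitValue("200000")       -> (200000, true)
// numericContextLimitValue(200000.5)       -> (0, false), non-integral float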
+420
@@ -0,0 +1,420 @@
package chatloop //nolint:testpackage // Uses internal symbols.
import (
"context"
"iter"
"strings"
"sync"
"testing"
"charm.land/fantasy"
fantasyanthropic "charm.land/fantasy/providers/anthropic"
"github.com/stretchr/testify/require"
"golang.org/x/xerrors"
)
const activeToolName = "read_file"
func TestRun_ActiveToolsPrepareBehavior(t *testing.T) {
t.Parallel()
var capturedCall fantasy.Call
model := &loopTestModel{
provider: fantasyanthropic.Name,
streamFn: func(_ context.Context, call fantasy.Call) (fantasy.StreamResponse, error) {
capturedCall = call
return streamFromParts([]fantasy.StreamPart{
{Type: fantasy.StreamPartTypeTextStart, ID: "text-1"},
{Type: fantasy.StreamPartTypeTextDelta, ID: "text-1", Delta: "done"},
{Type: fantasy.StreamPartTypeTextEnd, ID: "text-1"},
{Type: fantasy.StreamPartTypeFinish, FinishReason: fantasy.FinishReasonStop},
}), nil
},
}
persistStepCalls := 0
var persistedStep PersistedStep
err := Run(context.Background(), RunOptions{
Model: model,
Messages: []fantasy.Message{
textMessage(fantasy.MessageRoleSystem, "sys-1"),
textMessage(fantasy.MessageRoleSystem, "sys-2"),
textMessage(fantasy.MessageRoleUser, "hello"),
textMessage(fantasy.MessageRoleAssistant, "working"),
textMessage(fantasy.MessageRoleUser, "continue"),
},
Tools: []fantasy.AgentTool{
newNoopTool(activeToolName),
newNoopTool("write_file"),
},
MaxSteps: 3,
ActiveTools: []string{activeToolName},
ContextLimitFallback: 4096,
PersistStep: func(_ context.Context, step PersistedStep) error {
persistStepCalls++
persistedStep = step
return nil
},
})
require.NoError(t, err)
require.Equal(t, 1, persistStepCalls)
require.True(t, persistedStep.ContextLimit.Valid)
require.Equal(t, int64(4096), persistedStep.ContextLimit.Int64)
require.NotEmpty(t, capturedCall.Prompt)
require.False(t, containsPromptSentinel(capturedCall.Prompt))
require.Len(t, capturedCall.Tools, 1)
require.Equal(t, activeToolName, capturedCall.Tools[0].GetName())
require.Len(t, capturedCall.Prompt, 5)
require.False(t, hasAnthropicEphemeralCacheControl(capturedCall.Prompt[0]))
require.True(t, hasAnthropicEphemeralCacheControl(capturedCall.Prompt[1]))
require.False(t, hasAnthropicEphemeralCacheControl(capturedCall.Prompt[2]))
require.True(t, hasAnthropicEphemeralCacheControl(capturedCall.Prompt[3]))
require.True(t, hasAnthropicEphemeralCacheControl(capturedCall.Prompt[4]))
}
func TestRun_InterruptedStepPersistsSyntheticToolResult(t *testing.T) {
t.Parallel()
started := make(chan struct{})
model := &loopTestModel{
provider: "fake",
streamFn: func(ctx context.Context, _ fantasy.Call) (fantasy.StreamResponse, error) {
return iter.Seq[fantasy.StreamPart](func(yield func(fantasy.StreamPart) bool) {
parts := []fantasy.StreamPart{
{
Type: fantasy.StreamPartTypeToolInputStart,
ID: "interrupt-tool-1",
ToolCallName: "read_file",
},
{
Type: fantasy.StreamPartTypeToolInputDelta,
ID: "interrupt-tool-1",
ToolCallName: "read_file",
Delta: `{"path":"main.go"`,
},
{Type: fantasy.StreamPartTypeTextStart, ID: "text-1"},
{Type: fantasy.StreamPartTypeTextDelta, ID: "text-1", Delta: "partial assistant output"},
}
for _, part := range parts {
if !yield(part) {
return
}
}
select {
case <-started:
default:
close(started)
}
<-ctx.Done()
_ = yield(fantasy.StreamPart{
Type: fantasy.StreamPartTypeError,
Error: ctx.Err(),
})
}), nil
},
}
ctx, cancel := context.WithCancelCause(context.Background())
defer cancel(nil)
go func() {
<-started
cancel(ErrInterrupted)
}()
persistedAssistantCtxErr := xerrors.New("unset")
var persistedContent []fantasy.Content
err := Run(ctx, RunOptions{
Model: model,
Messages: []fantasy.Message{
textMessage(fantasy.MessageRoleUser, "hello"),
},
Tools: []fantasy.AgentTool{
newNoopTool("read_file"),
},
MaxSteps: 3,
PersistStep: func(persistCtx context.Context, step PersistedStep) error {
persistedAssistantCtxErr = persistCtx.Err()
persistedContent = append([]fantasy.Content(nil), step.Content...)
return nil
},
})
require.ErrorIs(t, err, ErrInterrupted)
require.NoError(t, persistedAssistantCtxErr)
require.NotEmpty(t, persistedContent)
var (
foundText bool
foundToolCall bool
foundToolResult bool
)
for _, block := range persistedContent {
if text, ok := fantasy.AsContentType[fantasy.TextContent](block); ok {
if strings.Contains(text.Text, "partial assistant output") {
foundText = true
}
continue
}
if toolCall, ok := fantasy.AsContentType[fantasy.ToolCallContent](block); ok {
if toolCall.ToolCallID == "interrupt-tool-1" &&
toolCall.ToolName == "read_file" &&
strings.Contains(toolCall.Input, `"path":"main.go"`) {
foundToolCall = true
}
continue
}
if toolResult, ok := fantasy.AsContentType[fantasy.ToolResultContent](block); ok {
if toolResult.ToolCallID == "interrupt-tool-1" &&
toolResult.ToolName == "read_file" {
_, isErr := toolResult.Result.(fantasy.ToolResultOutputContentError)
require.True(t, isErr, "interrupted tool result should be an error")
foundToolResult = true
}
}
}
require.True(t, foundText)
require.True(t, foundToolCall)
require.True(t, foundToolResult)
}
type loopTestModel struct {
provider string
model string
generateFn func(context.Context, fantasy.Call) (*fantasy.Response, error)
streamFn func(context.Context, fantasy.Call) (fantasy.StreamResponse, error)
}
func (m *loopTestModel) Provider() string {
if m.provider != "" {
return m.provider
}
return "fake"
}
func (m *loopTestModel) Model() string {
if m.model != "" {
return m.model
}
return "fake"
}
func (m *loopTestModel) Generate(ctx context.Context, call fantasy.Call) (*fantasy.Response, error) {
if m.generateFn != nil {
return m.generateFn(ctx, call)
}
return &fantasy.Response{}, nil
}
func (m *loopTestModel) Stream(ctx context.Context, call fantasy.Call) (fantasy.StreamResponse, error) {
if m.streamFn != nil {
return m.streamFn(ctx, call)
}
return streamFromParts([]fantasy.StreamPart{{
Type: fantasy.StreamPartTypeFinish,
FinishReason: fantasy.FinishReasonStop,
}}), nil
}
func (*loopTestModel) GenerateObject(context.Context, fantasy.ObjectCall) (*fantasy.ObjectResponse, error) {
return nil, xerrors.New("not implemented")
}
func (*loopTestModel) StreamObject(context.Context, fantasy.ObjectCall) (fantasy.ObjectStreamResponse, error) {
return nil, xerrors.New("not implemented")
}
func streamFromParts(parts []fantasy.StreamPart) fantasy.StreamResponse {
return iter.Seq[fantasy.StreamPart](func(yield func(fantasy.StreamPart) bool) {
for _, part := range parts {
if !yield(part) {
return
}
}
})
}
func newNoopTool(name string) fantasy.AgentTool {
return fantasy.NewAgentTool(
name,
"test noop tool",
func(context.Context, struct{}, fantasy.ToolCall) (fantasy.ToolResponse, error) {
return fantasy.ToolResponse{}, nil
},
)
}
func textMessage(role fantasy.MessageRole, text string) fantasy.Message {
return fantasy.Message{
Role: role,
Content: []fantasy.MessagePart{
fantasy.TextPart{Text: text},
},
}
}
func containsPromptSentinel(prompt []fantasy.Message) bool {
for _, message := range prompt {
if message.Role != fantasy.MessageRoleUser || len(message.Content) != 1 {
continue
}
textPart, ok := fantasy.AsMessagePart[fantasy.TextPart](message.Content[0])
if !ok {
continue
}
if strings.HasPrefix(textPart.Text, "__chatd_agent_prompt_sentinel_") {
return true
}
}
return false
}
func TestRun_MultiStepToolExecution(t *testing.T) {
t.Parallel()
var mu sync.Mutex
var streamCalls int
var secondCallPrompt []fantasy.Message
model := &loopTestModel{
provider: "fake",
streamFn: func(_ context.Context, call fantasy.Call) (fantasy.StreamResponse, error) {
mu.Lock()
step := streamCalls
streamCalls++
mu.Unlock()
switch step {
case 0:
// Step 0: produce a tool call.
return streamFromParts([]fantasy.StreamPart{
{Type: fantasy.StreamPartTypeToolInputStart, ID: "tc-1", ToolCallName: "read_file"},
{Type: fantasy.StreamPartTypeToolInputDelta, ID: "tc-1", Delta: `{"path":"main.go"}`},
{Type: fantasy.StreamPartTypeToolInputEnd, ID: "tc-1"},
{
Type: fantasy.StreamPartTypeToolCall,
ID: "tc-1",
ToolCallName: "read_file",
ToolCallInput: `{"path":"main.go"}`,
},
{Type: fantasy.StreamPartTypeFinish, FinishReason: fantasy.FinishReasonToolCalls},
}), nil
default:
// Step 1: capture the prompt the loop sent us,
// then return plain text.
mu.Lock()
secondCallPrompt = append([]fantasy.Message(nil), call.Prompt...)
mu.Unlock()
return streamFromParts([]fantasy.StreamPart{
{Type: fantasy.StreamPartTypeTextStart, ID: "text-1"},
{Type: fantasy.StreamPartTypeTextDelta, ID: "text-1", Delta: "all done"},
{Type: fantasy.StreamPartTypeTextEnd, ID: "text-1"},
{Type: fantasy.StreamPartTypeFinish, FinishReason: fantasy.FinishReasonStop},
}), nil
}
},
}
var persistStepCalls int
err := Run(context.Background(), RunOptions{
Model: model,
Messages: []fantasy.Message{
textMessage(fantasy.MessageRoleUser, "please read main.go"),
},
Tools: []fantasy.AgentTool{
newNoopTool("read_file"),
},
MaxSteps: 5,
PersistStep: func(_ context.Context, _ PersistedStep) error {
persistStepCalls++
return nil
},
})
require.NoError(t, err)
// Stream was called twice: once for the tool-call step,
// once for the follow-up text step.
require.Equal(t, 2, streamCalls)
// PersistStep is called once per step.
require.Equal(t, 2, persistStepCalls)
// The second call's prompt must contain the assistant message
// from step 0 (with the tool call) and a tool-result message.
require.NotEmpty(t, secondCallPrompt)
var foundAssistantToolCall bool
var foundToolResult bool
for _, msg := range secondCallPrompt {
if msg.Role == fantasy.MessageRoleAssistant {
for _, part := range msg.Content {
if tc, ok := fantasy.AsMessagePart[fantasy.ToolCallPart](part); ok {
if tc.ToolCallID == "tc-1" && tc.ToolName == "read_file" {
foundAssistantToolCall = true
}
}
}
}
if msg.Role == fantasy.MessageRoleTool {
for _, part := range msg.Content {
if tr, ok := fantasy.AsMessagePart[fantasy.ToolResultPart](part); ok {
if tr.ToolCallID == "tc-1" {
foundToolResult = true
}
}
}
}
}
require.True(t, foundAssistantToolCall, "second call prompt should contain assistant tool call from step 0")
require.True(t, foundToolResult, "second call prompt should contain tool result message")
}
func TestRun_PersistStepErrorPropagates(t *testing.T) {
t.Parallel()
model := &loopTestModel{
provider: "fake",
streamFn: func(_ context.Context, _ fantasy.Call) (fantasy.StreamResponse, error) {
return streamFromParts([]fantasy.StreamPart{
{Type: fantasy.StreamPartTypeTextStart, ID: "text-1"},
{Type: fantasy.StreamPartTypeTextDelta, ID: "text-1", Delta: "hello"},
{Type: fantasy.StreamPartTypeTextEnd, ID: "text-1"},
{Type: fantasy.StreamPartTypeFinish, FinishReason: fantasy.FinishReasonStop},
}), nil
},
}
persistErr := xerrors.New("database write failed")
err := Run(context.Background(), RunOptions{
Model: model,
Messages: []fantasy.Message{
textMessage(fantasy.MessageRoleUser, "hello"),
},
MaxSteps: 1,
PersistStep: func(_ context.Context, _ PersistedStep) error {
return persistErr
},
})
require.Error(t, err)
require.ErrorContains(t, err, "database write failed")
}
func hasAnthropicEphemeralCacheControl(message fantasy.Message) bool {
if len(message.ProviderOptions) == 0 {
return false
}
options, ok := message.ProviderOptions[fantasyanthropic.Name]
if !ok {
return false
}
cacheOptions, ok := options.(*fantasyanthropic.ProviderCacheControlOptions)
return ok && cacheOptions.CacheControl.Type == "ephemeral"
}
+318
@@ -0,0 +1,318 @@
package chatloop
import (
"context"
"encoding/json"
"strings"
"time"
"charm.land/fantasy"
"golang.org/x/xerrors"
"github.com/coder/coder/v2/codersdk"
)
const (
defaultCompactionThresholdPercent = int32(70)
minCompactionThresholdPercent = int32(0)
maxCompactionThresholdPercent = int32(100)
defaultCompactionSummaryPrompt = "Summarize the current chat so a " +
"new assistant can continue seamlessly. Include the user's goals, " +
"decisions made, concrete technical details (files, commands, APIs), " +
"errors encountered and fixes, and open questions. Be dense and factual. " +
"Omit pleasantries and next-step suggestions."
defaultCompactionSystemSummaryPrefix = "Summary of earlier chat context:"
defaultCompactionTimeout = 90 * time.Second
)
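// CompactionOptions configures automatic chat compaction: when it
// triggers, how the summary is generated and persisted, and how
// progress and errors are surfaced to connected clients.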
type CompactionOptions struct {
ThresholdPercent int32
ContextLimit int64
SummaryPrompt string
SystemSummaryPrefix string
Timeout time.Duration
Persist func(context.Context, CompactionResult) error
// ToolCallID and ToolName identify the synthetic tool call
// used to represent compaction in the message stream.
ToolCallID string
ToolName string
// PublishMessagePart publishes streaming parts to connected
// clients so they see "Summarizing..." / "Summarized" UI
// transitions during compaction.
PublishMessagePart func(fantasy.MessageRole, codersdk.ChatMessagePart)
OnError func(error)
}
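// CompactionResult carries the generated summary together with the
// threshold and measured context usage that triggered compaction.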
type CompactionResult struct {
SystemSummary string
SummaryReport string
ThresholdPercent int32
UsagePercent float64
ContextTokens int64
ContextLimit int64
}
// tryCompact checks whether context usage exceeds the compaction
// threshold and, if so, generates and persists a summary. Returns
// (true, nil) when compaction was performed, (false, nil) when not
// needed, and (false, err) on failure.
func tryCompact(
ctx context.Context,
model fantasy.LanguageModel,
compaction *CompactionOptions,
contextLimitFallback int64,
stepUsage fantasy.Usage,
stepMetadata fantasy.ProviderMetadata,
allMessages []fantasy.Message,
) (bool, error) {
config, ok := normalizedCompactionConfig(compaction)
if !ok {
return false, nil
}
contextTokens := contextTokensFromUsage(stepUsage)
if contextTokens <= 0 {
return false, nil
}
metadataLimit := extractContextLimit(stepMetadata)
contextLimit := resolveContextLimit(
metadataLimit.Int64,
config.ContextLimit,
contextLimitFallback,
)
usagePercent, compact := shouldCompact(
contextTokens, contextLimit, config.ThresholdPercent,
)
if !compact {
return false, nil
}
// Publish the "Summarizing..." tool-call indicator so
// connected clients see activity during summary generation.
if config.PublishMessagePart != nil && config.ToolCallID != "" {
config.PublishMessagePart(
fantasy.MessageRoleAssistant,
codersdk.ChatMessagePart{
Type: codersdk.ChatMessagePartTypeToolCall,
ToolCallID: config.ToolCallID,
ToolName: config.ToolName,
},
)
}
summary, err := generateCompactionSummary(
ctx, model, allMessages, config,
)
if err != nil {
// Publish a tool-result error so connected clients see the
// failure instead of a stuck "Summarizing..." indicator.
publishCompactionError(config, "failed to generate compaction summary")
return false, err
}
if summary == "" {
// Publish a tool-result error so connected clients
// see the compaction failure.
publishCompactionError(config, "compaction produced an empty summary")
return false, xerrors.New("compaction produced an empty summary")
}
systemSummary := strings.TrimSpace(
config.SystemSummaryPrefix + "\n\n" + summary,
)
err = config.Persist(ctx, CompactionResult{
SystemSummary: systemSummary,
SummaryReport: summary,
ThresholdPercent: config.ThresholdPercent,
UsagePercent: usagePercent,
ContextTokens: contextTokens,
ContextLimit: contextLimit,
})
if err != nil {
publishCompactionError(config, "failed to persist compaction result")
return false, xerrors.Errorf("persist compaction: %w", err)
}
// Publish the "Summarized" tool-result part so the client
// transitions from the in-progress indicator to the final
// state.
if config.PublishMessagePart != nil && config.ToolCallID != "" {
resultJSON, _ := json.Marshal(map[string]any{
"summary": summary,
"source": "automatic",
"threshold_percent": config.ThresholdPercent,
"usage_percent": usagePercent,
"context_tokens": contextTokens,
"context_limit_tokens": contextLimit,
})
config.PublishMessagePart(
fantasy.MessageRoleTool,
codersdk.ChatMessagePart{
Type: codersdk.ChatMessagePartTypeToolResult,
ToolCallID: config.ToolCallID,
ToolName: config.ToolName,
Result: resultJSON,
},
)
}
return true, nil
}
// publishCompactionError sends a tool-result error part so
// connected clients see that compaction failed.
func publishCompactionError(config CompactionOptions, msg string) {
if config.PublishMessagePart == nil || config.ToolCallID == "" {
return
}
errJSON, _ := json.Marshal(map[string]any{
"error": msg,
})
config.PublishMessagePart(
fantasy.MessageRoleTool,
codersdk.ChatMessagePart{
Type: codersdk.ChatMessagePartTypeToolResult,
ToolCallID: config.ToolCallID,
ToolName: config.ToolName,
Result: errJSON,
IsError: true,
},
)
}
// normalizedCompactionConfig returns a copy of the compaction options
// with defaults applied. The bool is false when compaction is
// disabled (nil options, missing Persist callback, or threshold at
// 100%).
func normalizedCompactionConfig(opts *CompactionOptions) (CompactionOptions, bool) {
if opts == nil {
return CompactionOptions{}, false
}
config := *opts
if config.Persist == nil {
return CompactionOptions{}, false
}
if strings.TrimSpace(config.SummaryPrompt) == "" {
config.SummaryPrompt = defaultCompactionSummaryPrompt
}
if strings.TrimSpace(config.SystemSummaryPrefix) == "" {
config.SystemSummaryPrefix = defaultCompactionSystemSummaryPrefix
}
if config.Timeout <= 0 {
config.Timeout = defaultCompactionTimeout
}
if config.ThresholdPercent < minCompactionThresholdPercent ||
config.ThresholdPercent > maxCompactionThresholdPercent {
config.ThresholdPercent = defaultCompactionThresholdPercent
}
if config.ThresholdPercent == maxCompactionThresholdPercent {
return CompactionOptions{}, false
}
return config, true
}
// contextTokensFromUsage returns the total context token count from
// a step's usage report. It sums input, cache-read, and
// cache-creation tokens when available, falling back to TotalTokens
// if none of the granular fields are set.
func contextTokensFromUsage(usage fantasy.Usage) int64 {
total := int64(0)
hasContextTokens := false
if usage.InputTokens > 0 {
total += usage.InputTokens
hasContextTokens = true
}
if usage.CacheReadTokens > 0 {
total += usage.CacheReadTokens
hasContextTokens = true
}
if usage.CacheCreationTokens > 0 {
total += usage.CacheCreationTokens
hasContextTokens = true
}
if !hasContextTokens && usage.TotalTokens > 0 {
total = usage.TotalTokens
}
return total
}
// resolveContextLimit picks the first positive value from metadata,
// configured limit, and fallback — in that priority order. Returns
// 0 when none are positive.
func resolveContextLimit(metadataLimit, configLimit, fallback int64) int64 {
if metadataLimit > 0 {
return metadataLimit
}
if configLimit > 0 {
return configLimit
}
if fallback > 0 {
return fallback
}
return 0
}
// shouldCompact returns the usage percentage and whether it exceeds
// the threshold. Returns (0, false) when contextLimit is
// non-positive.
func shouldCompact(contextTokens, contextLimit int64, thresholdPercent int32) (float64, bool) {
if contextLimit <= 0 {
return 0, false
}
usagePercent := (float64(contextTokens) / float64(contextLimit)) * 100
return usagePercent, usagePercent >= float64(thresholdPercent)
}
// generateCompactionSummary asks the model to summarize the
// conversation so far. The provided messages should contain the
// complete history (system prompt, user/assistant turns, tool
// results). A final user message with the summary prompt is appended
// before calling the model.
func generateCompactionSummary(
ctx context.Context,
model fantasy.LanguageModel,
messages []fantasy.Message,
options CompactionOptions,
) (string, error) {
summaryPrompt := make([]fantasy.Message, 0, len(messages)+1)
summaryPrompt = append(summaryPrompt, messages...)
summaryPrompt = append(summaryPrompt, fantasy.Message{
Role: fantasy.MessageRoleUser,
Content: []fantasy.MessagePart{
fantasy.TextPart{Text: options.SummaryPrompt},
},
})
toolChoice := fantasy.ToolChoiceNone
summaryCtx, cancel := context.WithTimeout(ctx, options.Timeout)
defer cancel()
response, err := model.Generate(summaryCtx, fantasy.Call{
Prompt: summaryPrompt,
ToolChoice: &toolChoice,
})
if err != nil {
return "", xerrors.Errorf("generate summary text: %w", err)
}
parts := make([]string, 0, len(response.Content))
for _, block := range response.Content {
textBlock, ok := fantasy.AsContentType[fantasy.TextContent](block)
if !ok {
continue
}
text := strings.TrimSpace(textBlock.Text)
if text == "" {
continue
}
parts = append(parts, text)
}
return strings.TrimSpace(strings.Join(parts, " ")), nil
}
@@ -0,0 +1,465 @@
package chatloop //nolint:testpackage // Uses internal symbols.
import (
"context"
"sync"
"testing"
"charm.land/fantasy"
"github.com/stretchr/testify/require"
"golang.org/x/xerrors"
"github.com/coder/coder/v2/codersdk"
)
func TestRun_Compaction(t *testing.T) {
t.Parallel()
t.Run("PersistsWhenThresholdReached", func(t *testing.T) {
t.Parallel()
persistCompactionCalls := 0
var persistedCompaction CompactionResult
const summaryText = "summary text for compaction"
model := &loopTestModel{
provider: "fake",
streamFn: func(_ context.Context, _ fantasy.Call) (fantasy.StreamResponse, error) {
return streamFromParts([]fantasy.StreamPart{
{Type: fantasy.StreamPartTypeTextStart, ID: "text-1"},
{Type: fantasy.StreamPartTypeTextDelta, ID: "text-1", Delta: "done"},
{Type: fantasy.StreamPartTypeTextEnd, ID: "text-1"},
{
Type: fantasy.StreamPartTypeFinish,
FinishReason: fantasy.FinishReasonStop,
Usage: fantasy.Usage{
InputTokens: 80,
TotalTokens: 85,
},
},
}), nil
},
generateFn: func(_ context.Context, call fantasy.Call) (*fantasy.Response, error) {
require.NotEmpty(t, call.Prompt)
lastPrompt := call.Prompt[len(call.Prompt)-1]
require.Equal(t, fantasy.MessageRoleUser, lastPrompt.Role)
require.Len(t, lastPrompt.Content, 1)
instruction, ok := fantasy.AsMessagePart[fantasy.TextPart](lastPrompt.Content[0])
require.True(t, ok)
require.Equal(t, "summarize now", instruction.Text)
return &fantasy.Response{
Content: []fantasy.Content{
fantasy.TextContent{Text: summaryText},
},
}, nil
},
}
err := Run(context.Background(), RunOptions{
Model: model,
Messages: []fantasy.Message{
textMessage(fantasy.MessageRoleUser, "hello"),
},
MaxSteps: 1,
PersistStep: func(_ context.Context, _ PersistedStep) error {
return nil
},
ContextLimitFallback: 100,
Compaction: &CompactionOptions{
ThresholdPercent: 70,
SummaryPrompt: "summarize now",
Persist: func(_ context.Context, result CompactionResult) error {
persistCompactionCalls++
persistedCompaction = result
return nil
},
},
})
require.NoError(t, err)
require.Equal(t, 1, persistCompactionCalls)
require.Contains(t, persistedCompaction.SystemSummary, summaryText)
require.Equal(t, summaryText, persistedCompaction.SummaryReport)
require.Equal(t, int64(80), persistedCompaction.ContextTokens)
require.Equal(t, int64(100), persistedCompaction.ContextLimit)
require.InDelta(t, 80.0, persistedCompaction.UsagePercent, 0.0001)
})
t.Run("PublishesPartsBeforeAndAfterPersist", func(t *testing.T) {
t.Parallel()
const summaryText = "compaction summary for ordering test"
// Track the order of callbacks to verify the tool-call
// part publishes before Generate (summary generation)
// and the tool-result part publishes after Persist.
var callOrder []string
model := &loopTestModel{
provider: "fake",
streamFn: func(_ context.Context, _ fantasy.Call) (fantasy.StreamResponse, error) {
return streamFromParts([]fantasy.StreamPart{
{Type: fantasy.StreamPartTypeTextStart, ID: "text-1"},
{Type: fantasy.StreamPartTypeTextDelta, ID: "text-1", Delta: "done"},
{Type: fantasy.StreamPartTypeTextEnd, ID: "text-1"},
{
Type: fantasy.StreamPartTypeFinish,
FinishReason: fantasy.FinishReasonStop,
Usage: fantasy.Usage{
InputTokens: 80,
TotalTokens: 85,
},
},
}), nil
},
generateFn: func(_ context.Context, _ fantasy.Call) (*fantasy.Response, error) {
callOrder = append(callOrder, "generate")
return &fantasy.Response{
Content: []fantasy.Content{
fantasy.TextContent{Text: summaryText},
},
}, nil
},
}
err := Run(context.Background(), RunOptions{
Model: model,
Messages: []fantasy.Message{
textMessage(fantasy.MessageRoleUser, "hello"),
},
MaxSteps: 1,
PersistStep: func(_ context.Context, _ PersistedStep) error {
return nil
},
ContextLimitFallback: 100,
Compaction: &CompactionOptions{
ThresholdPercent: 70,
SummaryPrompt: "summarize now",
ToolCallID: "test-tool-call-id",
ToolName: "chat_summarized",
PublishMessagePart: func(role fantasy.MessageRole, part codersdk.ChatMessagePart) {
switch part.Type {
case codersdk.ChatMessagePartTypeToolCall:
callOrder = append(callOrder, "publish_tool_call")
case codersdk.ChatMessagePartTypeToolResult:
callOrder = append(callOrder, "publish_tool_result")
}
},
Persist: func(_ context.Context, _ CompactionResult) error {
callOrder = append(callOrder, "persist")
return nil
},
},
})
require.NoError(t, err)
require.Equal(t, []string{
"publish_tool_call",
"generate",
"persist",
"publish_tool_result",
}, callOrder)
})
t.Run("PublishNotCalledBelowThreshold", func(t *testing.T) {
t.Parallel()
publishCalled := false
model := &loopTestModel{
provider: "fake",
streamFn: func(_ context.Context, _ fantasy.Call) (fantasy.StreamResponse, error) {
return streamFromParts([]fantasy.StreamPart{
{
Type: fantasy.StreamPartTypeFinish,
FinishReason: fantasy.FinishReasonStop,
Usage: fantasy.Usage{
InputTokens: 10,
},
},
}), nil
},
}
err := Run(context.Background(), RunOptions{
Model: model,
Messages: []fantasy.Message{
textMessage(fantasy.MessageRoleUser, "hello"),
},
MaxSteps: 1,
PersistStep: func(_ context.Context, _ PersistedStep) error {
return nil
},
ContextLimitFallback: 100,
Compaction: &CompactionOptions{
ThresholdPercent: 70,
ToolCallID: "test-tool-call-id",
ToolName: "chat_summarized",
PublishMessagePart: func(_ fantasy.MessageRole, _ codersdk.ChatMessagePart) {
publishCalled = true
},
Persist: func(_ context.Context, _ CompactionResult) error {
return nil
},
},
})
require.NoError(t, err)
require.False(t, publishCalled, "PublishMessagePart should not fire when usage is below threshold")
})
t.Run("MidLoopCompactionReloadsMessages", func(t *testing.T) {
t.Parallel()
var mu sync.Mutex
var streamCallCount int
persistCompactionCalls := 0
reloadCalls := 0
const summaryText = "compacted summary"
model := &loopTestModel{
provider: "fake",
streamFn: func(_ context.Context, _ fantasy.Call) (fantasy.StreamResponse, error) {
mu.Lock()
step := streamCallCount
streamCallCount++
mu.Unlock()
switch step {
case 0:
// Step 0: tool call with high usage (80/100 = 80% > 70%).
return streamFromParts([]fantasy.StreamPart{
{Type: fantasy.StreamPartTypeToolInputStart, ID: "tc-1", ToolCallName: "read_file"},
{Type: fantasy.StreamPartTypeToolInputDelta, ID: "tc-1", Delta: `{}`},
{Type: fantasy.StreamPartTypeToolInputEnd, ID: "tc-1"},
{
Type: fantasy.StreamPartTypeToolCall,
ID: "tc-1",
ToolCallName: "read_file",
ToolCallInput: `{}`,
},
{
Type: fantasy.StreamPartTypeFinish,
FinishReason: fantasy.FinishReasonToolCalls,
Usage: fantasy.Usage{
InputTokens: 80,
TotalTokens: 85,
},
},
}), nil
default:
// Step 1: text with low usage (30/100 = 30% < 70%).
return streamFromParts([]fantasy.StreamPart{
{Type: fantasy.StreamPartTypeTextStart, ID: "text-1"},
{Type: fantasy.StreamPartTypeTextDelta, ID: "text-1", Delta: "done"},
{Type: fantasy.StreamPartTypeTextEnd, ID: "text-1"},
{
Type: fantasy.StreamPartTypeFinish,
FinishReason: fantasy.FinishReasonStop,
Usage: fantasy.Usage{
InputTokens: 30,
TotalTokens: 35,
},
},
}), nil
}
},
generateFn: func(_ context.Context, _ fantasy.Call) (*fantasy.Response, error) {
return &fantasy.Response{
Content: []fantasy.Content{
fantasy.TextContent{Text: summaryText},
},
}, nil
},
}
compactedMessages := []fantasy.Message{
textMessage(fantasy.MessageRoleSystem, "compacted system"),
textMessage(fantasy.MessageRoleUser, "compacted user"),
}
err := Run(context.Background(), RunOptions{
Model: model,
Messages: []fantasy.Message{
textMessage(fantasy.MessageRoleUser, "hello"),
},
Tools: []fantasy.AgentTool{
newNoopTool("read_file"),
},
MaxSteps: 5,
PersistStep: func(_ context.Context, _ PersistedStep) error {
return nil
},
ContextLimitFallback: 100,
Compaction: &CompactionOptions{
ThresholdPercent: 70,
SummaryPrompt: "summarize now",
Persist: func(_ context.Context, _ CompactionResult) error {
persistCompactionCalls++
return nil
},
},
ReloadMessages: func(_ context.Context) ([]fantasy.Message, error) {
reloadCalls++
return compactedMessages, nil
},
})
require.NoError(t, err)
// Compaction fired after step 0 (above threshold).
require.GreaterOrEqual(t, persistCompactionCalls, 1)
// ReloadMessages was called after mid-loop compaction.
require.GreaterOrEqual(t, reloadCalls, 1)
// Both steps ran (tool-call step + follow-up text step).
require.Equal(t, 2, streamCallCount)
})
t.Run("PostRunCompactionSkippedAfterMidLoop", func(t *testing.T) {
t.Parallel()
var mu sync.Mutex
var streamCallCount int
persistCompactionCalls := 0
const summaryText = "compacted summary for skip test"
model := &loopTestModel{
provider: "fake",
streamFn: func(_ context.Context, _ fantasy.Call) (fantasy.StreamResponse, error) {
mu.Lock()
step := streamCallCount
streamCallCount++
mu.Unlock()
switch step {
case 0:
// Step 0: tool call with high usage (80/100 = 80% > 70%).
return streamFromParts([]fantasy.StreamPart{
{Type: fantasy.StreamPartTypeToolInputStart, ID: "tc-1", ToolCallName: "read_file"},
{Type: fantasy.StreamPartTypeToolInputDelta, ID: "tc-1", Delta: `{}`},
{Type: fantasy.StreamPartTypeToolInputEnd, ID: "tc-1"},
{
Type: fantasy.StreamPartTypeToolCall,
ID: "tc-1",
ToolCallName: "read_file",
ToolCallInput: `{}`,
},
{
Type: fantasy.StreamPartTypeFinish,
FinishReason: fantasy.FinishReasonToolCalls,
Usage: fantasy.Usage{
InputTokens: 80,
TotalTokens: 85,
},
},
}), nil
default:
// Step 1: text with low usage (20/100 = 20% < 70%).
return streamFromParts([]fantasy.StreamPart{
{Type: fantasy.StreamPartTypeTextStart, ID: "text-1"},
{Type: fantasy.StreamPartTypeTextDelta, ID: "text-1", Delta: "done"},
{Type: fantasy.StreamPartTypeTextEnd, ID: "text-1"},
{
Type: fantasy.StreamPartTypeFinish,
FinishReason: fantasy.FinishReasonStop,
Usage: fantasy.Usage{
InputTokens: 20,
TotalTokens: 25,
},
},
}), nil
}
},
generateFn: func(_ context.Context, _ fantasy.Call) (*fantasy.Response, error) {
return &fantasy.Response{
Content: []fantasy.Content{
fantasy.TextContent{Text: summaryText},
},
}, nil
},
}
compactedMessages := []fantasy.Message{
textMessage(fantasy.MessageRoleSystem, "compacted system"),
textMessage(fantasy.MessageRoleUser, "compacted user"),
}
err := Run(context.Background(), RunOptions{
Model: model,
Messages: []fantasy.Message{
textMessage(fantasy.MessageRoleUser, "hello"),
},
Tools: []fantasy.AgentTool{
newNoopTool("read_file"),
},
MaxSteps: 5,
PersistStep: func(_ context.Context, _ PersistedStep) error {
return nil
},
ContextLimitFallback: 100,
Compaction: &CompactionOptions{
ThresholdPercent: 70,
SummaryPrompt: "summarize now",
Persist: func(_ context.Context, _ CompactionResult) error {
persistCompactionCalls++
return nil
},
},
ReloadMessages: func(_ context.Context) ([]fantasy.Message, error) {
return compactedMessages, nil
},
})
require.NoError(t, err)
// Only mid-loop compaction fires after step 0. The post-run
// safety net is skipped because alreadyCompacted is true.
require.Equal(t, 1, persistCompactionCalls)
})
t.Run("ErrorsAreReported", func(t *testing.T) {
t.Parallel()
model := &loopTestModel{
provider: "fake",
streamFn: func(_ context.Context, _ fantasy.Call) (fantasy.StreamResponse, error) {
return streamFromParts([]fantasy.StreamPart{
{
Type: fantasy.StreamPartTypeFinish,
FinishReason: fantasy.FinishReasonStop,
Usage: fantasy.Usage{
InputTokens: 80,
},
},
}), nil
},
generateFn: func(_ context.Context, _ fantasy.Call) (*fantasy.Response, error) {
return nil, xerrors.New("generate failed")
},
}
compactionErr := xerrors.New("unset")
err := Run(context.Background(), RunOptions{
Model: model,
Messages: []fantasy.Message{
textMessage(fantasy.MessageRoleUser, "hello"),
},
MaxSteps: 1,
PersistStep: func(_ context.Context, _ PersistedStep) error {
return nil
},
ContextLimitFallback: 100,
Compaction: &CompactionOptions{
ThresholdPercent: 70,
Persist: func(_ context.Context, _ CompactionResult) error {
return nil
},
OnError: func(err error) {
compactionErr = err
},
},
})
require.NoError(t, err)
require.Error(t, compactionErr)
require.ErrorContains(t, compactionErr, "generate summary text")
})
}
@@ -0,0 +1,982 @@
package chatprompt
import (
"encoding/json"
"regexp"
"strings"
"charm.land/fantasy"
fantasyopenai "charm.land/fantasy/providers/openai"
"github.com/sqlc-dev/pqtype"
"golang.org/x/xerrors"
"github.com/coder/coder/v2/coderd/database"
"github.com/coder/coder/v2/codersdk"
)
var toolCallIDSanitizer = regexp.MustCompile(`[^a-zA-Z0-9_-]`)
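// ConvertMessages converts persisted chat messages into fantasy prompt
// messages. Rows not visible to the model are skipped, and orphaned
// tool calls and tool results are repaired before the prompt is
// returned.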
func ConvertMessages(
messages []database.ChatMessage,
) ([]fantasy.Message, error) {
prompt := make([]fantasy.Message, 0, len(messages))
toolNameByCallID := make(map[string]string)
for _, message := range messages {
visibility := message.Visibility
if visibility == "" {
visibility = database.ChatMessageVisibilityBoth
}
if visibility != database.ChatMessageVisibilityModel &&
visibility != database.ChatMessageVisibilityBoth {
continue
}
switch message.Role {
case string(fantasy.MessageRoleSystem):
content, err := parseSystemContent(message.Content)
if err != nil {
return nil, err
}
if strings.TrimSpace(content) == "" {
continue
}
prompt = append(prompt, fantasy.Message{
Role: fantasy.MessageRoleSystem,
Content: []fantasy.MessagePart{
fantasy.TextPart{Text: content},
},
})
case string(fantasy.MessageRoleUser):
content, err := ParseContent(string(fantasy.MessageRoleUser), message.Content)
if err != nil {
return nil, err
}
prompt = append(prompt, fantasy.Message{
Role: fantasy.MessageRoleUser,
Content: ToMessageParts(content),
})
case string(fantasy.MessageRoleAssistant):
content, err := ParseContent(string(fantasy.MessageRoleAssistant), message.Content)
if err != nil {
return nil, err
}
parts := normalizeAssistantToolCallInputs(ToMessageParts(content))
for _, toolCall := range ExtractToolCalls(parts) {
if toolCall.ToolCallID == "" || strings.TrimSpace(toolCall.ToolName) == "" {
continue
}
toolNameByCallID[sanitizeToolCallID(toolCall.ToolCallID)] = toolCall.ToolName
}
prompt = append(prompt, fantasy.Message{
Role: fantasy.MessageRoleAssistant,
Content: parts,
})
case string(fantasy.MessageRoleTool):
rows, err := parseToolResultRows(message.Content)
if err != nil {
return nil, err
}
parts := make([]fantasy.MessagePart, 0, len(rows))
for _, row := range rows {
if row.ToolCallID != "" && row.ToolName != "" {
toolNameByCallID[sanitizeToolCallID(row.ToolCallID)] = row.ToolName
}
parts = append(parts, row.toToolResultPart())
}
prompt = append(prompt, fantasy.Message{
Role: fantasy.MessageRoleTool,
Content: parts,
})
default:
return nil, xerrors.Errorf("unsupported chat message role %q", message.Role)
}
}
prompt = injectMissingToolResults(prompt)
prompt = injectMissingToolUses(
prompt,
toolNameByCallID,
)
return prompt, nil
}
// PrependSystem prepends a system message unless an existing system
// message already mentions create_workspace guidance.
func PrependSystem(prompt []fantasy.Message, instruction string) []fantasy.Message {
instruction = strings.TrimSpace(instruction)
if instruction == "" {
return prompt
}
for _, message := range prompt {
if message.Role != fantasy.MessageRoleSystem {
continue
}
for _, part := range message.Content {
textPart, ok := fantasy.AsMessagePart[fantasy.TextPart](part)
if !ok {
continue
}
if strings.Contains(strings.ToLower(textPart.Text), "create_workspace") {
return prompt
}
}
}
out := make([]fantasy.Message, 0, len(prompt)+1)
out = append(out, fantasy.Message{
Role: fantasy.MessageRoleSystem,
Content: []fantasy.MessagePart{
fantasy.TextPart{Text: instruction},
},
})
out = append(out, prompt...)
return out
}
// InsertSystem inserts a system message after the existing system
// block and before the first non-system message.
func InsertSystem(prompt []fantasy.Message, instruction string) []fantasy.Message {
instruction = strings.TrimSpace(instruction)
if instruction == "" {
return prompt
}
systemMessage := fantasy.Message{
Role: fantasy.MessageRoleSystem,
Content: []fantasy.MessagePart{
fantasy.TextPart{Text: instruction},
},
}
out := make([]fantasy.Message, 0, len(prompt)+1)
inserted := false
for _, message := range prompt {
if !inserted && message.Role != fantasy.MessageRoleSystem {
out = append(out, systemMessage)
inserted = true
}
out = append(out, message)
}
if !inserted {
out = append(out, systemMessage)
}
return out
}
// AppendUser appends an instruction as a user message at the end of
// the prompt.
func AppendUser(prompt []fantasy.Message, instruction string) []fantasy.Message {
instruction = strings.TrimSpace(instruction)
if instruction == "" {
return prompt
}
out := make([]fantasy.Message, 0, len(prompt)+1)
out = append(out, prompt...)
out = append(out, fantasy.Message{
Role: fantasy.MessageRoleUser,
Content: []fantasy.MessagePart{
fantasy.TextPart{Text: instruction},
},
})
return out
}
// ParseContent decodes persisted chat message content blocks.
func ParseContent(role string, raw pqtype.NullRawMessage) ([]fantasy.Content, error) {
if !raw.Valid || len(raw.RawMessage) == 0 {
return nil, nil
}
var text string
if err := json.Unmarshal(raw.RawMessage, &text); err == nil {
return []fantasy.Content{fantasy.TextContent{Text: text}}, nil
}
var rawBlocks []json.RawMessage
if err := json.Unmarshal(raw.RawMessage, &rawBlocks); err != nil {
return nil, xerrors.Errorf("parse %s content: %w", role, err)
}
content := make([]fantasy.Content, 0, len(rawBlocks))
for i, rawBlock := range rawBlocks {
block, err := fantasy.UnmarshalContent(rawBlock)
if err != nil {
return nil, xerrors.Errorf("parse %s content block %d: %w", role, i, err)
}
content = append(content, block)
}
return content, nil
}
// toolResultRaw is a loosely-typed representation of a persisted tool
// result row. Unknown fields are ignored and missing fields default to
// their zero values, so historical shapes are never rejected.
type toolResultRaw struct {
ToolCallID string `json:"tool_call_id"`
ToolName string `json:"tool_name"`
Result json.RawMessage `json:"result"`
IsError bool `json:"is_error,omitempty"`
}
// parseToolResultRows decodes persisted tool result rows.
func parseToolResultRows(raw pqtype.NullRawMessage) ([]toolResultRaw, error) {
if !raw.Valid || len(raw.RawMessage) == 0 {
return nil, nil
}
var rows []toolResultRaw
if err := json.Unmarshal(raw.RawMessage, &rows); err != nil {
return nil, xerrors.Errorf("parse tool content: %w", err)
}
return rows, nil
}
func (r toolResultRaw) toToolResultPart() fantasy.ToolResultPart {
toolCallID := sanitizeToolCallID(r.ToolCallID)
resultText := string(r.Result)
if resultText == "" || resultText == "null" {
resultText = "{}"
}
if r.IsError {
message := strings.TrimSpace(resultText)
if extracted := extractErrorString(r.Result); extracted != "" {
message = extracted
}
return fantasy.ToolResultPart{
ToolCallID: toolCallID,
Output: fantasy.ToolResultOutputContentError{
Error: xerrors.New(message),
},
}
}
return fantasy.ToolResultPart{
ToolCallID: toolCallID,
Output: fantasy.ToolResultOutputContentText{
Text: resultText,
},
}
}
// extractErrorString pulls the "error" field from a JSON object if
// present, returning it as a string. Returns "" if the field is
// missing or the input is not an object.
func extractErrorString(raw json.RawMessage) string {
var fields map[string]json.RawMessage
if err := json.Unmarshal(raw, &fields); err != nil {
return ""
}
errField, ok := fields["error"]
if !ok {
return ""
}
var s string
if err := json.Unmarshal(errField, &s); err != nil {
return ""
}
return strings.TrimSpace(s)
}
// ToMessageParts converts fantasy content blocks into message parts.
func ToMessageParts(content []fantasy.Content) []fantasy.MessagePart {
parts := make([]fantasy.MessagePart, 0, len(content))
for _, block := range content {
switch value := block.(type) {
case fantasy.TextContent:
parts = append(parts, fantasy.TextPart{
Text: value.Text,
ProviderOptions: fantasy.ProviderOptions(value.ProviderMetadata),
})
case *fantasy.TextContent:
parts = append(parts, fantasy.TextPart{
Text: value.Text,
ProviderOptions: fantasy.ProviderOptions(value.ProviderMetadata),
})
case fantasy.ReasoningContent:
parts = append(parts, fantasy.ReasoningPart{
Text: value.Text,
ProviderOptions: fantasy.ProviderOptions(value.ProviderMetadata),
})
case *fantasy.ReasoningContent:
parts = append(parts, fantasy.ReasoningPart{
Text: value.Text,
ProviderOptions: fantasy.ProviderOptions(value.ProviderMetadata),
})
case fantasy.ToolCallContent:
parts = append(parts, fantasy.ToolCallPart{
ToolCallID: sanitizeToolCallID(value.ToolCallID),
ToolName: value.ToolName,
Input: value.Input,
ProviderExecuted: value.ProviderExecuted,
ProviderOptions: fantasy.ProviderOptions(value.ProviderMetadata),
})
case *fantasy.ToolCallContent:
parts = append(parts, fantasy.ToolCallPart{
ToolCallID: sanitizeToolCallID(value.ToolCallID),
ToolName: value.ToolName,
Input: value.Input,
ProviderExecuted: value.ProviderExecuted,
ProviderOptions: fantasy.ProviderOptions(value.ProviderMetadata),
})
case fantasy.FileContent:
parts = append(parts, fantasy.FilePart{
Data: value.Data,
MediaType: value.MediaType,
ProviderOptions: fantasy.ProviderOptions(value.ProviderMetadata),
})
case *fantasy.FileContent:
parts = append(parts, fantasy.FilePart{
Data: value.Data,
MediaType: value.MediaType,
ProviderOptions: fantasy.ProviderOptions(value.ProviderMetadata),
})
case fantasy.ToolResultContent:
parts = append(parts, fantasy.ToolResultPart{
ToolCallID: sanitizeToolCallID(value.ToolCallID),
Output: value.Result,
ProviderOptions: fantasy.ProviderOptions(value.ProviderMetadata),
})
case *fantasy.ToolResultContent:
parts = append(parts, fantasy.ToolResultPart{
ToolCallID: sanitizeToolCallID(value.ToolCallID),
Output: value.Result,
ProviderOptions: fantasy.ProviderOptions(value.ProviderMetadata),
})
}
}
return parts
}
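// normalizeAssistantToolCallInputs rewrites assistant tool call parts
// so their input is always a valid JSON object string (see
// normalizeToolCallInput). Other part types pass through unchanged.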
func normalizeAssistantToolCallInputs(
parts []fantasy.MessagePart,
) []fantasy.MessagePart {
normalized := make([]fantasy.MessagePart, 0, len(parts))
for _, part := range parts {
toolCall, ok := fantasy.AsMessagePart[fantasy.ToolCallPart](part)
if !ok {
normalized = append(normalized, part)
continue
}
toolCall.Input = normalizeToolCallInput(toolCall.Input)
normalized = append(normalized, toolCall)
}
return normalized
}
// normalizeToolCallInput guarantees tool call input is a JSON object string.
// Anthropic drops assistant tool calls with malformed input, which can leave
// following tool results orphaned.
func normalizeToolCallInput(input string) string {
input = strings.TrimSpace(input)
if input == "" {
return "{}"
}
var object map[string]any
if err := json.Unmarshal([]byte(input), &object); err != nil || object == nil {
return "{}"
}
return input
}
// ExtractToolCalls returns all tool call parts as content blocks.
func ExtractToolCalls(parts []fantasy.MessagePart) []fantasy.ToolCallContent {
toolCalls := make([]fantasy.ToolCallContent, 0, len(parts))
for _, part := range parts {
toolCall, ok := fantasy.AsMessagePart[fantasy.ToolCallPart](part)
if !ok {
continue
}
toolCalls = append(toolCalls, fantasy.ToolCallContent{
ToolCallID: toolCall.ToolCallID,
ToolName: toolCall.ToolName,
Input: toolCall.Input,
ProviderExecuted: toolCall.ProviderExecuted,
})
}
return toolCalls
}
// MarshalContent encodes message content blocks for persistence.
func MarshalContent(blocks []fantasy.Content) (pqtype.NullRawMessage, error) {
if len(blocks) == 0 {
return pqtype.NullRawMessage{}, nil
}
encodedBlocks := make([]json.RawMessage, 0, len(blocks))
for i, block := range blocks {
encoded, err := marshalContentBlock(block)
if err != nil {
return pqtype.NullRawMessage{}, xerrors.Errorf(
"encode content block %d: %w",
i,
err,
)
}
encodedBlocks = append(encodedBlocks, encoded)
}
data, err := json.Marshal(encodedBlocks)
if err != nil {
return pqtype.NullRawMessage{}, xerrors.Errorf("encode content blocks: %w", err)
}
return pqtype.NullRawMessage{RawMessage: data, Valid: true}, nil
}
// MarshalToolResult encodes a single tool result for persistence as
// an opaque JSON blob. The stored shape is
// [{"tool_call_id":…,"tool_name":…,"result":…,"is_error":…}].
func MarshalToolResult(toolCallID, toolName string, result json.RawMessage, isError bool) (pqtype.NullRawMessage, error) {
row := toolResultRaw{
ToolCallID: toolCallID,
ToolName: toolName,
Result: result,
IsError: isError,
}
data, err := json.Marshal([]toolResultRaw{row})
if err != nil {
return pqtype.NullRawMessage{}, xerrors.Errorf("encode tool result: %w", err)
}
return pqtype.NullRawMessage{RawMessage: data, Valid: true}, nil
}
// MarshalToolResultContent encodes a fantasy tool result content
// block for persistence. It extracts the raw fields and delegates
// to MarshalToolResult.
func MarshalToolResultContent(content fantasy.ToolResultContent) (pqtype.NullRawMessage, error) {
var result json.RawMessage
var isError bool
switch output := content.Result.(type) {
case fantasy.ToolResultOutputContentError:
isError = true
if output.Error != nil {
result, _ = json.Marshal(map[string]any{"error": output.Error.Error()})
} else {
result = []byte(`{"error":""}`)
}
case fantasy.ToolResultOutputContentText:
result = json.RawMessage(output.Text)
if !json.Valid(result) {
result, _ = json.Marshal(map[string]any{"output": output.Text})
}
case fantasy.ToolResultOutputContentMedia:
result, _ = json.Marshal(map[string]any{
"data": output.Data,
"mime_type": output.MediaType,
"text": output.Text,
})
default:
result = []byte(`{}`)
}
return MarshalToolResult(content.ToolCallID, content.ToolName, result, isError)
}
// PartFromContent converts fantasy content into an SDK chat message part.
func PartFromContent(block fantasy.Content) codersdk.ChatMessagePart {
switch value := block.(type) {
case fantasy.TextContent:
return codersdk.ChatMessagePart{
Type: codersdk.ChatMessagePartTypeText,
Text: value.Text,
}
case *fantasy.TextContent:
return codersdk.ChatMessagePart{
Type: codersdk.ChatMessagePartTypeText,
Text: value.Text,
}
case fantasy.ReasoningContent:
return codersdk.ChatMessagePart{
Type: codersdk.ChatMessagePartTypeReasoning,
Text: value.Text,
Title: reasoningSummaryTitle(value.ProviderMetadata),
}
case *fantasy.ReasoningContent:
return codersdk.ChatMessagePart{
Type: codersdk.ChatMessagePartTypeReasoning,
Text: value.Text,
Title: reasoningSummaryTitle(value.ProviderMetadata),
}
case fantasy.ToolCallContent:
return codersdk.ChatMessagePart{
Type: codersdk.ChatMessagePartTypeToolCall,
ToolCallID: value.ToolCallID,
ToolName: value.ToolName,
Args: []byte(value.Input),
}
case *fantasy.ToolCallContent:
return codersdk.ChatMessagePart{
Type: codersdk.ChatMessagePartTypeToolCall,
ToolCallID: value.ToolCallID,
ToolName: value.ToolName,
Args: []byte(value.Input),
}
case fantasy.SourceContent:
return codersdk.ChatMessagePart{
Type: codersdk.ChatMessagePartTypeSource,
SourceID: value.ID,
URL: value.URL,
Title: value.Title,
}
case *fantasy.SourceContent:
return codersdk.ChatMessagePart{
Type: codersdk.ChatMessagePartTypeSource,
SourceID: value.ID,
URL: value.URL,
Title: value.Title,
}
case fantasy.FileContent:
return codersdk.ChatMessagePart{
Type: codersdk.ChatMessagePartTypeFile,
MediaType: value.MediaType,
Data: value.Data,
}
case *fantasy.FileContent:
return codersdk.ChatMessagePart{
Type: codersdk.ChatMessagePartTypeFile,
MediaType: value.MediaType,
Data: value.Data,
}
case fantasy.ToolResultContent:
return toolResultContentToPart(value)
case *fantasy.ToolResultContent:
return toolResultContentToPart(*value)
default:
return codersdk.ChatMessagePart{}
}
}
// ToolResultToPart converts a tool call ID, tool name, raw result, and
// error flag into a ChatMessagePart. This is the minimal conversion used
// both during streaming and when reading from the database.
func ToolResultToPart(toolCallID, toolName string, result json.RawMessage, isError bool) codersdk.ChatMessagePart {
return codersdk.ChatMessagePart{
Type: codersdk.ChatMessagePartTypeToolResult,
ToolCallID: toolCallID,
ToolName: toolName,
Result: result,
IsError: isError,
}
}
// toolResultContentToPart converts a fantasy ToolResultContent
// directly into a ChatMessagePart without an intermediate struct.
func toolResultContentToPart(content fantasy.ToolResultContent) codersdk.ChatMessagePart {
var result json.RawMessage
var isError bool
switch output := content.Result.(type) {
case fantasy.ToolResultOutputContentError:
isError = true
if output.Error != nil {
result, _ = json.Marshal(map[string]any{"error": output.Error.Error()})
} else {
result = []byte(`{"error":""}`)
}
case fantasy.ToolResultOutputContentText:
result = json.RawMessage(output.Text)
// Ensure valid JSON; wrap in an object if not.
if !json.Valid(result) {
result, _ = json.Marshal(map[string]any{"output": output.Text})
}
case fantasy.ToolResultOutputContentMedia:
result, _ = json.Marshal(map[string]any{
"data": output.Data,
"mime_type": output.MediaType,
"text": output.Text,
})
default:
result = []byte(`{}`)
}
return ToolResultToPart(content.ToolCallID, content.ToolName, result, isError)
}
// ReasoningTitleFromFirstLine extracts a compact markdown title.
func ReasoningTitleFromFirstLine(text string) string {
text = strings.TrimSpace(text)
if text == "" {
return ""
}
firstLine := text
if idx := strings.IndexAny(firstLine, "\r\n"); idx >= 0 {
firstLine = firstLine[:idx]
}
firstLine = strings.TrimSpace(firstLine)
if firstLine == "" || !strings.HasPrefix(firstLine, "**") {
return ""
}
rest := firstLine[2:]
end := strings.Index(rest, "**")
if end < 0 {
return ""
}
title := strings.TrimSpace(rest[:end])
if title == "" {
return ""
}
// Require the first line to be exactly "**title**" (ignoring
// surrounding whitespace) so providers without this format don't
// accidentally emit a title.
if strings.TrimSpace(rest[end+2:]) != "" {
return ""
}
return compactReasoningSummaryTitle(title)
}
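// injectMissingToolResults appends a synthetic error tool result for
// every assistant tool call that has no matching result in the tool
// messages that immediately follow it, so the model never sees an
// unanswered tool call.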
func injectMissingToolResults(prompt []fantasy.Message) []fantasy.Message {
result := make([]fantasy.Message, 0, len(prompt))
for i := 0; i < len(prompt); i++ {
msg := prompt[i]
result = append(result, msg)
if msg.Role != fantasy.MessageRoleAssistant {
continue
}
toolCalls := ExtractToolCalls(msg.Content)
if len(toolCalls) == 0 {
continue
}
// Collect the tool call IDs that have results in the
// following tool message(s).
answered := make(map[string]struct{})
j := i + 1
for ; j < len(prompt); j++ {
if prompt[j].Role != fantasy.MessageRoleTool {
break
}
for _, part := range prompt[j].Content {
tr, ok := fantasy.AsMessagePart[fantasy.ToolResultPart](part)
if !ok {
continue
}
answered[tr.ToolCallID] = struct{}{}
}
}
if i+1 < j {
// Preserve persisted tool result ordering and inject any
// synthetic results after the existing contiguous tool messages.
result = append(result, prompt[i+1:j]...)
i = j - 1
}
// Build synthetic results for any unanswered tool calls.
var missing []fantasy.MessagePart
for _, tc := range toolCalls {
if _, ok := answered[tc.ToolCallID]; !ok {
missing = append(missing, fantasy.ToolResultPart{
ToolCallID: tc.ToolCallID,
Output: fantasy.ToolResultOutputContentError{
Error: xerrors.New("tool call was interrupted and did not receive a result"),
},
})
}
}
if len(missing) > 0 {
result = append(result, fantasy.Message{
Role: fantasy.MessageRoleTool,
Content: missing,
})
}
}
return result
}
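// injectMissingToolUses inserts a synthetic assistant tool-call message
// ahead of any tool results whose originating tool call is missing from
// the nearest preceding assistant message, so the model never sees an
// orphaned tool result. Orphans whose tool name cannot be recovered are
// left unchanged.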
func injectMissingToolUses(
prompt []fantasy.Message,
toolNameByCallID map[string]string,
) []fantasy.Message {
result := make([]fantasy.Message, 0, len(prompt))
for _, msg := range prompt {
if msg.Role != fantasy.MessageRoleTool {
result = append(result, msg)
continue
}
toolResults := make([]fantasy.ToolResultPart, 0, len(msg.Content))
for _, part := range msg.Content {
toolResult, ok := fantasy.AsMessagePart[fantasy.ToolResultPart](part)
if !ok {
continue
}
toolResults = append(toolResults, toolResult)
}
if len(toolResults) == 0 {
result = append(result, msg)
continue
}
// Walk backwards through the result to find the nearest
// preceding assistant message (skipping over other tool
// messages that belong to the same batch of results).
answeredByPrevious := make(map[string]struct{})
for k := len(result) - 1; k >= 0; k-- {
if result[k].Role == fantasy.MessageRoleAssistant {
for _, toolCall := range ExtractToolCalls(result[k].Content) {
toolCallID := sanitizeToolCallID(toolCall.ToolCallID)
if toolCallID == "" {
continue
}
answeredByPrevious[toolCallID] = struct{}{}
}
break
}
if result[k].Role != fantasy.MessageRoleTool {
break
}
}
matchingResults := make([]fantasy.ToolResultPart, 0, len(toolResults))
orphanResults := make([]fantasy.ToolResultPart, 0, len(toolResults))
for _, toolResult := range toolResults {
toolCallID := sanitizeToolCallID(toolResult.ToolCallID)
if _, ok := answeredByPrevious[toolCallID]; ok {
matchingResults = append(matchingResults, toolResult)
continue
}
orphanResults = append(orphanResults, toolResult)
}
if len(orphanResults) == 0 {
result = append(result, msg)
continue
}
syntheticToolUse := syntheticToolUseMessage(
orphanResults,
toolNameByCallID,
)
if len(syntheticToolUse.Content) == 0 {
result = append(result, msg)
continue
}
if len(matchingResults) > 0 {
result = append(result, toolMessageFromToolResultParts(matchingResults))
}
result = append(result, syntheticToolUse)
result = append(result, toolMessageFromToolResultParts(orphanResults))
}
return result
}
func toolMessageFromToolResultParts(results []fantasy.ToolResultPart) fantasy.Message {
parts := make([]fantasy.MessagePart, 0, len(results))
for _, result := range results {
parts = append(parts, result)
}
return fantasy.Message{
Role: fantasy.MessageRoleTool,
Content: parts,
}
}
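// syntheticToolUseMessage builds an assistant message containing one
// synthetic tool call (with "{}" input) per orphaned tool result whose
// tool name is known, deduplicated by tool call ID.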
func syntheticToolUseMessage(
toolResults []fantasy.ToolResultPart,
toolNameByCallID map[string]string,
) fantasy.Message {
parts := make([]fantasy.MessagePart, 0, len(toolResults))
seen := make(map[string]struct{}, len(toolResults))
for _, toolResult := range toolResults {
toolCallID := sanitizeToolCallID(toolResult.ToolCallID)
if toolCallID == "" {
continue
}
if _, ok := seen[toolCallID]; ok {
continue
}
toolName := strings.TrimSpace(toolNameByCallID[toolCallID])
if toolName == "" {
continue
}
seen[toolCallID] = struct{}{}
parts = append(parts, fantasy.ToolCallPart{
ToolCallID: toolCallID,
ToolName: toolName,
Input: "{}",
})
}
return fantasy.Message{
Role: fantasy.MessageRoleAssistant,
Content: parts,
}
}
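// parseSystemContent decodes a persisted system message, which is
// stored as a plain JSON string rather than as content blocks.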
func parseSystemContent(raw pqtype.NullRawMessage) (string, error) {
if !raw.Valid || len(raw.RawMessage) == 0 {
return "", nil
}
var content string
if err := json.Unmarshal(raw.RawMessage, &content); err != nil {
return "", xerrors.Errorf("parse system message content: %w", err)
}
return content, nil
}
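// sanitizeToolCallID replaces any character outside [a-zA-Z0-9_-] with
// an underscore. Empty IDs are returned unchanged.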
func sanitizeToolCallID(id string) string {
if id == "" {
return ""
}
return toolCallIDSanitizer.ReplaceAllString(id, "_")
}
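// marshalContentBlock encodes a single content block, adding a derived
// "title" field to reasoning blocks whose text begins with a bold
// markdown heading.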
func marshalContentBlock(block fantasy.Content) (json.RawMessage, error) {
encoded, err := json.Marshal(block)
if err != nil {
return nil, err
}
title, ok := reasoningTitleFromContent(block)
if !ok || title == "" {
return encoded, nil
}
var envelope struct {
Type string `json:"type"`
Data map[string]any `json:"data"`
}
if err := json.Unmarshal(encoded, &envelope); err != nil {
return nil, err
}
if !strings.EqualFold(envelope.Type, string(fantasy.ContentTypeReasoning)) {
return encoded, nil
}
if envelope.Data == nil {
envelope.Data = map[string]any{}
}
envelope.Data["title"] = title
encodedWithTitle, err := json.Marshal(envelope)
if err != nil {
return nil, err
}
return encodedWithTitle, nil
}
func reasoningTitleFromContent(block fantasy.Content) (string, bool) {
switch value := block.(type) {
case fantasy.ReasoningContent:
return ReasoningTitleFromFirstLine(value.Text), true
case *fantasy.ReasoningContent:
if value == nil {
return "", false
}
return ReasoningTitleFromFirstLine(value.Text), true
default:
return "", false
}
}
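// reasoningSummaryTitle derives a display title from OpenAI reasoning
// summary metadata, returning the first summary entry that compacts to
// a non-empty title.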
func reasoningSummaryTitle(metadata fantasy.ProviderMetadata) string {
if len(metadata) == 0 {
return ""
}
reasoningMetadata := fantasyopenai.GetReasoningMetadata(
fantasy.ProviderOptions(metadata),
)
if reasoningMetadata == nil {
return ""
}
for _, summary := range reasoningMetadata.Summary {
if title := compactReasoningSummaryTitle(summary); title != "" {
return title
}
}
return ""
}
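// compactReasoningSummaryTitle reduces a summary to its headline, caps
// it at eight words (appending an ellipsis when truncated), and limits
// the result to 80 runes.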
func compactReasoningSummaryTitle(summary string) string {
const maxWords = 8
const maxRunes = 80
summary = strings.TrimSpace(summary)
if summary == "" {
return ""
}
summary = strings.Trim(summary, "\"'`")
summary = reasoningSummaryHeadline(summary)
words := strings.Fields(summary)
if len(words) == 0 {
return ""
}
truncated := false
if len(words) > maxWords {
words = words[:maxWords]
truncated = true
}
title := strings.Join(words, " ")
if truncated {
title += "…"
}
return truncateRunes(title, maxRunes)
}
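// reasoningSummaryHeadline keeps only the first line (or paragraph) of
// a summary and unwraps a leading "**bold**" heading.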
func reasoningSummaryHeadline(summary string) string {
summary = strings.TrimSpace(summary)
if summary == "" {
return ""
}
// OpenAI summary_text may be markdown like:
// "**Title**\n\nLonger explanation ...".
// Keep only the heading segment for UI titles.
if idx := strings.Index(summary, "\n\n"); idx >= 0 {
summary = summary[:idx]
}
if idx := strings.IndexAny(summary, "\r\n"); idx >= 0 {
summary = summary[:idx]
}
summary = strings.TrimSpace(summary)
if summary == "" {
return ""
}
if strings.HasPrefix(summary, "**") {
rest := summary[2:]
if end := strings.Index(rest, "**"); end >= 0 {
bold := strings.TrimSpace(rest[:end])
if bold != "" {
summary = bold
}
}
}
return strings.TrimSpace(strings.Trim(summary, "\"'`"))
}
func truncateRunes(value string, maxLen int) string {
if maxLen <= 0 {
return ""
}
runes := []rune(value)
if len(runes) <= maxLen {
return value
}
return string(runes[:maxLen])
}
@@ -0,0 +1,91 @@
package chatprompt_test
import (
"encoding/json"
"testing"
"charm.land/fantasy"
"github.com/stretchr/testify/require"
"github.com/coder/coder/v2/coderd/chatd/chatprompt"
"github.com/coder/coder/v2/coderd/database"
)
func TestConvertMessages_NormalizesAssistantToolCallInput(t *testing.T) {
t.Parallel()
testCases := []struct {
name string
input string
expected string
}{
{
name: "empty input",
input: "",
expected: "{}",
},
{
name: "invalid json",
input: "{\"command\":",
expected: "{}",
},
{
name: "non-object json",
input: "[]",
expected: "{}",
},
{
name: "valid object json",
input: "{\"command\":\"ls\"}",
expected: "{\"command\":\"ls\"}",
},
}
for _, tc := range testCases {
tc := tc
t.Run(tc.name, func(t *testing.T) {
t.Parallel()
assistantContent, err := chatprompt.MarshalContent([]fantasy.Content{
fantasy.ToolCallContent{
ToolCallID: "toolu_01C4PqN6F2493pi7Ebag8Vg7",
ToolName: "execute",
Input: tc.input,
},
})
require.NoError(t, err)
toolContent, err := chatprompt.MarshalToolResult(
"toolu_01C4PqN6F2493pi7Ebag8Vg7",
"execute",
json.RawMessage(`{"error":"tool call was interrupted before it produced a result"}`),
true,
)
require.NoError(t, err)
prompt, err := chatprompt.ConvertMessages([]database.ChatMessage{
{
Role: string(fantasy.MessageRoleAssistant),
Visibility: database.ChatMessageVisibilityBoth,
Content: assistantContent,
},
{
Role: string(fantasy.MessageRoleTool),
Visibility: database.ChatMessageVisibilityBoth,
Content: toolContent,
},
})
require.NoError(t, err)
require.Len(t, prompt, 2)
require.Equal(t, fantasy.MessageRoleAssistant, prompt[0].Role)
toolCalls := chatprompt.ExtractToolCalls(prompt[0].Content)
require.Len(t, toolCalls, 1)
require.Equal(t, tc.expected, toolCalls[0].Input)
require.Equal(t, "execute", toolCalls[0].ToolName)
require.Equal(t, "toolu_01C4PqN6F2493pi7Ebag8Vg7", toolCalls[0].ToolCallID)
require.Equal(t, fantasy.MessageRoleTool, prompt[1].Role)
})
}
}
File diff suppressed because it is too large
@@ -0,0 +1,191 @@
package chatprovider_test
import (
"testing"
fantasyanthropic "charm.land/fantasy/providers/anthropic"
fantasyopenai "charm.land/fantasy/providers/openai"
fantasyopenrouter "charm.land/fantasy/providers/openrouter"
fantasyvercel "charm.land/fantasy/providers/vercel"
"github.com/stretchr/testify/require"
"github.com/coder/coder/v2/coderd/chatd/chatprovider"
"github.com/coder/coder/v2/codersdk"
)
func TestReasoningEffortFromChat(t *testing.T) {
t.Parallel()
tests := []struct {
name string
provider string
input *string
want *string
}{
{
name: "OpenAICaseInsensitive",
provider: "openai",
input: stringPtr(" HIGH "),
want: stringPtr(string(fantasyopenai.ReasoningEffortHigh)),
},
{
name: "AnthropicEffort",
provider: "anthropic",
input: stringPtr("max"),
want: stringPtr(string(fantasyanthropic.EffortMax)),
},
{
name: "OpenRouterEffort",
provider: "openrouter",
input: stringPtr("medium"),
want: stringPtr(string(fantasyopenrouter.ReasoningEffortMedium)),
},
{
name: "VercelEffort",
provider: "vercel",
input: stringPtr("xhigh"),
want: stringPtr(string(fantasyvercel.ReasoningEffortXHigh)),
},
{
name: "InvalidEffortReturnsNil",
provider: "openai",
input: stringPtr("unknown"),
want: nil,
},
{
name: "UnsupportedProviderReturnsNil",
provider: "bedrock",
input: stringPtr("high"),
want: nil,
},
{
name: "NilInputReturnsNil",
provider: "openai",
input: nil,
want: nil,
},
}
for _, tt := range tests {
tt := tt
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
got := chatprovider.ReasoningEffortFromChat(tt.provider, tt.input)
require.Equal(t, tt.want, got)
})
}
}
func TestMergeMissingProviderOptions_OpenRouterNested(t *testing.T) {
t.Parallel()
options := &codersdk.ChatModelProviderOptions{
OpenRouter: &codersdk.ChatModelOpenRouterProviderOptions{
Reasoning: &codersdk.ChatModelOpenRouterReasoningOptions{
Enabled: boolPtr(true),
},
Provider: &codersdk.ChatModelOpenRouterProvider{
Order: []string{"openai"},
},
},
}
defaults := &codersdk.ChatModelProviderOptions{
OpenRouter: &codersdk.ChatModelOpenRouterProviderOptions{
Reasoning: &codersdk.ChatModelOpenRouterReasoningOptions{
Enabled: boolPtr(false),
Exclude: boolPtr(true),
MaxTokens: int64Ptr(123),
Effort: stringPtr("high"),
},
IncludeUsage: boolPtr(true),
Provider: &codersdk.ChatModelOpenRouterProvider{
Order: []string{"anthropic"},
AllowFallbacks: boolPtr(true),
RequireParameters: boolPtr(false),
DataCollection: stringPtr("allow"),
Only: []string{"openai"},
Ignore: []string{"foo"},
Quantizations: []string{"int8"},
Sort: stringPtr("latency"),
},
},
}
chatprovider.MergeMissingProviderOptions(&options, defaults)
require.NotNil(t, options)
require.NotNil(t, options.OpenRouter)
require.NotNil(t, options.OpenRouter.Reasoning)
require.True(t, *options.OpenRouter.Reasoning.Enabled)
require.Equal(t, true, *options.OpenRouter.Reasoning.Exclude)
require.EqualValues(t, 123, *options.OpenRouter.Reasoning.MaxTokens)
require.Equal(t, "high", *options.OpenRouter.Reasoning.Effort)
require.NotNil(t, options.OpenRouter.IncludeUsage)
require.True(t, *options.OpenRouter.IncludeUsage)
require.NotNil(t, options.OpenRouter.Provider)
require.Equal(t, []string{"openai"}, options.OpenRouter.Provider.Order)
require.NotNil(t, options.OpenRouter.Provider.AllowFallbacks)
require.True(t, *options.OpenRouter.Provider.AllowFallbacks)
require.NotNil(t, options.OpenRouter.Provider.RequireParameters)
require.False(t, *options.OpenRouter.Provider.RequireParameters)
require.Equal(t, "allow", *options.OpenRouter.Provider.DataCollection)
require.Equal(t, []string{"openai"}, options.OpenRouter.Provider.Only)
require.Equal(t, []string{"foo"}, options.OpenRouter.Provider.Ignore)
require.Equal(t, []string{"int8"}, options.OpenRouter.Provider.Quantizations)
require.Equal(t, "latency", *options.OpenRouter.Provider.Sort)
}
func TestMergeMissingCallConfig_FillsUnsetFields(t *testing.T) {
t.Parallel()
dst := codersdk.ChatModelCallConfig{
Temperature: float64Ptr(0.2),
ProviderOptions: &codersdk.ChatModelProviderOptions{
OpenAI: &codersdk.ChatModelOpenAIProviderOptions{
User: stringPtr("alice"),
},
},
}
defaults := codersdk.ChatModelCallConfig{
MaxOutputTokens: int64Ptr(512),
Temperature: float64Ptr(0.9),
TopP: float64Ptr(0.8),
ProviderOptions: &codersdk.ChatModelProviderOptions{
OpenAI: &codersdk.ChatModelOpenAIProviderOptions{
User: stringPtr("bob"),
ReasoningEffort: stringPtr("medium"),
},
},
}
chatprovider.MergeMissingCallConfig(&dst, defaults)
require.NotNil(t, dst.MaxOutputTokens)
require.EqualValues(t, 512, *dst.MaxOutputTokens)
require.NotNil(t, dst.Temperature)
require.Equal(t, 0.2, *dst.Temperature)
require.NotNil(t, dst.TopP)
require.Equal(t, 0.8, *dst.TopP)
require.NotNil(t, dst.ProviderOptions)
require.NotNil(t, dst.ProviderOptions.OpenAI)
require.Equal(t, "alice", *dst.ProviderOptions.OpenAI.User)
require.Equal(t, "medium", *dst.ProviderOptions.OpenAI.ReasoningEffort)
}
func stringPtr(value string) *string {
return &value
}
func boolPtr(value bool) *bool {
return &value
}
func int64Ptr(value int64) *int64 {
return &value
}
func float64Ptr(value float64) *float64 {
return &value
}
@@ -0,0 +1,175 @@
// Package chatretry provides retry logic for transient LLM provider
// errors. It classifies errors as retryable or permanent and
// implements exponential backoff matching the behavior of coder/mux.
package chatretry
import (
"context"
"errors"
"strings"
"time"
)
const (
// InitialDelay is the backoff duration for the first retry
// attempt.
InitialDelay = 1 * time.Second
// MaxDelay is the upper bound for the exponential backoff
// duration. Matches the cap used in coder/mux.
MaxDelay = 60 * time.Second
)
// nonRetryablePatterns are substrings that indicate a permanent error
// which should not be retried. These are checked first so that
// ambiguous messages (e.g. "quota exceeded: rate limit") are correctly
// classified as non-retryable.
var nonRetryablePatterns = []string{
"context canceled",
"context deadline exceeded",
"authentication",
"unauthorized",
"forbidden",
"invalid api key",
"invalid_api_key",
"invalid model",
"model not found",
"model_not_found",
"context length exceeded",
"context_exceeded",
"maximum context length",
"quota",
"billing",
}
// retryablePatterns are substrings that indicate a transient error
// worth retrying.
var retryablePatterns = []string{
"overloaded",
"rate limit",
"rate_limit",
"too many requests",
"server error",
"status 500",
"status 502",
"status 503",
"status 529",
"connection reset",
"connection refused",
"eof",
"broken pipe",
"timeout",
"unavailable",
"service unavailable",
}
// IsRetryable determines whether an error from an LLM provider is
// transient and worth retrying. It inspects the error message
// (including the messages of wrapped errors) for known patterns;
// non-retryable patterns take precedence over retryable ones.
func IsRetryable(err error) bool {
if err == nil {
return false
}
// context.Canceled is always non-retryable regardless of
// wrapping.
if errors.Is(err, context.Canceled) {
return false
}
lower := strings.ToLower(err.Error())
// Check non-retryable patterns first so they take precedence.
for _, p := range nonRetryablePatterns {
if strings.Contains(lower, p) {
return false
}
}
for _, p := range retryablePatterns {
if strings.Contains(lower, p) {
return true
}
}
return false
}
// StatusCodeRetryable returns true for HTTP status codes that
// indicate a transient failure worth retrying.
func StatusCodeRetryable(code int) bool {
switch code {
case 429, 500, 502, 503, 529:
return true
default:
return false
}
}
// Delay returns the backoff duration for the given 0-indexed attempt.
// Uses exponential backoff: min(InitialDelay * 2^attempt, MaxDelay).
// Matches the backoff curve used in coder/mux.
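// For example: attempt 0 waits 1s, attempt 1 waits 2s, attempt 2 waits
// 4s, attempt 5 waits 32s, and attempt 6 or later is capped at 60s.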
func Delay(attempt int) time.Duration {
d := InitialDelay
for range attempt {
d *= 2
if d >= MaxDelay {
return MaxDelay
}
}
return d
}
// RetryFn is the function to retry. It receives the caller's context
// and returns an error; Retry passes its ctx argument through
// unchanged on every attempt.
type RetryFn func(ctx context.Context) error
// OnRetryFn is called before each retry attempt with the attempt
// number (1-indexed), the error that triggered the retry, and the
// delay before the next attempt.
type OnRetryFn func(attempt int, err error, delay time.Duration)
// Retry calls fn repeatedly until it succeeds, returns a
// non-retryable error, or ctx is canceled. There is no max attempt
// limit — retries continue indefinitely with exponential backoff
// (capped at 60s), matching the behavior of coder/mux.
//
// The onRetry callback (if non-nil) is called before each retry
// attempt, giving the caller a chance to reset state, log, or
// publish status events.
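//
// A minimal sketch of a call site (callProvider is a placeholder for
// the actual provider request):
//
//	err := Retry(ctx, func(ctx context.Context) error {
//		return callProvider(ctx)
//	}, func(attempt int, err error, delay time.Duration) {
//		// Reset streaming state or log before the next attempt.
//	})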
func Retry(ctx context.Context, fn RetryFn, onRetry OnRetryFn) error {
var attempt int
for {
err := fn(ctx)
if err == nil {
return nil
}
if !IsRetryable(err) {
return err
}
// If the caller's context is already done, return the
// context error so cancellation propagates cleanly.
if ctx.Err() != nil {
return ctx.Err()
}
delay := Delay(attempt)
if onRetry != nil {
onRetry(attempt+1, err, delay)
}
timer := time.NewTimer(delay)
select {
case <-ctx.Done():
timer.Stop()
return ctx.Err()
case <-timer.C:
}
attempt++
}
}
@@ -0,0 +1,452 @@
package chatretry_test
import (
"context"
"errors"
"fmt"
"sync/atomic"
"testing"
"time"
"golang.org/x/xerrors"
"github.com/coder/coder/v2/coderd/chatd/chatretry"
)
func TestIsRetryable(t *testing.T) {
t.Parallel()
tests := []struct {
name string
err error
retryable bool
}{
// Retryable errors.
{
name: "Overloaded",
err: xerrors.New("model is overloaded, please try again"),
retryable: true,
},
{
name: "RateLimit",
err: xerrors.New("rate limit exceeded"),
retryable: true,
},
{
name: "RateLimitUnderscore",
err: xerrors.New("rate_limit: too many requests"),
retryable: true,
},
{
name: "TooManyRequests",
err: xerrors.New("too many requests"),
retryable: true,
},
{
name: "HTTP429InMessage",
err: xerrors.New("received status 429 from upstream"),
retryable: false, // "429" alone is not a pattern; needs matching text.
},
{
name: "HTTP529InMessage",
err: xerrors.New("received status 529 from upstream"),
retryable: true,
},
{
name: "ServerError500",
err: xerrors.New("status 500: internal server error"),
retryable: true,
},
{
name: "ServerErrorGeneric",
err: xerrors.New("server error"),
retryable: true,
},
{
name: "ConnectionReset",
err: xerrors.New("read tcp: connection reset by peer"),
retryable: true,
},
{
name: "ConnectionRefused",
err: xerrors.New("dial tcp: connection refused"),
retryable: true,
},
{
name: "EOF",
err: xerrors.New("unexpected EOF"),
retryable: true,
},
{
name: "BrokenPipe",
err: xerrors.New("write: broken pipe"),
retryable: true,
},
{
name: "NetworkTimeout",
err: xerrors.New("i/o timeout"),
retryable: true,
},
{
name: "ServiceUnavailable",
err: xerrors.New("service unavailable"),
retryable: true,
},
{
name: "Unavailable",
err: xerrors.New("the service is currently unavailable"),
retryable: true,
},
{
name: "Status502",
err: xerrors.New("status 502: bad gateway"),
retryable: true,
},
{
name: "Status503",
err: xerrors.New("status 503"),
retryable: true,
},
// Non-retryable errors.
{
name: "Nil",
err: nil,
retryable: false,
},
{
name: "ContextCanceled",
err: context.Canceled,
retryable: false,
},
{
name: "ContextCanceledWrapped",
err: xerrors.Errorf("operation failed: %w", context.Canceled),
retryable: false,
},
{
name: "ContextCanceledMessage",
err: xerrors.New("context canceled"),
retryable: false,
},
{
name: "ContextDeadlineExceeded",
err: xerrors.New("context deadline exceeded"),
retryable: false,
},
{
name: "Authentication",
err: xerrors.New("authentication failed"),
retryable: false,
},
{
name: "Unauthorized",
err: xerrors.New("401 Unauthorized"),
retryable: false,
},
{
name: "Forbidden",
err: xerrors.New("403 Forbidden"),
retryable: false,
},
{
name: "InvalidAPIKey",
err: xerrors.New("invalid api key"),
retryable: false,
},
{
name: "InvalidAPIKeyUnderscore",
err: xerrors.New("invalid_api_key"),
retryable: false,
},
{
name: "InvalidModel",
err: xerrors.New("invalid model: gpt-5-turbo"),
retryable: false,
},
{
name: "ModelNotFound",
err: xerrors.New("model not found"),
retryable: false,
},
{
name: "ModelNotFoundUnderscore",
err: xerrors.New("model_not_found"),
retryable: false,
},
{
name: "ContextLengthExceeded",
err: xerrors.New("context length exceeded"),
retryable: false,
},
{
name: "ContextExceededUnderscore",
err: xerrors.New("context_exceeded"),
retryable: false,
},
{
name: "MaximumContextLength",
err: xerrors.New("maximum context length"),
retryable: false,
},
{
name: "QuotaExceeded",
err: xerrors.New("quota exceeded"),
retryable: false,
},
{
name: "BillingError",
err: xerrors.New("billing issue: payment required"),
retryable: false,
},
// Wrapped errors preserve retryability.
{
name: "WrappedRetryable",
err: xerrors.Errorf("provider call failed: %w", xerrors.New("service unavailable")),
retryable: true,
},
{
name: "WrappedNonRetryable",
err: xerrors.Errorf("provider call failed: %w", xerrors.New("invalid api key")),
retryable: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
got := chatretry.IsRetryable(tt.err)
if got != tt.retryable {
t.Errorf("IsRetryable(%v) = %v, want %v", tt.err, got, tt.retryable)
}
})
}
}
func TestStatusCodeRetryable(t *testing.T) {
t.Parallel()
tests := []struct {
code int
retryable bool
}{
{429, true},
{500, true},
{502, true},
{503, true},
{529, true},
{200, false},
{400, false},
{401, false},
{403, false},
{404, false},
}
for _, tt := range tests {
t.Run(fmt.Sprintf("Status%d", tt.code), func(t *testing.T) {
t.Parallel()
got := chatretry.StatusCodeRetryable(tt.code)
if got != tt.retryable {
t.Errorf("StatusCodeRetryable(%d) = %v, want %v", tt.code, got, tt.retryable)
}
})
}
}
func TestDelay(t *testing.T) {
t.Parallel()
tests := []struct {
attempt int
want time.Duration
}{
{0, 1 * time.Second},
{1, 2 * time.Second},
{2, 4 * time.Second},
{3, 8 * time.Second},
{4, 16 * time.Second},
{5, 32 * time.Second},
{6, 60 * time.Second}, // Capped at MaxDelay.
{10, 60 * time.Second}, // Still capped.
{100, 60 * time.Second},
}
for _, tt := range tests {
t.Run(fmt.Sprintf("Attempt%d", tt.attempt), func(t *testing.T) {
t.Parallel()
got := chatretry.Delay(tt.attempt)
if got != tt.want {
t.Errorf("Delay(%d) = %v, want %v", tt.attempt, got, tt.want)
}
})
}
}
func TestRetry_SuccessOnFirstTry(t *testing.T) {
t.Parallel()
calls := 0
err := chatretry.Retry(context.Background(), func(_ context.Context) error {
calls++
return nil
}, nil)
if err != nil {
t.Fatalf("expected nil error, got %v", err)
}
if calls != 1 {
t.Fatalf("expected fn called once, got %d", calls)
}
}
func TestRetry_TransientThenSuccess(t *testing.T) {
t.Parallel()
calls := 0
err := chatretry.Retry(context.Background(), func(_ context.Context) error {
calls++
if calls == 1 {
return xerrors.New("service unavailable")
}
return nil
}, nil)
if err != nil {
t.Fatalf("expected nil error, got %v", err)
}
if calls != 2 {
t.Fatalf("expected fn called twice, got %d", calls)
}
}
func TestRetry_MultipleTransientThenSuccess(t *testing.T) {
t.Parallel()
calls := 0
err := chatretry.Retry(context.Background(), func(_ context.Context) error {
calls++
if calls <= 3 {
return xerrors.New("overloaded")
}
return nil
}, nil)
if err != nil {
t.Fatalf("expected nil error, got %v", err)
}
if calls != 4 {
t.Fatalf("expected fn called 4 times, got %d", calls)
}
}
func TestRetry_NonRetryableError(t *testing.T) {
t.Parallel()
calls := 0
err := chatretry.Retry(context.Background(), func(_ context.Context) error {
calls++
return xerrors.New("invalid api key")
}, nil)
if err == nil {
t.Fatal("expected error, got nil")
}
if err.Error() != "invalid api key" {
t.Fatalf("expected 'invalid api key', got %q", err.Error())
}
if calls != 1 {
t.Fatalf("expected fn called once, got %d", calls)
}
}
func TestRetry_ContextCanceledDuringWait(t *testing.T) {
t.Parallel()
ctx, cancel := context.WithCancel(context.Background())
calls := 0
err := chatretry.Retry(ctx, func(_ context.Context) error {
calls++
// Cancel after the first retryable error so the wait
// select picks up the cancellation.
if calls == 1 {
cancel()
}
return xerrors.New("overloaded")
}, nil)
if !errors.Is(err, context.Canceled) {
t.Fatalf("expected context.Canceled, got %v", err)
}
}
func TestRetry_ContextCanceledDuringFn(t *testing.T) {
t.Parallel()
ctx, cancel := context.WithCancel(context.Background())
err := chatretry.Retry(ctx, func(_ context.Context) error {
cancel()
// Return a retryable error; the loop should detect that
// ctx is done and return the context error.
return xerrors.New("overloaded")
}, nil)
if !errors.Is(err, context.Canceled) {
t.Fatalf("expected context.Canceled, got %v", err)
}
}
func TestRetry_OnRetryCalledWithCorrectArgs(t *testing.T) {
t.Parallel()
type retryRecord struct {
attempt int
errMsg string
delay time.Duration
}
var records []retryRecord
calls := 0
err := chatretry.Retry(context.Background(), func(_ context.Context) error {
calls++
if calls <= 2 {
return xerrors.New("rate limit exceeded")
}
return nil
}, func(attempt int, err error, delay time.Duration) {
records = append(records, retryRecord{
attempt: attempt,
errMsg: err.Error(),
delay: delay,
})
})
if err != nil {
t.Fatalf("expected nil error, got %v", err)
}
if len(records) != 2 {
t.Fatalf("expected 2 onRetry calls, got %d", len(records))
}
if records[0].attempt != 1 {
t.Errorf("first onRetry attempt = %d, want 1", records[0].attempt)
}
if records[1].attempt != 2 {
t.Errorf("second onRetry attempt = %d, want 2", records[1].attempt)
}
if records[0].errMsg != "rate limit exceeded" {
t.Errorf("first onRetry error = %q, want 'rate limit exceeded'", records[0].errMsg)
}
}
func TestRetry_OnRetryNilDoesNotPanic(t *testing.T) {
t.Parallel()
var calls atomic.Int32
err := chatretry.Retry(context.Background(), func(_ context.Context) error {
if calls.Add(1) == 1 {
return xerrors.New("overloaded")
}
return nil
}, nil)
if err != nil {
t.Fatalf("expected nil error, got %v", err)
}
}
@@ -0,0 +1,409 @@
package chattest
import (
"encoding/json"
"fmt"
"net/http"
"net/http/httptest"
"sync"
"testing"
"github.com/google/uuid"
)
// AnthropicHandler handles Anthropic API requests and returns a response.
type AnthropicHandler func(req *AnthropicRequest) AnthropicResponse
// AnthropicResponse represents a response to an Anthropic request.
// Exactly one of StreamingChunks, Response, or Error should be set.
type AnthropicResponse struct {
StreamingChunks <-chan AnthropicChunk
Response *AnthropicMessage
Error *ErrorResponse // If set, server returns this HTTP error instead of streaming/JSON.
}
// AnthropicRequest represents an Anthropic messages request.
type AnthropicRequest struct {
*http.Request // Embed http.Request
Model string `json:"model"`
Messages []AnthropicRequestMessage `json:"messages"`
Stream bool `json:"stream,omitempty"`
MaxTokens int `json:"max_tokens,omitempty"`
// TODO: encoding/json ignores inline tags. Add custom UnmarshalJSON to capture unknown keys.
Options map[string]interface{} `json:",inline"` //nolint:revive
}
// AnthropicRequestMessage represents a message in an Anthropic request.
// Content may be either a string or a structured content array.
type AnthropicRequestMessage struct {
Role string `json:"role"`
Content json.RawMessage `json:"content"`
}
// AnthropicMessage represents a message in an Anthropic response.
type AnthropicMessage struct {
ID string `json:"id,omitempty"`
Type string `json:"type,omitempty"`
Role string `json:"role"`
Content string `json:"content,omitempty"`
Model string `json:"model,omitempty"`
StopReason string `json:"stop_reason,omitempty"`
Usage AnthropicUsage `json:"usage,omitempty"`
}
// AnthropicUsage represents usage information in an Anthropic response.
type AnthropicUsage struct {
InputTokens int `json:"input_tokens"`
OutputTokens int `json:"output_tokens"`
}
// AnthropicChunk represents a streaming chunk from Anthropic.
type AnthropicChunk struct {
Type string `json:"type"`
Index int `json:"index,omitempty"`
Message AnthropicChunkMessage `json:"message,omitempty"`
ContentBlock AnthropicContentBlock `json:"content_block,omitempty"`
Delta AnthropicDeltaBlock `json:"delta,omitempty"`
StopReason string `json:"stop_reason,omitempty"`
StopSequence *string `json:"stop_sequence,omitempty"`
Usage AnthropicUsage `json:"usage,omitempty"`
}
// AnthropicChunkMessage represents message metadata in a chunk.
type AnthropicChunkMessage struct {
ID string `json:"id"`
Type string `json:"type"`
Role string `json:"role"`
Model string `json:"model"`
}
// AnthropicContentBlock represents a content block in a chunk.
type AnthropicContentBlock struct {
Type string `json:"type"`
Text string `json:"text,omitempty"`
ID string `json:"id,omitempty"`
Name string `json:"name,omitempty"`
Input json.RawMessage `json:"input,omitempty"`
}
// AnthropicDeltaBlock represents a delta block in a chunk.
type AnthropicDeltaBlock struct {
Type string `json:"type"`
Text string `json:"text,omitempty"`
PartialJSON string `json:"partial_json,omitempty"`
}
// anthropicServer is a test server that mocks the Anthropic API.
type anthropicServer struct {
mu sync.Mutex
server *httptest.Server
handler AnthropicHandler
request *AnthropicRequest
}
// NewAnthropic creates a new Anthropic test server with a handler function.
// The handler is called for each request and should return either a streaming
// response (via channel) or a non-streaming response.
// Returns the base URL of the server.
func NewAnthropic(t testing.TB, handler AnthropicHandler) string {
t.Helper()
s := &anthropicServer{
handler: handler,
}
mux := http.NewServeMux()
mux.HandleFunc("POST /v1/messages", s.handleMessages)
s.server = httptest.NewServer(mux)
t.Cleanup(func() {
s.server.Close()
})
return s.server.URL
}
func (s *anthropicServer) handleMessages(w http.ResponseWriter, r *http.Request) {
var req AnthropicRequest
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
// Return a more detailed error for debugging
http.Error(w, fmt.Sprintf("decode request: %v", err), http.StatusBadRequest)
return
}
req.Request = r // Embed the original http.Request
s.mu.Lock()
s.request = &req
s.mu.Unlock()
resp := s.handler(&req)
s.writeResponse(w, &req, resp)
}
func (s *anthropicServer) writeResponse(w http.ResponseWriter, req *AnthropicRequest, resp AnthropicResponse) {
if resp.Error != nil {
writeErrorResponse(w, resp.Error)
return
}
hasStreaming := resp.StreamingChunks != nil
hasNonStreaming := resp.Response != nil
switch {
case hasStreaming && hasNonStreaming:
http.Error(w, "handler returned both streaming and non-streaming responses", http.StatusInternalServerError)
return
case !hasStreaming && !hasNonStreaming:
http.Error(w, "handler returned empty response", http.StatusInternalServerError)
return
case req.Stream && !hasStreaming:
http.Error(w, "handler returned non-streaming response for streaming request", http.StatusInternalServerError)
return
case !req.Stream && !hasNonStreaming:
http.Error(w, "handler returned streaming response for non-streaming request", http.StatusInternalServerError)
return
case hasStreaming:
s.writeStreamingResponse(w, resp.StreamingChunks)
default:
s.writeNonStreamingResponse(w, resp.Response)
}
}
func (s *anthropicServer) writeStreamingResponse(w http.ResponseWriter, chunks <-chan AnthropicChunk) {
_ = s // receiver unused but kept for consistency
w.Header().Set("Content-Type", "text/event-stream")
w.Header().Set("Cache-Control", "no-cache")
w.Header().Set("Connection", "keep-alive")
w.Header().Set("anthropic-version", "2023-06-01")
w.WriteHeader(http.StatusOK)
flusher, ok := w.(http.Flusher)
if !ok {
http.Error(w, "streaming not supported", http.StatusInternalServerError)
return
}
for chunk := range chunks {
chunkData := make(map[string]interface{})
chunkData["type"] = chunk.Type
switch chunk.Type {
case "message_start":
chunkData["message"] = chunk.Message
case "content_block_start":
chunkData["index"] = chunk.Index
chunkData["content_block"] = chunk.ContentBlock
case "content_block_delta":
chunkData["index"] = chunk.Index
chunkData["delta"] = chunk.Delta
case "content_block_stop":
chunkData["index"] = chunk.Index
case "message_delta":
chunkData["delta"] = map[string]interface{}{
"stop_reason": chunk.StopReason,
"stop_sequence": chunk.StopSequence,
}
chunkData["usage"] = chunk.Usage
case "message_stop":
// No additional fields
}
chunkBytes, err := json.Marshal(chunkData)
if err != nil {
return
}
// Send both event and data lines to match Anthropic API format
if _, err := fmt.Fprintf(w, "event: %s\ndata: %s\n\n", chunk.Type, chunkBytes); err != nil {
return
}
flusher.Flush()
}
}
func (s *anthropicServer) writeNonStreamingResponse(w http.ResponseWriter, resp *AnthropicMessage) {
_ = s // receiver unused but kept for consistency
response := map[string]interface{}{
"id": resp.ID,
"type": resp.Type,
"role": resp.Role,
"model": resp.Model,
"content": []map[string]interface{}{
{
"type": "text",
"text": resp.Content,
},
},
"stop_reason": resp.StopReason,
"usage": resp.Usage,
}
w.Header().Set("Content-Type", "application/json")
w.Header().Set("anthropic-version", "2023-06-01")
_ = json.NewEncoder(w).Encode(response)
}
// AnthropicStreamingResponse creates a streaming response from chunks.
func AnthropicStreamingResponse(chunks ...AnthropicChunk) AnthropicResponse {
ch := make(chan AnthropicChunk, len(chunks))
go func() {
for _, chunk := range chunks {
ch <- chunk
}
close(ch)
}()
return AnthropicResponse{StreamingChunks: ch}
}
// AnthropicNonStreamingResponse creates a non-streaming response with the given text.
func AnthropicNonStreamingResponse(text string) AnthropicResponse {
return AnthropicResponse{
Response: &AnthropicMessage{
ID: fmt.Sprintf("msg-%s", uuid.New().String()[:8]),
Type: "message",
Role: "assistant",
Content: text,
Model: "claude-3-opus-20240229",
StopReason: "end_turn",
Usage: AnthropicUsage{
InputTokens: 10,
OutputTokens: 5,
},
},
}
}
// AnthropicTextChunks creates a complete streaming response with text deltas.
// Takes text deltas and creates all required chunks (message_start,
// content_block_start, content_block_delta for each delta,
// content_block_stop, message_delta, message_stop).
func AnthropicTextChunks(deltas ...string) []AnthropicChunk {
if len(deltas) == 0 {
return nil
}
messageID := fmt.Sprintf("msg-%s", uuid.New().String()[:8])
model := "claude-3-opus-20240229"
chunks := []AnthropicChunk{
{
Type: "message_start",
Message: AnthropicChunkMessage{
ID: messageID,
Type: "message",
Role: "assistant",
Model: model,
},
},
{
Type: "content_block_start",
Index: 0,
ContentBlock: AnthropicContentBlock{
Type: "text",
Text: "", // According to Anthropic API spec, text should be empty in content_block_start
},
},
}
// Add a delta chunk for each delta
for _, delta := range deltas {
chunks = append(chunks, AnthropicChunk{
Type: "content_block_delta",
Index: 0,
Delta: AnthropicDeltaBlock{
Type: "text_delta",
Text: delta,
},
})
}
chunks = append(chunks,
AnthropicChunk{
Type: "content_block_stop",
Index: 0,
},
AnthropicChunk{
Type: "message_delta",
StopReason: "end_turn",
Usage: AnthropicUsage{
InputTokens: 10,
OutputTokens: 5,
},
},
AnthropicChunk{
Type: "message_stop",
},
)
return chunks
}
// AnthropicToolCallChunks creates a complete streaming response for a tool call.
// Input JSON can be split across multiple deltas, matching Anthropic's
// input_json_delta streaming behavior.
func AnthropicToolCallChunks(toolName string, inputJSONDeltas ...string) []AnthropicChunk {
if len(inputJSONDeltas) == 0 {
return nil
}
if toolName == "" {
toolName = "tool"
}
messageID := fmt.Sprintf("msg-%s", uuid.New().String()[:8])
model := "claude-3-opus-20240229"
toolCallID := fmt.Sprintf("toolu_%s", uuid.New().String()[:8])
chunks := []AnthropicChunk{
{
Type: "message_start",
Message: AnthropicChunkMessage{
ID: messageID,
Type: "message",
Role: "assistant",
Model: model,
},
},
{
Type: "content_block_start",
Index: 0,
ContentBlock: AnthropicContentBlock{
Type: "tool_use",
ID: toolCallID,
Name: toolName,
Input: json.RawMessage("{}"),
},
},
}
for _, delta := range inputJSONDeltas {
chunks = append(chunks, AnthropicChunk{
Type: "content_block_delta",
Index: 0,
Delta: AnthropicDeltaBlock{
Type: "input_json_delta",
PartialJSON: delta,
},
})
}
chunks = append(chunks,
AnthropicChunk{
Type: "content_block_stop",
Index: 0,
},
AnthropicChunk{
Type: "message_delta",
StopReason: "tool_use",
Usage: AnthropicUsage{
InputTokens: 10,
OutputTokens: 5,
},
},
AnthropicChunk{
Type: "message_stop",
},
)
return chunks
}
@@ -0,0 +1,221 @@
package chattest_test
import (
"context"
"sync/atomic"
"testing"
"charm.land/fantasy"
fantasyanthropic "charm.land/fantasy/providers/anthropic"
"github.com/stretchr/testify/require"
"github.com/coder/coder/v2/coderd/chatd/chattest"
)
func TestAnthropic_Streaming(t *testing.T) {
t.Parallel()
serverURL := chattest.NewAnthropic(t, func(req *chattest.AnthropicRequest) chattest.AnthropicResponse {
return chattest.AnthropicStreamingResponse(
chattest.AnthropicTextChunks("Hello", " world", "!")...,
)
})
// Create fantasy client pointing to our test server
client, err := fantasyanthropic.New(
fantasyanthropic.WithAPIKey("test-key"),
fantasyanthropic.WithBaseURL(serverURL),
)
require.NoError(t, err)
ctx := context.Background()
model, err := client.LanguageModel(ctx, "claude-3-opus-20240229")
require.NoError(t, err)
call := fantasy.Call{
Prompt: []fantasy.Message{
{
Role: fantasy.MessageRoleUser,
Content: []fantasy.MessagePart{
fantasy.TextPart{Text: "Say hello"},
},
},
},
}
stream, err := model.Stream(ctx, call)
require.NoError(t, err)
expectedDeltas := []string{"Hello", " world", "!"}
deltaIndex := 0
var allParts []fantasy.StreamPart
for part := range stream {
allParts = append(allParts, part)
if part.Type == fantasy.StreamPartTypeTextDelta {
require.Less(t, deltaIndex, len(expectedDeltas), "Received more deltas than expected")
require.Equal(t, expectedDeltas[deltaIndex], part.Delta,
"Delta at index %d should be %q, got %q", deltaIndex, expectedDeltas[deltaIndex], part.Delta)
deltaIndex++
}
}
require.Equal(t, len(expectedDeltas), deltaIndex, "Expected %d deltas, got %d. Total parts received: %d", len(expectedDeltas), deltaIndex, len(allParts))
}
func TestAnthropic_ToolCalls(t *testing.T) {
t.Parallel()
var requestCount atomic.Int32
serverURL := chattest.NewAnthropic(t, func(req *chattest.AnthropicRequest) chattest.AnthropicResponse {
switch requestCount.Add(1) {
case 1:
return chattest.AnthropicStreamingResponse(
chattest.AnthropicToolCallChunks("get_weather", `{"location":"San Francisco"}`)...,
)
default:
return chattest.AnthropicStreamingResponse(
chattest.AnthropicTextChunks("The weather in San Francisco is 72F.")...,
)
}
})
client, err := fantasyanthropic.New(
fantasyanthropic.WithAPIKey("test-key"),
fantasyanthropic.WithBaseURL(serverURL),
)
require.NoError(t, err)
model, err := client.LanguageModel(context.Background(), "claude-3-opus-20240229")
require.NoError(t, err)
type weatherInput struct {
Location string `json:"location"`
}
var toolCallCount atomic.Int32
weatherTool := fantasy.NewAgentTool(
"get_weather",
"Get weather for a location.",
func(ctx context.Context, input weatherInput, _ fantasy.ToolCall) (fantasy.ToolResponse, error) {
toolCallCount.Add(1)
require.Equal(t, "San Francisco", input.Location)
return fantasy.NewTextResponse("72F"), nil
},
)
agent := fantasy.NewAgent(
model,
fantasy.WithSystemPrompt("You are a helpful assistant."),
fantasy.WithTools(weatherTool),
)
result, err := agent.Stream(context.Background(), fantasy.AgentStreamCall{
Prompt: "What's the weather in San Francisco?",
})
require.NoError(t, err)
require.NotNil(t, result)
require.Equal(t, int32(1), toolCallCount.Load(), "expected exactly one tool execution")
require.GreaterOrEqual(t, requestCount.Load(), int32(2), "expected follow-up model call after tool execution")
}
func TestAnthropic_NonStreaming(t *testing.T) {
t.Parallel()
serverURL := chattest.NewAnthropic(t, func(req *chattest.AnthropicRequest) chattest.AnthropicResponse {
return chattest.AnthropicNonStreamingResponse("Response text")
})
// Create fantasy client pointing to our test server
client, err := fantasyanthropic.New(
fantasyanthropic.WithAPIKey("test-key"),
fantasyanthropic.WithBaseURL(serverURL),
)
require.NoError(t, err)
ctx := context.Background()
model, err := client.LanguageModel(ctx, "claude-3-opus-20240229")
require.NoError(t, err)
call := fantasy.Call{
Prompt: []fantasy.Message{
{
Role: fantasy.MessageRoleUser,
Content: []fantasy.MessagePart{
fantasy.TextPart{Text: "Test message"},
},
},
},
}
response, err := model.Generate(ctx, call)
require.NoError(t, err)
require.NotNil(t, response)
}
func TestAnthropic_Streaming_MismatchReturnsErrorPart(t *testing.T) {
t.Parallel()
serverURL := chattest.NewAnthropic(t, func(req *chattest.AnthropicRequest) chattest.AnthropicResponse {
return chattest.AnthropicNonStreamingResponse("wrong response type")
})
client, err := fantasyanthropic.New(
fantasyanthropic.WithAPIKey("test-key"),
fantasyanthropic.WithBaseURL(serverURL),
)
require.NoError(t, err)
model, err := client.LanguageModel(context.Background(), "claude-3-opus-20240229")
require.NoError(t, err)
stream, err := model.Stream(context.Background(), fantasy.Call{
Prompt: []fantasy.Message{
{
Role: fantasy.MessageRoleUser,
Content: []fantasy.MessagePart{fantasy.TextPart{Text: "hello"}},
},
},
})
require.NoError(t, err)
var streamErr error
for part := range stream {
if part.Type == fantasy.StreamPartTypeError {
streamErr = part.Error
break
}
}
require.Error(t, streamErr)
require.Contains(t, streamErr.Error(), "500 Internal Server Error")
}
func TestAnthropic_NonStreaming_MismatchReturnsError(t *testing.T) {
t.Parallel()
serverURL := chattest.NewAnthropic(t, func(req *chattest.AnthropicRequest) chattest.AnthropicResponse {
return chattest.AnthropicStreamingResponse(
chattest.AnthropicTextChunks("wrong", " response")...,
)
})
client, err := fantasyanthropic.New(
fantasyanthropic.WithAPIKey("test-key"),
fantasyanthropic.WithBaseURL(serverURL),
)
require.NoError(t, err)
model, err := client.LanguageModel(context.Background(), "claude-3-opus-20240229")
require.NoError(t, err)
_, err = model.Generate(context.Background(), fantasy.Call{
Prompt: []fantasy.Message{
{
Role: fantasy.MessageRoleUser,
Content: []fantasy.MessagePart{fantasy.TextPart{Text: "hello"}},
},
},
})
require.Error(t, err)
require.Contains(t, err.Error(), "500 Internal Server Error")
}
@@ -0,0 +1,74 @@
package chattest
import (
"encoding/json"
"net/http"
)
// ErrorResponse describes an HTTP error that a test server should return
// instead of a normal streaming or JSON response.
type ErrorResponse struct {
StatusCode int
Type string
Message string
}
// writeErrorResponse writes a JSON error response matching the common
// provider error format used by both Anthropic and OpenAI.
func writeErrorResponse(w http.ResponseWriter, errResp *ErrorResponse) {
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(errResp.StatusCode)
body := map[string]interface{}{
"error": map[string]interface{}{
"type": errResp.Type,
"message": errResp.Message,
},
}
_ = json.NewEncoder(w).Encode(body)
}
// AnthropicErrorResponse returns an AnthropicResponse that causes the
// test server to respond with the given HTTP status code and error.
// This simulates provider errors like 529 Overloaded or 429 Rate Limited.
func AnthropicErrorResponse(statusCode int, errorType, message string) AnthropicResponse {
return AnthropicResponse{
Error: &ErrorResponse{
StatusCode: statusCode,
Type: errorType,
Message: message,
},
}
}
// AnthropicOverloadedResponse returns a 529 "overloaded" error matching
// Anthropic's overloaded response format.
func AnthropicOverloadedResponse() AnthropicResponse {
return AnthropicErrorResponse(529, "overloaded_error", "Overloaded")
}
// AnthropicRateLimitResponse returns a 429 rate limit error.
func AnthropicRateLimitResponse() AnthropicResponse {
return AnthropicErrorResponse(http.StatusTooManyRequests, "rate_limit_error", "Rate limited")
}
// OpenAIErrorResponse returns an OpenAIResponse that causes the
// test server to respond with the given HTTP status code and error.
func OpenAIErrorResponse(statusCode int, errorType, message string) OpenAIResponse {
return OpenAIResponse{
Error: &ErrorResponse{
StatusCode: statusCode,
Type: errorType,
Message: message,
},
}
}
// OpenAIRateLimitResponse returns a 429 rate limit error.
func OpenAIRateLimitResponse() OpenAIResponse {
return OpenAIErrorResponse(http.StatusTooManyRequests, "rate_limit_exceeded", "Rate limit exceeded")
}
// OpenAIServerErrorResponse returns a 500 internal server error.
func OpenAIServerErrorResponse() OpenAIResponse {
return OpenAIErrorResponse(http.StatusInternalServerError, "server_error", "Internal server error")
}
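As an illustrative sketch (not part of the committed code), these helpers compose with the Anthropic mock server above to simulate a transient failure followed by success, using the same request-counting pattern as the tool-call tests:

// Hypothetical handler: the first request gets a 529 Overloaded error,
// and subsequent requests stream a normal completion.
var requests atomic.Int32
serverURL := chattest.NewAnthropic(t, func(_ *chattest.AnthropicRequest) chattest.AnthropicResponse {
	if requests.Add(1) == 1 {
		return chattest.AnthropicOverloadedResponse()
	}
	return chattest.AnthropicStreamingResponse(chattest.AnthropicTextChunks("ok")...)
})
// Point the provider client at serverURL, as the tests above do.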
@@ -0,0 +1,490 @@
package chattest
import (
"encoding/json"
"fmt"
"log"
"net/http"
"net/http/httptest"
"sync"
"testing"
"time"
"github.com/google/uuid"
)
// OpenAIHandler handles OpenAI API requests and returns a response.
type OpenAIHandler func(req *OpenAIRequest) OpenAIResponse
// OpenAIResponse represents a response to an OpenAI request.
// Exactly one of StreamingChunks, Response, or Error should be set.
type OpenAIResponse struct {
StreamingChunks <-chan OpenAIChunk
Response *OpenAICompletion
Error *ErrorResponse // If set, server returns this HTTP error instead of streaming/JSON.
}
// OpenAIRequest represents an OpenAI chat completion request.
type OpenAIRequest struct {
*http.Request
Model string `json:"model"`
Messages []OpenAIMessage `json:"messages"`
Stream bool `json:"stream,omitempty"`
Prompt []interface{} `json:"prompt,omitempty"` // For responses API
// TODO: encoding/json ignores inline tags. Add custom UnmarshalJSON to capture unknown keys.
Options map[string]interface{} `json:",inline"` //nolint:revive
}
// OpenAIMessage represents a message in an OpenAI request.
type OpenAIMessage struct {
Role string `json:"role"`
Content string `json:"content"`
}
// OpenAIToolCallFunction represents the function details in a tool call.
type OpenAIToolCallFunction struct {
Name string `json:"name,omitempty"`
Arguments string `json:"arguments,omitempty"`
}
// OpenAIToolCall represents a tool call in a streaming chunk or completion.
type OpenAIToolCall struct {
ID string `json:"id,omitempty"`
Type string `json:"type,omitempty"`
Function OpenAIToolCallFunction `json:"function,omitempty"`
Index int `json:"index,omitempty"` // For streaming deltas
}
// OpenAIChunkChoice represents a choice in a streaming chunk.
type OpenAIChunkChoice struct {
Index int `json:"index"`
Delta string `json:"delta,omitempty"`
ToolCalls []OpenAIToolCall `json:"tool_calls,omitempty"`
FinishReason string `json:"finish_reason,omitempty"`
}
// OpenAIChunk represents a streaming chunk from OpenAI.
type OpenAIChunk struct {
ID string `json:"id"`
Object string `json:"object"`
Created int64 `json:"created"`
Model string `json:"model"`
Choices []OpenAIChunkChoice `json:"choices"`
}
// OpenAICompletionChoice represents a choice in a completion response.
type OpenAICompletionChoice struct {
Index int `json:"index"`
Message OpenAIMessage `json:"message"`
ToolCalls []OpenAIToolCall `json:"tool_calls,omitempty"`
FinishReason string `json:"finish_reason"`
}
// OpenAICompletionUsage represents usage information in a completion response.
type OpenAICompletionUsage struct {
PromptTokens int `json:"prompt_tokens"`
CompletionTokens int `json:"completion_tokens"`
TotalTokens int `json:"total_tokens"`
}
// OpenAICompletion represents a non-streaming OpenAI completion response.
type OpenAICompletion struct {
ID string `json:"id"`
Object string `json:"object"`
Created int64 `json:"created"`
Model string `json:"model"`
Choices []OpenAICompletionChoice `json:"choices"`
Usage OpenAICompletionUsage `json:"usage"`
}
// openAIServer is a test server that mocks the OpenAI API.
type openAIServer struct {
mu sync.Mutex
server *httptest.Server
handler OpenAIHandler
request *OpenAIRequest
}
// NewOpenAI creates a new OpenAI test server with a handler function.
// The handler is called for each request and should return either a streaming
// response (via channel) or a non-streaming response.
// Returns the base URL of the server.
func NewOpenAI(t testing.TB, handler OpenAIHandler) string {
t.Helper()
s := &openAIServer{
handler: handler,
}
mux := http.NewServeMux()
mux.HandleFunc("POST /chat/completions", s.handleChatCompletions)
mux.HandleFunc("POST /responses", s.handleResponses)
s.server = httptest.NewServer(mux)
t.Cleanup(func() {
s.server.Close()
})
return s.server.URL
}
func (s *openAIServer) handleChatCompletions(w http.ResponseWriter, r *http.Request) {
var req OpenAIRequest
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
http.Error(w, err.Error(), http.StatusBadRequest)
return
}
req.Request = r
s.mu.Lock()
s.request = &req
s.mu.Unlock()
resp := s.handler(&req)
s.writeChatCompletionsResponse(w, &req, resp)
}
func (s *openAIServer) handleResponses(w http.ResponseWriter, r *http.Request) {
var req OpenAIRequest
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
http.Error(w, err.Error(), http.StatusBadRequest)
return
}
req.Request = r
s.mu.Lock()
s.request = &req
s.mu.Unlock()
resp := s.handler(&req)
s.writeResponsesAPIResponse(w, &req, resp)
}
func (s *openAIServer) writeChatCompletionsResponse(w http.ResponseWriter, req *OpenAIRequest, resp OpenAIResponse) {
if resp.Error != nil {
writeErrorResponse(w, resp.Error)
return
}
hasStreaming := resp.StreamingChunks != nil
hasNonStreaming := resp.Response != nil
switch {
case hasStreaming && hasNonStreaming:
http.Error(w, "handler returned both streaming and non-streaming responses", http.StatusInternalServerError)
return
case !hasStreaming && !hasNonStreaming:
http.Error(w, "handler returned empty response", http.StatusInternalServerError)
return
case req.Stream && !hasStreaming:
http.Error(w, "handler returned non-streaming response for streaming request", http.StatusInternalServerError)
return
case !req.Stream && !hasNonStreaming:
http.Error(w, "handler returned streaming response for non-streaming request", http.StatusInternalServerError)
return
case hasStreaming:
writeChatCompletionsStreaming(w, req.Request, resp.StreamingChunks)
default:
s.writeChatCompletionsNonStreaming(w, resp.Response)
}
}
func (s *openAIServer) writeResponsesAPIResponse(w http.ResponseWriter, req *OpenAIRequest, resp OpenAIResponse) {
if resp.Error != nil {
writeErrorResponse(w, resp.Error)
return
}
hasStreaming := resp.StreamingChunks != nil
hasNonStreaming := resp.Response != nil
switch {
case hasStreaming && hasNonStreaming:
http.Error(w, "handler returned both streaming and non-streaming responses", http.StatusInternalServerError)
return
case !hasStreaming && !hasNonStreaming:
http.Error(w, "handler returned empty response", http.StatusInternalServerError)
return
case req.Stream && !hasStreaming:
http.Error(w, "handler returned non-streaming response for streaming request", http.StatusInternalServerError)
return
case !req.Stream && !hasNonStreaming:
http.Error(w, "handler returned streaming response for non-streaming request", http.StatusInternalServerError)
return
case hasStreaming:
writeResponsesAPIStreaming(w, req.Request, resp.StreamingChunks)
default:
s.writeResponsesAPINonStreaming(w, resp.Response)
}
}
func writeChatCompletionsStreaming(w http.ResponseWriter, r *http.Request, chunks <-chan OpenAIChunk) {
w.Header().Set("Content-Type", "text/event-stream")
w.Header().Set("Cache-Control", "no-cache")
w.Header().Set("Connection", "keep-alive")
w.WriteHeader(http.StatusOK)
flusher, ok := w.(http.Flusher)
if !ok {
http.Error(w, "streaming not supported", http.StatusInternalServerError)
return
}
for {
var chunk OpenAIChunk
var ok bool
select {
case <-r.Context().Done():
log.Printf("writeChatCompletionsStreaming: request context canceled, stopping stream")
return
case chunk, ok = <-chunks:
if !ok {
_, _ = fmt.Fprintf(w, "data: [DONE]\n\n")
flusher.Flush()
return
}
}
choicesData := make([]map[string]interface{}, len(chunk.Choices))
for i, choice := range chunk.Choices {
choiceData := map[string]interface{}{
"index": choice.Index,
}
if choice.Delta != "" {
choiceData["delta"] = map[string]interface{}{
"content": choice.Delta,
}
}
if len(choice.ToolCalls) > 0 {
// Tool calls come in the delta
if choiceData["delta"] == nil {
choiceData["delta"] = make(map[string]interface{})
}
delta, ok := choiceData["delta"].(map[string]interface{})
if !ok {
delta = make(map[string]interface{})
choiceData["delta"] = delta
}
delta["tool_calls"] = choice.ToolCalls
}
if choice.FinishReason != "" {
choiceData["finish_reason"] = choice.FinishReason
}
choicesData[i] = choiceData
}
chunkData := map[string]interface{}{
"id": chunk.ID,
"object": chunk.Object,
"created": chunk.Created,
"model": chunk.Model,
"choices": choicesData,
}
chunkBytes, err := json.Marshal(chunkData)
if err != nil {
return
}
if _, err := fmt.Fprintf(w, "data: %s\n\n", chunkBytes); err != nil {
return
}
flusher.Flush()
}
}
func writeResponsesAPIStreaming(w http.ResponseWriter, r *http.Request, chunks <-chan OpenAIChunk) {
w.Header().Set("Content-Type", "text/event-stream")
w.Header().Set("Cache-Control", "no-cache")
w.Header().Set("Connection", "keep-alive")
w.WriteHeader(http.StatusOK)
flusher, ok := w.(http.Flusher)
if !ok {
http.Error(w, "streaming not supported", http.StatusInternalServerError)
return
}
itemIDs := make(map[int]string)
for {
var chunk OpenAIChunk
var ok bool
select {
case <-r.Context().Done():
log.Printf("writeResponsesAPIStreaming: request context canceled, stopping stream")
return
case chunk, ok = <-chunks:
if !ok {
_, _ = fmt.Fprintf(w, "data: [DONE]\n\n")
flusher.Flush()
return
}
}
// Responses API sends one event per choice
for outputIndex, choice := range chunk.Choices {
if choice.Index != 0 {
outputIndex = choice.Index
}
itemID, found := itemIDs[outputIndex]
if !found {
itemID = fmt.Sprintf("msg_%s", uuid.New().String()[:8])
itemIDs[outputIndex] = itemID
}
chunkData := map[string]interface{}{
"type": "response.output_text.delta",
"item_id": itemID,
"output_index": outputIndex,
"created": chunk.Created,
"model": chunk.Model,
"content_index": 0,
"delta": choice.Delta,
}
chunkBytes, err := json.Marshal(chunkData)
if err != nil {
return
}
if _, err := fmt.Fprintf(w, "data: %s\n\n", chunkBytes); err != nil {
return
}
flusher.Flush()
}
}
}
func (s *openAIServer) writeChatCompletionsNonStreaming(w http.ResponseWriter, resp *OpenAICompletion) {
_ = s // receiver unused but kept for consistency
w.Header().Set("Content-Type", "application/json")
_ = json.NewEncoder(w).Encode(resp)
}
func (s *openAIServer) writeResponsesAPINonStreaming(w http.ResponseWriter, resp *OpenAICompletion) {
_ = s // receiver unused but kept for consistency
// Convert all choices to output format
outputs := make([]map[string]interface{}, len(resp.Choices))
for i, choice := range resp.Choices {
outputs[i] = map[string]interface{}{
"id": uuid.New().String(),
"type": "message",
"role": "assistant",
"content": []map[string]interface{}{
{
"type": "output_text",
"text": choice.Message.Content,
},
},
}
}
response := map[string]interface{}{
"id": resp.ID,
"object": "response",
"created": resp.Created,
"model": resp.Model,
"output": outputs,
"usage": resp.Usage,
}
w.Header().Set("Content-Type", "application/json")
_ = json.NewEncoder(w).Encode(response)
}
// OpenAIStreamingResponse creates a streaming response from chunks.
func OpenAIStreamingResponse(chunks ...OpenAIChunk) OpenAIResponse {
ch := make(chan OpenAIChunk, len(chunks))
go func() {
for _, chunk := range chunks {
ch <- chunk
}
close(ch)
}()
return OpenAIResponse{StreamingChunks: ch}
}
// OpenAINonStreamingResponse creates a non-streaming response with the given text.
func OpenAINonStreamingResponse(text string) OpenAIResponse {
return OpenAIResponse{
Response: &OpenAICompletion{
ID: fmt.Sprintf("chatcmpl-%s", uuid.New().String()[:8]),
Object: "chat.completion",
Created: time.Now().Unix(),
Model: "gpt-4",
Choices: []OpenAICompletionChoice{
{
Index: 0,
Message: OpenAIMessage{
Role: "assistant",
Content: text,
},
FinishReason: "stop",
},
},
Usage: OpenAICompletionUsage{
PromptTokens: 10,
CompletionTokens: 5,
TotalTokens: 15,
},
},
}
}
// OpenAITextChunks creates streaming chunks with text deltas.
// Each delta string becomes a separate chunk with a single choice.
// Returns a slice of chunks, one per delta, with each choice having its index (0, 1, 2, ...).
func OpenAITextChunks(deltas ...string) []OpenAIChunk {
if len(deltas) == 0 {
return nil
}
chunkID := fmt.Sprintf("chatcmpl-%s", uuid.New().String()[:8])
now := time.Now().Unix()
chunks := make([]OpenAIChunk, len(deltas))
for i, delta := range deltas {
chunks[i] = OpenAIChunk{
ID: chunkID,
Object: "chat.completion.chunk",
Created: now,
Model: "gpt-4",
Choices: []OpenAIChunkChoice{
{
Index: i,
Delta: delta,
},
},
}
}
return chunks
}
// OpenAIToolCallChunk creates a streaming chunk with a tool call.
// Takes the tool name and arguments JSON string, creates a tool call for choice index 0.
func OpenAIToolCallChunk(toolName, arguments string) OpenAIChunk {
return OpenAIChunk{
ID: fmt.Sprintf("chatcmpl-%s", uuid.New().String()[:8]),
Object: "chat.completion.chunk",
Created: time.Now().Unix(),
Model: "gpt-4",
Choices: []OpenAIChunkChoice{
{
Index: 0,
ToolCalls: []OpenAIToolCall{
{
Index: 0,
ID: fmt.Sprintf("call_%s", uuid.New().String()[:8]),
Type: "function",
Function: OpenAIToolCallFunction{
Name: toolName,
Arguments: arguments,
},
},
},
},
},
}
}
@@ -0,0 +1,367 @@
package chattest_test
import (
"context"
"sync/atomic"
"testing"
"charm.land/fantasy"
fantasyopenai "charm.land/fantasy/providers/openai"
"github.com/stretchr/testify/require"
"github.com/coder/coder/v2/coderd/chatd/chattest"
)
func TestOpenAI_Streaming(t *testing.T) {
t.Parallel()
serverURL := chattest.NewOpenAI(t, func(req *chattest.OpenAIRequest) chattest.OpenAIResponse {
return chattest.OpenAIStreamingResponse(
append(
append(
chattest.OpenAITextChunks("Hello", "Hi"),
chattest.OpenAITextChunks(" world", " there")...,
),
chattest.OpenAITextChunks("!", "!")...,
)...,
)
})
// Create fantasy client pointing to our test server
client, err := fantasyopenai.New(
fantasyopenai.WithAPIKey("test-key"),
fantasyopenai.WithBaseURL(serverURL),
)
require.NoError(t, err)
ctx := context.Background()
model, err := client.LanguageModel(ctx, "gpt-4")
require.NoError(t, err)
call := fantasy.Call{
Prompt: []fantasy.Message{
{
Role: fantasy.MessageRoleUser,
Content: []fantasy.MessagePart{
fantasy.TextPart{Text: "Say hello"},
},
},
},
}
stream, err := model.Stream(ctx, call)
require.NoError(t, err)
// We expect chunks in order: one choice per chunk
// So we get: "Hello" (choice 0), "Hi" (choice 1), " world" (choice 0), " there" (choice 1), "!" (choice 0), "!" (choice 1)
expectedDeltas := []string{"Hello", "Hi", " world", " there", "!", "!"}
deltaIndex := 0
for part := range stream {
if part.Type == fantasy.StreamPartTypeTextDelta {
// Verify we're getting deltas in the expected order
require.Less(t, deltaIndex, len(expectedDeltas), "Received more deltas than expected")
require.Equal(t, expectedDeltas[deltaIndex], part.Delta,
"Delta at index %d should be %q, got %q", deltaIndex, expectedDeltas[deltaIndex], part.Delta)
deltaIndex++
}
}
// Verify we received all expected deltas
require.Equal(t, len(expectedDeltas), deltaIndex, "Expected %d deltas, got %d", len(expectedDeltas), deltaIndex)
}
func TestOpenAI_Streaming_ResponsesAPI(t *testing.T) {
t.Parallel()
serverURL := chattest.NewOpenAI(t, func(req *chattest.OpenAIRequest) chattest.OpenAIResponse {
return chattest.OpenAIStreamingResponse(
append(
append(
chattest.OpenAITextChunks("First", "Second"),
chattest.OpenAITextChunks(" output", " output")...,
),
chattest.OpenAITextChunks("!", "!")...,
)...,
)
})
// Create fantasy client pointing to our test server (responses API)
client, err := fantasyopenai.New(
fantasyopenai.WithAPIKey("test-key"),
fantasyopenai.WithBaseURL(serverURL),
fantasyopenai.WithUseResponsesAPI(),
)
require.NoError(t, err)
ctx := context.Background()
model, err := client.LanguageModel(ctx, "gpt-4")
require.NoError(t, err)
call := fantasy.Call{
Prompt: []fantasy.Message{
{
Role: fantasy.MessageRoleUser,
Content: []fantasy.MessagePart{
fantasy.TextPart{Text: "Say hello"},
},
},
},
}
stream, err := model.Stream(ctx, call)
require.NoError(t, err)
var parts []fantasy.StreamPart
for part := range stream {
parts = append(parts, part)
}
// Verify we received the chunks in order
require.Greater(t, len(parts), 0)
// Extract text deltas from parts and verify they match expected chunks in order
// We expect: "First", " output", "!" for choice 0, and "Second", " output", "!" for choice 1
var allDeltas []string
for _, part := range parts {
if part.Type == fantasy.StreamPartTypeTextDelta {
allDeltas = append(allDeltas, part.Delta)
}
}
// Verify we received deltas (responses API may handle multiple choices differently)
// If we got text deltas, verify the content
if len(allDeltas) > 0 {
allText := ""
for _, delta := range allDeltas {
allText += delta
}
require.Contains(t, allText, "First")
require.Contains(t, allText, "Second")
require.Contains(t, allText, "output")
require.Contains(t, allText, "!")
} else {
// If no text deltas, at least verify we got some parts (may be different format)
require.Greater(t, len(parts), 0, "Expected at least one stream part")
}
}
func TestOpenAI_NonStreaming_CompletionsAPI(t *testing.T) {
t.Parallel()
serverURL := chattest.NewOpenAI(t, func(req *chattest.OpenAIRequest) chattest.OpenAIResponse {
return chattest.OpenAINonStreamingResponse("First response")
})
// Create fantasy client pointing to our test server (completions API)
client, err := fantasyopenai.New(
fantasyopenai.WithAPIKey("test-key"),
fantasyopenai.WithBaseURL(serverURL),
)
require.NoError(t, err)
ctx := context.Background()
model, err := client.LanguageModel(ctx, "gpt-4")
require.NoError(t, err)
call := fantasy.Call{
Prompt: []fantasy.Message{
{
Role: fantasy.MessageRoleUser,
Content: []fantasy.MessagePart{
fantasy.TextPart{Text: "Test message"},
},
},
},
}
response, err := model.Generate(ctx, call)
require.NoError(t, err)
require.NotNil(t, response)
}
func TestOpenAI_ToolCalls(t *testing.T) {
t.Parallel()
var requestCount atomic.Int32
serverURL := chattest.NewOpenAI(t, func(req *chattest.OpenAIRequest) chattest.OpenAIResponse {
switch requestCount.Add(1) {
case 1:
return chattest.OpenAIStreamingResponse(
chattest.OpenAIToolCallChunk("get_weather", `{"location":"San Francisco"}`),
)
default:
return chattest.OpenAIStreamingResponse(
chattest.OpenAITextChunks("The weather in San Francisco is 72F.")...,
)
}
})
// Create fantasy client pointing to our test server
client, err := fantasyopenai.New(
fantasyopenai.WithAPIKey("test-key"),
fantasyopenai.WithBaseURL(serverURL),
)
require.NoError(t, err)
ctx := context.Background()
model, err := client.LanguageModel(ctx, "gpt-4")
require.NoError(t, err)
type weatherInput struct {
Location string `json:"location"`
}
var toolCallCount atomic.Int32
weatherTool := fantasy.NewAgentTool(
"get_weather",
"Get weather for a location.",
func(ctx context.Context, input weatherInput, _ fantasy.ToolCall) (fantasy.ToolResponse, error) {
toolCallCount.Add(1)
require.Equal(t, "San Francisco", input.Location)
return fantasy.NewTextResponse("72F"), nil
},
)
agent := fantasy.NewAgent(
model,
fantasy.WithSystemPrompt("You are a helpful assistant."),
fantasy.WithTools(weatherTool),
)
result, err := agent.Stream(ctx, fantasy.AgentStreamCall{
Prompt: "What's the weather in San Francisco?",
})
require.NoError(t, err)
require.NotNil(t, result)
require.Equal(t, int32(1), toolCallCount.Load(), "expected exactly one tool execution")
require.GreaterOrEqual(t, requestCount.Load(), int32(2), "expected follow-up model call after tool execution")
}
func TestOpenAI_NonStreaming_ResponsesAPI(t *testing.T) {
t.Parallel()
serverURL := chattest.NewOpenAI(t, func(req *chattest.OpenAIRequest) chattest.OpenAIResponse {
return chattest.OpenAINonStreamingResponse("First output")
})
// Create fantasy client pointing to our test server (responses API)
client, err := fantasyopenai.New(
fantasyopenai.WithAPIKey("test-key"),
fantasyopenai.WithBaseURL(serverURL),
fantasyopenai.WithUseResponsesAPI(),
)
require.NoError(t, err)
ctx := context.Background()
model, err := client.LanguageModel(ctx, "gpt-4")
require.NoError(t, err)
call := fantasy.Call{
Prompt: []fantasy.Message{
{
Role: fantasy.MessageRoleUser,
Content: []fantasy.MessagePart{
fantasy.TextPart{Text: "Test message"},
},
},
},
}
response, err := model.Generate(ctx, call)
require.NoError(t, err)
require.NotNil(t, response)
}
func TestOpenAI_Streaming_MismatchReturnsErrorPart(t *testing.T) {
t.Parallel()
serverURL := chattest.NewOpenAI(t, func(req *chattest.OpenAIRequest) chattest.OpenAIResponse {
return chattest.OpenAINonStreamingResponse("wrong response type")
})
client, err := fantasyopenai.New(
fantasyopenai.WithAPIKey("test-key"),
fantasyopenai.WithBaseURL(serverURL),
)
require.NoError(t, err)
model, err := client.LanguageModel(context.Background(), "gpt-4")
require.NoError(t, err)
stream, err := model.Stream(context.Background(), fantasy.Call{
Prompt: []fantasy.Message{
{
Role: fantasy.MessageRoleUser,
Content: []fantasy.MessagePart{fantasy.TextPart{Text: "hello"}},
},
},
})
require.NoError(t, err)
var streamErr error
for part := range stream {
if part.Type == fantasy.StreamPartTypeError {
streamErr = part.Error
break
}
}
require.Error(t, streamErr)
require.Contains(t, streamErr.Error(), "non-streaming response for streaming request")
}
func TestOpenAI_NonStreaming_MismatchReturnsError_CompletionsAPI(t *testing.T) {
t.Parallel()
serverURL := chattest.NewOpenAI(t, func(req *chattest.OpenAIRequest) chattest.OpenAIResponse {
return chattest.OpenAIStreamingResponse(chattest.OpenAITextChunks("wrong response type")...)
})
client, err := fantasyopenai.New(
fantasyopenai.WithAPIKey("test-key"),
fantasyopenai.WithBaseURL(serverURL),
)
require.NoError(t, err)
model, err := client.LanguageModel(context.Background(), "gpt-4")
require.NoError(t, err)
_, err = model.Generate(context.Background(), fantasy.Call{
Prompt: []fantasy.Message{
{
Role: fantasy.MessageRoleUser,
Content: []fantasy.MessagePart{fantasy.TextPart{Text: "hello"}},
},
},
})
require.Error(t, err)
require.Contains(t, err.Error(), "streaming response for non-streaming request")
}
func TestOpenAI_NonStreaming_MismatchReturnsError_ResponsesAPI(t *testing.T) {
t.Parallel()
serverURL := chattest.NewOpenAI(t, func(req *chattest.OpenAIRequest) chattest.OpenAIResponse {
return chattest.OpenAIStreamingResponse(chattest.OpenAITextChunks("wrong response type")...)
})
client, err := fantasyopenai.New(
fantasyopenai.WithAPIKey("test-key"),
fantasyopenai.WithBaseURL(serverURL),
fantasyopenai.WithUseResponsesAPI(),
)
require.NoError(t, err)
model, err := client.LanguageModel(context.Background(), "gpt-4")
require.NoError(t, err)
_, err = model.Generate(context.Background(), fantasy.Call{
Prompt: []fantasy.Message{
{
Role: fantasy.MessageRoleUser,
Content: []fantasy.MessagePart{fantasy.TextPart{Text: "hello"}},
},
},
})
require.Error(t, err)
require.Contains(t, err.Error(), "streaming response for non-streaming request")
}
@@ -0,0 +1,33 @@
package chattool
import (
"encoding/json"
"unicode/utf8"
"charm.land/fantasy"
)
// toolResponse builds a fantasy.ToolResponse from a JSON-serializable
// result payload.
func toolResponse(result map[string]any) fantasy.ToolResponse {
data, err := json.Marshal(result)
if err != nil {
return fantasy.NewTextResponse("{}")
}
return fantasy.NewTextResponse(string(data))
}
func truncateRunes(value string, maxLen int) string {
if maxLen <= 0 || value == "" {
return ""
}
if utf8.RuneCountInString(value) <= maxLen {
return value
}
runes := []rune(value)
if maxLen > len(runes) {
maxLen = len(runes)
}
return string(runes[:maxLen])
}
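A quick illustrative check of the rune-safe truncation (a hypothetical in-package test sketch, not part of the committed code; testing and require imports omitted):

func TestTruncateRunes_Sketch(t *testing.T) {
	// "héllo" is 6 bytes but 5 runes; truncation counts runes, not bytes.
	require.Equal(t, "hél", truncateRunes("héllo", 3))
	// Strings at or under the limit pass through unchanged.
	require.Equal(t, "hi", truncateRunes("hi", 10))
	// Non-positive limits yield the empty string.
	require.Equal(t, "", truncateRunes("anything", 0))
}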
@@ -0,0 +1,423 @@
package chattool
import (
"context"
"database/sql"
"fmt"
"strings"
"sync"
"time"
"charm.land/fantasy"
"github.com/google/uuid"
"golang.org/x/xerrors"
"github.com/coder/coder/v2/coderd/database"
"github.com/coder/coder/v2/coderd/util/namesgenerator"
"github.com/coder/coder/v2/codersdk"
"github.com/coder/coder/v2/codersdk/workspacesdk"
)
const (
// buildPollInterval is how often we check if the workspace
// build has completed.
buildPollInterval = 2 * time.Second
// buildTimeout is the maximum time to wait for a workspace
// build to complete before giving up.
buildTimeout = 10 * time.Minute
// agentConnectTimeout is the maximum time to wait for the
// workspace agent to become reachable after a successful build.
agentConnectTimeout = 2 * time.Minute
// agentRetryInterval is how often we retry connecting to the
// workspace agent.
agentRetryInterval = 2 * time.Second
// agentAttemptTimeout is the timeout for a single connection
// attempt to the workspace agent during the retry loop.
agentAttemptTimeout = 5 * time.Second
// agentPingTimeout is the timeout for a single agent ping
// when checking whether an existing workspace is alive.
agentPingTimeout = 5 * time.Second
)
// CreateWorkspaceFn creates a workspace for the given owner.
type CreateWorkspaceFn func(
ctx context.Context,
ownerID uuid.UUID,
req codersdk.CreateWorkspaceRequest,
) (codersdk.Workspace, error)
// AgentConnFunc provides access to workspace agent connections.
type AgentConnFunc func(
ctx context.Context,
agentID uuid.UUID,
) (workspacesdk.AgentConn, func(), error)
// CreateWorkspaceOptions configures the create_workspace tool.
type CreateWorkspaceOptions struct {
DB database.Store
OwnerID uuid.UUID
ChatID uuid.UUID
CreateFn CreateWorkspaceFn
AgentConnFn AgentConnFunc
WorkspaceMu *sync.Mutex
}
type createWorkspaceArgs struct {
TemplateID string `json:"template_id"`
Name string `json:"name,omitempty"`
Parameters map[string]string `json:"parameters,omitempty"`
}
// CreateWorkspace returns a tool that creates a new workspace from a
// template. The tool is idempotent: if the chat already has a
// workspace that is building or running, it returns the existing
// workspace instead of creating a new one. A mutex prevents parallel
// calls from creating duplicate workspaces.
func CreateWorkspace(options CreateWorkspaceOptions) fantasy.AgentTool {
return fantasy.NewAgentTool(
"create_workspace",
"Create a new workspace from a template. Requires a "+
"template_id (from list_templates). Optionally provide "+
"a name and parameter values (from read_template). "+
"If no name is given, one will be generated. "+
"This tool is idempotent — if the chat already has a "+
"workspace that is building or running, the existing "+
"workspace is returned.",
func(ctx context.Context, args createWorkspaceArgs, _ fantasy.ToolCall) (fantasy.ToolResponse, error) {
if options.CreateFn == nil {
return fantasy.NewTextErrorResponse("workspace creator is not configured"), nil
}
templateIDStr := strings.TrimSpace(args.TemplateID)
if templateIDStr == "" {
return fantasy.NewTextErrorResponse("template_id is required; use list_templates to find one"), nil
}
templateID, err := uuid.Parse(templateIDStr)
if err != nil {
return fantasy.NewTextErrorResponse(
xerrors.Errorf("invalid template_id: %w", err).Error(),
), nil
}
// Serialize workspace creation to prevent parallel
// tool calls from creating duplicate workspaces.
if options.WorkspaceMu != nil {
options.WorkspaceMu.Lock()
defer options.WorkspaceMu.Unlock()
}
// Check for an existing workspace on the chat.
if options.DB != nil && options.ChatID != uuid.Nil {
existing, done, existErr := checkExistingWorkspace(
ctx, options.DB, options.ChatID,
options.AgentConnFn,
)
if existErr != nil {
return fantasy.NewTextErrorResponse(existErr.Error()), nil
}
if done {
return toolResponse(existing), nil
}
}
ownerID := options.OwnerID
// Set up dbauthz context for DB lookups.
if options.DB != nil {
ownerCtx, ownerErr := asOwner(ctx, options.DB, ownerID)
if ownerErr != nil {
return fantasy.NewTextErrorResponse(ownerErr.Error()), nil
}
ctx = ownerCtx
}
createReq := codersdk.CreateWorkspaceRequest{
TemplateID: templateID,
}
// Resolve workspace name.
name := strings.TrimSpace(args.Name)
if name == "" {
seed := "workspace"
if options.DB != nil {
if t, lookupErr := options.DB.GetTemplateByID(ctx, templateID); lookupErr == nil {
seed = t.Name
}
}
name = generatedWorkspaceName(seed)
} else if err := codersdk.NameValid(name); err != nil {
name = generatedWorkspaceName(name)
}
createReq.Name = name
// Map parameters.
for k, v := range args.Parameters {
createReq.RichParameterValues = append(
createReq.RichParameterValues,
codersdk.WorkspaceBuildParameter{Name: k, Value: v},
)
}
workspace, err := options.CreateFn(ctx, ownerID, createReq)
if err != nil {
return fantasy.NewTextErrorResponse(err.Error()), nil
}
// Wait for the build to complete and the agent to
// come online so subsequent tools can use the
// workspace immediately.
if options.DB != nil {
if err := waitForBuild(ctx, options.DB, workspace.ID); err != nil {
return fantasy.NewTextErrorResponse(
xerrors.Errorf("workspace build failed: %w", err).Error(),
), nil
}
}
// Look up the first agent so we can wait for it to come online below.
workspaceAgentID := uuid.Nil
if options.DB != nil {
agents, agentErr := options.DB.GetWorkspaceAgentsInLatestBuildByWorkspaceID(ctx, workspace.ID)
if agentErr == nil && len(agents) > 0 {
workspaceAgentID = agents[0].ID
}
}
// Persist the workspace association on the chat.
if options.DB != nil && options.ChatID != uuid.Nil {
_, _ = options.DB.UpdateChatWorkspace(ctx, database.UpdateChatWorkspaceParams{
ID: options.ChatID,
WorkspaceID: uuid.NullUUID{
UUID: workspace.ID,
Valid: true,
},
})
}
// Wait for the agent to come online.
if workspaceAgentID != uuid.Nil && options.AgentConnFn != nil {
if err := waitForAgent(ctx, options.AgentConnFn, workspaceAgentID); err != nil {
// Non-fatal: the workspace was created
// successfully, the agent just isn't ready
// yet. The model can retry.
return toolResponse(map[string]any{
"created": true,
"workspace_name": workspace.FullName(),
"agent_status": "not_ready",
"agent_error": err.Error(),
}), nil
}
}
return toolResponse(map[string]any{
"created": true,
"workspace_name": workspace.FullName(),
}), nil
},
)
}
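As an illustrative sketch (not part of the committed code), a caller might wire the tool into a fantasy agent like this; db, ownerID, chatID, createFn, connFn, and model are hypothetical stand-ins supplied by the caller:

// Hypothetical wiring of the create_workspace tool into an agent.
var workspaceMu sync.Mutex
tool := chattool.CreateWorkspace(chattool.CreateWorkspaceOptions{
	DB:          db,
	OwnerID:     ownerID,
	ChatID:      chatID,
	CreateFn:    createFn,
	AgentConnFn: connFn,
	WorkspaceMu: &workspaceMu,
})
agent := fantasy.NewAgent(model, fantasy.WithTools(tool))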
// checkExistingWorkspace checks whether the chat already has a usable
// workspace. Returns the result map and true if the caller should
// return early (workspace exists and is alive or building). Returns
// false if the caller should proceed with creation (workspace is dead
// or missing).
func checkExistingWorkspace(
ctx context.Context,
db database.Store,
chatID uuid.UUID,
agentConnFn AgentConnFunc,
) (map[string]any, bool, error) {
chat, err := db.GetChatByID(ctx, chatID)
if err != nil {
return nil, false, xerrors.Errorf("load chat: %w", err)
}
if !chat.WorkspaceID.Valid {
return nil, false, nil
}
// Check if workspace still exists.
ws, err := db.GetWorkspaceByID(ctx, chat.WorkspaceID.UUID)
if err != nil {
if xerrors.Is(err, sql.ErrNoRows) {
// Workspace was deleted — allow creation.
return nil, false, nil
}
return nil, false, xerrors.Errorf("load workspace: %w", err)
}
// Check the latest build status.
build, err := db.GetLatestWorkspaceBuildByWorkspaceID(ctx, ws.ID)
if err != nil {
// Can't determine status — allow creation.
return nil, false, nil
}
job, err := db.GetProvisionerJobByID(ctx, build.JobID)
if err != nil {
return nil, false, nil
}
switch job.JobStatus {
case database.ProvisionerJobStatusPending,
database.ProvisionerJobStatusRunning:
// Build is in progress — wait for it instead of
// creating a new workspace.
if err := waitForBuild(ctx, db, ws.ID); err != nil {
return nil, false, xerrors.Errorf(
"existing workspace build failed: %w", err,
)
}
return map[string]any{
"created": false,
"workspace_name": ws.Name,
"status": "already_exists",
"message": "workspace was already being built and is now ready",
}, true, nil
case database.ProvisionerJobStatusSucceeded:
// Build succeeded — check if agent is reachable.
agents, agentsErr := db.GetWorkspaceAgentsInLatestBuildByWorkspaceID(ctx, ws.ID)
if agentsErr == nil && len(agents) > 0 && agentConnFn != nil {
pingCtx, cancel := context.WithTimeout(
ctx, agentPingTimeout,
)
defer cancel()
conn, release, connErr := agentConnFn(
pingCtx, agents[0].ID,
)
if connErr == nil {
release()
_ = conn
return map[string]any{
"created": false,
"workspace_name": ws.Name,
"status": "already_exists",
"message": "workspace is already running and reachable",
}, true, nil
}
// Agent unreachable — workspace is dead, allow
// creation.
}
// No agent ID or no conn func — allow creation.
return nil, false, nil
default:
// Failed, canceled, etc — allow creation.
return nil, false, nil
}
}
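// To summarize the branches above: pending or running builds are waited on and
// the existing workspace is reused; a succeeded build is reused only if the
// agent answers a ping within agentPingTimeout; failed, canceled, or deleted
// workspaces fall through to creation.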
// waitForBuild polls the workspace's latest build until it
// completes or the context expires.
func waitForBuild(
ctx context.Context,
db database.Store,
workspaceID uuid.UUID,
) error {
buildCtx, cancel := context.WithTimeout(ctx, buildTimeout)
defer cancel()
ticker := time.NewTicker(buildPollInterval)
defer ticker.Stop()
for {
build, err := db.GetLatestWorkspaceBuildByWorkspaceID(
buildCtx, workspaceID,
)
if err != nil {
return xerrors.Errorf("get latest build: %w", err)
}
job, err := db.GetProvisionerJobByID(buildCtx, build.JobID)
if err != nil {
return xerrors.Errorf("get provisioner job: %w", err)
}
switch job.JobStatus {
case database.ProvisionerJobStatusSucceeded:
return nil
case database.ProvisionerJobStatusFailed:
errMsg := "build failed"
if job.Error.Valid {
errMsg = job.Error.String
}
return xerrors.New(errMsg)
case database.ProvisionerJobStatusCanceled:
return xerrors.New("build was canceled")
case database.ProvisionerJobStatusPending,
database.ProvisionerJobStatusRunning,
database.ProvisionerJobStatusCanceling:
// Still in progress — keep waiting.
default:
return xerrors.Errorf("unexpected job status: %s", job.JobStatus)
}
select {
case <-buildCtx.Done():
return xerrors.Errorf(
"timed out waiting for workspace build: %w",
buildCtx.Err(),
)
case <-ticker.C:
}
}
}
// waitForAgent retries connecting to the workspace agent until it
// succeeds or the timeout expires.
func waitForAgent(
ctx context.Context,
agentConnFn AgentConnFunc,
agentID uuid.UUID,
) error {
agentCtx, cancel := context.WithTimeout(ctx, agentConnectTimeout)
defer cancel()
ticker := time.NewTicker(agentRetryInterval)
defer ticker.Stop()
var lastErr error
for {
attemptCtx, attemptCancel := context.WithTimeout(agentCtx, agentAttemptTimeout)
conn, release, err := agentConnFn(attemptCtx, agentID)
attemptCancel()
if err == nil {
release()
_ = conn
return nil
}
lastErr = err
select {
case <-agentCtx.Done():
return xerrors.Errorf(
"timed out waiting for workspace agent: %w",
lastErr,
)
case <-ticker.C:
}
}
}
func generatedWorkspaceName(seed string) string {
base := codersdk.UsernameFrom(strings.TrimSpace(strings.ToLower(seed)))
if strings.TrimSpace(base) == "" {
base = "workspace"
}
suffix := strings.ReplaceAll(uuid.NewString(), "-", "")[:4]
if len(base) > 27 {
base = strings.Trim(base[:27], "-")
}
if base == "" {
base = "workspace"
}
name := fmt.Sprintf("%s-%s", base, suffix)
if err := codersdk.NameValid(name); err == nil {
return name
}
return namesgenerator.NameDigitWith("-")
}
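// toolResponse is used throughout this package but is not defined in this
// hunk. A minimal sketch of what such a helper typically looks like, assuming
// it simply JSON-encodes the payload for the model (the real implementation
// may differ):
//
//	func toolResponse(payload any) fantasy.ToolResponse {
//		data, err := json.Marshal(payload)
//		if err != nil {
//			return fantasy.NewTextErrorResponse(err.Error())
//		}
//		return fantasy.NewTextResponse(string(data))
//	}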
@@ -0,0 +1,50 @@
package chattool
import (
"context"
"charm.land/fantasy"
"github.com/coder/coder/v2/codersdk/workspacesdk"
)
type EditFilesOptions struct {
GetWorkspaceConn func(context.Context) (workspacesdk.AgentConn, error)
}
type EditFilesArgs struct {
Files []workspacesdk.FileEdits `json:"files"`
}
func EditFiles(options EditFilesOptions) fantasy.AgentTool {
return fantasy.NewAgentTool(
"edit_files",
"Perform search-and-replace edits on one or more files in the workspace."+
" Each file can have multiple edits applied atomically.",
func(ctx context.Context, args EditFilesArgs, _ fantasy.ToolCall) (fantasy.ToolResponse, error) {
if options.GetWorkspaceConn == nil {
return fantasy.NewTextErrorResponse("workspace connection resolver is not configured"), nil
}
conn, err := options.GetWorkspaceConn(ctx)
if err != nil {
return fantasy.NewTextErrorResponse(err.Error()), nil
}
return executeEditFilesTool(ctx, conn, args)
},
)
}
func executeEditFilesTool(
ctx context.Context,
conn workspacesdk.AgentConn,
args EditFilesArgs,
) (fantasy.ToolResponse, error) {
if len(args.Files) == 0 {
return fantasy.NewTextErrorResponse("files is required"), nil
}
if err := conn.EditFiles(ctx, workspacesdk.FileEditRequest{Files: args.Files}); err != nil {
return fantasy.NewTextErrorResponse(err.Error()), nil
}
return toolResponse(map[string]any{"ok": true}), nil
}
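// editFilesWithConn is a minimal illustrative wiring of the tool above (not
// part of this change): it reuses an already-established agent connection
// instead of resolving one per call.
func editFilesWithConn(conn workspacesdk.AgentConn) fantasy.AgentTool {
	return EditFiles(EditFilesOptions{
		GetWorkspaceConn: func(context.Context) (workspacesdk.AgentConn, error) {
			return conn, nil
		},
	})
}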
@@ -0,0 +1,451 @@
package chattool
import (
"context"
"encoding/json"
"fmt"
"regexp"
"strings"
"time"
"charm.land/fantasy"
"github.com/coder/coder/v2/codersdk/workspacesdk"
)
const (
// defaultTimeout is the default timeout for command
// execution.
defaultTimeout = 10 * time.Second
// maxOutputToModel is the maximum output sent to the LLM.
maxOutputToModel = 32 << 10 // 32KB
// pollInterval is how often we check for process completion
// in foreground mode.
pollInterval = 200 * time.Millisecond
)
// nonInteractiveEnvVars are set on every process to prevent
// interactive prompts that would hang a headless execution.
var nonInteractiveEnvVars = map[string]string{
"GIT_EDITOR": "true",
"GIT_SEQUENCE_EDITOR": "true",
"EDITOR": "true",
"VISUAL": "true",
"GIT_TERMINAL_PROMPT": "0",
"NO_COLOR": "1",
"TERM": "dumb",
"PAGER": "cat",
"GIT_PAGER": "cat",
}
// fileDumpPatterns detects commands that dump entire files.
// When matched, a note is added suggesting read_file instead.
var fileDumpPatterns = []*regexp.Regexp{
regexp.MustCompile(`^cat\s+`),
regexp.MustCompile(`^(rg|grep)\s+.*--include-all`),
regexp.MustCompile(`^(rg|grep)\s+-l\s+`),
}
// ExecuteResult is the structured response from the execute
// tool.
type ExecuteResult struct {
Success bool `json:"success"`
Output string `json:"output,omitempty"`
ExitCode int `json:"exit_code"`
WallDurationMs int64 `json:"wall_duration_ms"`
Error string `json:"error,omitempty"`
Truncated *workspacesdk.ProcessTruncation `json:"truncated,omitempty"`
Note string `json:"note,omitempty"`
BackgroundProcessID string `json:"background_process_id,omitempty"`
}
// ExecuteOptions configures the execute tool.
type ExecuteOptions struct {
GetWorkspaceConn func(context.Context) (workspacesdk.AgentConn, error)
DefaultTimeout time.Duration
ChatID string
}
// ProcessToolOptions configures a process management tool
// (process_output, process_list, or process_signal). Each of
// these tools only needs a workspace connection resolver.
type ProcessToolOptions struct {
GetWorkspaceConn func(context.Context) (workspacesdk.AgentConn, error)
}
// ExecuteArgs are the parameters accepted by the execute tool.
type ExecuteArgs struct {
Command string `json:"command"`
Timeout *string `json:"timeout,omitempty"`
WorkDir *string `json:"workdir,omitempty"`
RunInBackground *bool `json:"run_in_background,omitempty"`
}
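// Example arguments (illustrative): {"command": "go test ./...", "timeout": "2m"},
// or {"command": "npm run dev", "run_in_background": true} for a long-lived process.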
// Execute returns an AgentTool that runs a shell command in the
// workspace via the agent HTTP API.
func Execute(options ExecuteOptions) fantasy.AgentTool {
return fantasy.NewAgentTool(
"execute",
"Execute a shell command in the workspace.",
func(ctx context.Context, args ExecuteArgs, _ fantasy.ToolCall) (fantasy.ToolResponse, error) {
if options.GetWorkspaceConn == nil {
return fantasy.NewTextErrorResponse("workspace connection resolver is not configured"), nil
}
conn, err := options.GetWorkspaceConn(ctx)
if err != nil {
return fantasy.NewTextErrorResponse(err.Error()), nil
}
return executeTool(ctx, conn, args, options.DefaultTimeout, options.ChatID), nil
},
)
}
func executeTool(
ctx context.Context,
conn workspacesdk.AgentConn,
args ExecuteArgs,
optTimeout time.Duration,
chatID string,
) fantasy.ToolResponse {
if args.Command == "" {
return fantasy.NewTextErrorResponse("command is required")
}
// Build the environment map for the process request.
env := make(map[string]string, len(nonInteractiveEnvVars)+1)
env["CODER_CHAT_AGENT"] = "true"
if chatID != "" {
env["CODER_CHAT_ID"] = chatID
}
for k, v := range nonInteractiveEnvVars {
env[k] = v
}
background := args.RunInBackground != nil && *args.RunInBackground
var workDir string
if args.WorkDir != nil {
workDir = *args.WorkDir
}
if background {
return executeBackground(ctx, conn, args.Command, workDir, env)
}
return executeForeground(ctx, conn, args, optTimeout, workDir, env)
}
// executeBackground starts a process in the background and
// returns immediately with the process ID.
func executeBackground(
ctx context.Context,
conn workspacesdk.AgentConn,
command string,
workDir string,
env map[string]string,
) fantasy.ToolResponse {
resp, err := conn.StartProcess(ctx, workspacesdk.StartProcessRequest{
Command: command,
WorkDir: workDir,
Env: env,
Background: true,
})
if err != nil {
return errorResult(fmt.Sprintf("start background process: %v", err))
}
result := ExecuteResult{
Success: true,
BackgroundProcessID: resp.ID,
}
data, err := json.Marshal(result)
if err != nil {
return fantasy.NewTextErrorResponse(err.Error())
}
return fantasy.NewTextResponse(string(data))
}
// executeForeground starts a process and polls for its
// completion, enforcing the configured timeout.
func executeForeground(
ctx context.Context,
conn workspacesdk.AgentConn,
args ExecuteArgs,
optTimeout time.Duration,
workDir string,
env map[string]string,
) fantasy.ToolResponse {
timeout := optTimeout
if timeout <= 0 {
timeout = defaultTimeout
}
if args.Timeout != nil {
parsed, err := time.ParseDuration(*args.Timeout)
if err != nil {
return fantasy.NewTextErrorResponse(
fmt.Sprintf("invalid timeout %q: %v", *args.Timeout, err),
)
}
timeout = parsed
}
cmdCtx, cancel := context.WithTimeout(ctx, timeout)
defer cancel()
start := time.Now()
resp, err := conn.StartProcess(cmdCtx, workspacesdk.StartProcessRequest{
Command: args.Command,
WorkDir: workDir,
Env: env,
Background: false,
})
if err != nil {
return errorResult(fmt.Sprintf("start process: %v", err))
}
result := pollProcess(cmdCtx, conn, resp.ID, timeout)
result.WallDurationMs = time.Since(start).Milliseconds()
// Add an advisory note for file-dump commands.
if note := detectFileDump(args.Command); note != "" {
result.Note = note
}
data, err := json.Marshal(result)
if err != nil {
return fantasy.NewTextErrorResponse(err.Error())
}
return fantasy.NewTextResponse(string(data))
}
// truncateOutput safely truncates output to maxOutputToModel,
// ensuring the result is valid UTF-8 even if the cut falls in
// the middle of a multi-byte character.
func truncateOutput(output string) string {
if len(output) > maxOutputToModel {
output = strings.ToValidUTF8(output[:maxOutputToModel], "")
}
return output
}
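// For illustration: maxOutputToModel is a byte count, so the cut can land in
// the middle of a multi-byte rune; strings.ToValidUTF8 then drops the partial
// rune (for example a trailing 0xC3 from a split "é") rather than returning
// invalid UTF-8 to the model.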
// pollProcess polls for process output until the process exits
// or the context times out.
func pollProcess(
ctx context.Context,
conn workspacesdk.AgentConn,
processID string,
timeout time.Duration,
) ExecuteResult {
ticker := time.NewTicker(pollInterval)
defer ticker.Stop()
for {
select {
case <-ctx.Done():
// Timed out. Grab whatever output we have, using a fresh
// context because the command context passed in as ctx is
// already canceled.
bgCtx, bgCancel := context.WithTimeout(
context.Background(),
5*time.Second,
)
outputResp, _ := conn.ProcessOutput(bgCtx, processID)
bgCancel()
output := truncateOutput(outputResp.Output)
return ExecuteResult{
Success: false,
Output: output,
ExitCode: -1,
Error: fmt.Sprintf("command timed out after %s", timeout),
Truncated: outputResp.Truncated,
}
case <-ticker.C:
outputResp, err := conn.ProcessOutput(ctx, processID)
if err != nil {
return ExecuteResult{
Success: false,
Error: fmt.Sprintf("get process output: %v", err),
}
}
if !outputResp.Running {
exitCode := 0
if outputResp.ExitCode != nil {
exitCode = *outputResp.ExitCode
}
output := truncateOutput(outputResp.Output)
return ExecuteResult{
Success: exitCode == 0,
Output: output,
ExitCode: exitCode,
Truncated: outputResp.Truncated,
}
}
}
}
}
// errorResult builds a ToolResponse from an ExecuteResult with
// an error message.
func errorResult(msg string) fantasy.ToolResponse {
data, err := json.Marshal(ExecuteResult{
Success: false,
Error: msg,
})
if err != nil {
return fantasy.NewTextErrorResponse(msg)
}
return fantasy.NewTextResponse(string(data))
}
// detectFileDump checks whether the command matches a file-dump
// pattern and returns an advisory note, or empty string if no
// match.
func detectFileDump(command string) string {
for _, pat := range fileDumpPatterns {
if pat.MatchString(command) {
return "Consider using read_file instead of " +
"dumping file contents with shell commands."
}
}
return ""
}
// ProcessOutputArgs are the parameters accepted by the
// process_output tool.
type ProcessOutputArgs struct {
ProcessID string `json:"process_id"`
}
// ProcessOutput returns an AgentTool that retrieves the output
// of a background process by its ID.
func ProcessOutput(options ProcessToolOptions) fantasy.AgentTool {
return fantasy.NewAgentTool(
"process_output",
"Retrieve output from a background process. "+
"Use the process_id returned by execute with "+
"run_in_background=true. Returns the current output, "+
"whether the process is still running, and the exit "+
"code if it has finished.",
func(ctx context.Context, args ProcessOutputArgs, _ fantasy.ToolCall) (fantasy.ToolResponse, error) {
if options.GetWorkspaceConn == nil {
return fantasy.NewTextErrorResponse("workspace connection resolver is not configured"), nil
}
if args.ProcessID == "" {
return fantasy.NewTextErrorResponse("process_id is required"), nil
}
conn, err := options.GetWorkspaceConn(ctx)
if err != nil {
return fantasy.NewTextErrorResponse(err.Error()), nil
}
resp, err := conn.ProcessOutput(ctx, args.ProcessID)
if err != nil {
return errorResult(fmt.Sprintf("get process output: %v", err)), nil
}
output := truncateOutput(resp.Output)
exitCode := 0
if resp.ExitCode != nil {
exitCode = *resp.ExitCode
}
result := ExecuteResult{
Success: !resp.Running && exitCode == 0,
Output: output,
ExitCode: exitCode,
Truncated: resp.Truncated,
}
if resp.Running {
// Process is still running, so the final exit code is
// unknown. Report success so the model does not treat an
// in-progress run as a failure.
result.Success = true
result.Note = "process is still running"
}
data, err := json.Marshal(result)
if err != nil {
return fantasy.NewTextErrorResponse(err.Error()), nil
}
return fantasy.NewTextResponse(string(data)), nil
},
)
}
// ProcessList returns an AgentTool that lists all tracked
// processes on the workspace agent.
func ProcessList(options ProcessToolOptions) fantasy.AgentTool {
return fantasy.NewAgentTool(
"process_list",
"List all tracked processes in the workspace. "+
"Returns process IDs, commands, status (running or "+
"exited), and exit codes. Use this to discover "+
"background processes or check which processes are "+
"still running.",
func(ctx context.Context, _ struct{}, _ fantasy.ToolCall) (fantasy.ToolResponse, error) {
if options.GetWorkspaceConn == nil {
return fantasy.NewTextErrorResponse("workspace connection resolver is not configured"), nil
}
conn, err := options.GetWorkspaceConn(ctx)
if err != nil {
return fantasy.NewTextErrorResponse(err.Error()), nil
}
resp, err := conn.ListProcesses(ctx)
if err != nil {
return errorResult(fmt.Sprintf("list processes: %v", err)), nil
}
data, err := json.Marshal(resp)
if err != nil {
return fantasy.NewTextErrorResponse(err.Error()), nil
}
return fantasy.NewTextResponse(string(data)), nil
},
)
}
// ProcessSignalArgs are the parameters accepted by the
// process_signal tool.
type ProcessSignalArgs struct {
ProcessID string `json:"process_id"`
Signal string `json:"signal"`
}
// ProcessSignal returns an AgentTool that sends a signal to a
// tracked process on the workspace agent.
func ProcessSignal(options ProcessToolOptions) fantasy.AgentTool {
return fantasy.NewAgentTool(
"process_signal",
"Send a signal to a background process. "+
"Use \"terminate\" (SIGTERM) for graceful shutdown "+
"or \"kill\" (SIGKILL) to force stop. Use the "+
"process_id returned by execute with "+
"run_in_background=true or from process_list.",
func(ctx context.Context, args ProcessSignalArgs, _ fantasy.ToolCall) (fantasy.ToolResponse, error) {
if options.GetWorkspaceConn == nil {
return fantasy.NewTextErrorResponse("workspace connection resolver is not configured"), nil
}
if args.ProcessID == "" {
return fantasy.NewTextErrorResponse("process_id is required"), nil
}
if args.Signal != "terminate" && args.Signal != "kill" {
return fantasy.NewTextErrorResponse(
"signal must be \"terminate\" or \"kill\"",
), nil
}
conn, err := options.GetWorkspaceConn(ctx)
if err != nil {
return fantasy.NewTextErrorResponse(err.Error()), nil
}
if err := conn.SignalProcess(ctx, args.ProcessID, args.Signal); err != nil {
return errorResult(fmt.Sprintf("signal process: %v", err)), nil
}
data, err := json.Marshal(map[string]any{
"success": true,
"message": fmt.Sprintf(
"signal %q sent to process %s",
args.Signal, args.ProcessID,
),
})
if err != nil {
return fantasy.NewTextErrorResponse(err.Error()), nil
}
return fantasy.NewTextResponse(string(data)), nil
},
)
}
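// processTools is an illustrative helper (not part of this change) showing how
// the execute and process tools compose into a single tool set for the agent.
// getConn stands in for whatever connection resolver the chat service provides.
func processTools(getConn func(context.Context) (workspacesdk.AgentConn, error), chatID string) []fantasy.AgentTool {
	return []fantasy.AgentTool{
		Execute(ExecuteOptions{GetWorkspaceConn: getConn, ChatID: chatID}),
		ProcessOutput(ProcessToolOptions{GetWorkspaceConn: getConn}),
		ProcessList(ProcessToolOptions{GetWorkspaceConn: getConn}),
		ProcessSignal(ProcessToolOptions{GetWorkspaceConn: getConn}),
	}
}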
@@ -0,0 +1,148 @@
package chattool
import (
"context"
"database/sql"
"sort"
"strings"
"charm.land/fantasy"
"github.com/google/uuid"
"golang.org/x/xerrors"
"github.com/coder/coder/v2/coderd/database"
"github.com/coder/coder/v2/coderd/database/dbauthz"
"github.com/coder/coder/v2/coderd/httpmw"
"github.com/coder/coder/v2/coderd/rbac"
)
const listTemplatesPageSize = 10
// ListTemplatesOptions configures the list_templates tool.
type ListTemplatesOptions struct {
DB database.Store
OwnerID uuid.UUID
}
type listTemplatesArgs struct {
Query string `json:"query,omitempty"`
Page int `json:"page,omitempty"`
}
// ListTemplates returns a tool that lists available workspace templates.
// The agent uses this to discover templates before creating a workspace.
// Results are ordered by number of active developers (most popular first)
// and paginated at 10 per page.
func ListTemplates(options ListTemplatesOptions) fantasy.AgentTool {
return fantasy.NewAgentTool(
"list_templates",
"List available workspace templates. Optionally filter by a "+
"search query matching template name or description. "+
"Use this to find a template before creating a workspace. "+
"Results are ordered by number of active developers (most popular first). "+
"Returns 10 per page. Use the page parameter to paginate through results.",
func(ctx context.Context, args listTemplatesArgs, _ fantasy.ToolCall) (fantasy.ToolResponse, error) {
if options.DB == nil {
return fantasy.NewTextErrorResponse("database is not configured"), nil
}
ctx, err := asOwner(ctx, options.DB, options.OwnerID)
if err != nil {
return fantasy.NewTextErrorResponse(err.Error()), nil
}
filterParams := database.GetTemplatesWithFilterParams{
Deleted: false,
Deprecated: sql.NullBool{
Bool: false,
Valid: true,
},
}
query := strings.TrimSpace(args.Query)
if query != "" {
filterParams.FuzzyName = query
}
templates, err := options.DB.GetTemplatesWithFilter(ctx, filterParams)
if err != nil {
return fantasy.NewTextErrorResponse(err.Error()), nil
}
// Look up active developer counts so we can sort by popularity.
templateIDs := make([]uuid.UUID, len(templates))
for i, t := range templates {
templateIDs[i] = t.ID
}
ownerCounts := make(map[uuid.UUID]int64)
if len(templateIDs) > 0 {
rows, countErr := options.DB.GetWorkspaceUniqueOwnerCountByTemplateIDs(ctx, templateIDs)
if countErr == nil {
for _, row := range rows {
ownerCounts[row.TemplateID] = row.UniqueOwnersSum
}
}
}
// Sort by active developer count descending.
sort.SliceStable(templates, func(i, j int) bool {
return ownerCounts[templates[i].ID] > ownerCounts[templates[j].ID]
})
// Paginate.
page := args.Page
if page < 1 {
page = 1
}
totalCount := len(templates)
totalPages := (totalCount + listTemplatesPageSize - 1) / listTemplatesPageSize
if totalPages == 0 {
totalPages = 1
}
start := (page - 1) * listTemplatesPageSize
end := start + listTemplatesPageSize
if start > totalCount {
start = totalCount
}
if end > totalCount {
end = totalCount
}
pageTemplates := templates[start:end]
items := make([]map[string]any, 0, len(pageTemplates))
for _, t := range pageTemplates {
item := map[string]any{
"id": t.ID.String(),
"name": t.Name,
}
if display := strings.TrimSpace(t.DisplayName); display != "" {
item["display_name"] = display
}
if desc := strings.TrimSpace(t.Description); desc != "" {
item["description"] = truncateRunes(desc, 200)
}
if count, ok := ownerCounts[t.ID]; ok && count > 0 {
item["active_developers"] = count
}
items = append(items, item)
}
return toolResponse(map[string]any{
"templates": items,
"count": len(items),
"page": page,
"total_pages": totalPages,
"total_count": totalCount,
}), nil
},
)
}
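// Example arguments (illustrative): {} for the first page of all templates,
// {"query": "python"} to filter by name, or {"page": 2} for the next page.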
// asOwner sets up a dbauthz context for the given owner so that
// subsequent database calls are scoped to what that user can access.
func asOwner(ctx context.Context, db database.Store, ownerID uuid.UUID) (context.Context, error) {
actor, _, err := httpmw.UserRBACSubject(ctx, db, ownerID, rbac.ScopeAll)
if err != nil {
return ctx, xerrors.Errorf("load user authorization: %w", err)
}
return dbauthz.As(ctx, actor), nil
}
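// truncateRunes (used above) is defined elsewhere in the package. A minimal
// sketch of the usual shape of such a helper, assuming it cuts on rune
// boundaries rather than bytes (the real implementation may differ):
//
//	func truncateRunes(s string, n int) string {
//		r := []rune(s)
//		if len(r) <= n {
//			return s
//		}
//		return string(r[:n])
//	}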

Some files were not shown because too many files have changed in this diff.