Compare commits

..

48 Commits

Author SHA1 Message Date
Ethan Dickson de1a317890 refactor(coderd/database/pubsub): clean up batching implementation
- Trim metrics from 18 to 6, removing counters that duplicated info
  available from existing metrics or provided little operational signal.
- Inline single-use constants.
- Re-add hidden PubsubFlushInterval and PubsubQueueSize config knobs
  to ChatConfig for tuning without code changes.
- Remove dead ErrBatchingPubsubQueueFull export (never returned).
- Remove unreachable nil branch in batchFlushStage.
- Propagate resetErr through flushBatch even when delegate replay
  succeeds, so drain reports the broken sender state.
- Annotate sender field with goroutine-safety invariant.
- Regenerate metrics docs.
2026-04-09 14:43:59 +00:00
Ethan Dickson d64ee2e1cc chore(codersdk): remove pubsub batching config knobs 2026-04-09 05:17:48 +00:00
Ethan Dickson 13281d8235 feat(coderd/database/pubsub): add batched pubsub with flush-failure fallback and sender reset
Adds a chatd-specific BatchingPubsub that routes publishes through a
dedicated single-connection sender, coalescing notifications into
single transactions on a 50ms timer. Includes flush-failure fallback
to the shared delegate, automatic sender reset/recreate, expanded
histogram buckets, and focused recovery tests.
2026-04-09 02:19:21 +00:00
Matt Vollmer d954460380 docs: rename "Security implications" to "Security posture" (#24181)
Renames the "Security implications" section to "Security posture" and
reframes the intro paragraph. "Implications" reads as a caveat or
warning; the section actually describes built-in structural guarantees
of the control plane architecture.

> PR generated with Coder Agents
2026-04-08 19:55:56 -04:00
dylanhuff-at-coder f4240bb8c1 fix: sanitize workspace agent logs before insert (#24028)
Workspace agent logs could still fail after the earlier invalid UTF-8
fix because NUL bytes are valid Go/protobuf strings but are rejected by
Postgres text columns. The legacy HTTP log upload path also bypassed the
old sanitization entirely, and both server insert paths computed
logs_length from the unsanitized input.

Add a shared log-output sanitizer in agentsdk, use it in the protobuf
conversion path and both server-side insert paths, and compute
OutputLength from the sanitized string so overflow accounting matches
what is actually stored. This preserves the old invalid UTF-8 behavior
while also handling embedded NUL bytes consistently across DRPC and HTTP
log ingestion.

Refs [#23292](https://github.com/coder/coder/issues/23292)
Refs [#13433](https://github.com/coder/coder/issues/13433)
2026-04-08 16:29:38 -07:00
Zach 7caef4987f feat: add input validation for user secret env names and file paths (#24103)
Adds backend validation for user secret environment variable names and file paths.

Env name validation enforces POSIX naming rules and applies a deliberately aggressive denylist of reserved names and prefixes. The denylist errs on the side of blocking too much, since it's easier to remove entries later than to add them after users have created conflicting secrets.

File path validation requires paths to start with ~/ or /.
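The rules above can be sketched in Go as follows. The denylist entries, function names, and error messages here are illustrative assumptions — the shipped list in coderd differs:

```go
package main

import (
	"fmt"
	"regexp"
	"strings"
)

// Illustrative only: example reserved names/prefixes, not the real denylist.
var (
	envNameRe        = regexp.MustCompile(`^[A-Za-z_][A-Za-z0-9_]*$`)
	reservedNames    = map[string]bool{"PATH": true, "HOME": true, "USER": true}
	reservedPrefixes = []string{"CODER_", "GIT_"}
)

func validateSecretEnvName(name string) error {
	// POSIX: letters, digits, underscores; must not start with a digit.
	if !envNameRe.MatchString(name) {
		return fmt.Errorf("%q is not a valid POSIX environment variable name", name)
	}
	if reservedNames[strings.ToUpper(name)] {
		return fmt.Errorf("%q is a reserved name", name)
	}
	for _, p := range reservedPrefixes {
		if strings.HasPrefix(strings.ToUpper(name), p) {
			return fmt.Errorf("%q uses reserved prefix %q", name, p)
		}
	}
	return nil
}

func validateSecretFilePath(path string) error {
	if strings.HasPrefix(path, "~/") || strings.HasPrefix(path, "/") {
		return nil
	}
	return fmt.Errorf("%q must start with ~/ or /", path)
}

func main() {
	fmt.Println(validateSecretEnvName("MY_TOKEN"))
	fmt.Println(validateSecretEnvName("1BAD"))
	fmt.Println(validateSecretFilePath("~/.config/token"))
}
```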
2026-04-08 17:02:33 -06:00
Zach 9b91af8ab7 feat: add user secrets SDK types and db2sdk converters (#24102)
Adds the SDK types and database-to-SDK conversion helpers for the user secrets feature.
2026-04-08 16:48:41 -06:00
Matt Vollmer 506fba9ebf docs: add BYOK docs, fix tool tables, add platform controls (#24178)
Fixes several documentation gaps and inaccuracies in the Coder Agents
docs identified during a deep review against the current product state.

## BYOK (User API Keys)

`models.md` stated *"Developers cannot add their own providers, models,
or API keys"* — this has been incorrect since the provider key policy
system shipped (Apr 2, #23751/#23781).

- Added **Key policy** section documenting the three admin toggles
(`central_api_key_enabled`, `allow_user_api_key`,
`allow_central_api_key_fallback`) with a truth table showing all
resolution outcomes
- Added **User API keys (BYOK)** section covering the developer-facing
key management page, status indicators, selection priority, and key
removal
- Updated `platform-controls/index.md` to reference BYOK instead of
claiming keys are admin-only

## Reasoning effort enum fixes

- **OpenAI**: removed `none` — code accepts `minimal, low, medium, high,
xhigh`
- **OpenRouter**: narrowed to `low, medium, high` per
`ReasoningEffortFromChat` in `chatprovider.go`

## Tool table completeness

- Added `spawn_computer_use_agent`, `read_skill`, `read_skill_file` to
`index.md` tool table
- Added "Workspace extension tools" section to `architecture.md` for
`read_skill`/`read_skill_file`
- Fixed orchestration restriction note to list all 5 gated tools instead
of just `spawn_agent`
- Added conditional availability notes for desktop and skills tools

## Platform controls

Three admin-only settings existed in the Behavior tab with no
documentation:

- **Virtual desktop** — admin toggle, Anthropic + portabledesktop
requirements
- **Workspace autostop fallback** — default TTL for agent workspaces
without template-defined autostop
- **Data retention** — moved `chat-retention.md` into
`platform-controls/` since it's admin-only, fixed nav path

---

> PR generated with Coder Agents
2026-04-08 18:24:12 -04:00
Cian Johnston 461a31e5d8 feat(site): add under-construction navbar stripes for pre-release builds (#24157)
Dev and RC builds now show diagonal warning stripes in the navbar plus a
centered version badge, making it impossible to miss which build you're
running.

**Devel build:** amber "warning" from theme

**RC build:** sky "pending" from theme

> 🤖 Written by a Coder Agent. Will be reviewed by a human.
2026-04-08 20:10:03 +00:00
Carlo Field e3a0dcd6fc feat: add httproute for K8s Gateway API (#23501)
No AI was used to generate this PR.

Adds support for [Gateway API
HTTPRoutes](https://gateway-api.sigs.k8s.io/api-types/httproute/) as an
alternative to Ingress.

---------

Signed-off-by: Carlo Field <carlo@swiss.dev>
Co-authored-by: bpmct <bpmct@users.noreply.github.com>
Co-authored-by: Ben Potter <ben@coder.com>
2026-04-08 14:59:17 -05:00
Danielle Maywood 12ada0115f fix(site): move pagination test from vitest to storybook story (#24165) 2026-04-08 20:56:53 +01:00
Cian Johnston 7b0421d8c6 fix: revert auto-assign agents-access role enabled (#24170)
This reverts commit d4a9c63e91 (#23968).

---------

Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
2026-04-08 20:56:17 +01:00
Hugo Dutka 477d6d0cde fix(site): fix agents right panel layout on small landscape viewports (#24161)
Currently, when you're using Agents on mobile with a vertical viewport
and you open the sidebar, the sidebar takes up the entire screen. That's
great, since there isn't enough space to show the other tabs. But when
you tilt your phone to horizontal mode, all 3 tabs show up, and none of
them are very legible:


https://github.com/user-attachments/assets/50a54791-fe53-4a5d-ba7b-85e82f970851

This PR makes it so that the right sidebar takes up the entire screen on
small viewports (<1024px) in horizontal mode too.



https://github.com/user-attachments/assets/a06069df-9f2f-42bd-8072-a237434434e5
2026-04-08 20:01:59 +02:00
Jeremy Ruppel de61ac529d fix(site): scroll when request logs tool call is huge (#24162)
**Disclaimer: I've never encountered this on dogfood, only on my local
where Claude likes to do really long tool calls**

On the Request Logs page, if a tool call has super long lines, it will
break the row layout:


https://github.com/user-attachments/assets/fd1a8be0-7912-4611-a1c3-0c7943b1ea52

This adds stories to demonstrate the behavior, and then a lil
`overflow-x: auto` for the fix


https://github.com/user-attachments/assets/f0fd94da-8254-4330-a718-08599909e8ec
2026-04-08 13:53:43 -04:00
Yevhenii Shcherbina 7f496c2f18 feat: byok-observability for aibridge (#23808)
## Summary

Adds `credential_kind` and `credential_hint` columns to
`aibridge_interceptions` to record how each LLM request was
authenticated and provide a masked credential identifier for audit
purposes.

This enables admins to distinguish between centralized API keys,
personal API keys, and subscription-based credentials in the
interceptions audit log.

## Changes

- New migration adding `credential_kind` and `credential_hint` to
`aibridge_interceptions`
- Updated `InsertAIBridgeInterception` query and proto definition to
carry the new fields
- Wired proto fields through `translator.go` and `aibridgedserver.go` to
the database

Depends on https://github.com/coder/aibridge/pull/239
2026-04-08 13:24:28 -04:00
Michael Suchacz 590235138f fix: pin fixed anthropic/fantasy forks for streaming token accounting (#24077) 2026-04-08 17:07:39 +00:00
blinkagent[bot] 543c448b72 docs: update release calendar to reflect 2.31 as stable (#24159)
Update the release calendar table now that v2.31.7 has been promoted to
stable (`latest` on GitHub Releases).

## Changes

| Release | Old Status | New Status | Latest Patch |
|---------|-----------|------------|-------------|
| 2.31 | Mainline | Stable | v2.31.7 |
| 2.30 | Stable | Security Support | v2.30.6 |
| 2.29 | Security Support + ESR | Extended Support Release | v2.29.9 |

---

> **Note:** The auto-generation script
(`scripts/update-release-calendar.sh`) determines status positionally
from the latest non-RC tag, so it will always mark the latest minor
version as "Mainline". This manual update is needed to reflect the
promotion of 2.31 to stable.

Co-authored-by: blink-so[bot] <211532188+blink-so[bot]@users.noreply.github.com>
2026-04-08 17:02:07 +00:00
Kyle Carberry 35c26ce22a feat: add CreatedAt to tool-call and tool-result ChatMessageParts (#24101)
Adds an optional `CreatedAt` timestamp to `tool-call` and `tool-result`
`ChatMessagePart` variants so the frontend can compute tool execution
duration (`result.created_at - call.created_at`).

Timestamps are recorded at the correct moments in the chatloop:
- **Tool-call**: when the model stream emits the tool call
- **Tool-result**: when tool execution completes (or is interrupted)

These are passed through `PersistedStep.PartCreatedAt` so the
persistence layer can apply accurate timestamps to stored parts.
SSE-published parts also carry `CreatedAt` for real-time display.

Old persisted messages without `created_at` deserialize to `nil` — fully
backward compatible.

<details><summary>Implementation notes (Coder Agents
generated)</summary>

### Why not stamp in `PartFromContent`?

`PartFromContent` is called both for SSE publishing (correct timing) and
during persistence (wrong timing — both tool-call and tool-result would
get the same "persistence time" timestamp, yielding ~0 duration).
Instead, timestamps are captured in the chatloop at the right moments
and carried through `PersistedStep.PartCreatedAt` as a
`map[string]time.Time` keyed by `"call:<id>"` / `"result:<id>"`.

### Interrupted tool calls

`persistInterruptedStep` also stamps `CreatedAt` on synthetic error
results for cancelled/interrupted tool calls, so partial duration is
available.

### Files changed

| File | Change |
|------|--------|
| `codersdk/chats.go` | Add `CreatedAt *time.Time` field |
| `codersdk/chats_test.go` | JSON round-trip test |
| `coderd/database/dbtime/dbtime.go` | Add `TimePtr` helper |
| `coderd/x/chatd/chatloop/chatloop.go` | Track timestamps, pass through
`PersistedStep` |
| `coderd/x/chatd/chatd.go` | Apply timestamps during persistence |
| `coderd/x/chatd/chatprompt/chatprompt_test.go` | Verify
`PartFromContent` does NOT stamp |
| `site/src/api/typesGenerated.ts` | Auto-generated |

</details>

---------

Co-authored-by: Ethan <39577870+ethanndickson@users.noreply.github.com>
2026-04-08 12:42:03 -04:00
Jiachen Jiang c2592c9f12 docs: add AI Bridge structured log record types and monitoring cross-link (#23979)
## What

Two small docs improvements for AI Bridge:

1. **`setup.md` – Structured Logging section**: Added a `record_type`
table documenting the six event types emitted by AI Bridge structured
logs (`interception_start`, `interception_end`, `token_usage`,
`prompt_usage`, `tool_usage`, `model_thought`) along with their key
fields. Previously only the `"interception log"` message prefix was
mentioned.

2. **`monitoring.md`**: Added a "Structured Logging" section that
cross-links to `setup.md#structured-logging`, so users landing on the
monitoring page can discover the feature without navigating to the setup
guide first.

<details><summary>Source reference</summary>

Record types and fields were extracted from
`enterprise/aibridgedserver/aibridgedserver.go` where they are emitted
as `slog.F("record_type", "...")` string literals under the
`InterceptionLogMarker` (`"interception log"`) message.

</details>
2026-04-08 08:57:17 -07:00
Kyle Carberry b969d66978 feat: add dynamic tools support for chat API (#24036)
Adds client-executed dynamic tools to the chat API. Dynamic tools are
declared by the client at chat creation time, presented to the LLM
alongside built-in tools, but executed by the client rather than chatd.
This enables external systems (Slack bots, IDE extensions, Discord bots,
CI/CD integrations) to plug custom tools into the LLM chat loop without
modifying chatd's built-in tool set.

Modeled after OpenAI's Assistants API: the chat pauses with
`requires_action` status when the LLM calls a dynamic tool, the client
POSTs results back via `POST /chats/{id}/tool-results`, and the chat
resumes.

See [this example](https://github.com/coder/coder-slackbot-poc) as a
reference for how this is used. It's highly-configurable, which would
enable creating chats from webhooks, periodically polling, or running as
a Slackbot.

<details>
<summary>Design context</summary>

### Architecture

The chatloop **exits** when it encounters dynamic tools and
**re-enters** when results arrive. No blocking channels, no pubsub for
tool results, no in-memory registry. The DB is the only coordination
mechanism.

```
Phase 1 (chatloop):
  LLM response → execute built-in tools only →
  Persist(assistant + built-in results) →
  status = requires_action → chatloop exits

Phase 2 (POST /tool-results):
  Persist(dynamic tool results) →
  status = pending → wakeCh → chatloop re-enters
```

### Validation (POST /tool-results)

1. Chat status must be `requires_action` (409 if not)
2. Read chat's `dynamic_tools` → set of dynamic tool names
3. Read last assistant message → extract tool-call parts matching
dynamic tool names
4. Submitted tool_call_ids must match exactly (400 for missing/extra)
5. Persist tool-result message parts, set status to `pending`, signal
wake
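Steps 2–4 above can be sketched as a pure validation function. The types and names here are illustrative, not the real chatd handler:

```go
package main

import (
	"fmt"
	"sort"
	"strings"
)

type toolCall struct {
	ID, Name string
}

// validateToolResults checks that submitted tool_call_ids exactly match
// the dynamic tool calls from the last assistant message: any missing or
// extra ID is a validation error (the handler would return 400).
func validateToolResults(lastCalls []toolCall, dynamicTools map[string]bool, submittedIDs []string) error {
	want := map[string]bool{}
	for _, c := range lastCalls {
		if dynamicTools[c.Name] { // built-in tools were already executed in phase 1
			want[c.ID] = true
		}
	}
	got := map[string]bool{}
	for _, id := range submittedIDs {
		got[id] = true
	}
	var missing, extra []string
	for id := range want {
		if !got[id] {
			missing = append(missing, id)
		}
	}
	for id := range got {
		if !want[id] {
			extra = append(extra, id)
		}
	}
	if len(missing) > 0 || len(extra) > 0 {
		sort.Strings(missing)
		sort.Strings(extra)
		return fmt.Errorf("tool result mismatch: missing=[%s] extra=[%s]",
			strings.Join(missing, ","), strings.Join(extra, ","))
	}
	return nil
}

func main() {
	calls := []toolCall{{ID: "c1", Name: "slack_post"}, {ID: "c2", Name: "read_file"}}
	dynamic := map[string]bool{"slack_post": true} // read_file is built-in
	fmt.Println(validateToolResults(calls, dynamic, []string{"c1"}))
	fmt.Println(validateToolResults(calls, dynamic, []string{"c2"}))
}
```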

### Idempotency

Tool call IDs scoped per LLM step. State machine (`requires_action` →
`pending`) is the guard. First POST wins, subsequent get 409.

### Mixed tool calls

When the LLM calls both built-in and dynamic tools in one step, built-in
tools execute immediately. Their results are persisted in phase 1.
Dynamic tool results arrive via POST in phase 2. The LLM sees all
results when the chatloop resumes.

</details>

> 🤖 Generated by Coder Agents
2026-04-08 11:54:44 -04:00
Jaayden Halko 1f808cdc62 fix(site): standardize scrollbar styling with global baseline (#24019)
## Summary

Standardizes all frontend scrollbars to use `scrollbar-width: thin` and
`scrollbar-color: hsl(var(--surface-quaternary)) transparent`.

### Changes

**Global baseline** (`site/src/index.css`):
- Both properties are inherited, so this cascades to all scroll
containers
- Components that hide scrollbars (e.g. `SidebarTabView`) override
locally with `scrollbar-width: none`

**Removed redundant per-component scrollbar utilities**:
- `AgentDetailView.tsx` — removed `[scrollbar-width:thin]` and
`[scrollbar-color:...]` (preserved `[scrollbar-gutter:stable]`)
- `ConfigureAgentsDialog.tsx` — removed redundant scrollbar utilities
from two locations
- `DeploymentBannerView.tsx` — removed `[scrollbar-width:thin]`
- `ChatMessageInput.tsx` — removed redundant scrollbar utilities

**Aligned specialized scrollbar surfaces**:
- `TerminalPage.tsx` — updated webkit scrollbar thumb from hardcoded
`rgba(255, 255, 255, 0.18)` → `hsl(var(--surface-quaternary))`, track
from `inherit` → `transparent`, width from `10px` → `8px`
- `Chart.tsx` — removed local JS-style scrollbar overrides (now covered
by global baseline)

### Preserved as-is
- `SidebarTabView.tsx` — intentional hidden scrollbar (`scrollbar-width:
none` overrides global)
- `ScrollArea.tsx` — already uses `bg-surface-quaternary` ✓
- `MonacoEditor.tsx` — Monaco manages its own scrollbars internally
- All `[scrollbar-gutter:stable]` usages preserved
2026-04-08 16:41:23 +01:00
Cian Johnston 497f637f58 chore: revert force deploying main (#23290) (#24072)
⚠️ DO NOT MERGE UNTIL @f0ssel SAYS SO ⚠️ 

This reverts commit 8f78c5145f
(https://github.com/coder/coder/pull/23290).
2026-04-08 11:19:14 -04:00
Ethan be686a8d0d fix(scripts/githooks): clear all repo-local Git env vars in hooks (#24138)
## Problem

In linked worktrees, Git hooks inherit multiple repo-local environment
variables: `GIT_DIR`, `GIT_COMMON_DIR`, `GIT_INDEX_FILE`, and others.
The
pre-commit and pre-push hooks only unset `GIT_DIR`, leaving the rest in
place.

When `make pre-commit` runs `go build`, Go tries to stamp VCS info by
shelling
out to `git`. With the leftover partial Git environment, `git` exits 128
and
the build fails:

```
error obtaining VCS status: exit status 128
    Use -buildvcs=false to disable VCS stamping.
```

This only happens inside hooks in a linked worktree — running `make
pre-commit`
directly from the terminal works fine because the repo-local vars are
not set.

## Fix

Replace the bare `unset GIT_DIR` in both hooks with a loop that clears
every
variable reported by `git rev-parse --local-env-vars`:

```sh
while IFS= read -r var; do
    unset "$var"
done < <(git rev-parse --local-env-vars)
```

This covers all 15 repo-local variables Git may inject (`GIT_DIR`,
`GIT_COMMON_DIR`, `GIT_INDEX_FILE`, `GIT_OBJECT_DIRECTORY`, etc.) and is
forward-compatible — if Git adds new local vars in the future, the loop
picks
them up automatically.
2026-04-09 01:06:12 +10:00
Garrett Delfosse 7b7baea851 feat: support disabling reverse/local port forwarding in agent SSH server (#24026)
The agent SSH server unconditionally allows all four SSH forwarding
paths (TCP local, TCP reverse, Unix local, Unix reverse). This is a
sandbox escape vector when workspaces are used for AI agent containment
— a reverse tunnel lets anything inside the workspace reach the user's
local machine, bypassing network isolation.

This adds two new agent CLI flags / environment variables:

- `--block-reverse-port-forwarding` /
`CODER_AGENT_BLOCK_REVERSE_PORT_FORWARDING` — blocks both TCP (`ssh -R`)
and Unix socket reverse forwarding
- `--block-local-port-forwarding` /
`CODER_AGENT_BLOCK_LOCAL_PORT_FORWARDING` — blocks both TCP (`ssh -L`)
and Unix socket local forwarding

Template admins can set these via the `env` block on the container/VM
resource that runs the agent (e.g. `docker_container`,
`kubernetes_pod`), or via `coder_env` resources tied to the agent.

Fixes https://github.com/coder/coder/issues/22275

<details>
<summary>Implementation notes</summary>

Follows the existing `BlockFileTransfer` pattern:

1. `agent/agentssh/agentssh.go` — New `BlockReversePortForwarding` and
`BlockLocalPortForwarding` fields on `Config`. TCP callbacks check these
before allowing forwarding. The `direct-streamlocal@openssh.com` channel
handler is wrapped to reject Unix local forwards.
2. `agent/agentssh/forward.go` — `forwardedUnixHandler` gains a
`blockReversePortForwarding` field to reject
`streamlocal-forward@openssh.com` requests.
3. `agent/agent.go` — New fields on `Options` and `agent` struct,
plumbed to SSH config.
4. `cli/agent.go` — New serpent flags with env vars.
5. Tests cover all four blocked paths: TCP local, TCP reverse, Unix
local, Unix reverse.

</details>
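The gating itself is simple enough to sketch — field names below are simplified stand-ins for the agentssh `Config` fields, and `allowed` is a hypothetical helper, not the real callback wiring:

```go
package main

import "fmt"

// forwardConfig sketches the two new gates, following the existing
// BlockFileTransfer pattern of boolean Config fields.
type forwardConfig struct {
	BlockReversePortForwarding bool // ssh -R and streamlocal-forward@openssh.com
	BlockLocalPortForwarding   bool // ssh -L and direct-streamlocal@openssh.com
}

// allowed decides whether a forwarding request may proceed. Both the TCP
// and Unix-socket variants of each direction share the same gate, so one
// flag blocks both.
func (c forwardConfig) allowed(reverse bool) bool {
	if reverse {
		return !c.BlockReversePortForwarding
	}
	return !c.BlockLocalPortForwarding
}

func main() {
	// e.g. set via CODER_AGENT_BLOCK_REVERSE_PORT_FORWARDING=true
	c := forwardConfig{BlockReversePortForwarding: true}
	fmt.Println(c.allowed(false), c.allowed(true))
}
```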

> 🤖 Generated by Coder Agents
2026-04-08 10:41:55 -04:00
Garrett Delfosse a3de0fc78d ci: add automatic backport workflow (#24025)
Adds a GitHub Actions workflow that automatically cherry-picks merged
PRs to the last 3 release branches when the `backport` label is applied.

## How it works

1. Add the `backport` label to any PR targeting `main` (before or after
merge).
2. On merge (or on label if already merged), the workflow discovers the
latest 3 `release/*` branches by semver.
3. For each branch, it cherry-picks the merge commit (`-x -m1`) and
opens a PR.

Created backport PRs follow existing repo conventions:
- **Branch:** `backport/<pr>-to-<version>`
- **Title:** `<original PR title> (#<pr>)` — e.g. `fix(site): correct
button alignment (#12345)`
- **Body:** links back to the original PR and merge commit

If cherry-pick has conflicts, the PR is still opened with instructions
for manual resolution — no conflict markers are committed.

Also:
- Removes `scripts/backport-pr.sh` (replaced by this workflow)
- Removes `.github/cherry-pick-bot.yml` (old bot config)
- Adds a section to the contributing docs explaining how to use the
backport label

> [!NOTE]
> Generated with [Coder Agents](https://coder.com/agents)
2026-04-08 14:30:48 +00:00
Garrett Delfosse ab77154975 ci: add cherry-pick to latest release workflow (#24051)
Adds a GitHub Actions workflow that cherry-picks merged PRs to the
latest release branch when the `cherry-pick` label is applied.

## How it works

1. Add the `cherry-pick` label to any PR targeting `main` (before or
after merge).
2. On merge (or on label if already merged), the workflow detects the
latest `release/*` branch.
3. It cherry-picks the merge commit (`-x -m1`) and opens a PR.

This complements the `backport` label (see #24025) which targets the
latest **3** release branches. `cherry-pick` targets only the **latest**
one — useful for getting fixes into the current release.

Created PRs follow existing repo conventions:
- **Branch:** `backport/<pr>-to-<version>`
- **Title:** `<original PR title> (#<pr>)` — e.g. `fix(site): correct
button alignment (#12345)`
- **Body:** links back to the original PR and merge commit

If the cherry-pick encounters conflicts, the workflow aborts the
cherry-pick, creates an empty commit with resolution instructions, and
opens the PR with a `[CONFLICT]` prefix so the author can resolve
manually.

Also:
- Removes `scripts/backport-pr.sh` (replaced by this workflow)
- Removes `.github/cherry-pick-bot.yml` (old bot config)
- Adds a section to the contributing docs explaining the `cherry-pick`
label

> [!NOTE]
> Generated with [Coder Agents](https://coder.com/agents)
2026-04-08 10:22:33 -04:00
Kyle Carberry c5d720f73d feat(coderd): add telemetry for agents chats and messages (#24068)
Adds telemetry collection for the agents chat system (`/agents`) to the
existing telemetry snapshot pipeline.

Three new snapshot fields:
- **`Chats`** — per-chat metadata (id, owner, status, mode,
workspace_id, root_chat_id, has_parent, archived, model config)
collected time-windowed via `createdAfter`
- **`ChatMessageSummaries`** — per-chat aggregated message metrics
(counts by role, token sums by type, cost, runtime, model count,
compression count) collected time-windowed
- **`ChatModelConfigs`** — model configuration metadata (provider,
model, context limit, enabled, default) collected as full dump

No PII is included — titles, message content, and URLs are excluded at
the SQL level. Only structural metadata flows through telemetry.

<details><summary>Implementation plan</summary>

### SQL Queries (`coderd/database/queries/chats.sql`)
- `GetChatsCreatedAfter` — time-windowed chat metadata
- `GetChatMessageSummariesPerChat` — per-chat message aggregates via
`GROUP BY`
- `GetChatModelConfigsForTelemetry` — full dump of model configs

### Telemetry (`coderd/telemetry/telemetry.go`)
- `Chat`, `ChatMessageSummary`, `ChatModelConfig` structs
- `ConvertChat`, `ConvertChatMessageSummary`, `ConvertChatModelConfig`
conversion functions
- Three `eg.Go()` blocks in `createSnapshot()` following the existing
collection pattern

### Authorization (`coderd/database/dbauthz/dbauthz.go`)
- System-only access for all three queries via `rbac.ResourceSystem`

### Tests
- `TestChatsTelemetry` in `coderd/telemetry/telemetry_test.go` — creates
chats (root + child), messages with token/cost data, model configs;
verifies all snapshot fields
- dbauthz test entries for all three queries in
`coderd/database/dbauthz/dbauthz_test.go`

</details>

> 🤖 Generated by Coder Agents
2026-04-08 09:47:44 -04:00
Atif Ali 983819860f docs: replace dockerd with service docker start in Sysbox examples (#24004)
## Problem

The Sysbox docker-in-workspaces docs examples use `sudo dockerd &` in
`startup_script` to start Docker. This causes workspaces to report as
unhealthy because `dockerd` keeps references to stdout/stderr after the
script exits.

## Fix

Replace `sudo dockerd &` with `sudo service docker start`, which
properly daemonizes Docker through the service manager and returns
cleanly. This matches the pattern used in our [dogfood
template](https://github.com/coder/coder/blob/main/dogfood/coder/main.tf#L614).

## Validation

Created a test template and workspace on dogfood — agent reported `✔
healthy` and `docker info` confirmed the daemon running inside the
workspace.

Fixes #21166

> 🤖 This PR was created with the help of Coder Agents, and has been
reviewed by my human. 🧑💻
2026-04-08 13:03:18 +00:00
Cian Johnston f820945d9f refactor: decompose AgentSettingsBehaviorPageView + remove kyleosophy (#24141)
- Remove Kyleosophy alternative completion chimes (keeps original chime
intact)
- Extract 5 sub-components from the 717-line god component:
  - `PersonalInstructionsSettings` — user prompt textarea form
  - `SystemInstructionsSettings` — admin system prompt + TextPreviewDialog
  - `VirtualDesktopSettings` — admin desktop toggle
  - `WorkspaceAutostopSettings` — admin autostop toggle + duration form
  - `RetentionPeriodSettings` — admin retention toggle + number input
- Parent is now a ~160-line layout shell
- `isAnyPromptSaving` coupling preserved via prop
- Add `docs/plans/` to `.gitignore`

> 🤖 Written by a Coder Agent. Reviewed by a human.
2026-04-08 14:01:38 +01:00
Hugo Dutka da5395a8ae feat(site): take/release control agents desktop buttons (#24009)
Add "Take control" and "Release control" buttons to the agents desktop
sidebar. This prevents accidental inputs in the VNC window.


https://github.com/user-attachments/assets/b5319579-e1c5-433b-9ba5-b239661a2e4c
2026-04-08 12:53:42 +02:00
Danielle Maywood 86b919e4f7 refactor: replace useEffectEvent polyfill with native React 19.2 hook (#24060) 2026-04-08 11:17:11 +01:00
Cian Johnston 233343c010 feat: add chat and chat_files cleanup to dbpurge (#23833)
Fixes https://github.com/coder/coder/issues/23910

Adds periodic cleanup of chats and chat files to the dbpurge background
goroutine, with a configurable retention period exposed in the Agent
settings UI.

> 🤖 Written by a Coder Agent. Reviewed by a human.
2026-04-08 11:08:09 +01:00
Danielle Maywood 3a612898c6 refactor(site/src/pages/AgentsPage): extract ConfirmDeleteDialog component (#24128) 2026-04-08 11:07:39 +01:00
Danielle Maywood 3f7a3e3354 perf: reorder declarations to fix React Compiler scope pruning (#24098) 2026-04-08 09:40:41 +01:00
Danielle Maywood 17a71aea72 refactor(site/src/pages/AgentsPage): extract BackButton and AdminBadge (#24130) 2026-04-08 09:32:40 +01:00
Jeremy Ruppel 7d3c5ac78c fix(site): inline dl/dt/dd classNames and use justify-between layout in session tables (#24118)
When we refactored into definition lists for tables, we lost the ability
to have the rows extend beyond the vertical line between `<dt>` and
`<dd>`

This adds a wrapping `<div>` to make each row independent, which is
[a-ok per
MDN](https://developer.mozilla.org/en-US/docs/Web/HTML/Reference/Elements/dl#wrapping_name-value_groups_in_div_elements),
and is also implied in the Figma:
<img width="477" height="182" alt="Screenshot 2026-04-07 at 4 29 14 PM"
src="https://github.com/user-attachments/assets/524acfc3-c614-479e-9a13-36107c158ee8"
/>

---

Before 
<img width="420" height="266" alt="Screenshot 2026-04-07 at 4 24 22 PM"
src="https://github.com/user-attachments/assets/7001c17c-05da-4f90-b6d4-a9c6cab695cb"
/>

After
<img width="410" height="355" alt="Screenshot 2026-04-07 at 4 24 36 PM"
src="https://github.com/user-attachments/assets/3d1d278d-0080-44be-8d32-bb5dff879969"
/>
2026-04-08 16:17:39 +10:00
dependabot[bot] d87c5ef439 chore: bump github.com/aws/aws-sdk-go-v2/service/s3 from 1.96.0 to 1.97.3 (#24136)
Bumps
[github.com/aws/aws-sdk-go-v2/service/s3](https://github.com/aws/aws-sdk-go-v2)
from 1.96.0 to 1.97.3.
<details>
<summary>Commits</summary>
<ul>
<li><a
href="https://github.com/aws/aws-sdk-go-v2/commit/90650dd22735ab68f6089ae5c39b6614286ae9ec"><code>90650dd</code></a>
Release 2026-03-26</li>
<li><a
href="https://github.com/aws/aws-sdk-go-v2/commit/dd88818bee7d632a8b9da6e2c78ef92e23c94c62"><code>dd88818</code></a>
Regenerated Clients</li>
<li><a
href="https://github.com/aws/aws-sdk-go-v2/commit/b662c50138bd393927871b46e84ee3483377f5be"><code>b662c50</code></a>
Update endpoints model</li>
<li><a
href="https://github.com/aws/aws-sdk-go-v2/commit/500a9cb3522a0e71d798d7079ff5856b23c2cac1"><code>500a9cb</code></a>
Update API model</li>
<li><a
href="https://github.com/aws/aws-sdk-go-v2/commit/6221102f763bd65d7e403fa62c3a1e3d39e24dc6"><code>6221102</code></a>
fix stale skew and delayed skew healing (<a
href="https://redirect.github.com/aws/aws-sdk-go-v2/issues/3359">#3359</a>)</li>
<li><a
href="https://github.com/aws/aws-sdk-go-v2/commit/0a39373433a121800bc68efa743a7486eb07aa3f"><code>0a39373</code></a>
fix order of generated event header handlers (<a
href="https://redirect.github.com/aws/aws-sdk-go-v2/issues/3361">#3361</a>)</li>
<li><a
href="https://github.com/aws/aws-sdk-go-v2/commit/098f3898271e2eaaf8a92e38d1d928fb018805a6"><code>098f389</code></a>
Only generate resolveAccountID when it's required (<a
href="https://redirect.github.com/aws/aws-sdk-go-v2/issues/3360">#3360</a>)</li>
<li><a
href="https://github.com/aws/aws-sdk-go-v2/commit/6ebab66428e97db0ee252fea042d56b1313cb9f6"><code>6ebab66</code></a>
Release 2026-03-25</li>
<li><a
href="https://github.com/aws/aws-sdk-go-v2/commit/b2ec3beebb986a5e74e50d0c105119d84e1e934e"><code>b2ec3be</code></a>
Regenerated Clients</li>
<li><a
href="https://github.com/aws/aws-sdk-go-v2/commit/abc126f6b35bfe2f77e2505f6d04f8ceced971ee"><code>abc126f</code></a>
Update API model</li>
<li>Additional commits viewable in <a
href="https://github.com/aws/aws-sdk-go-v2/compare/service/s3/v1.96.0...service/s3/v1.97.3">compare
view</a></li>
</ul>
</details>
<br />


[![Dependabot compatibility
score](https://dependabot-badges.githubapp.com/badges/compatibility_score?dependency-name=github.com/aws/aws-sdk-go-v2/service/s3&package-manager=go_modules&previous-version=1.96.0&new-version=1.97.3)](https://docs.github.com/en/github/managing-security-vulnerabilities/about-dependabot-security-updates#about-compatibility-scores)

Dependabot will resolve any conflicts with this PR as long as you don't
alter it yourself. You can also trigger a rebase manually by commenting
`@dependabot rebase`.

[//]: # (dependabot-automerge-start)
[//]: # (dependabot-automerge-end)

---

<details>
<summary>Dependabot commands and options</summary>
<br />

You can trigger Dependabot actions by commenting on this PR:
- `@dependabot rebase` will rebase this PR
- `@dependabot recreate` will recreate this PR, overwriting any edits
that have been made to it
- `@dependabot show <dependency name> ignore conditions` will show all
of the ignore conditions of the specified dependency
- `@dependabot ignore this major version` will close this PR and stop
Dependabot creating any more for this major version (unless you reopen
the PR or upgrade to it yourself)
- `@dependabot ignore this minor version` will close this PR and stop
Dependabot creating any more for this minor version (unless you reopen
the PR or upgrade to it yourself)
- `@dependabot ignore this dependency` will close this PR and stop
Dependabot creating any more for this dependency (unless you reopen the
PR or upgrade to it yourself)
You can disable automated security fix PRs for this repo from the
[Security Alerts page](https://github.com/coder/coder/network/alerts).

</details>

Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2026-04-08 04:40:17 +00:00
dependabot[bot] ef3e17317c chore: bump github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream from 1.7.6 to 1.7.8 (#24134)
Bumps
[github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream](https://github.com/aws/aws-sdk-go-v2)
from 1.7.6 to 1.7.8.
<details>
<summary>Commits</summary>
<ul>
<li><a
href="https://github.com/aws/aws-sdk-go-v2/commit/e3b97d2a02cd4e27c40224f05aa1a7deba24abe2"><code>e3b97d2</code></a>
Release 2023-10-12</li>
<li><a
href="https://github.com/aws/aws-sdk-go-v2/commit/863010ddb23c242c2a5d49d9f40094a6a49b5525"><code>863010d</code></a>
Regenerated Clients</li>
<li><a
href="https://github.com/aws/aws-sdk-go-v2/commit/6946ef8b9149fe75ac1b427ca2c7f57cdcb64549"><code>6946ef8</code></a>
Update endpoints model</li>
<li><a
href="https://github.com/aws/aws-sdk-go-v2/commit/6d93ded4536184d38a664b4b75dadd36cbd79878"><code>6d93ded</code></a>
Update API model</li>
<li><a
href="https://github.com/aws/aws-sdk-go-v2/commit/bebc232e7f65b02d0b519d11e73cf925c38e716f"><code>bebc232</code></a>
fix: fail to load config if configured profile doesn't exist (<a
href="https://redirect.github.com/aws/aws-sdk-go-v2/issues/2309">#2309</a>)</li>
<li><a
href="https://github.com/aws/aws-sdk-go-v2/commit/5de46742b7fb1b72d93d344ee81568800a707267"><code>5de4674</code></a>
fix DNS timeout error not retried (<a
href="https://redirect.github.com/aws/aws-sdk-go-v2/issues/2300">#2300</a>)</li>
<li><a
href="https://github.com/aws/aws-sdk-go-v2/commit/e155bb72a2ec20ec61db50fc3d4568e373fa4b63"><code>e155bb7</code></a>
Release 2023-10-06</li>
<li><a
href="https://github.com/aws/aws-sdk-go-v2/commit/9d342ba33937c562d215f317a37dea121ee9763d"><code>9d342ba</code></a>
Regenerated Clients</li>
<li><a
href="https://github.com/aws/aws-sdk-go-v2/commit/1df99141a143a38570d64a182ed972ce9e3dba65"><code>1df9914</code></a>
Update SDK's smithy-go dependency to v1.15.0</li>
<li><a
href="https://github.com/aws/aws-sdk-go-v2/commit/32ada3a191ac770b1b24164b667692183fc77ed9"><code>32ada3a</code></a>
Update API model</li>
<li>See full diff in <a
href="https://github.com/aws/aws-sdk-go-v2/compare/service/m2/v1.7.6...service/m2/v1.7.8">compare
view</a></li>
</ul>
</details>
<br />


[![Dependabot compatibility
score](https://dependabot-badges.githubapp.com/badges/compatibility_score?dependency-name=github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream&package-manager=go_modules&previous-version=1.7.6&new-version=1.7.8)](https://docs.github.com/en/github/managing-security-vulnerabilities/about-dependabot-security-updates#about-compatibility-scores)

Dependabot will resolve any conflicts with this PR as long as you don't
alter it yourself. You can also trigger a rebase manually by commenting
`@dependabot rebase`.

[//]: # (dependabot-automerge-start)
[//]: # (dependabot-automerge-end)

---

<details>
<summary>Dependabot commands and options</summary>
<br />

You can trigger Dependabot actions by commenting on this PR:
- `@dependabot rebase` will rebase this PR
- `@dependabot recreate` will recreate this PR, overwriting any edits
that have been made to it
- `@dependabot show <dependency name> ignore conditions` will show all
of the ignore conditions of the specified dependency
- `@dependabot ignore this major version` will close this PR and stop
Dependabot creating any more for this major version (unless you reopen
the PR or upgrade to it yourself)
- `@dependabot ignore this minor version` will close this PR and stop
Dependabot creating any more for this minor version (unless you reopen
the PR or upgrade to it yourself)
- `@dependabot ignore this dependency` will close this PR and stop
Dependabot creating any more for this dependency (unless you reopen the
PR or upgrade to it yourself)
You can disable automated security fix PRs for this repo from the
[Security Alerts page](https://github.com/coder/coder/network/alerts).

</details>

Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2026-04-08 03:14:12 +00:00
Kayla はな 1187b84c54 refactor(site): remove mui from icon components (#24117) 2026-04-07 17:32:05 -06:00
Jeremy Ruppel 45336bd9ce fix(site): use field value instead of controlled value in PasswordField (#24123)
`<PasswordField>`'s value should come from the field helpers, not from a
prop
2026-04-07 19:04:29 -04:00
Jeremy Ruppel 36cf7debce fix(site): add resize observer to session timeline expandable text (#24119)
I said I wouldn't but the illustrious @jakehwll added a ResizeObserver
recently so imma do that too.

This makes `<ExpandableText>` determine if it should be expandable or
not on resize
2026-04-07 19:04:05 -04:00
Ehab Younes 027c222e82 fix(cli): add dial timeout and keepalive for Coder Connect (#24015)
The default `net.Dialer` in the Coder Connect path had no timeout,
falling back to the OS TCP timeout when the tunnel was broken but DNS
still resolved. Add a 5s dial timeout and 30s TCP keepalive.

Fixes #24006
2026-04-08 01:11:28 +03:00
Ehab Younes d00f148b76 fix(cli): retry transient connection failures during SSH setup (#24010)
When `coder ssh` connects to a workspace after laptop wake, DNS or the
control plane may be briefly unavailable. Previously this caused an
immediate failure, which VS Code Remote SSH classified as permanent
("Reload Window").

Wrap each network step (workspace resolution, template version fetch,
agent connection info, Coder Connect dial, tailnet dial) with
`retryWithInterval` so transient errors (DNS, connection refused, 5xx)
are retried individually. Non-retryable errors (auth, 404) and context
cancellation stop immediately. Data transfer is never retried.
2026-04-08 00:59:10 +03:00
Garrett Delfosse 48bc215f20 chore: tag RCs on main, cut release branch only for releases (#24001)
RC tags are now created directly on `main`. The `release/X.Y` branch is
only cut when the actual release is ready. This eliminates the need to
cherry-pick hundreds of commits from main onto the release branch
between the first RC and the release.

## Workflow

```
main:  ──●──●──●──●──●──●──●──●──●──
              ↑           ↑     ↑
           rc.0        rc.1    cut release/2.34, tag v2.34.0
                                     \
                               release/2.34:  ──●── v2.34.1 (patch)
```

1. **RC:** On `main`, run `./scripts/release.sh`. The tool detects main
(or a detached HEAD reachable from main), prompts for the commit SHA to
tag, suggests the next RC version, and tags it.
2. **Release:** When the RC is blessed, create `release/X.Y` from `main`
(or the specific RC commit). Switch to that branch and run
`./scripts/release.sh`, which suggests `vX.Y.0`.
3. **Patch:** Cherry-pick fixes onto `release/X.Y` and run
`./scripts/release.sh` from that branch.

## Changes

### `scripts/releaser/release.go`
- Two modes based on branch:
- **`main` (or detached HEAD from main)** — RC tagging. Prompts for the
commit SHA to tag (defaults to HEAD). Always checks out the target
commit so the flow operates in detached HEAD. Suggests the next RC based
on existing RC tags.
- **`release/X.Y`** — Release/patch mode. Suggests `vX.Y.0` if the
latest tag is an RC, or the next patch otherwise.
- Detached HEAD support: if `git branch --show-current` is empty, checks
whether HEAD is an ancestor of `origin/main` and enters RC mode
automatically.
- Commit selection prompt in RC mode: shows current commit, lets the
user confirm or provide a different SHA.
- Warns if you try to tag a non-RC on main, or an RC on a release
branch.
- Skips open-PR check and branch sync check in RC mode (not useful on
main).

### `scripts/releaser/main.go`
- Updated help text.

### `.github/workflows/release.yaml`
- RC tags (`*-rc.*`): skip the release-branch validation (they live on
main).
- Non-RC tags: still require the corresponding `release/X.Y` branch.

### `docs/about/contributing/CONTRIBUTING.md`
- Rewrote the Releases section with the new workflow, release types
table, and ASCII diagram.
- Replaced the old "Creating a release" / "Creating a release (via
workflow dispatch)" subsections.

<details><summary>Decision log</summary>

### Why this approach?

Previously, cutting a release branch early for an RC meant
cherry-picking all of main's progress onto that branch before the actual
release — often hundreds of commits. This approach avoids that entirely:
RCs are just tagged snapshots of main, and the release branch only
exists once you need it for stabilization and backports.

### Files NOT changed

- **`scripts/release/publish.sh`** — `--rc` flag controls GitHub
prerelease marking (tag-level, not branch-level). `target_commitish`
already defaults to `main` when the tag isn't on a release branch.
- **`scripts/release/tag_version.sh`** — No RC-specific branch logic.
- **`scripts/releaser/version.go`** — Version parsing/comparison
unchanged.
- **`docs/install/releases/index.md`** — Public-facing docs describe RC
as a release channel with no branch-level detail.

</details>

> Generated by Coder Agents
2026-04-07 15:21:22 -04:00
Jon Ayers 08bd9e672a fix: resolve Test_batcherFlush/RetriesOnTransientFailure flake (#24112)
fixes https://github.com/coder/internal/issues/1452
2026-04-07 13:46:26 -05:00
Kayla はな c5f1a2fccf feat: make service accounts a Premium feature (#24020) 2026-04-07 12:25:32 -06:00
Jake Howell 655d647d40 fix: resolve style not passing in <LogLine /> (#24111)
This pull request resolves a regression where the prop spread was overriding
the styles required by the `react-window` virtualised rows. This was
causing the scroll to behave erratically.
2026-04-07 17:54:16 +00:00
Kyle Carberry f3f0a2c553 fix(enterprise/coderd/x/chatd): harden TestSubscribeRelayEstablishedMidStream against CI flakes (#24108)
Fixes coder/internal#1455

Three changes to eliminate the timing-sensitive flake in
`TestSubscribeRelayEstablishedMidStream`:

1. **Reduce `PendingChatAcquireInterval` from `time.Hour` to
`time.Second`.**
   The primary trigger is still `signalWake()` from `SendMessage`, but a
   short fallback poll ensures the worker picks up the pending chat
   even under heavy CI goroutine scheduling contention.

2. **Increase context timeout from `WaitLong` (25s) to `WaitSuperLong`
(60s).**
   The worker pipeline (model resolution, message loading, LLM call)
   involves multiple DB round-trips that can be slow when PostgreSQL
   is shared with many parallel test packages.

3. **Add a status-polling loop while waiting for the streaming
request.**
   If the worker errors out during chat processing, the test now
   fails immediately with the error status and message instead of
   silently timing out.

> Generated by Coder Agents
2026-04-07 13:41:33 -04:00
225 changed files with 11941 additions and 2893 deletions
-2
View File
@@ -1,2 +0,0 @@
enabled: true
preservePullRequestTitle: true
+174
View File
@@ -0,0 +1,174 @@
# Automatically backport merged PRs to the last N release branches when the
# "backport" label is applied. Works whether the label is added before or
# after the PR is merged.
#
# Usage:
# 1. Add the "backport" label to a PR targeting main.
# 2. When the PR merges (or if already merged), the workflow detects the
# latest release/* branches and opens one cherry-pick PR per branch.
#
# The created backport PRs follow existing repo conventions:
# - Branch: backport/<pr>-to-<version>
# - Title: <original PR title> (#<pr>)
# - Body: links back to the original PR and merge commit
name: Backport
on:
pull_request_target:
branches:
- main
types:
- closed
- labeled
permissions:
contents: write
pull-requests: write
# Prevent duplicate runs for the same PR when both 'closed' and 'labeled'
# fire in quick succession.
concurrency:
group: backport-${{ github.event.pull_request.number }}
jobs:
detect:
name: Detect target branches
if: >
github.event.pull_request.merged == true &&
contains(github.event.pull_request.labels.*.name, 'backport')
runs-on: ubuntu-latest
outputs:
branches: ${{ steps.find.outputs.branches }}
steps:
- uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
with:
# Need all refs to discover release branches.
fetch-depth: 0
- name: Find latest release branches
id: find
run: |
# List remote release branches matching the exact release/2.X
# pattern (no suffixes like release/2.31_hotfix), sort by minor
# version descending, and take the top 3.
BRANCHES=$(
git branch -r \
| grep -E '^\s*origin/release/2\.[0-9]+$' \
| sed 's|.*origin/||' \
| sort -t. -k2 -n -r \
| head -3
)
if [ -z "$BRANCHES" ]; then
echo "No release branches found."
echo "branches=[]" >> "$GITHUB_OUTPUT"
exit 0
fi
# Convert to JSON array for the matrix.
JSON=$(echo "$BRANCHES" | jq -Rnc '[inputs | select(length > 0)]')
echo "branches=$JSON" >> "$GITHUB_OUTPUT"
echo "Will backport to: $JSON"
backport:
name: "Backport to ${{ matrix.branch }}"
needs: detect
if: needs.detect.outputs.branches != '[]'
runs-on: ubuntu-latest
strategy:
matrix:
branch: ${{ fromJson(needs.detect.outputs.branches) }}
fail-fast: false
env:
PR_NUMBER: ${{ github.event.pull_request.number }}
PR_TITLE: ${{ github.event.pull_request.title }}
PR_URL: ${{ github.event.pull_request.html_url }}
MERGE_SHA: ${{ github.event.pull_request.merge_commit_sha }}
steps:
- uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
with:
# Full history required for cherry-pick.
fetch-depth: 0
- name: Cherry-pick and open PR
env:
GH_TOKEN: ${{ secrets.GITHUB_TOKEN }}
run: |
set -euo pipefail
RELEASE_VERSION="${{ matrix.branch }}"
# Strip the release/ prefix for naming.
VERSION="${RELEASE_VERSION#release/}"
BACKPORT_BRANCH="backport/${PR_NUMBER}-to-${VERSION}"
git config user.name "github-actions[bot]"
git config user.email "41898282+github-actions[bot]@users.noreply.github.com"
# Check if backport branch already exists (idempotency for re-runs).
if git ls-remote --exit-code origin "refs/heads/${BACKPORT_BRANCH}" >/dev/null 2>&1; then
echo "Backport branch ${BACKPORT_BRANCH} already exists, skipping."
exit 0
fi
# Create the backport branch from the target release branch.
git checkout -b "$BACKPORT_BRANCH" "origin/${RELEASE_VERSION}"
# Cherry-pick the merge commit. Use -x to record provenance and
# -m1 to pick the first parent (the main branch side).
CONFLICTS=false
if ! git cherry-pick -x -m1 "$MERGE_SHA"; then
echo "::warning::Cherry-pick to ${RELEASE_VERSION} had conflicts."
CONFLICTS=true
# Abort the failed cherry-pick and create an empty commit
# explaining the situation.
git cherry-pick --abort
git commit --allow-empty -m "Cherry-pick of #${PR_NUMBER} requires manual resolution
The automatic cherry-pick of ${MERGE_SHA} to ${RELEASE_VERSION} had conflicts.
Please cherry-pick manually:
git cherry-pick -x -m1 ${MERGE_SHA}"
fi
git push origin "$BACKPORT_BRANCH"
TITLE="${PR_TITLE} (#${PR_NUMBER})"
BODY=$(cat <<EOF
Backport of ${PR_URL}
Original PR: #${PR_NUMBER} — ${PR_TITLE}
Merge commit: ${MERGE_SHA}
EOF
)
if [ "$CONFLICTS" = true ]; then
TITLE="${TITLE} (conflicts)"
BODY="${BODY}
> [!WARNING]
> The automatic cherry-pick had conflicts.
> Please resolve manually by cherry-picking the original merge commit:
>
> \`\`\`
> git fetch origin ${BACKPORT_BRANCH}
> git checkout ${BACKPORT_BRANCH}
> git reset --hard origin/${RELEASE_VERSION}
> git cherry-pick -x -m1 ${MERGE_SHA}
> # resolve conflicts, then push
> \`\`\`"
fi
# Check if a PR already exists for this branch (idempotency
# for re-runs).
EXISTING_PR=$(gh pr list --head "$BACKPORT_BRANCH" --base "$RELEASE_VERSION" --state all --json number --jq '.[0].number // empty')
if [ -n "$EXISTING_PR" ]; then
echo "PR #${EXISTING_PR} already exists for ${BACKPORT_BRANCH}, skipping."
exit 0
fi
gh pr create \
--base "$RELEASE_VERSION" \
--head "$BACKPORT_BRANCH" \
--title "$TITLE" \
--body "$BODY"
+139
View File
@@ -0,0 +1,139 @@
# Automatically cherry-pick merged PRs to the latest release branch when the
# "cherry-pick" label is applied. Works whether the label is added before or
# after the PR is merged.
#
# Usage:
# 1. Add the "cherry-pick" label to a PR targeting main.
# 2. When the PR merges (or if already merged), the workflow detects the
# latest release/* branch and opens a cherry-pick PR against it.
#
# The created PRs follow existing repo conventions:
# - Branch: backport/<pr>-to-<version>
# - Title: <original PR title> (#<pr>)
# - Body: links back to the original PR and merge commit
name: Cherry-pick to release
on:
pull_request_target:
branches:
- main
types:
- closed
- labeled
permissions:
contents: write
pull-requests: write
# Prevent duplicate runs for the same PR when both 'closed' and 'labeled'
# fire in quick succession.
concurrency:
group: cherry-pick-${{ github.event.pull_request.number }}
jobs:
cherry-pick:
name: Cherry-pick to latest release
if: >
github.event.pull_request.merged == true &&
contains(github.event.pull_request.labels.*.name, 'cherry-pick')
runs-on: ubuntu-latest
env:
PR_NUMBER: ${{ github.event.pull_request.number }}
PR_TITLE: ${{ github.event.pull_request.title }}
PR_URL: ${{ github.event.pull_request.html_url }}
MERGE_SHA: ${{ github.event.pull_request.merge_commit_sha }}
steps:
- uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
with:
# Full history required for cherry-pick and branch discovery.
fetch-depth: 0
- name: Cherry-pick and open PR
env:
GH_TOKEN: ${{ secrets.GITHUB_TOKEN }}
run: |
set -euo pipefail
# Find the latest release branch matching the exact release/2.X
# pattern (no suffixes like release/2.31_hotfix).
RELEASE_BRANCH=$(
git branch -r \
| grep -E '^\s*origin/release/2\.[0-9]+$' \
| sed 's|.*origin/||' \
| sort -t. -k2 -n -r \
| head -1
)
if [ -z "$RELEASE_BRANCH" ]; then
echo "::error::No release branch found."
exit 1
fi
# Strip the release/ prefix for naming.
VERSION="${RELEASE_BRANCH#release/}"
BACKPORT_BRANCH="backport/${PR_NUMBER}-to-${VERSION}"
echo "Target branch: $RELEASE_BRANCH"
echo "Backport branch: $BACKPORT_BRANCH"
# Check if backport branch already exists (idempotency for re-runs).
if git ls-remote --exit-code origin "refs/heads/${BACKPORT_BRANCH}" >/dev/null 2>&1; then
echo "Branch ${BACKPORT_BRANCH} already exists, skipping."
exit 0
fi
git config user.name "github-actions[bot]"
git config user.email "41898282+github-actions[bot]@users.noreply.github.com"
# Create the backport branch from the target release branch.
git checkout -b "$BACKPORT_BRANCH" "origin/${RELEASE_BRANCH}"
# Cherry-pick the merge commit. Use -x to record provenance and
# -m1 to pick the first parent (the main branch side).
CONFLICT=false
if ! git cherry-pick -x -m1 "$MERGE_SHA"; then
CONFLICT=true
echo "::warning::Cherry-pick to ${RELEASE_BRANCH} had conflicts."
# Abort the failed cherry-pick and create an empty commit with
# instructions so the PR can still be opened.
git cherry-pick --abort
git commit --allow-empty -m "cherry-pick of #${PR_NUMBER} failed — resolve conflicts manually
Cherry-pick of ${MERGE_SHA} onto ${RELEASE_BRANCH} had conflicts.
To resolve:
git fetch origin ${BACKPORT_BRANCH}
git checkout ${BACKPORT_BRANCH}
git cherry-pick -x -m1 ${MERGE_SHA}
# resolve conflicts
git push origin ${BACKPORT_BRANCH}"
fi
git push origin "$BACKPORT_BRANCH"
BODY=$(cat <<EOF
Cherry-pick of ${PR_URL}
Original PR: #${PR_NUMBER} — ${PR_TITLE}
Merge commit: ${MERGE_SHA}
EOF
)
TITLE="${PR_TITLE} (#${PR_NUMBER})"
if [ "$CONFLICT" = true ]; then
TITLE="[CONFLICT] ${TITLE}"
fi
# Check if a PR already exists for this branch (idempotency
# for re-runs). Use --state all to catch closed/merged PRs too.
EXISTING_PR=$(gh pr list --head "$BACKPORT_BRANCH" --base "$RELEASE_BRANCH" --state all --json number --jq '.[0].number // empty')
if [ -n "$EXISTING_PR" ]; then
echo "PR #${EXISTING_PR} already exists for ${BACKPORT_BRANCH}, skipping."
exit 0
fi
gh pr create \
--base "$RELEASE_BRANCH" \
--head "$BACKPORT_BRANCH" \
--title "$TITLE" \
--body "$BODY"
+13 -13
View File
@@ -121,22 +121,22 @@ jobs:
fi
# Derive the release branch from the version tag.
# Standard: 2.10.2 -> release/2.10
# RC: 2.32.0-rc.0 -> release/2.32-rc.0
# Non-RC releases must be on a release/X.Y branch.
# RC tags are allowed on any branch (typically main).
version="$(./scripts/version.sh)"
# Strip any pre-release suffix first (e.g. 2.32.0-rc.0 -> 2.32.0)
base_version="${version%%-*}"
# Then strip patch to get major.minor (e.g. 2.32.0 -> 2.32)
release_branch="release/${base_version%.*}"
if [[ "$version" == *-rc.* ]]; then
# Extract major.minor and rc suffix from e.g. 2.32.0-rc.0
base_version="${version%%-rc.*}" # 2.32.0
major_minor="${base_version%.*}" # 2.32
rc_suffix="${version##*-rc.}" # 0
release_branch="release/${major_minor}-rc.${rc_suffix}"
echo "RC release detected — skipping release branch check (RC tags are cut from main)."
else
release_branch=release/${version%.*}
fi
branch_contains_tag=$(git branch --remotes --contains "${GITHUB_REF}" --list "*/${release_branch}" --format='%(refname)')
if [[ -z "${branch_contains_tag}" ]]; then
echo "Ref tag must exist in a branch named ${release_branch} when creating a release, did you use scripts/release.sh?"
exit 1
branch_contains_tag=$(git branch --remotes --contains "${GITHUB_REF}" --list "*/${release_branch}" --format='%(refname)')
if [[ -z "${branch_contains_tag}" ]]; then
echo "Ref tag must exist in a branch named ${release_branch} when creating a non-RC release, did you use scripts/release.sh?"
exit 1
fi
fi
if [[ -z "${CODER_RELEASE_NOTES}" ]]; then
+1
View File
@@ -36,6 +36,7 @@ typ = "typ"
styl = "styl"
edn = "edn"
Inferrable = "Inferrable"
IIF = "IIF"
[files]
extend-exclude = [
+3
View File
@@ -103,3 +103,6 @@ PLAN.md
# Ignore any dev licenses
license.txt
-e
# Agent planning documents (local working files).
docs/plans/
+14 -6
View File
@@ -102,6 +102,8 @@ type Options struct {
ReportMetadataInterval time.Duration
ServiceBannerRefreshInterval time.Duration
BlockFileTransfer bool
BlockReversePortForwarding bool
BlockLocalPortForwarding bool
Execer agentexec.Execer
Devcontainers bool
DevcontainerAPIOptions []agentcontainers.Option // Enable Devcontainers for these to be effective.
@@ -214,6 +216,8 @@ func New(options Options) Agent {
subsystems: options.Subsystems,
logSender: agentsdk.NewLogSender(options.Logger),
blockFileTransfer: options.BlockFileTransfer,
blockReversePortForwarding: options.BlockReversePortForwarding,
blockLocalPortForwarding: options.BlockLocalPortForwarding,
prometheusRegistry: prometheusRegistry,
metrics: newAgentMetrics(prometheusRegistry),
@@ -280,6 +284,8 @@ type agent struct {
sshServer *agentssh.Server
sshMaxTimeout time.Duration
blockFileTransfer bool
blockReversePortForwarding bool
blockLocalPortForwarding bool
lifecycleUpdate chan struct{}
lifecycleReported chan codersdk.WorkspaceAgentLifecycle
@@ -331,12 +337,14 @@ func (a *agent) TailnetConn() *tailnet.Conn {
func (a *agent) init() {
// pass the "hard" context because we explicitly close the SSH server as part of graceful shutdown.
sshSrv, err := agentssh.NewServer(a.hardCtx, a.logger.Named("ssh-server"), a.prometheusRegistry, a.filesystem, a.execer, &agentssh.Config{
MaxTimeout: a.sshMaxTimeout,
MOTDFile: func() string { return a.manifest.Load().MOTDFile },
AnnouncementBanners: func() *[]codersdk.BannerConfig { return a.announcementBanners.Load() },
UpdateEnv: a.updateCommandEnv,
WorkingDirectory: func() string { return a.manifest.Load().Directory },
BlockFileTransfer: a.blockFileTransfer,
MaxTimeout: a.sshMaxTimeout,
MOTDFile: func() string { return a.manifest.Load().MOTDFile },
AnnouncementBanners: func() *[]codersdk.BannerConfig { return a.announcementBanners.Load() },
UpdateEnv: a.updateCommandEnv,
WorkingDirectory: func() string { return a.manifest.Load().Directory },
BlockFileTransfer: a.blockFileTransfer,
BlockReversePortForwarding: a.blockReversePortForwarding,
BlockLocalPortForwarding: a.blockLocalPortForwarding,
ReportConnection: func(id uuid.UUID, magicType agentssh.MagicSessionType, ip string) func(code int, reason string) {
var connectionType proto.Connection_Type
switch magicType {
+155
View File
@@ -986,6 +986,161 @@ func TestAgent_TCPRemoteForwarding(t *testing.T) {
requireEcho(t, conn)
}
func TestAgent_TCPLocalForwardingBlocked(t *testing.T) {
t.Parallel()
ctx := testutil.Context(t, testutil.WaitLong)
rl, err := net.Listen("tcp", "127.0.0.1:0")
require.NoError(t, err)
defer rl.Close()
tcpAddr, valid := rl.Addr().(*net.TCPAddr)
require.True(t, valid)
remotePort := tcpAddr.Port
//nolint:dogsled
agentConn, _, _, _, _ := setupAgent(t, agentsdk.Manifest{}, 0, func(_ *agenttest.Client, o *agent.Options) {
o.BlockLocalPortForwarding = true
})
sshClient, err := agentConn.SSHClient(ctx)
require.NoError(t, err)
defer sshClient.Close()
_, err = sshClient.Dial("tcp", fmt.Sprintf("127.0.0.1:%d", remotePort))
require.ErrorContains(t, err, "administratively prohibited")
}
func TestAgent_TCPRemoteForwardingBlocked(t *testing.T) {
t.Parallel()
ctx := testutil.Context(t, testutil.WaitLong)
//nolint:dogsled
agentConn, _, _, _, _ := setupAgent(t, agentsdk.Manifest{}, 0, func(_ *agenttest.Client, o *agent.Options) {
o.BlockReversePortForwarding = true
})
sshClient, err := agentConn.SSHClient(ctx)
require.NoError(t, err)
defer sshClient.Close()
localhost := netip.MustParseAddr("127.0.0.1")
randomPort := testutil.RandomPortNoListen(t)
addr := net.TCPAddrFromAddrPort(netip.AddrPortFrom(localhost, randomPort))
_, err = sshClient.ListenTCP(addr)
require.ErrorContains(t, err, "tcpip-forward request denied by peer")
}
func TestAgent_UnixLocalForwardingBlocked(t *testing.T) {
t.Parallel()
if runtime.GOOS == "windows" {
t.Skip("unix domain sockets are not fully supported on Windows")
}
ctx := testutil.Context(t, testutil.WaitLong)
tmpdir := testutil.TempDirUnixSocket(t)
remoteSocketPath := filepath.Join(tmpdir, "remote-socket")
l, err := net.Listen("unix", remoteSocketPath)
require.NoError(t, err)
defer l.Close()
//nolint:dogsled
agentConn, _, _, _, _ := setupAgent(t, agentsdk.Manifest{}, 0, func(_ *agenttest.Client, o *agent.Options) {
o.BlockLocalPortForwarding = true
})
sshClient, err := agentConn.SSHClient(ctx)
require.NoError(t, err)
defer sshClient.Close()
_, err = sshClient.Dial("unix", remoteSocketPath)
require.ErrorContains(t, err, "administratively prohibited")
}
func TestAgent_UnixRemoteForwardingBlocked(t *testing.T) {
t.Parallel()
if runtime.GOOS == "windows" {
t.Skip("unix domain sockets are not fully supported on Windows")
}
ctx := testutil.Context(t, testutil.WaitLong)
tmpdir := testutil.TempDirUnixSocket(t)
remoteSocketPath := filepath.Join(tmpdir, "remote-socket")
//nolint:dogsled
agentConn, _, _, _, _ := setupAgent(t, agentsdk.Manifest{}, 0, func(_ *agenttest.Client, o *agent.Options) {
o.BlockReversePortForwarding = true
})
sshClient, err := agentConn.SSHClient(ctx)
require.NoError(t, err)
defer sshClient.Close()
_, err = sshClient.ListenUnix(remoteSocketPath)
require.ErrorContains(t, err, "streamlocal-forward@openssh.com request denied by peer")
}
// TestAgent_LocalBlockedDoesNotAffectReverse verifies that blocking
// local port forwarding does not prevent reverse port forwarding from
// working. A field-name transposition at any plumbing hop would block
// the wrong direction when only one flag is set; this test catches that.
func TestAgent_LocalBlockedDoesNotAffectReverse(t *testing.T) {
t.Parallel()
ctx := testutil.Context(t, testutil.WaitLong)
//nolint:dogsled
agentConn, _, _, _, _ := setupAgent(t, agentsdk.Manifest{}, 0, func(_ *agenttest.Client, o *agent.Options) {
o.BlockLocalPortForwarding = true
})
sshClient, err := agentConn.SSHClient(ctx)
require.NoError(t, err)
defer sshClient.Close()
// Reverse forwarding must still work.
localhost := netip.MustParseAddr("127.0.0.1")
var ll net.Listener
for {
randomPort := testutil.RandomPortNoListen(t)
addr := net.TCPAddrFromAddrPort(netip.AddrPortFrom(localhost, randomPort))
ll, err = sshClient.ListenTCP(addr)
if err != nil {
t.Logf("error remote forwarding: %s", err.Error())
select {
case <-ctx.Done():
t.Fatal("timed out getting random listener")
default:
continue
}
}
break
}
_ = ll.Close()
}
// TestAgent_ReverseBlockedDoesNotAffectLocal verifies that blocking
// reverse port forwarding does not prevent local port forwarding from
// working.
func TestAgent_ReverseBlockedDoesNotAffectLocal(t *testing.T) {
t.Parallel()
ctx := testutil.Context(t, testutil.WaitLong)
rl, err := net.Listen("tcp", "127.0.0.1:0")
require.NoError(t, err)
defer rl.Close()
tcpAddr, valid := rl.Addr().(*net.TCPAddr)
require.True(t, valid)
remotePort := tcpAddr.Port
go echoOnce(t, rl)
//nolint:dogsled
agentConn, _, _, _, _ := setupAgent(t, agentsdk.Manifest{}, 0, func(_ *agenttest.Client, o *agent.Options) {
o.BlockReversePortForwarding = true
})
sshClient, err := agentConn.SSHClient(ctx)
require.NoError(t, err)
defer sshClient.Close()
// Local forwarding must still work.
conn, err := sshClient.Dial("tcp", fmt.Sprintf("127.0.0.1:%d", remotePort))
require.NoError(t, err)
defer conn.Close()
requireEcho(t, conn)
}
func TestAgent_UnixLocalForwarding(t *testing.T) {
t.Parallel()
if runtime.GOOS == "windows" {
+26 -3
View File
@@ -117,6 +117,10 @@ type Config struct {
X11MaxPort *int
// BlockFileTransfer restricts use of file transfer applications.
BlockFileTransfer bool
// BlockReversePortForwarding disables reverse port forwarding (ssh -R).
BlockReversePortForwarding bool
// BlockLocalPortForwarding disables local port forwarding (ssh -L).
BlockLocalPortForwarding bool
// ReportConnection.
ReportConnection reportConnectionFunc
// Experimental: allow connecting to running containers via Docker exec.
@@ -190,7 +194,7 @@ func NewServer(ctx context.Context, logger slog.Logger, prometheusRegistry *prom
}
forwardHandler := &ssh.ForwardedTCPHandler{}
unixForwardHandler := newForwardedUnixHandler(logger)
unixForwardHandler := newForwardedUnixHandler(logger, config.BlockReversePortForwarding)
metrics := newSSHServerMetrics(prometheusRegistry)
s := &Server{
@@ -229,8 +233,15 @@ func NewServer(ctx context.Context, logger slog.Logger, prometheusRegistry *prom
wrapped := NewJetbrainsChannelWatcher(ctx, s.logger, s.config.ReportConnection, newChan, &s.connCountJetBrains)
ssh.DirectTCPIPHandler(srv, conn, wrapped, ctx)
},
"direct-streamlocal@openssh.com": directStreamLocalHandler,
"session": ssh.DefaultSessionHandler,
"direct-streamlocal@openssh.com": func(srv *ssh.Server, conn *gossh.ServerConn, newChan gossh.NewChannel, ctx ssh.Context) {
if s.config.BlockLocalPortForwarding {
s.logger.Warn(ctx, "unix local port forward blocked")
_ = newChan.Reject(gossh.Prohibited, "local port forwarding is disabled")
return
}
directStreamLocalHandler(srv, conn, newChan, ctx)
},
"session": ssh.DefaultSessionHandler,
},
ConnectionFailedCallback: func(conn net.Conn, err error) {
s.logger.Warn(ctx, "ssh connection failed",
@@ -250,6 +261,12 @@ func NewServer(ctx context.Context, logger slog.Logger, prometheusRegistry *prom
// be set before we start listening.
HostSigners: []ssh.Signer{},
LocalPortForwardingCallback: func(ctx ssh.Context, destinationHost string, destinationPort uint32) bool {
if s.config.BlockLocalPortForwarding {
s.logger.Warn(ctx, "local port forward blocked",
slog.F("destination_host", destinationHost),
slog.F("destination_port", destinationPort))
return false
}
// Allow local port forwarding all!
s.logger.Debug(ctx, "local port forward",
slog.F("destination_host", destinationHost),
@@ -260,6 +277,12 @@ func NewServer(ctx context.Context, logger slog.Logger, prometheusRegistry *prom
return true
},
ReversePortForwardingCallback: func(ctx ssh.Context, bindHost string, bindPort uint32) bool {
if s.config.BlockReversePortForwarding {
s.logger.Warn(ctx, "reverse port forward blocked",
slog.F("bind_host", bindHost),
slog.F("bind_port", bindPort))
return false
}
// Allow all reverse port forwarding.
s.logger.Debug(ctx, "reverse port forward",
slog.F("bind_host", bindHost),
+11 -5
@@ -35,8 +35,9 @@ type forwardedStreamLocalPayload struct {
// streamlocal forwarding (aka. unix forwarding) instead of TCP forwarding.
type forwardedUnixHandler struct {
sync.Mutex
log slog.Logger
forwards map[forwardKey]net.Listener
log slog.Logger
forwards map[forwardKey]net.Listener
blockReversePortForwarding bool
}
type forwardKey struct {
@@ -44,10 +45,11 @@ type forwardKey struct {
addr string
}
func newForwardedUnixHandler(log slog.Logger) *forwardedUnixHandler {
func newForwardedUnixHandler(log slog.Logger, blockReversePortForwarding bool) *forwardedUnixHandler {
return &forwardedUnixHandler{
log: log,
forwards: make(map[forwardKey]net.Listener),
log: log,
forwards: make(map[forwardKey]net.Listener),
blockReversePortForwarding: blockReversePortForwarding,
}
}
@@ -62,6 +64,10 @@ func (h *forwardedUnixHandler) HandleSSHRequest(ctx ssh.Context, _ *ssh.Server,
switch req.Type {
case "streamlocal-forward@openssh.com":
if h.blockReversePortForwarding {
log.Warn(ctx, "unix reverse port forward blocked")
return false, nil
}
var reqPayload streamLocalForwardPayload
err := gossh.Unmarshal(req.Payload, &reqPayload)
if err != nil {
+22 -4
@@ -53,6 +53,8 @@ func workspaceAgent() *serpent.Command {
slogJSONPath string
slogStackdriverPath string
blockFileTransfer bool
blockReversePortForwarding bool
blockLocalPortForwarding bool
agentHeaderCommand string
agentHeader []string
devcontainers bool
@@ -319,10 +321,12 @@ func workspaceAgent() *serpent.Command {
SSHMaxTimeout: sshMaxTimeout,
Subsystems: subsystems,
PrometheusRegistry: prometheusRegistry,
BlockFileTransfer: blockFileTransfer,
Execer: execer,
Devcontainers: devcontainers,
PrometheusRegistry: prometheusRegistry,
BlockFileTransfer: blockFileTransfer,
BlockReversePortForwarding: blockReversePortForwarding,
BlockLocalPortForwarding: blockLocalPortForwarding,
Execer: execer,
Devcontainers: devcontainers,
DevcontainerAPIOptions: []agentcontainers.Option{
agentcontainers.WithSubAgentURL(agentAuth.agentURL.String()),
agentcontainers.WithProjectDiscovery(devcontainerProjectDiscovery),
@@ -493,6 +497,20 @@ func workspaceAgent() *serpent.Command {
Description: fmt.Sprintf("Block file transfer using known applications: %s.", strings.Join(agentssh.BlockedFileTransferCommands, ",")),
Value: serpent.BoolOf(&blockFileTransfer),
},
{
Flag: "block-reverse-port-forwarding",
Default: "false",
Env: "CODER_AGENT_BLOCK_REVERSE_PORT_FORWARDING",
Description: "Block reverse port forwarding through the SSH server (ssh -R).",
Value: serpent.BoolOf(&blockReversePortForwarding),
},
{
Flag: "block-local-port-forwarding",
Default: "false",
Env: "CODER_AGENT_BLOCK_LOCAL_PORT_FORWARDING",
Description: "Block local port forwarding through the SSH server (ssh -L).",
Value: serpent.BoolOf(&blockLocalPortForwarding),
},
{
Flag: "devcontainers-enable",
Default: "true",
+20
@@ -768,10 +768,30 @@ func (r *RootCmd) Server(newAPI func(context.Context, *coderd.Options) (*coderd.
return xerrors.Errorf("create pubsub: %w", err)
}
options.Pubsub = ps
options.ChatPubsub = ps
if options.DeploymentValues.Prometheus.Enable {
options.PrometheusRegistry.MustRegister(ps)
}
defer options.Pubsub.Close()
chatPubsub, err := pubsub.NewBatching(
ctx,
logger.Named("chatd").Named("pubsub_batch"),
ps,
sqlDB,
dbURL,
pubsub.BatchingConfig{
FlushInterval: options.DeploymentValues.AI.Chat.PubsubFlushInterval.Value(),
QueueSize: int(options.DeploymentValues.AI.Chat.PubsubQueueSize.Value()),
},
)
if err != nil {
return xerrors.Errorf("create chat pubsub batcher: %w", err)
}
options.ChatPubsub = chatPubsub
if options.DeploymentValues.Prometheus.Enable {
options.PrometheusRegistry.MustRegister(chatPubsub)
}
defer options.ChatPubsub.Close()
psWatchdog := pubsub.NewWatchdog(ctx, logger.Named("pswatch"), ps)
pubsubWatchdogTimeout = psWatchdog.Timeout()
defer psWatchdog.Close()
+97 -17
@@ -52,6 +52,10 @@ import (
const (
disableUsageApp = "disable"
// Retry transient errors during SSH connection establishment.
sshRetryInterval = 2 * time.Second
sshMaxAttempts = 10 // total attempts (initial + retries) per step
)
var (
@@ -62,6 +66,51 @@ var (
workspaceNameRe = regexp.MustCompile(`[/.]+|--`)
)
// isRetryableError checks for transient connection errors worth
// retrying: DNS failures, connection refused, and server 5xx.
func isRetryableError(err error) bool {
if err == nil {
return false
}
if xerrors.Is(err, context.Canceled) || xerrors.Is(err, context.DeadlineExceeded) {
return false
}
if codersdk.IsConnectionError(err) {
return true
}
var sdkErr *codersdk.Error
if xerrors.As(err, &sdkErr) {
return sdkErr.StatusCode() >= 500
}
return false
}
// retryWithInterval calls fn up to maxAttempts times, waiting
// interval between attempts. Stops on success, non-retryable
// error, or context cancellation.
func retryWithInterval(ctx context.Context, logger slog.Logger, interval time.Duration, maxAttempts int, fn func() error) error {
var lastErr error
attempt := 0
for r := retry.New(interval, interval); r.Wait(ctx); {
lastErr = fn()
if lastErr == nil || !isRetryableError(lastErr) {
return lastErr
}
attempt++
if attempt >= maxAttempts {
break
}
logger.Warn(ctx, "transient error, retrying",
slog.Error(lastErr),
slog.F("attempt", attempt),
)
}
if lastErr != nil {
return lastErr
}
return ctx.Err()
}
func (r *RootCmd) ssh() *serpent.Command {
var (
stdio bool
@@ -277,10 +326,17 @@ func (r *RootCmd) ssh() *serpent.Command {
HostnameSuffix: hostnameSuffix,
}
workspace, workspaceAgent, err := findWorkspaceAndAgentByHostname(
ctx, inv, client,
inv.Args[0], cliConfig, disableAutostart)
if err != nil {
// Populated by the closure below.
var workspace codersdk.Workspace
var workspaceAgent codersdk.WorkspaceAgent
resolveWorkspace := func() error {
var err error
workspace, workspaceAgent, err = findWorkspaceAndAgentByHostname(
ctx, inv, client,
inv.Args[0], cliConfig, disableAutostart)
return err
}
if err := retryWithInterval(ctx, logger, sshRetryInterval, sshMaxAttempts, resolveWorkspace); err != nil {
return err
}
@@ -306,8 +362,13 @@ func (r *RootCmd) ssh() *serpent.Command {
wait = false
}
templateVersion, err := client.TemplateVersion(ctx, workspace.LatestBuild.TemplateVersionID)
if err != nil {
var templateVersion codersdk.TemplateVersion
fetchVersion := func() error {
var err error
templateVersion, err = client.TemplateVersion(ctx, workspace.LatestBuild.TemplateVersionID)
return err
}
if err := retryWithInterval(ctx, logger, sshRetryInterval, sshMaxAttempts, fetchVersion); err != nil {
return err
}
@@ -347,8 +408,12 @@ func (r *RootCmd) ssh() *serpent.Command {
// If we're in stdio mode, check to see if we can use Coder Connect.
// We don't support Coder Connect over non-stdio coder ssh yet.
if stdio && !forceNewTunnel {
connInfo, err := wsClient.AgentConnectionInfoGeneric(ctx)
if err != nil {
var connInfo workspacesdk.AgentConnectionInfo
if err := retryWithInterval(ctx, logger, sshRetryInterval, sshMaxAttempts, func() error {
var err error
connInfo, err = wsClient.AgentConnectionInfoGeneric(ctx)
return err
}); err != nil {
return xerrors.Errorf("get agent connection info: %w", err)
}
coderConnectHost := fmt.Sprintf("%s.%s.%s.%s",
@@ -384,23 +449,27 @@ func (r *RootCmd) ssh() *serpent.Command {
})
defer closeUsage()
}
return runCoderConnectStdio(ctx, fmt.Sprintf("%s:22", coderConnectHost), stdioReader, stdioWriter, stack)
return runCoderConnectStdio(ctx, fmt.Sprintf("%s:22", coderConnectHost), stdioReader, stdioWriter, stack, logger)
}
}
if r.disableDirect {
_, _ = fmt.Fprintln(inv.Stderr, "Direct connections disabled.")
}
conn, err := wsClient.
DialAgent(ctx, workspaceAgent.ID, &workspacesdk.DialAgentOptions{
var conn workspacesdk.AgentConn
if err := retryWithInterval(ctx, logger, sshRetryInterval, sshMaxAttempts, func() error {
var err error
conn, err = wsClient.DialAgent(ctx, workspaceAgent.ID, &workspacesdk.DialAgentOptions{
Logger: logger,
BlockEndpoints: r.disableDirect,
EnableTelemetry: !r.disableNetworkTelemetry,
})
if err != nil {
return err
}); err != nil {
return xerrors.Errorf("dial agent: %w", err)
}
if err = stack.push("agent conn", conn); err != nil {
_ = conn.Close()
return err
}
conn.AwaitReachable(ctx)
@@ -1578,16 +1647,27 @@ func WithTestOnlyCoderConnectDialer(ctx context.Context, dialer coderConnectDial
func testOrDefaultDialer(ctx context.Context) coderConnectDialer {
dialer, ok := ctx.Value(coderConnectDialerContextKey{}).(coderConnectDialer)
if !ok || dialer == nil {
return &net.Dialer{}
// Timeout prevents hanging on broken tunnels (OS default is very long).
return &net.Dialer{
Timeout: 5 * time.Second,
KeepAlive: 30 * time.Second,
}
}
return dialer
}
func runCoderConnectStdio(ctx context.Context, addr string, stdin io.Reader, stdout io.Writer, stack *closerStack) error {
func runCoderConnectStdio(ctx context.Context, addr string, stdin io.Reader, stdout io.Writer, stack *closerStack, logger slog.Logger) error {
dialer := testOrDefaultDialer(ctx)
conn, err := dialer.DialContext(ctx, "tcp", addr)
if err != nil {
return xerrors.Errorf("dial coder connect host: %w", err)
var conn net.Conn
if err := retryWithInterval(ctx, logger, sshRetryInterval, sshMaxAttempts, func() error {
var err error
conn, err = dialer.DialContext(ctx, "tcp", addr)
if err != nil {
return xerrors.Errorf("dial coder connect host %q over tcp: %w", addr, err)
}
return nil
}); err != nil {
return err
}
if err := stack.push("tcp conn", conn); err != nil {
return err
+149 -1
@@ -5,7 +5,9 @@ import (
"fmt"
"io"
"net"
"net/http"
"net/url"
"os"
"sync"
"testing"
"time"
@@ -226,6 +228,41 @@ func TestCloserStack_Timeout(t *testing.T) {
testutil.TryReceive(ctx, t, closed)
}
func TestCloserStack_PushAfterClose_ConnClosed(t *testing.T) {
t.Parallel()
ctx := testutil.Context(t, testutil.WaitShort)
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true}).Leveled(slog.LevelDebug)
uut := newCloserStack(ctx, logger, quartz.NewMock(t))
uut.close(xerrors.New("canceled"))
closes := new([]*fakeCloser)
fc := &fakeCloser{closes: closes}
err := uut.push("conn", fc)
require.Error(t, err)
require.Equal(t, []*fakeCloser{fc}, *closes, "should close conn on failed push")
}
func TestCoderConnectDialer_DefaultTimeout(t *testing.T) {
t.Parallel()
ctx := context.Background()
dialer := testOrDefaultDialer(ctx)
d, ok := dialer.(*net.Dialer)
require.True(t, ok, "expected *net.Dialer")
assert.Equal(t, 5*time.Second, d.Timeout)
assert.Equal(t, 30*time.Second, d.KeepAlive)
}
func TestCoderConnectDialer_Overridden(t *testing.T) {
t.Parallel()
custom := &net.Dialer{Timeout: 99 * time.Second}
ctx := WithTestOnlyCoderConnectDialer(context.Background(), custom)
dialer := testOrDefaultDialer(ctx)
assert.Equal(t, custom, dialer)
}
func TestCoderConnectStdio(t *testing.T) {
t.Parallel()
@@ -254,7 +291,7 @@ func TestCoderConnectStdio(t *testing.T) {
stdioDone := make(chan struct{})
go func() {
err = runCoderConnectStdio(ctx, ln.Addr().String(), clientOutput, serverInput, stack)
err = runCoderConnectStdio(ctx, ln.Addr().String(), clientOutput, serverInput, stack, logger)
assert.NoError(t, err)
close(stdioDone)
}()
@@ -448,3 +485,114 @@ func Test_getWorkspaceAgent(t *testing.T) {
assert.Contains(t, err.Error(), "available agents: [clark krypton zod]")
})
}
func TestIsRetryableError(t *testing.T) {
t.Parallel()
tests := []struct {
name string
err error
retryable bool
}{
{"Nil", nil, false},
{"ContextCanceled", context.Canceled, false},
{"ContextDeadlineExceeded", context.DeadlineExceeded, false},
{"WrappedContextCanceled", xerrors.Errorf("wrapped: %w", context.Canceled), false},
{"DNSError", &net.DNSError{Err: "no such host", Name: "example.com", IsNotFound: true}, true},
{"OpError", &net.OpError{Op: "dial", Net: "tcp", Err: &os.SyscallError{}}, true},
{"WrappedDNSError", xerrors.Errorf("connect: %w", &net.DNSError{Err: "no such host", Name: "example.com"}), true},
{"SDKError_500", codersdk.NewTestError(http.StatusInternalServerError, "GET", "/api"), true},
{"SDKError_502", codersdk.NewTestError(http.StatusBadGateway, "GET", "/api"), true},
{"SDKError_503", codersdk.NewTestError(http.StatusServiceUnavailable, "GET", "/api"), true},
{"SDKError_401", codersdk.NewTestError(http.StatusUnauthorized, "GET", "/api"), false},
{"SDKError_403", codersdk.NewTestError(http.StatusForbidden, "GET", "/api"), false},
{"SDKError_404", codersdk.NewTestError(http.StatusNotFound, "GET", "/api"), false},
{"GenericError", xerrors.New("something went wrong"), false},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
assert.Equal(t, tt.retryable, isRetryableError(tt.err))
})
}
}
func TestRetryWithInterval(t *testing.T) {
t.Parallel()
const interval = time.Millisecond
const maxAttempts = 3
dnsErr := &net.DNSError{Err: "no such host", Name: "example.com", IsNotFound: true}
logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug)
t.Run("Succeeds_FirstTry", func(t *testing.T) {
t.Parallel()
ctx := testutil.Context(t, testutil.WaitShort)
attempts := 0
err := retryWithInterval(ctx, logger, interval, maxAttempts, func() error {
attempts++
return nil
})
require.NoError(t, err)
assert.Equal(t, 1, attempts)
})
t.Run("Succeeds_AfterTransientFailures", func(t *testing.T) {
t.Parallel()
ctx := testutil.Context(t, testutil.WaitShort)
attempts := 0
err := retryWithInterval(ctx, logger, interval, maxAttempts, func() error {
attempts++
if attempts < 3 {
return dnsErr
}
return nil
})
require.NoError(t, err)
assert.Equal(t, 3, attempts)
})
t.Run("Stops_NonRetryableError", func(t *testing.T) {
t.Parallel()
ctx := testutil.Context(t, testutil.WaitShort)
attempts := 0
err := retryWithInterval(ctx, logger, interval, maxAttempts, func() error {
attempts++
return xerrors.New("permanent failure")
})
require.ErrorContains(t, err, "permanent failure")
assert.Equal(t, 1, attempts)
})
t.Run("Stops_MaxAttemptsExhausted", func(t *testing.T) {
t.Parallel()
ctx := testutil.Context(t, testutil.WaitShort)
attempts := 0
err := retryWithInterval(ctx, logger, interval, maxAttempts, func() error {
attempts++
return dnsErr
})
require.Error(t, err)
assert.Equal(t, maxAttempts, attempts)
})
t.Run("Stops_ContextCanceled", func(t *testing.T) {
t.Parallel()
ctx, cancel := context.WithCancel(context.Background())
attempts := 0
err := retryWithInterval(ctx, logger, interval, maxAttempts, func() error {
attempts++
cancel()
return dnsErr
})
require.Error(t, err)
assert.Equal(t, 1, attempts)
})
}
+6
@@ -39,6 +39,12 @@ OPTIONS:
--block-file-transfer bool, $CODER_AGENT_BLOCK_FILE_TRANSFER (default: false)
Block file transfer using known applications: nc,rsync,scp,sftp.
--block-local-port-forwarding bool, $CODER_AGENT_BLOCK_LOCAL_PORT_FORWARDING (default: false)
Block local port forwarding through the SSH server (ssh -L).
--block-reverse-port-forwarding bool, $CODER_AGENT_BLOCK_REVERSE_PORT_FORWARDING (default: false)
Block reverse port forwarding through the SSH server (ssh -R).
--boundary-log-proxy-socket-path string, $CODER_AGENT_BOUNDARY_LOG_PROXY_SOCKET_PATH (default: /tmp/boundary-audit.sock)
The path for the boundary log proxy server Unix socket. Boundary
should write audit logs to this socket.
+1
@@ -134,6 +134,7 @@ func TestUserCreate(t *testing.T) {
{
name: "ServiceAccount",
args: []string{"--service-account", "-u", "dean"},
err: "Premium feature",
},
{
name: "ServiceAccountLoginType",
+3 -2
@@ -77,8 +77,9 @@ func (a *LogsAPI) BatchCreateLogs(ctx context.Context, req *agentproto.BatchCrea
level := make([]database.LogLevel, 0)
outputLength := 0
for _, logEntry := range req.Logs {
output = append(output, logEntry.Output)
outputLength += len(logEntry.Output)
sanitizedOutput := agentsdk.SanitizeLogOutput(logEntry.Output)
output = append(output, sanitizedOutput)
outputLength += len(sanitizedOutput)
var dbLevel database.LogLevel
switch logEntry.Level {
+53
@@ -139,6 +139,59 @@ func TestBatchCreateLogs(t *testing.T) {
require.True(t, publishWorkspaceAgentLogsUpdateCalled)
})
t.Run("SanitizesOutput", func(t *testing.T) {
t.Parallel()
dbM := dbmock.NewMockStore(gomock.NewController(t))
now := dbtime.Now()
api := &agentapi.LogsAPI{
AgentFn: func(context.Context) (database.WorkspaceAgent, error) {
return agent, nil
},
Database: dbM,
Log: testutil.Logger(t),
TimeNowFn: func() time.Time {
return now
},
}
rawOutput := "before\x00middle\xc3\x28after"
sanitizedOutput := agentsdk.SanitizeLogOutput(rawOutput)
expectedOutputLength := int32(len(sanitizedOutput)) //nolint:gosec // Test-controlled string length is small.
req := &agentproto.BatchCreateLogsRequest{
LogSourceId: logSource.ID[:],
Logs: []*agentproto.Log{
{
CreatedAt: timestamppb.New(now),
Level: agentproto.Log_WARN,
Output: rawOutput,
},
},
}
dbM.EXPECT().InsertWorkspaceAgentLogs(gomock.Any(), database.InsertWorkspaceAgentLogsParams{
AgentID: agent.ID,
LogSourceID: logSource.ID,
CreatedAt: now,
Output: []string{sanitizedOutput},
Level: []database.LogLevel{database.LogLevelWarn},
OutputLength: expectedOutputLength,
}).Return([]database.WorkspaceAgentLog{
{
AgentID: agent.ID,
CreatedAt: now,
ID: 1,
Output: sanitizedOutput,
Level: database.LogLevelWarn,
LogSourceID: logSource.ID,
},
}, nil)
resp, err := api.BatchCreateLogs(context.Background(), req)
require.NoError(t, err)
require.Equal(t, &agentproto.BatchCreateLogsResponse{}, resp)
})
t.Run("NoWorkspacePublishIfNotFirstLogs", func(t *testing.T) {
t.Parallel()
+78
@@ -1266,6 +1266,68 @@ const docTemplate = `{
]
}
},
"/experimental/chats/config/retention-days": {
"get": {
"produces": [
"application/json"
],
"tags": [
"Chats"
],
"summary": "Get chat retention days",
"operationId": "get-chat-retention-days",
"responses": {
"200": {
"description": "OK",
"schema": {
"$ref": "#/definitions/codersdk.ChatRetentionDaysResponse"
}
}
},
"security": [
{
"CoderSessionToken": []
}
],
"x-apidocgen": {
"skip": true
}
},
"put": {
"consumes": [
"application/json"
],
"tags": [
"Chats"
],
"summary": "Update chat retention days",
"operationId": "update-chat-retention-days",
"parameters": [
{
"description": "Request body",
"name": "request",
"in": "body",
"required": true,
"schema": {
"$ref": "#/definitions/codersdk.UpdateChatRetentionDaysRequest"
}
}
],
"responses": {
"204": {
"description": "No Content"
}
},
"security": [
{
"CoderSessionToken": []
}
],
"x-apidocgen": {
"skip": true
}
}
},
"/experimental/watch-all-workspacebuilds": {
"get": {
"produces": [
@@ -14420,6 +14482,14 @@ const docTemplate = `{
}
}
},
"codersdk.ChatRetentionDaysResponse": {
"type": "object",
"properties": {
"retention_days": {
"type": "integer"
}
}
},
"codersdk.ConnectionLatency": {
"type": "object",
"properties": {
@@ -20952,6 +21022,14 @@ const docTemplate = `{
}
}
},
"codersdk.UpdateChatRetentionDaysRequest": {
"type": "object",
"properties": {
"retention_days": {
"type": "integer"
}
}
},
"codersdk.UpdateCheckResponse": {
"type": "object",
"properties": {
+70
@@ -1103,6 +1103,60 @@
]
}
},
"/experimental/chats/config/retention-days": {
"get": {
"produces": ["application/json"],
"tags": ["Chats"],
"summary": "Get chat retention days",
"operationId": "get-chat-retention-days",
"responses": {
"200": {
"description": "OK",
"schema": {
"$ref": "#/definitions/codersdk.ChatRetentionDaysResponse"
}
}
},
"security": [
{
"CoderSessionToken": []
}
],
"x-apidocgen": {
"skip": true
}
},
"put": {
"consumes": ["application/json"],
"tags": ["Chats"],
"summary": "Update chat retention days",
"operationId": "update-chat-retention-days",
"parameters": [
{
"description": "Request body",
"name": "request",
"in": "body",
"required": true,
"schema": {
"$ref": "#/definitions/codersdk.UpdateChatRetentionDaysRequest"
}
}
],
"responses": {
"204": {
"description": "No Content"
}
},
"security": [
{
"CoderSessionToken": []
}
],
"x-apidocgen": {
"skip": true
}
}
},
"/experimental/watch-all-workspacebuilds": {
"get": {
"produces": ["application/json"],
@@ -12963,6 +13017,14 @@
}
}
},
"codersdk.ChatRetentionDaysResponse": {
"type": "object",
"properties": {
"retention_days": {
"type": "integer"
}
}
},
"codersdk.ConnectionLatency": {
"type": "object",
"properties": {
@@ -19243,6 +19305,14 @@
}
}
},
"codersdk.UpdateChatRetentionDaysRequest": {
"type": "object",
"properties": {
"retention_days": {
"type": "integer"
}
}
},
"codersdk.UpdateCheckResponse": {
"type": "object",
"properties": {
+13 -2
@@ -159,7 +159,10 @@ type Options struct {
Logger slog.Logger
Database database.Store
Pubsub pubsub.Pubsub
RuntimeConfig *runtimeconfig.Manager
// ChatPubsub allows chatd to use a dedicated publish path without changing
// the shared pubsub used by the rest of coderd.
ChatPubsub pubsub.Pubsub
RuntimeConfig *runtimeconfig.Manager
// CacheDir is used for caching files served by the API.
CacheDir string
@@ -777,6 +780,11 @@ func New(options *Options) *API {
maxChatsPerAcquire = math.MinInt32
}
chatPubsub := options.ChatPubsub
if chatPubsub == nil {
chatPubsub = options.Pubsub
}
api.chatDaemon = chatd.New(chatd.Config{
Logger: options.Logger.Named("chatd"),
Database: options.Database,
@@ -789,7 +797,7 @@ func New(options *Options) *API {
InstructionLookupTimeout: options.ChatdInstructionLookupTimeout,
CreateWorkspace: api.chatCreateWorkspace,
StartWorkspace: api.chatStartWorkspace,
Pubsub: options.Pubsub,
Pubsub: chatPubsub,
WebpushDispatcher: options.WebPushDispatcher,
UsageTracker: options.WorkspaceUsageTracker,
})
@@ -1189,6 +1197,8 @@ func New(options *Options) *API {
r.Delete("/user-compaction-thresholds/{modelConfig}", api.deleteUserChatCompactionThreshold)
r.Get("/workspace-ttl", api.getChatWorkspaceTTL)
r.Put("/workspace-ttl", api.putChatWorkspaceTTL)
r.Get("/retention-days", api.getChatRetentionDays)
r.Put("/retention-days", api.putChatRetentionDays)
r.Get("/template-allowlist", api.getChatTemplateAllowlist)
r.Put("/template-allowlist", api.putChatTemplateAllowlist)
})
@@ -1243,6 +1253,7 @@ func New(options *Options) *API {
r.Get("/git", api.watchChatGit)
})
r.Post("/interrupt", api.interruptChat)
r.Post("/tool-results", api.postChatToolResults)
r.Post("/title/regenerate", api.regenerateChatTitle)
r.Get("/diff", api.getChatDiffContents)
r.Route("/queue/{queuedMessage}", func(r chi.Router) {
+16 -5
@@ -123,6 +123,10 @@ func UsersPagination(
require.Contains(t, gotUsers[0].Name, "after")
}
type UsersFilterOptions struct {
CreateServiceAccounts bool
}
// UsersFilter creates a set of users to run various filters against for
// testing. It can be used to test filtering both users and group members.
func UsersFilter(
@@ -130,11 +134,16 @@ func UsersFilter(
t *testing.T,
client *codersdk.Client,
db database.Store,
options *UsersFilterOptions,
setup func(users []codersdk.User),
fetch func(ctx context.Context, req codersdk.UsersRequest) []codersdk.ReducedUser,
) {
t.Helper()
if options == nil {
options = &UsersFilterOptions{}
}
firstUser, err := client.User(setupCtx, codersdk.Me)
require.NoError(t, err, "fetch me")
@@ -211,11 +220,13 @@ func UsersFilter(
}
// Add some service accounts.
for range 3 {
_, user := CreateAnotherUserMutators(t, client, orgID, nil, func(r *codersdk.CreateUserRequestWithOrgs) {
r.ServiceAccount = true
})
users = append(users, user)
if options.CreateServiceAccounts {
for range 3 {
_, user := CreateAnotherUserMutators(t, client, orgID, nil, func(r *codersdk.CreateUserRequestWithOrgs) {
r.ServiceAccount = true
})
users = append(users, user)
}
}
hashedPassword, err := userpassword.Hash("SomeStrongPassword!")
+38
@@ -1715,3 +1715,41 @@ func ChatDiffStatus(chatID uuid.UUID, status *database.ChatDiffStatus) codersdk.
return result
}
// UserSecret converts a database ListUserSecretsRow (metadata only,
// no value) to an SDK UserSecret.
func UserSecret(secret database.ListUserSecretsRow) codersdk.UserSecret {
return codersdk.UserSecret{
ID: secret.ID,
Name: secret.Name,
Description: secret.Description,
EnvName: secret.EnvName,
FilePath: secret.FilePath,
CreatedAt: secret.CreatedAt,
UpdatedAt: secret.UpdatedAt,
}
}
// UserSecretFromFull converts a full database UserSecret row to an
// SDK UserSecret, omitting the value and encryption key ID.
func UserSecretFromFull(secret database.UserSecret) codersdk.UserSecret {
return codersdk.UserSecret{
ID: secret.ID,
Name: secret.Name,
Description: secret.Description,
EnvName: secret.EnvName,
FilePath: secret.FilePath,
CreatedAt: secret.CreatedAt,
UpdatedAt: secret.UpdatedAt,
}
}
// UserSecrets converts a slice of database ListUserSecretsRow to
// SDK UserSecret values.
func UserSecrets(secrets []database.ListUserSecretsRow) []codersdk.UserSecret {
result := make([]codersdk.UserSecret, 0, len(secrets))
for _, s := range secrets {
result = append(result, UserSecret(s))
}
return result
}
+4
@@ -552,6 +552,10 @@ func TestChat_AllFieldsPopulated(t *testing.T) {
RawMessage: json.RawMessage(`[{"type":"context-file","context_file_path":"/AGENTS.md"}]`),
Valid: true,
},
DynamicTools: pqtype.NullRawMessage{
RawMessage: json.RawMessage(`[{"name":"tool1","description":"test tool","inputSchema":{"type":"object"}}]`),
Valid: true,
},
}
// Only ChatID is needed here. This test checks that
// Chat.DiffStatus is non-nil, not that every DiffStatus
+54
@@ -2031,6 +2031,20 @@ func (q *querier) DeleteOldAuditLogs(ctx context.Context, arg database.DeleteOld
return q.db.DeleteOldAuditLogs(ctx, arg)
}
func (q *querier) DeleteOldChatFiles(ctx context.Context, arg database.DeleteOldChatFilesParams) (int64, error) {
if err := q.authorizeContext(ctx, policy.ActionDelete, rbac.ResourceSystem); err != nil {
return 0, err
}
return q.db.DeleteOldChatFiles(ctx, arg)
}
func (q *querier) DeleteOldChats(ctx context.Context, arg database.DeleteOldChatsParams) (int64, error) {
if err := q.authorizeContext(ctx, policy.ActionDelete, rbac.ResourceSystem); err != nil {
return 0, err
}
return q.db.DeleteOldChats(ctx, arg)
}
func (q *querier) DeleteOldConnectionLogs(ctx context.Context, arg database.DeleteOldConnectionLogsParams) (int64, error) {
if err := q.authorizeContext(ctx, policy.ActionDelete, rbac.ResourceSystem); err != nil {
return 0, err
@@ -2622,6 +2636,14 @@ func (q *querier) GetChatMessageByID(ctx context.Context, id int64) (database.Ch
return msg, nil
}
func (q *querier) GetChatMessageSummariesPerChat(ctx context.Context, createdAfter time.Time) ([]database.GetChatMessageSummariesPerChatRow, error) {
// Telemetry queries are called from system contexts only.
if err := q.authorizeContext(ctx, policy.ActionRead, rbac.ResourceSystem); err != nil {
return nil, err
}
return q.db.GetChatMessageSummariesPerChat(ctx, createdAfter)
}
func (q *querier) GetChatMessagesByChatID(ctx context.Context, arg database.GetChatMessagesByChatIDParams) ([]database.ChatMessage, error) {
// Authorize read on the parent chat.
_, err := q.GetChatByID(ctx, arg.ChatID)
@@ -2670,6 +2692,14 @@ func (q *querier) GetChatModelConfigs(ctx context.Context) ([]database.ChatModel
return q.db.GetChatModelConfigs(ctx)
}
func (q *querier) GetChatModelConfigsForTelemetry(ctx context.Context) ([]database.GetChatModelConfigsForTelemetryRow, error) {
// Telemetry queries are called from system contexts only.
if err := q.authorizeContext(ctx, policy.ActionRead, rbac.ResourceSystem); err != nil {
return nil, err
}
return q.db.GetChatModelConfigsForTelemetry(ctx)
}
func (q *querier) GetChatProviderByID(ctx context.Context, id uuid.UUID) (database.ChatProvider, error) {
if err := q.authorizeContext(ctx, policy.ActionRead, rbac.ResourceDeploymentConfig); err != nil {
return database.ChatProvider{}, err
@@ -2699,6 +2729,15 @@ func (q *querier) GetChatQueuedMessages(ctx context.Context, chatID uuid.UUID) (
return q.db.GetChatQueuedMessages(ctx, chatID)
}
func (q *querier) GetChatRetentionDays(ctx context.Context) (int32, error) {
// Chat retention is a deployment-wide config read by dbpurge.
// Only requires a valid actor in context.
if _, ok := ActorFromContext(ctx); !ok {
return 0, ErrNoActor
}
return q.db.GetChatRetentionDays(ctx)
}
func (q *querier) GetChatSystemPrompt(ctx context.Context) (string, error) {
// The system prompt is a deployment-wide setting read during chat
// creation by every authenticated user, so no RBAC policy check
@@ -2777,6 +2816,14 @@ func (q *querier) GetChatsByWorkspaceIDs(ctx context.Context, ids []uuid.UUID) (
return fetchWithPostFilter(q.auth, policy.ActionRead, q.db.GetChatsByWorkspaceIDs)(ctx, ids)
}
func (q *querier) GetChatsUpdatedAfter(ctx context.Context, updatedAfter time.Time) ([]database.GetChatsUpdatedAfterRow, error) {
// Telemetry queries are called from system contexts only.
if err := q.authorizeContext(ctx, policy.ActionRead, rbac.ResourceSystem); err != nil {
return nil, err
}
return q.db.GetChatsUpdatedAfter(ctx, updatedAfter)
}
func (q *querier) GetConnectionLogsOffset(ctx context.Context, arg database.GetConnectionLogsOffsetParams) ([]database.GetConnectionLogsOffsetRow, error) {
// Just like with the audit logs query, shortcut if the user is an owner.
err := q.authorizeContext(ctx, policy.ActionRead, rbac.ResourceConnectionLog)
@@ -7031,6 +7078,13 @@ func (q *querier) UpsertChatIncludeDefaultSystemPrompt(ctx context.Context, incl
return q.db.UpsertChatIncludeDefaultSystemPrompt(ctx, includeDefaultSystemPrompt)
}
func (q *querier) UpsertChatRetentionDays(ctx context.Context, retentionDays int32) error {
if err := q.authorizeContext(ctx, policy.ActionUpdate, rbac.ResourceDeploymentConfig); err != nil {
return err
}
return q.db.UpsertChatRetentionDays(ctx, retentionDays)
}
func (q *querier) UpsertChatSystemPrompt(ctx context.Context, value string) error {
if err := q.authorizeContext(ctx, policy.ActionUpdate, rbac.ResourceDeploymentConfig); err != nil {
return err
+30
@@ -600,6 +600,22 @@ func (s *MethodTestSuite) TestChats() {
dbm.EXPECT().GetChatFileMetadataByChatID(gomock.Any(), file.ID).Return(rows, nil).AnyTimes()
check.Args(file.ID).Asserts(rbac.ResourceChat.WithOwner(file.OwnerID.String()).InOrg(file.OrganizationID).WithID(file.ID), policy.ActionRead).Returns(rows)
}))
s.Run("DeleteOldChatFiles", s.Mocked(func(dbm *dbmock.MockStore, _ *gofakeit.Faker, check *expects) {
dbm.EXPECT().DeleteOldChatFiles(gomock.Any(), database.DeleteOldChatFilesParams{}).Return(int64(0), nil).AnyTimes()
check.Args(database.DeleteOldChatFilesParams{}).Asserts(rbac.ResourceSystem, policy.ActionDelete)
}))
s.Run("DeleteOldChats", s.Mocked(func(dbm *dbmock.MockStore, _ *gofakeit.Faker, check *expects) {
dbm.EXPECT().DeleteOldChats(gomock.Any(), database.DeleteOldChatsParams{}).Return(int64(0), nil).AnyTimes()
check.Args(database.DeleteOldChatsParams{}).Asserts(rbac.ResourceSystem, policy.ActionDelete)
}))
s.Run("GetChatRetentionDays", s.Mocked(func(dbm *dbmock.MockStore, _ *gofakeit.Faker, check *expects) {
dbm.EXPECT().GetChatRetentionDays(gomock.Any()).Return(int32(30), nil).AnyTimes()
check.Args().Asserts()
}))
s.Run("UpsertChatRetentionDays", s.Mocked(func(dbm *dbmock.MockStore, _ *gofakeit.Faker, check *expects) {
dbm.EXPECT().UpsertChatRetentionDays(gomock.Any(), int32(30)).Return(nil).AnyTimes()
check.Args(int32(30)).Asserts(rbac.ResourceDeploymentConfig, policy.ActionUpdate)
}))
s.Run("GetChatMessageByID", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
chat := testutil.Fake(s.T(), faker, database.Chat{})
msg := testutil.Fake(s.T(), faker, database.ChatMessage{ChatID: chat.ID})
@@ -3996,6 +4012,20 @@ func (s *MethodTestSuite) TestSystemFunctions() {
dbm.EXPECT().GetWorkspaceAgentsCreatedAfter(gomock.Any(), ts).Return([]database.WorkspaceAgent{}, nil).AnyTimes()
check.Args(ts).Asserts(rbac.ResourceSystem, policy.ActionRead)
}))
s.Run("GetChatsUpdatedAfter", s.Mocked(func(dbm *dbmock.MockStore, _ *gofakeit.Faker, check *expects) {
ts := dbtime.Now()
dbm.EXPECT().GetChatsUpdatedAfter(gomock.Any(), ts).Return([]database.GetChatsUpdatedAfterRow{}, nil).AnyTimes()
check.Args(ts).Asserts(rbac.ResourceSystem, policy.ActionRead)
}))
s.Run("GetChatMessageSummariesPerChat", s.Mocked(func(dbm *dbmock.MockStore, _ *gofakeit.Faker, check *expects) {
ts := dbtime.Now()
dbm.EXPECT().GetChatMessageSummariesPerChat(gomock.Any(), ts).Return([]database.GetChatMessageSummariesPerChatRow{}, nil).AnyTimes()
check.Args(ts).Asserts(rbac.ResourceSystem, policy.ActionRead)
}))
s.Run("GetChatModelConfigsForTelemetry", s.Mocked(func(dbm *dbmock.MockStore, _ *gofakeit.Faker, check *expects) {
dbm.EXPECT().GetChatModelConfigsForTelemetry(gomock.Any()).Return([]database.GetChatModelConfigsForTelemetryRow{}, nil).AnyTimes()
check.Args().Asserts(rbac.ResourceSystem, policy.ActionRead)
}))
s.Run("GetWorkspaceAppsCreatedAfter", s.Mocked(func(dbm *dbmock.MockStore, _ *gofakeit.Faker, check *expects) {
ts := dbtime.Now()
dbm.EXPECT().GetWorkspaceAppsCreatedAfter(gomock.Any(), ts).Return([]database.WorkspaceApp{}, nil).AnyTimes()
@@ -1644,6 +1644,8 @@ func AIBridgeInterception(t testing.TB, db database.Store, seed database.InsertA
ThreadParentInterceptionID: seed.ThreadParentInterceptionID,
ThreadRootInterceptionID: seed.ThreadRootInterceptionID,
ClientSessionID: seed.ClientSessionID,
CredentialKind: takeFirst(seed.CredentialKind, database.CredentialKindCentralized),
CredentialHint: takeFirst(seed.CredentialHint, ""),
})
if endedAt != nil {
interception, err = db.UpdateAIBridgeInterceptionEnded(genCtx, database.UpdateAIBridgeInterceptionEndedParams{
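The `takeFirst(seed.CredentialKind, database.CredentialKindCentralized)` calls in the dbgen hunk above follow a seed-or-default convention. This standalone sketch (the generic two-argument form is an assumption; the real dbgen helper may be variadic) shows the behavior: a zero-valued seed field falls back to the default, a set field wins.

```go
package main

import "fmt"

// takeFirst mirrors the dbgen seeding convention used above: return
// the first argument if it is not the zero value, else the fallback.
// Two-argument sketch only; hypothetical, not the dbgen source.
func takeFirst[T comparable](v, fallback T) T {
	var zero T
	if v != zero {
		return v
	}
	return fallback
}

func main() {
	fmt.Println(takeFirst("", "centralized"))      // seed unset -> default
	fmt.Println(takeFirst("oauth", "centralized")) // seed set -> kept
}
```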
@@ -592,6 +592,22 @@ func (m queryMetricsStore) DeleteOldAuditLogs(ctx context.Context, arg database.
return r0, r1
}
func (m queryMetricsStore) DeleteOldChatFiles(ctx context.Context, arg database.DeleteOldChatFilesParams) (int64, error) {
start := time.Now()
r0, r1 := m.s.DeleteOldChatFiles(ctx, arg)
m.queryLatencies.WithLabelValues("DeleteOldChatFiles").Observe(time.Since(start).Seconds())
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "DeleteOldChatFiles").Inc()
return r0, r1
}
func (m queryMetricsStore) DeleteOldChats(ctx context.Context, arg database.DeleteOldChatsParams) (int64, error) {
start := time.Now()
r0, r1 := m.s.DeleteOldChats(ctx, arg)
m.queryLatencies.WithLabelValues("DeleteOldChats").Observe(time.Since(start).Seconds())
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "DeleteOldChats").Inc()
return r0, r1
}
func (m queryMetricsStore) DeleteOldConnectionLogs(ctx context.Context, arg database.DeleteOldConnectionLogsParams) (int64, error) {
start := time.Now()
r0, r1 := m.s.DeleteOldConnectionLogs(ctx, arg)
@@ -1160,6 +1176,14 @@ func (m queryMetricsStore) GetChatMessageByID(ctx context.Context, id int64) (da
return r0, r1
}
func (m queryMetricsStore) GetChatMessageSummariesPerChat(ctx context.Context, createdAfter time.Time) ([]database.GetChatMessageSummariesPerChatRow, error) {
start := time.Now()
r0, r1 := m.s.GetChatMessageSummariesPerChat(ctx, createdAfter)
m.queryLatencies.WithLabelValues("GetChatMessageSummariesPerChat").Observe(time.Since(start).Seconds())
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "GetChatMessageSummariesPerChat").Inc()
return r0, r1
}
func (m queryMetricsStore) GetChatMessagesByChatID(ctx context.Context, chatID database.GetChatMessagesByChatIDParams) ([]database.ChatMessage, error) {
start := time.Now()
r0, r1 := m.s.GetChatMessagesByChatID(ctx, chatID)
@@ -1208,6 +1232,14 @@ func (m queryMetricsStore) GetChatModelConfigs(ctx context.Context) ([]database.
return r0, r1
}
func (m queryMetricsStore) GetChatModelConfigsForTelemetry(ctx context.Context) ([]database.GetChatModelConfigsForTelemetryRow, error) {
start := time.Now()
r0, r1 := m.s.GetChatModelConfigsForTelemetry(ctx)
m.queryLatencies.WithLabelValues("GetChatModelConfigsForTelemetry").Observe(time.Since(start).Seconds())
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "GetChatModelConfigsForTelemetry").Inc()
return r0, r1
}
func (m queryMetricsStore) GetChatProviderByID(ctx context.Context, id uuid.UUID) (database.ChatProvider, error) {
start := time.Now()
r0, r1 := m.s.GetChatProviderByID(ctx, id)
@@ -1240,6 +1272,14 @@ func (m queryMetricsStore) GetChatQueuedMessages(ctx context.Context, chatID uui
return r0, r1
}
func (m queryMetricsStore) GetChatRetentionDays(ctx context.Context) (int32, error) {
start := time.Now()
r0, r1 := m.s.GetChatRetentionDays(ctx)
m.queryLatencies.WithLabelValues("GetChatRetentionDays").Observe(time.Since(start).Seconds())
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "GetChatRetentionDays").Inc()
return r0, r1
}
func (m queryMetricsStore) GetChatSystemPrompt(ctx context.Context) (string, error) {
start := time.Now()
r0, r1 := m.s.GetChatSystemPrompt(ctx)
@@ -1312,6 +1352,14 @@ func (m queryMetricsStore) GetChatsByWorkspaceIDs(ctx context.Context, ids []uui
return r0, r1
}
func (m queryMetricsStore) GetChatsUpdatedAfter(ctx context.Context, updatedAfter time.Time) ([]database.GetChatsUpdatedAfterRow, error) {
start := time.Now()
r0, r1 := m.s.GetChatsUpdatedAfter(ctx, updatedAfter)
m.queryLatencies.WithLabelValues("GetChatsUpdatedAfter").Observe(time.Since(start).Seconds())
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "GetChatsUpdatedAfter").Inc()
return r0, r1
}
func (m queryMetricsStore) GetConnectionLogsOffset(ctx context.Context, arg database.GetConnectionLogsOffsetParams) ([]database.GetConnectionLogsOffsetRow, error) {
start := time.Now()
r0, r1 := m.s.GetConnectionLogsOffset(ctx, arg)
@@ -5000,6 +5048,14 @@ func (m queryMetricsStore) UpsertChatIncludeDefaultSystemPrompt(ctx context.Cont
return r0
}
func (m queryMetricsStore) UpsertChatRetentionDays(ctx context.Context, retentionDays int32) error {
start := time.Now()
r0 := m.s.UpsertChatRetentionDays(ctx, retentionDays)
m.queryLatencies.WithLabelValues("UpsertChatRetentionDays").Observe(time.Since(start).Seconds())
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "UpsertChatRetentionDays").Inc()
return r0
}
func (m queryMetricsStore) UpsertChatSystemPrompt(ctx context.Context, value string) error {
start := time.Now()
r0 := m.s.UpsertChatSystemPrompt(ctx, value)
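Every `queryMetricsStore` wrapper above has the same shape: record a start time, delegate, observe latency and count under the query name. A minimal generic sketch of that pattern (names here are hypothetical; the real store uses Prometheus histograms and counters, not a callback):

```go
package main

import (
	"fmt"
	"time"
)

// observer stands in for the Prometheus histogram/counter pair used
// by queryMetricsStore; this signature is an assumption for the sketch.
type observer func(query string, seconds float64)

// timedQuery wraps a store call, recording its latency under the query
// name and passing the delegate's result through unchanged.
func timedQuery[T any](obs observer, name string, fn func() (T, error)) (T, error) {
	start := time.Now()
	v, err := fn()
	obs(name, time.Since(start).Seconds())
	return v, err
}

func main() {
	latencies := map[string]float64{}
	obs := func(q string, s float64) { latencies[q] = s }
	n, err := timedQuery(obs, "DeleteOldChats", func() (int64, error) {
		return 42, nil // pretend delegate call
	})
	fmt.Println(n, err, len(latencies))
}
```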
@@ -984,6 +984,36 @@ func (mr *MockStoreMockRecorder) DeleteOldAuditLogs(ctx, arg any) *gomock.Call {
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteOldAuditLogs", reflect.TypeOf((*MockStore)(nil).DeleteOldAuditLogs), ctx, arg)
}
// DeleteOldChatFiles mocks base method.
func (m *MockStore) DeleteOldChatFiles(ctx context.Context, arg database.DeleteOldChatFilesParams) (int64, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "DeleteOldChatFiles", ctx, arg)
ret0, _ := ret[0].(int64)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// DeleteOldChatFiles indicates an expected call of DeleteOldChatFiles.
func (mr *MockStoreMockRecorder) DeleteOldChatFiles(ctx, arg any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteOldChatFiles", reflect.TypeOf((*MockStore)(nil).DeleteOldChatFiles), ctx, arg)
}
// DeleteOldChats mocks base method.
func (m *MockStore) DeleteOldChats(ctx context.Context, arg database.DeleteOldChatsParams) (int64, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "DeleteOldChats", ctx, arg)
ret0, _ := ret[0].(int64)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// DeleteOldChats indicates an expected call of DeleteOldChats.
func (mr *MockStoreMockRecorder) DeleteOldChats(ctx, arg any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteOldChats", reflect.TypeOf((*MockStore)(nil).DeleteOldChats), ctx, arg)
}
// DeleteOldConnectionLogs mocks base method.
func (m *MockStore) DeleteOldConnectionLogs(ctx context.Context, arg database.DeleteOldConnectionLogsParams) (int64, error) {
m.ctrl.T.Helper()
@@ -2132,6 +2162,21 @@ func (mr *MockStoreMockRecorder) GetChatMessageByID(ctx, id any) *gomock.Call {
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetChatMessageByID", reflect.TypeOf((*MockStore)(nil).GetChatMessageByID), ctx, id)
}
// GetChatMessageSummariesPerChat mocks base method.
func (m *MockStore) GetChatMessageSummariesPerChat(ctx context.Context, createdAfter time.Time) ([]database.GetChatMessageSummariesPerChatRow, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetChatMessageSummariesPerChat", ctx, createdAfter)
ret0, _ := ret[0].([]database.GetChatMessageSummariesPerChatRow)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// GetChatMessageSummariesPerChat indicates an expected call of GetChatMessageSummariesPerChat.
func (mr *MockStoreMockRecorder) GetChatMessageSummariesPerChat(ctx, createdAfter any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetChatMessageSummariesPerChat", reflect.TypeOf((*MockStore)(nil).GetChatMessageSummariesPerChat), ctx, createdAfter)
}
// GetChatMessagesByChatID mocks base method.
func (m *MockStore) GetChatMessagesByChatID(ctx context.Context, arg database.GetChatMessagesByChatIDParams) ([]database.ChatMessage, error) {
m.ctrl.T.Helper()
@@ -2222,6 +2267,21 @@ func (mr *MockStoreMockRecorder) GetChatModelConfigs(ctx any) *gomock.Call {
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetChatModelConfigs", reflect.TypeOf((*MockStore)(nil).GetChatModelConfigs), ctx)
}
// GetChatModelConfigsForTelemetry mocks base method.
func (m *MockStore) GetChatModelConfigsForTelemetry(ctx context.Context) ([]database.GetChatModelConfigsForTelemetryRow, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetChatModelConfigsForTelemetry", ctx)
ret0, _ := ret[0].([]database.GetChatModelConfigsForTelemetryRow)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// GetChatModelConfigsForTelemetry indicates an expected call of GetChatModelConfigsForTelemetry.
func (mr *MockStoreMockRecorder) GetChatModelConfigsForTelemetry(ctx any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetChatModelConfigsForTelemetry", reflect.TypeOf((*MockStore)(nil).GetChatModelConfigsForTelemetry), ctx)
}
// GetChatProviderByID mocks base method.
func (m *MockStore) GetChatProviderByID(ctx context.Context, id uuid.UUID) (database.ChatProvider, error) {
m.ctrl.T.Helper()
@@ -2282,6 +2342,21 @@ func (mr *MockStoreMockRecorder) GetChatQueuedMessages(ctx, chatID any) *gomock.
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetChatQueuedMessages", reflect.TypeOf((*MockStore)(nil).GetChatQueuedMessages), ctx, chatID)
}
// GetChatRetentionDays mocks base method.
func (m *MockStore) GetChatRetentionDays(ctx context.Context) (int32, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetChatRetentionDays", ctx)
ret0, _ := ret[0].(int32)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// GetChatRetentionDays indicates an expected call of GetChatRetentionDays.
func (mr *MockStoreMockRecorder) GetChatRetentionDays(ctx any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetChatRetentionDays", reflect.TypeOf((*MockStore)(nil).GetChatRetentionDays), ctx)
}
// GetChatSystemPrompt mocks base method.
func (m *MockStore) GetChatSystemPrompt(ctx context.Context) (string, error) {
m.ctrl.T.Helper()
@@ -2417,6 +2492,21 @@ func (mr *MockStoreMockRecorder) GetChatsByWorkspaceIDs(ctx, ids any) *gomock.Ca
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetChatsByWorkspaceIDs", reflect.TypeOf((*MockStore)(nil).GetChatsByWorkspaceIDs), ctx, ids)
}
// GetChatsUpdatedAfter mocks base method.
func (m *MockStore) GetChatsUpdatedAfter(ctx context.Context, updatedAfter time.Time) ([]database.GetChatsUpdatedAfterRow, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetChatsUpdatedAfter", ctx, updatedAfter)
ret0, _ := ret[0].([]database.GetChatsUpdatedAfterRow)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// GetChatsUpdatedAfter indicates an expected call of GetChatsUpdatedAfter.
func (mr *MockStoreMockRecorder) GetChatsUpdatedAfter(ctx, updatedAfter any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetChatsUpdatedAfter", reflect.TypeOf((*MockStore)(nil).GetChatsUpdatedAfter), ctx, updatedAfter)
}
// GetConnectionLogsOffset mocks base method.
func (m *MockStore) GetConnectionLogsOffset(ctx context.Context, arg database.GetConnectionLogsOffsetParams) ([]database.GetConnectionLogsOffsetRow, error) {
m.ctrl.T.Helper()
@@ -9399,6 +9489,20 @@ func (mr *MockStoreMockRecorder) UpsertChatIncludeDefaultSystemPrompt(ctx, inclu
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpsertChatIncludeDefaultSystemPrompt", reflect.TypeOf((*MockStore)(nil).UpsertChatIncludeDefaultSystemPrompt), ctx, includeDefaultSystemPrompt)
}
// UpsertChatRetentionDays mocks base method.
func (m *MockStore) UpsertChatRetentionDays(ctx context.Context, retentionDays int32) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "UpsertChatRetentionDays", ctx, retentionDays)
ret0, _ := ret[0].(error)
return ret0
}
// UpsertChatRetentionDays indicates an expected call of UpsertChatRetentionDays.
func (mr *MockStoreMockRecorder) UpsertChatRetentionDays(ctx, retentionDays any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpsertChatRetentionDays", reflect.TypeOf((*MockStore)(nil).UpsertChatRetentionDays), ctx, retentionDays)
}
// UpsertChatSystemPrompt mocks base method.
func (m *MockStore) UpsertChatSystemPrompt(ctx context.Context, value string) error {
m.ctrl.T.Helper()
@@ -34,6 +34,11 @@ const (
// long enough to cover the maximum interval of a heartbeat event (currently
// 1 hour) plus some buffer.
maxTelemetryHeartbeatAge = 24 * time.Hour
// Batch sizes for chat purging. Both use 1000, which is smaller
// than the audit/connection log batches (10000) because chat_files
// rows carry bytea blob data, which makes large batches heavier.
chatsBatchSize = 1000
chatFilesBatchSize = 1000
)
// New creates a new periodically purging database instance.
@@ -109,6 +114,17 @@ func New(ctx context.Context, logger slog.Logger, db database.Store, vals *coder
// purgeTick performs a single purge iteration. It returns an error if the
// purge fails.
func (i *instance) purgeTick(ctx context.Context, db database.Store, start time.Time) error {
// Read chat retention config outside the transaction to
// avoid poisoning the tx if the stored value is corrupt.
// A SQL-level cast error (e.g. non-numeric text) puts PG
// into error state, failing all subsequent queries in the
// same transaction.
chatRetentionDays, err := db.GetChatRetentionDays(ctx)
if err != nil {
i.logger.Warn(ctx, "failed to read chat retention config, skipping chat purge", slog.Error(err))
chatRetentionDays = 0
}
// Start a transaction to grab advisory lock, we don't want to run
// multiple purges at the same time (multiple replicas).
return db.InTx(func(tx database.Store) error {
@@ -213,12 +229,43 @@ func (i *instance) purgeTick(ctx context.Context, db database.Store, start time.
}
}
// Chat retention is configured via site_configs. When
// enabled, old archived chats are deleted first, then
// orphaned chat files. Deleting a chat cascades to
// chat_file_links (removing references) but not to
// chat_files directly, so files from deleted chats
// become orphaned and are caught by DeleteOldChatFiles
// in the same tick.
var purgedChats int64
var purgedChatFiles int64
if chatRetentionDays > 0 {
chatRetention := time.Duration(chatRetentionDays) * 24 * time.Hour
deleteChatsBefore := start.Add(-chatRetention)
purgedChats, err = tx.DeleteOldChats(ctx, database.DeleteOldChatsParams{
BeforeTime: deleteChatsBefore,
LimitCount: chatsBatchSize,
})
if err != nil {
return xerrors.Errorf("failed to delete old chats: %w", err)
}
purgedChatFiles, err = tx.DeleteOldChatFiles(ctx, database.DeleteOldChatFilesParams{
BeforeTime: deleteChatsBefore,
LimitCount: chatFilesBatchSize,
})
if err != nil {
return xerrors.Errorf("failed to delete old chat files: %w", err)
}
}
i.logger.Debug(ctx, "purged old database entries",
slog.F("workspace_agent_logs", purgedWorkspaceAgentLogs),
slog.F("expired_api_keys", expiredAPIKeys),
slog.F("aibridge_records", purgedAIBridgeRecords),
slog.F("connection_logs", purgedConnectionLogs),
slog.F("audit_logs", purgedAuditLogs),
slog.F("chats", purgedChats),
slog.F("chat_files", purgedChatFiles),
slog.F("duration", i.clk.Since(start)),
)
@@ -232,6 +279,8 @@ func (i *instance) purgeTick(ctx context.Context, db database.Store, start time.
i.recordsPurged.WithLabelValues("aibridge_records").Add(float64(purgedAIBridgeRecords))
i.recordsPurged.WithLabelValues("connection_logs").Add(float64(purgedConnectionLogs))
i.recordsPurged.WithLabelValues("audit_logs").Add(float64(purgedAuditLogs))
i.recordsPurged.WithLabelValues("chats").Add(float64(purgedChats))
i.recordsPurged.WithLabelValues("chat_files").Add(float64(purgedChatFiles))
}
return nil
@@ -12,6 +12,7 @@ import (
"time"
"github.com/google/uuid"
"github.com/lib/pq"
"github.com/prometheus/client_golang/prometheus"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
@@ -53,6 +54,7 @@ func TestPurge(t *testing.T) {
clk := quartz.NewMock(t)
done := awaitDoTick(ctx, t, clk)
mDB := dbmock.NewMockStore(gomock.NewController(t))
mDB.EXPECT().GetChatRetentionDays(gomock.Any()).Return(int32(0), nil).AnyTimes()
mDB.EXPECT().InTx(gomock.Any(), database.DefaultTXOptions().WithID("db_purge")).Return(nil).Times(2)
purger := dbpurge.New(context.Background(), testutil.Logger(t), mDB, &codersdk.DeploymentValues{}, clk, prometheus.NewRegistry())
<-done // wait for doTick() to run.
@@ -125,6 +127,16 @@ func TestMetrics(t *testing.T) {
"record_type": "audit_logs",
})
require.GreaterOrEqual(t, auditLogs, 0)
chats := promhelp.CounterValue(t, reg, "coderd_dbpurge_records_purged_total", prometheus.Labels{
"record_type": "chats",
})
require.GreaterOrEqual(t, chats, 0)
chatFiles := promhelp.CounterValue(t, reg, "coderd_dbpurge_records_purged_total", prometheus.Labels{
"record_type": "chat_files",
})
require.GreaterOrEqual(t, chatFiles, 0)
})
t.Run("FailedIteration", func(t *testing.T) {
@@ -138,6 +150,7 @@ func TestMetrics(t *testing.T) {
ctrl := gomock.NewController(t)
mDB := dbmock.NewMockStore(ctrl)
mDB.EXPECT().GetChatRetentionDays(gomock.Any()).Return(int32(0), nil).AnyTimes()
mDB.EXPECT().InTx(gomock.Any(), database.DefaultTXOptions().WithID("db_purge")).
Return(xerrors.New("simulated database error")).
MinTimes(1)
@@ -1634,3 +1647,488 @@ func TestDeleteExpiredAPIKeys(t *testing.T) {
func ptr[T any](v T) *T {
return &v
}
//nolint:paralleltest // It uses LockIDDBPurge.
func TestDeleteOldChatFiles(t *testing.T) {
now := time.Date(2025, 6, 15, 12, 0, 0, 0, time.UTC)
// createChatFile inserts a chat file and backdates created_at.
createChatFile := func(ctx context.Context, t *testing.T, db database.Store, rawDB *sql.DB, ownerID, orgID uuid.UUID, createdAt time.Time) uuid.UUID {
t.Helper()
row, err := db.InsertChatFile(ctx, database.InsertChatFileParams{
OwnerID: ownerID,
OrganizationID: orgID,
Name: "test.png",
Mimetype: "image/png",
Data: []byte("fake-image-data"),
})
require.NoError(t, err)
_, err = rawDB.ExecContext(ctx, "UPDATE chat_files SET created_at = $1 WHERE id = $2", createdAt, row.ID)
require.NoError(t, err)
return row.ID
}
// createChat inserts a chat and optionally archives it, then
// backdates updated_at to control the "archived since" window.
createChat := func(ctx context.Context, t *testing.T, db database.Store, rawDB *sql.DB, ownerID, modelConfigID uuid.UUID, archived bool, updatedAt time.Time) database.Chat {
t.Helper()
chat, err := db.InsertChat(ctx, database.InsertChatParams{
OwnerID: ownerID,
LastModelConfigID: modelConfigID,
Title: "test-chat",
Status: database.ChatStatusWaiting,
})
require.NoError(t, err)
if archived {
_, err = db.ArchiveChatByID(ctx, chat.ID)
require.NoError(t, err)
}
_, err = rawDB.ExecContext(ctx, "UPDATE chats SET updated_at = $1 WHERE id = $2", updatedAt, chat.ID)
require.NoError(t, err)
return chat
}
// setupChatDeps creates the common dependencies needed for
// chat-related tests: user, org, org member, provider, model config.
type chatDeps struct {
user database.User
org database.Organization
modelConfig database.ChatModelConfig
}
setupChatDeps := func(ctx context.Context, t *testing.T, db database.Store) chatDeps {
t.Helper()
user := dbgen.User(t, db, database.User{})
org := dbgen.Organization(t, db, database.Organization{})
_ = dbgen.OrganizationMember(t, db, database.OrganizationMember{UserID: user.ID, OrganizationID: org.ID})
_, err := db.InsertChatProvider(ctx, database.InsertChatProviderParams{
Provider: "openai",
DisplayName: "OpenAI",
Enabled: true,
CentralApiKeyEnabled: true,
})
require.NoError(t, err)
mc, err := db.InsertChatModelConfig(ctx, database.InsertChatModelConfigParams{
Provider: "openai",
Model: "test-model",
ContextLimit: 8192,
Options: json.RawMessage("{}"),
})
require.NoError(t, err)
return chatDeps{user: user, org: org, modelConfig: mc}
}
tests := []struct {
name string
run func(t *testing.T)
}{
{
name: "ChatRetentionDisabled",
run: func(t *testing.T) {
ctx := testutil.Context(t, testutil.WaitLong)
clk := quartz.NewMock(t)
clk.Set(now).MustWait(ctx)
db, _, rawDB := dbtestutil.NewDBWithSQLDB(t, dbtestutil.WithDumpOnFailure())
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true})
deps := setupChatDeps(ctx, t, db)
// Disable retention.
err := db.UpsertChatRetentionDays(ctx, int32(0))
require.NoError(t, err)
// Create an old archived chat and an orphaned old file.
oldChat := createChat(ctx, t, db, rawDB, deps.user.ID, deps.modelConfig.ID, true, now.Add(-31*24*time.Hour))
oldFileID := createChatFile(ctx, t, db, rawDB, deps.user.ID, deps.org.ID, now.Add(-31*24*time.Hour))
done := awaitDoTick(ctx, t, clk)
closer := dbpurge.New(ctx, logger, db, &codersdk.DeploymentValues{}, clk, prometheus.NewRegistry())
defer closer.Close()
testutil.TryReceive(ctx, t, done)
// Both should still exist.
_, err = db.GetChatByID(ctx, oldChat.ID)
require.NoError(t, err, "chat should not be deleted when retention is disabled")
_, err = db.GetChatFileByID(ctx, oldFileID)
require.NoError(t, err, "chat file should not be deleted when retention is disabled")
},
},
{
name: "OldArchivedChatsDeleted",
run: func(t *testing.T) {
ctx := testutil.Context(t, testutil.WaitLong)
clk := quartz.NewMock(t)
clk.Set(now).MustWait(ctx)
db, _, rawDB := dbtestutil.NewDBWithSQLDB(t, dbtestutil.WithDumpOnFailure())
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true})
deps := setupChatDeps(ctx, t, db)
err := db.UpsertChatRetentionDays(ctx, int32(30))
require.NoError(t, err)
// Old archived chat (31 days) — should be deleted.
oldChat := createChat(ctx, t, db, rawDB, deps.user.ID, deps.modelConfig.ID, true, now.Add(-31*24*time.Hour))
// Insert a message so we can verify CASCADE.
_, err = db.InsertChatMessages(ctx, database.InsertChatMessagesParams{
ChatID: oldChat.ID,
CreatedBy: []uuid.UUID{deps.user.ID},
ModelConfigID: []uuid.UUID{deps.modelConfig.ID},
Role: []database.ChatMessageRole{database.ChatMessageRoleUser},
Content: []string{`[{"type":"text","text":"hello"}]`},
ContentVersion: []int16{0},
Visibility: []database.ChatMessageVisibility{database.ChatMessageVisibilityBoth},
InputTokens: []int64{0},
OutputTokens: []int64{0},
TotalTokens: []int64{0},
ReasoningTokens: []int64{0},
CacheCreationTokens: []int64{0},
CacheReadTokens: []int64{0},
ContextLimit: []int64{0},
Compressed: []bool{false},
TotalCostMicros: []int64{0},
RuntimeMs: []int64{0},
ProviderResponseID: []string{""},
})
require.NoError(t, err)
// Recently archived chat (10 days) — should be retained.
recentChat := createChat(ctx, t, db, rawDB, deps.user.ID, deps.modelConfig.ID, true, now.Add(-10*24*time.Hour))
// Active chat — should be retained.
activeChat := createChat(ctx, t, db, rawDB, deps.user.ID, deps.modelConfig.ID, false, now)
done := awaitDoTick(ctx, t, clk)
closer := dbpurge.New(ctx, logger, db, &codersdk.DeploymentValues{}, clk, prometheus.NewRegistry())
defer closer.Close()
testutil.TryReceive(ctx, t, done)
// Old archived chat should be gone.
_, err = db.GetChatByID(ctx, oldChat.ID)
require.Error(t, err, "old archived chat should be deleted")
// Its messages should be gone too (CASCADE).
msgs, err := db.GetChatMessagesByChatID(ctx, database.GetChatMessagesByChatIDParams{
ChatID: oldChat.ID,
AfterID: 0,
})
require.NoError(t, err)
require.Empty(t, msgs, "messages should be cascade-deleted")
// Recent archived and active chats should remain.
_, err = db.GetChatByID(ctx, recentChat.ID)
require.NoError(t, err, "recently archived chat should be retained")
_, err = db.GetChatByID(ctx, activeChat.ID)
require.NoError(t, err, "active chat should be retained")
},
},
{
name: "OrphanedOldFilesDeleted",
run: func(t *testing.T) {
ctx := testutil.Context(t, testutil.WaitLong)
clk := quartz.NewMock(t)
clk.Set(now).MustWait(ctx)
db, _, rawDB := dbtestutil.NewDBWithSQLDB(t, dbtestutil.WithDumpOnFailure())
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true})
deps := setupChatDeps(ctx, t, db)
err := db.UpsertChatRetentionDays(ctx, int32(30))
require.NoError(t, err)
// File A: 31 days old, NOT in any chat -> should be deleted.
fileA := createChatFile(ctx, t, db, rawDB, deps.user.ID, deps.org.ID, now.Add(-31*24*time.Hour))
// File B: 31 days old, in an active chat -> should be retained.
fileB := createChatFile(ctx, t, db, rawDB, deps.user.ID, deps.org.ID, now.Add(-31*24*time.Hour))
activeChat := createChat(ctx, t, db, rawDB, deps.user.ID, deps.modelConfig.ID, false, now)
_, err = db.LinkChatFiles(ctx, database.LinkChatFilesParams{
ChatID: activeChat.ID,
MaxFileLinks: 100,
FileIds: []uuid.UUID{fileB},
})
require.NoError(t, err)
// File C: 10 days old, NOT in any chat -> should be retained (too young).
fileC := createChatFile(ctx, t, db, rawDB, deps.user.ID, deps.org.ID, now.Add(-10*24*time.Hour))
// File near boundary: 29d23h old — close to threshold.
fileBoundary := createChatFile(ctx, t, db, rawDB, deps.user.ID, deps.org.ID, now.Add(-30*24*time.Hour).Add(time.Hour))
done := awaitDoTick(ctx, t, clk)
closer := dbpurge.New(ctx, logger, db, &codersdk.DeploymentValues{}, clk, prometheus.NewRegistry())
defer closer.Close()
testutil.TryReceive(ctx, t, done)
_, err = db.GetChatFileByID(ctx, fileA)
require.Error(t, err, "orphaned old file A should be deleted")
_, err = db.GetChatFileByID(ctx, fileB)
require.NoError(t, err, "file B in active chat should be retained")
_, err = db.GetChatFileByID(ctx, fileC)
require.NoError(t, err, "young file C should be retained")
_, err = db.GetChatFileByID(ctx, fileBoundary)
require.NoError(t, err, "file near 30d boundary should be retained")
},
},
{
name: "ArchivedChatFilesDeleted",
run: func(t *testing.T) {
ctx := testutil.Context(t, testutil.WaitLong)
clk := quartz.NewMock(t)
clk.Set(now).MustWait(ctx)
db, _, rawDB := dbtestutil.NewDBWithSQLDB(t, dbtestutil.WithDumpOnFailure())
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true})
deps := setupChatDeps(ctx, t, db)
err := db.UpsertChatRetentionDays(ctx, int32(30))
require.NoError(t, err)
// File D: 31 days old, in a chat archived 31 days ago -> should be deleted.
fileD := createChatFile(ctx, t, db, rawDB, deps.user.ID, deps.org.ID, now.Add(-31*24*time.Hour))
oldArchivedChat := createChat(ctx, t, db, rawDB, deps.user.ID, deps.modelConfig.ID, true, now.Add(-31*24*time.Hour))
_, err = db.LinkChatFiles(ctx, database.LinkChatFilesParams{
ChatID: oldArchivedChat.ID,
MaxFileLinks: 100,
FileIds: []uuid.UUID{fileD},
})
require.NoError(t, err)
// LinkChatFiles does not update chats.updated_at, so backdate.
_, err = rawDB.ExecContext(ctx, "UPDATE chats SET updated_at = $1 WHERE id = $2",
now.Add(-31*24*time.Hour), oldArchivedChat.ID)
require.NoError(t, err)
// File E: 31 days old, in a chat archived 10 days ago -> should be retained.
fileE := createChatFile(ctx, t, db, rawDB, deps.user.ID, deps.org.ID, now.Add(-31*24*time.Hour))
recentArchivedChat := createChat(ctx, t, db, rawDB, deps.user.ID, deps.modelConfig.ID, true, now.Add(-10*24*time.Hour))
_, err = db.LinkChatFiles(ctx, database.LinkChatFilesParams{
ChatID: recentArchivedChat.ID,
MaxFileLinks: 100,
FileIds: []uuid.UUID{fileE},
})
require.NoError(t, err)
_, err = rawDB.ExecContext(ctx, "UPDATE chats SET updated_at = $1 WHERE id = $2",
now.Add(-10*24*time.Hour), recentArchivedChat.ID)
require.NoError(t, err)
// File F: 31 days old, in BOTH an active chat AND an old archived chat -> should be retained.
fileF := createChatFile(ctx, t, db, rawDB, deps.user.ID, deps.org.ID, now.Add(-31*24*time.Hour))
anotherOldArchivedChat := createChat(ctx, t, db, rawDB, deps.user.ID, deps.modelConfig.ID, true, now.Add(-31*24*time.Hour))
_, err = db.LinkChatFiles(ctx, database.LinkChatFilesParams{
ChatID: anotherOldArchivedChat.ID,
MaxFileLinks: 100,
FileIds: []uuid.UUID{fileF},
})
require.NoError(t, err)
_, err = rawDB.ExecContext(ctx, "UPDATE chats SET updated_at = $1 WHERE id = $2",
now.Add(-31*24*time.Hour), anotherOldArchivedChat.ID)
require.NoError(t, err)
activeChatForF := createChat(ctx, t, db, rawDB, deps.user.ID, deps.modelConfig.ID, false, now)
_, err = db.LinkChatFiles(ctx, database.LinkChatFilesParams{
ChatID: activeChatForF.ID,
MaxFileLinks: 100,
FileIds: []uuid.UUID{fileF},
})
require.NoError(t, err)
done := awaitDoTick(ctx, t, clk)
closer := dbpurge.New(ctx, logger, db, &codersdk.DeploymentValues{}, clk, prometheus.NewRegistry())
defer closer.Close()
testutil.TryReceive(ctx, t, done)
_, err = db.GetChatFileByID(ctx, fileD)
require.Error(t, err, "file D in old archived chat should be deleted")
_, err = db.GetChatFileByID(ctx, fileE)
require.NoError(t, err, "file E in recently archived chat should be retained")
_, err = db.GetChatFileByID(ctx, fileF)
require.NoError(t, err, "file F in active + old archived chat should be retained")
},
},
{
name: "UnarchiveAfterFilePurge",
run: func(t *testing.T) {
// Validates that when dbpurge deletes chat_files rows,
// the FK cascade on chat_file_links automatically
// removes the stale links. Unarchiving a chat after
// file purge should show only surviving files.
ctx := testutil.Context(t, testutil.WaitLong)
db, _, rawDB := dbtestutil.NewDBWithSQLDB(t, dbtestutil.WithDumpOnFailure())
deps := setupChatDeps(ctx, t, db)
// Create a chat with three attached files.
fileA := createChatFile(ctx, t, db, rawDB, deps.user.ID, deps.org.ID, now)
fileB := createChatFile(ctx, t, db, rawDB, deps.user.ID, deps.org.ID, now)
fileC := createChatFile(ctx, t, db, rawDB, deps.user.ID, deps.org.ID, now)
chat := createChat(ctx, t, db, rawDB, deps.user.ID, deps.modelConfig.ID, false, now)
_, err := db.LinkChatFiles(ctx, database.LinkChatFilesParams{
ChatID: chat.ID,
MaxFileLinks: 100,
FileIds: []uuid.UUID{fileA, fileB, fileC},
})
require.NoError(t, err)
// Archive the chat.
_, err = db.ArchiveChatByID(ctx, chat.ID)
require.NoError(t, err)
// Simulate dbpurge deleting files A and B. The FK
// cascade on chat_file_links_file_id_fkey should
// automatically remove the corresponding link rows.
_, err = rawDB.ExecContext(ctx, "DELETE FROM chat_files WHERE id = ANY($1)", pq.Array([]uuid.UUID{fileA, fileB}))
require.NoError(t, err)
// Unarchive the chat.
_, err = db.UnarchiveChatByID(ctx, chat.ID)
require.NoError(t, err)
// Only file C should remain linked (FK cascade
// removed the links for deleted files A and B).
files, err := db.GetChatFileMetadataByChatID(ctx, chat.ID)
require.NoError(t, err)
require.Len(t, files, 1, "only surviving file should be linked")
require.Equal(t, fileC, files[0].ID)
// Edge case: delete the last file too. The chat
// should have zero linked files, not an error.
_, err = db.ArchiveChatByID(ctx, chat.ID)
require.NoError(t, err)
_, err = rawDB.ExecContext(ctx, "DELETE FROM chat_files WHERE id = $1", fileC)
require.NoError(t, err)
_, err = db.UnarchiveChatByID(ctx, chat.ID)
require.NoError(t, err)
files, err = db.GetChatFileMetadataByChatID(ctx, chat.ID)
require.NoError(t, err)
require.Empty(t, files, "all-files-deleted should yield empty result")
// Test parent+child cascade: deleting files should
// clean up links for both parent and child chats
// independently via FK cascade.
parentChat := createChat(ctx, t, db, rawDB, deps.user.ID, deps.modelConfig.ID, false, now)
childChat, err := db.InsertChat(ctx, database.InsertChatParams{
OwnerID: deps.user.ID,
LastModelConfigID: deps.modelConfig.ID,
Title: "child-chat",
Status: database.ChatStatusWaiting,
})
require.NoError(t, err)
// Set root_chat_id to link child to parent.
_, err = rawDB.ExecContext(ctx, "UPDATE chats SET root_chat_id = $1 WHERE id = $2", parentChat.ID, childChat.ID)
require.NoError(t, err)
// Attach different files to parent and child.
parentFileKeep := createChatFile(ctx, t, db, rawDB, deps.user.ID, deps.org.ID, now)
parentFileStale := createChatFile(ctx, t, db, rawDB, deps.user.ID, deps.org.ID, now)
childFileKeep := createChatFile(ctx, t, db, rawDB, deps.user.ID, deps.org.ID, now)
childFileStale := createChatFile(ctx, t, db, rawDB, deps.user.ID, deps.org.ID, now)
_, err = db.LinkChatFiles(ctx, database.LinkChatFilesParams{
ChatID: parentChat.ID,
MaxFileLinks: 100,
FileIds: []uuid.UUID{parentFileKeep, parentFileStale},
})
require.NoError(t, err)
_, err = db.LinkChatFiles(ctx, database.LinkChatFilesParams{
ChatID: childChat.ID,
MaxFileLinks: 100,
FileIds: []uuid.UUID{childFileKeep, childFileStale},
})
require.NoError(t, err)
// Archive via parent (cascades to child).
_, err = db.ArchiveChatByID(ctx, parentChat.ID)
require.NoError(t, err)
// Delete one file from each chat.
_, err = rawDB.ExecContext(ctx, "DELETE FROM chat_files WHERE id = ANY($1)",
pq.Array([]uuid.UUID{parentFileStale, childFileStale}))
require.NoError(t, err)
// Unarchive via parent.
_, err = db.UnarchiveChatByID(ctx, parentChat.ID)
require.NoError(t, err)
parentFiles, err := db.GetChatFileMetadataByChatID(ctx, parentChat.ID)
require.NoError(t, err)
require.Len(t, parentFiles, 1)
require.Equal(t, parentFileKeep, parentFiles[0].ID,
"parent should retain only non-stale file")
childFiles, err := db.GetChatFileMetadataByChatID(ctx, childChat.ID)
require.NoError(t, err)
require.Len(t, childFiles, 1)
require.Equal(t, childFileKeep, childFiles[0].ID,
"child should retain only non-stale file")
},
},
{
name: "BatchLimitFiles",
run: func(t *testing.T) {
ctx := testutil.Context(t, testutil.WaitLong)
db, _, rawDB := dbtestutil.NewDBWithSQLDB(t, dbtestutil.WithDumpOnFailure())
deps := setupChatDeps(ctx, t, db)
// Create 3 deletable orphaned files (all 31 days old).
for range 3 {
createChatFile(ctx, t, db, rawDB, deps.user.ID, deps.org.ID, now.Add(-31*24*time.Hour))
}
// Delete with limit 2 — should delete 2, leave 1.
deleted, err := db.DeleteOldChatFiles(ctx, database.DeleteOldChatFilesParams{
BeforeTime: now.Add(-30 * 24 * time.Hour),
LimitCount: 2,
})
require.NoError(t, err)
require.Equal(t, int64(2), deleted, "should delete exactly 2 files")
// Delete again — should delete the remaining 1.
deleted, err = db.DeleteOldChatFiles(ctx, database.DeleteOldChatFilesParams{
BeforeTime: now.Add(-30 * 24 * time.Hour),
LimitCount: 2,
})
require.NoError(t, err)
require.Equal(t, int64(1), deleted, "should delete remaining 1 file")
},
},
{
name: "BatchLimitChats",
run: func(t *testing.T) {
ctx := testutil.Context(t, testutil.WaitLong)
db, _, rawDB := dbtestutil.NewDBWithSQLDB(t, dbtestutil.WithDumpOnFailure())
deps := setupChatDeps(ctx, t, db)
// Create 3 deletable old archived chats.
for range 3 {
createChat(ctx, t, db, rawDB, deps.user.ID, deps.modelConfig.ID, true, now.Add(-31*24*time.Hour))
}
// Delete with limit 2 — should delete 2, leave 1.
deleted, err := db.DeleteOldChats(ctx, database.DeleteOldChatsParams{
BeforeTime: now.Add(-30 * 24 * time.Hour),
LimitCount: 2,
})
require.NoError(t, err)
require.Equal(t, int64(2), deleted, "should delete exactly 2 chats")
// Delete again — should delete the remaining 1.
deleted, err = db.DeleteOldChats(ctx, database.DeleteOldChatsParams{
BeforeTime: now.Add(-30 * 24 * time.Hour),
LimitCount: 2,
})
require.NoError(t, err)
require.Equal(t, int64(1), deleted, "should delete remaining 1 chat")
},
},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
tc.run(t)
})
}
}
@@ -293,7 +293,8 @@ CREATE TYPE chat_status AS ENUM (
'running',
'paused',
'completed',
'error'
'error',
'requires_action'
);
CREATE TYPE connection_status AS ENUM (
@@ -315,6 +316,11 @@ CREATE TYPE cors_behavior AS ENUM (
'passthru'
);
CREATE TYPE credential_kind AS ENUM (
'centralized',
'byok'
);
CREATE TYPE crypto_key_feature AS ENUM (
'workspace_apps_token',
'workspace_apps_api_key',
@@ -1101,7 +1107,9 @@ CREATE TABLE aibridge_interceptions (
thread_root_id uuid,
client_session_id character varying(256),
session_id text GENERATED ALWAYS AS (COALESCE(client_session_id, ((thread_root_id)::text)::character varying, ((id)::text)::character varying)) STORED NOT NULL,
provider_name text DEFAULT ''::text NOT NULL
provider_name text DEFAULT ''::text NOT NULL,
credential_kind credential_kind DEFAULT 'centralized'::credential_kind NOT NULL,
credential_hint character varying(15) DEFAULT ''::character varying NOT NULL
);
COMMENT ON TABLE aibridge_interceptions IS 'Audit log of requests intercepted by AI Bridge';
@@ -1118,6 +1126,10 @@ COMMENT ON COLUMN aibridge_interceptions.session_id IS 'Groups related intercept
COMMENT ON COLUMN aibridge_interceptions.provider_name IS 'The provider instance name which may differ from provider when multiple instances of the same provider type exist.';
COMMENT ON COLUMN aibridge_interceptions.credential_kind IS 'How the request was authenticated: centralized or byok.';
COMMENT ON COLUMN aibridge_interceptions.credential_hint IS 'Masked credential identifier for audit (e.g. sk-a***efgh).';
CREATE TABLE aibridge_model_thoughts (
interception_id uuid NOT NULL,
content text NOT NULL,
@@ -1418,7 +1430,8 @@ CREATE TABLE chats (
agent_id uuid,
pin_order integer DEFAULT 0 NOT NULL,
last_read_message_id bigint,
last_injected_context jsonb
last_injected_context jsonb,
dynamic_tools jsonb
);
CREATE TABLE connection_logs (
@@ -0,0 +1,31 @@
-- First update any rows using the value we're about to remove.
-- The column type is still the original chat_status at this point.
UPDATE chats SET status = 'error' WHERE status = 'requires_action';
-- Drop the column (this is independent of the enum).
ALTER TABLE chats DROP COLUMN IF EXISTS dynamic_tools;
-- Drop the partial index that references the chat_status enum type.
-- It must be removed before the rename-create-cast-drop cycle
-- because the index's WHERE clause (status = 'pending'::chat_status)
-- would otherwise cause a cross-type comparison failure.
DROP INDEX IF EXISTS idx_chats_pending;
-- Now recreate the enum without requires_action.
-- We must use the rename-create-cast-drop pattern.
ALTER TYPE chat_status RENAME TO chat_status_old;
CREATE TYPE chat_status AS ENUM (
'waiting',
'pending',
'running',
'paused',
'completed',
'error'
);
ALTER TABLE chats ALTER COLUMN status DROP DEFAULT;
ALTER TABLE chats ALTER COLUMN status TYPE chat_status USING status::text::chat_status;
ALTER TABLE chats ALTER COLUMN status SET DEFAULT 'waiting';
DROP TYPE chat_status_old;
-- Recreate the partial index.
CREATE INDEX idx_chats_pending ON chats USING btree (status) WHERE (status = 'pending'::chat_status);
@@ -0,0 +1,3 @@
ALTER TYPE chat_status ADD VALUE IF NOT EXISTS 'requires_action';
ALTER TABLE chats ADD COLUMN dynamic_tools JSONB DEFAULT NULL;
@@ -0,0 +1,5 @@
ALTER TABLE aibridge_interceptions
DROP COLUMN IF EXISTS credential_kind,
DROP COLUMN IF EXISTS credential_hint;
DROP TYPE IF EXISTS credential_kind;
@@ -0,0 +1,12 @@
CREATE TYPE credential_kind AS ENUM ('centralized', 'byok');
-- Records how each LLM request was authenticated and a masked credential
-- identifier for audit purposes. Existing rows default to 'centralized'
-- with an empty hint since we cannot retroactively determine their values.
ALTER TABLE aibridge_interceptions
ADD COLUMN credential_kind credential_kind NOT NULL DEFAULT 'centralized',
-- Length capped as a safety measure to ensure only masked values are stored.
ADD COLUMN credential_hint CHARACTER VARYING(15) NOT NULL DEFAULT '';
COMMENT ON COLUMN aibridge_interceptions.credential_kind IS 'How the request was authenticated: centralized or byok.';
COMMENT ON COLUMN aibridge_interceptions.credential_hint IS 'Masked credential identifier for audit (e.g. sk-a***efgh).';
@@ -798,6 +798,7 @@ func (q *sqlQuerier) GetAuthorizedChats(ctx context.Context, arg GetChatsParams,
&i.Chat.PinOrder,
&i.Chat.LastReadMessageID,
&i.Chat.LastInjectedContext,
&i.Chat.DynamicTools,
&i.HasUnread); err != nil {
return nil, err
}
@@ -868,6 +869,8 @@ func (q *sqlQuerier) ListAuthorizedAIBridgeInterceptions(ctx context.Context, ar
&i.AIBridgeInterception.ClientSessionID,
&i.AIBridgeInterception.SessionID,
&i.AIBridgeInterception.ProviderName,
&i.AIBridgeInterception.CredentialKind,
&i.AIBridgeInterception.CredentialHint,
&i.VisibleUser.ID,
&i.VisibleUser.Username,
&i.VisibleUser.Name,
@@ -1131,6 +1134,8 @@ func (q *sqlQuerier) ListAuthorizedAIBridgeSessionThreads(ctx context.Context, a
&i.AIBridgeInterception.ClientSessionID,
&i.AIBridgeInterception.SessionID,
&i.AIBridgeInterception.ProviderName,
&i.AIBridgeInterception.CredentialKind,
&i.AIBridgeInterception.CredentialHint,
); err != nil {
return nil, err
}
@@ -1290,12 +1290,13 @@ func AllChatModeValues() []ChatMode {
type ChatStatus string
const (
ChatStatusWaiting ChatStatus = "waiting"
ChatStatusPending ChatStatus = "pending"
ChatStatusRunning ChatStatus = "running"
ChatStatusPaused ChatStatus = "paused"
ChatStatusCompleted ChatStatus = "completed"
ChatStatusError ChatStatus = "error"
ChatStatusWaiting ChatStatus = "waiting"
ChatStatusPending ChatStatus = "pending"
ChatStatusRunning ChatStatus = "running"
ChatStatusPaused ChatStatus = "paused"
ChatStatusCompleted ChatStatus = "completed"
ChatStatusError ChatStatus = "error"
ChatStatusRequiresAction ChatStatus = "requires_action"
)
func (e *ChatStatus) Scan(src interface{}) error {
@@ -1340,7 +1341,8 @@ func (e ChatStatus) Valid() bool {
ChatStatusRunning,
ChatStatusPaused,
ChatStatusCompleted,
ChatStatusError:
ChatStatusError,
ChatStatusRequiresAction:
return true
}
return false
@@ -1354,6 +1356,7 @@ func AllChatStatusValues() []ChatStatus {
ChatStatusPaused,
ChatStatusCompleted,
ChatStatusError,
ChatStatusRequiresAction,
}
}
@@ -1543,6 +1546,64 @@ func AllCorsBehaviorValues() []CorsBehavior {
}
}
type CredentialKind string
const (
CredentialKindCentralized CredentialKind = "centralized"
CredentialKindByok CredentialKind = "byok"
)
func (e *CredentialKind) Scan(src interface{}) error {
switch s := src.(type) {
case []byte:
*e = CredentialKind(s)
case string:
*e = CredentialKind(s)
default:
return fmt.Errorf("unsupported scan type for CredentialKind: %T", src)
}
return nil
}
type NullCredentialKind struct {
CredentialKind CredentialKind `json:"credential_kind"`
Valid bool `json:"valid"` // Valid is true if CredentialKind is not NULL
}
// Scan implements the Scanner interface.
func (ns *NullCredentialKind) Scan(value interface{}) error {
if value == nil {
ns.CredentialKind, ns.Valid = "", false
return nil
}
ns.Valid = true
return ns.CredentialKind.Scan(value)
}
// Value implements the driver Valuer interface.
func (ns NullCredentialKind) Value() (driver.Value, error) {
if !ns.Valid {
return nil, nil
}
return string(ns.CredentialKind), nil
}
func (e CredentialKind) Valid() bool {
switch e {
case CredentialKindCentralized,
CredentialKindByok:
return true
}
return false
}
func AllCredentialKindValues() []CredentialKind {
return []CredentialKind{
CredentialKindCentralized,
CredentialKindByok,
}
}
type CryptoKeyFeature string
const (
@@ -4040,6 +4101,10 @@ type AIBridgeInterception struct {
SessionID string `db:"session_id" json:"session_id"`
// The provider instance name which may differ from provider when multiple instances of the same provider type exist.
ProviderName string `db:"provider_name" json:"provider_name"`
// How the request was authenticated: centralized or byok.
CredentialKind CredentialKind `db:"credential_kind" json:"credential_kind"`
// Masked credential identifier for audit (e.g. sk-a***efgh).
CredentialHint string `db:"credential_hint" json:"credential_hint"`
}
// Audit log of model thinking in intercepted requests in AI Bridge
@@ -4180,6 +4245,7 @@ type Chat struct {
PinOrder int32 `db:"pin_order" json:"pin_order"`
LastReadMessageID sql.NullInt64 `db:"last_read_message_id" json:"last_read_message_id"`
LastInjectedContext pqtype.NullRawMessage `db:"last_injected_context" json:"last_injected_context"`
DynamicTools pqtype.NullRawMessage `db:"dynamic_tools" json:"dynamic_tools"`
}
type ChatDiffStatus struct {
@@ -0,0 +1,749 @@
package pubsub
import (
"bytes"
"context"
"database/sql"
"errors"
"strings"
"sync"
"sync/atomic"
"time"
"github.com/lib/pq"
"github.com/prometheus/client_golang/prometheus"
"golang.org/x/xerrors"
"cdr.dev/slog/v3"
"github.com/coder/quartz"
)
const (
// DefaultBatchingFlushInterval is the default upper bound on how long chatd
// publishes wait before a scheduled flush when nearby publishes do not
// naturally coalesce sooner.
DefaultBatchingFlushInterval = 50 * time.Millisecond
// DefaultBatchingQueueSize is the default number of buffered chatd publish
// requests waiting to be flushed.
DefaultBatchingQueueSize = 8192
defaultBatchingPressureWait = 10 * time.Millisecond
defaultBatchingFinalFlushLimit = 15 * time.Second
batchingWarnInterval = 10 * time.Second
batchFlushScheduled = "scheduled"
batchFlushShutdown = "shutdown"
batchFlushStageNone = "none"
batchFlushStageBegin = "begin"
batchFlushStageExec = "exec"
batchFlushStageCommit = "commit"
batchDelegateFallbackReasonQueueFull = "queue_full"
batchDelegateFallbackReasonFlushError = "flush_error"
batchChannelClassStreamNotify = "stream_notify"
batchChannelClassOwnerEvent = "owner_event"
batchChannelClassConfigChange = "config_change"
batchChannelClassOther = "other"
)
// ErrBatchingPubsubClosed is returned when a batched pubsub publish is
// attempted after shutdown has started.
var ErrBatchingPubsubClosed = xerrors.New("batched pubsub is closed")
// BatchingConfig controls the chatd-specific PostgreSQL pubsub batching path.
// Flush timing is automatic: the run loop wakes every FlushInterval (or on
// backpressure) and drains everything currently queued into a single
// transaction. There is no fixed batch-size knob — the batch size is simply
// whatever accumulated since the last flush, which naturally adapts to load.
type BatchingConfig struct {
FlushInterval time.Duration
QueueSize int
PressureWait time.Duration
FinalFlushTimeout time.Duration
Clock quartz.Clock
}
type queuedPublish struct {
event string
channelClass string
message []byte
}
type batchSender interface {
Flush(ctx context.Context, batch []queuedPublish) error
Close() error
}
type batchFlushError struct {
stage string
err error
}
func (e *batchFlushError) Error() string {
return e.err.Error()
}
func (e *batchFlushError) Unwrap() error {
return e.err
}
// BatchingPubsub batches chatd publish traffic onto a dedicated PostgreSQL
// sender connection while delegating subscribe behavior to the shared listener
// pubsub instance.
type BatchingPubsub struct {
logger slog.Logger
delegate *PGPubsub
// sender is only accessed from the run() goroutine (including
// flushBatch and resetSender which it calls). Do not read or
// write this field from Publish or any other goroutine.
sender batchSender
newSender func(context.Context) (batchSender, error)
clock quartz.Clock
publishCh chan queuedPublish
flushCh chan struct{}
closeCh chan struct{}
doneCh chan struct{}
spaceMu sync.Mutex
spaceSignal chan struct{}
warnTicker *quartz.Ticker
flushInterval time.Duration
pressureWait time.Duration
finalFlushTimeout time.Duration
queuedCount atomic.Int64
closed atomic.Bool
closeOnce sync.Once
closeErr error
runErr error
runCtx context.Context
cancel context.CancelFunc
metrics batchingMetrics
}
type batchingMetrics struct {
QueueDepth prometheus.Gauge
BatchSize prometheus.Histogram
FlushDuration *prometheus.HistogramVec
DelegateFallbacksTotal *prometheus.CounterVec
SenderResetsTotal prometheus.Counter
SenderResetFailuresTotal prometheus.Counter
}
func newBatchingMetrics() batchingMetrics {
return batchingMetrics{
QueueDepth: prometheus.NewGauge(prometheus.GaugeOpts{
Namespace: "coder",
Subsystem: "pubsub",
Name: "batch_queue_depth",
Help: "The number of chatd notifications waiting in the batching queue.",
}),
BatchSize: prometheus.NewHistogram(prometheus.HistogramOpts{
Namespace: "coder",
Subsystem: "pubsub",
Name: "batch_size",
Help: "The number of logical notifications sent in each chatd batch flush.",
Buckets: []float64{1, 2, 4, 8, 16, 32, 64, 128, 256, 512, 1024, 2048, 4096, 8192},
}),
FlushDuration: prometheus.NewHistogramVec(prometheus.HistogramOpts{
Namespace: "coder",
Subsystem: "pubsub",
Name: "batch_flush_duration_seconds",
Help: "The time spent flushing one chatd batch to PostgreSQL.",
Buckets: []float64{0.001, 0.005, 0.01, 0.025, 0.05, 0.1, 0.25, 0.5, 1, 2.5, 5, 10, 20, 30},
}, []string{"reason"}),
DelegateFallbacksTotal: prometheus.NewCounterVec(prometheus.CounterOpts{
Namespace: "coder",
Subsystem: "pubsub",
Name: "batch_delegate_fallbacks_total",
Help: "The number of chatd publishes that fell back to the shared pubsub pool by channel class, reason, and flush stage.",
}, []string{"channel_class", "reason", "stage"}),
SenderResetsTotal: prometheus.NewCounter(prometheus.CounterOpts{
Namespace: "coder",
Subsystem: "pubsub",
Name: "batch_sender_resets_total",
Help: "The number of successful batched pubsub sender resets after flush failures.",
}),
SenderResetFailuresTotal: prometheus.NewCounter(prometheus.CounterOpts{
Namespace: "coder",
Subsystem: "pubsub",
Name: "batch_sender_reset_failures_total",
Help: "The number of batched pubsub sender reset attempts that failed.",
}),
}
}
func (m batchingMetrics) Describe(descs chan<- *prometheus.Desc) {
m.QueueDepth.Describe(descs)
m.BatchSize.Describe(descs)
m.FlushDuration.Describe(descs)
m.DelegateFallbacksTotal.Describe(descs)
m.SenderResetsTotal.Describe(descs)
m.SenderResetFailuresTotal.Describe(descs)
}
func (m batchingMetrics) Collect(metrics chan<- prometheus.Metric) {
m.QueueDepth.Collect(metrics)
m.BatchSize.Collect(metrics)
m.FlushDuration.Collect(metrics)
m.DelegateFallbacksTotal.Collect(metrics)
m.SenderResetsTotal.Collect(metrics)
m.SenderResetFailuresTotal.Collect(metrics)
}
// NewBatching creates a chatd-specific batched pubsub wrapper around the
// shared PostgreSQL listener implementation.
func NewBatching(
ctx context.Context,
logger slog.Logger,
delegate *PGPubsub,
prototype *sql.DB,
connectURL string,
cfg BatchingConfig,
) (*BatchingPubsub, error) {
if delegate == nil {
return nil, xerrors.New("delegate pubsub is nil")
}
if prototype == nil {
return nil, xerrors.New("prototype database is nil")
}
if connectURL == "" {
return nil, xerrors.New("connect URL is empty")
}
newSender := func(ctx context.Context) (batchSender, error) {
return newPGBatchSender(ctx, logger.Named("sender"), prototype, connectURL)
}
sender, err := newSender(ctx)
if err != nil {
return nil, err
}
ps, err := newBatchingPubsub(logger, delegate, sender, cfg)
if err != nil {
_ = sender.Close()
return nil, err
}
ps.newSender = newSender
return ps, nil
}
func newBatchingPubsub(
logger slog.Logger,
delegate *PGPubsub,
sender batchSender,
cfg BatchingConfig,
) (*BatchingPubsub, error) {
if delegate == nil {
return nil, xerrors.New("delegate pubsub is nil")
}
if sender == nil {
return nil, xerrors.New("batch sender is nil")
}
flushInterval := cfg.FlushInterval
if flushInterval == 0 {
flushInterval = DefaultBatchingFlushInterval
}
if flushInterval < 0 {
return nil, xerrors.New("flush interval must not be negative")
}
queueSize := cfg.QueueSize
if queueSize == 0 {
queueSize = DefaultBatchingQueueSize
}
if queueSize < 0 {
return nil, xerrors.New("queue size must not be negative")
}
pressureWait := cfg.PressureWait
if pressureWait == 0 {
pressureWait = defaultBatchingPressureWait
}
if pressureWait < 0 {
return nil, xerrors.New("pressure wait must not be negative")
}
finalFlushTimeout := cfg.FinalFlushTimeout
if finalFlushTimeout == 0 {
finalFlushTimeout = defaultBatchingFinalFlushLimit
}
if finalFlushTimeout < 0 {
return nil, xerrors.New("final flush timeout must not be negative")
}
clock := cfg.Clock
if clock == nil {
clock = quartz.NewReal()
}
runCtx, cancel := context.WithCancel(context.Background())
ps := &BatchingPubsub{
logger: logger,
delegate: delegate,
sender: sender,
clock: clock,
publishCh: make(chan queuedPublish, queueSize),
flushCh: make(chan struct{}, 1),
closeCh: make(chan struct{}),
doneCh: make(chan struct{}),
spaceSignal: make(chan struct{}),
warnTicker: clock.NewTicker(batchingWarnInterval, "pubsubBatcher", "warn"),
flushInterval: flushInterval,
pressureWait: pressureWait,
finalFlushTimeout: finalFlushTimeout,
runCtx: runCtx,
cancel: cancel,
metrics: newBatchingMetrics(),
}
ps.metrics.QueueDepth.Set(0)
go ps.run()
return ps, nil
}
// Describe implements prometheus.Collector.
func (p *BatchingPubsub) Describe(descs chan<- *prometheus.Desc) {
p.metrics.Describe(descs)
}
// Collect implements prometheus.Collector.
func (p *BatchingPubsub) Collect(metrics chan<- prometheus.Metric) {
p.metrics.Collect(metrics)
}
// Subscribe delegates to the shared PostgreSQL listener pubsub.
func (p *BatchingPubsub) Subscribe(event string, listener Listener) (func(), error) {
return p.delegate.Subscribe(event, listener)
}
// SubscribeWithErr delegates to the shared PostgreSQL listener pubsub.
func (p *BatchingPubsub) SubscribeWithErr(event string, listener ListenerWithErr) (func(), error) {
return p.delegate.SubscribeWithErr(event, listener)
}
// Publish enqueues a logical notification for asynchronous batched delivery.
func (p *BatchingPubsub) Publish(event string, message []byte) error {
channelClass := batchChannelClass(event)
if p.closed.Load() {
return ErrBatchingPubsubClosed
}
req := queuedPublish{
event: event,
channelClass: channelClass,
message: bytes.Clone(message),
}
if p.tryEnqueue(req) {
return nil
}
timer := p.clock.NewTimer(p.pressureWait, "pubsubBatcher", "pressureWait")
defer timer.Stop("pubsubBatcher", "pressureWait")
for {
if p.closed.Load() {
return ErrBatchingPubsubClosed
}
p.signalPressureFlush()
spaceSignal := p.currentSpaceSignal()
if p.tryEnqueue(req) {
return nil
}
select {
case <-spaceSignal:
continue
case <-timer.C:
if p.tryEnqueue(req) {
return nil
}
// The batching queue is still full after a pressure
// flush and brief wait. Fall back to the shared
// pubsub pool so the notification is still delivered
// rather than dropped.
p.observeDelegateFallback(channelClass, batchDelegateFallbackReasonQueueFull, batchFlushStageNone)
p.logPublishRejection(event)
return p.delegate.Publish(event, message)
case <-p.doneCh:
return ErrBatchingPubsubClosed
}
}
}
// Close stops accepting new publishes, performs a bounded best-effort drain,
// and then closes the dedicated sender connection.
func (p *BatchingPubsub) Close() error {
p.closeOnce.Do(func() {
p.closed.Store(true)
p.cancel()
p.notifySpaceAvailable()
close(p.closeCh)
<-p.doneCh
p.closeErr = p.runErr
})
return p.closeErr
}
func (p *BatchingPubsub) tryEnqueue(req queuedPublish) bool {
if p.closed.Load() {
return false
}
select {
case p.publishCh <- req:
queuedDepth := p.queuedCount.Add(1)
p.observeQueueDepth(queuedDepth)
return true
default:
return false
}
}
func (p *BatchingPubsub) observeQueueDepth(depth int64) {
p.metrics.QueueDepth.Set(float64(depth))
}
func (p *BatchingPubsub) signalPressureFlush() {
select {
case p.flushCh <- struct{}{}:
default:
}
}
func (p *BatchingPubsub) currentSpaceSignal() <-chan struct{} {
p.spaceMu.Lock()
defer p.spaceMu.Unlock()
return p.spaceSignal
}
func (p *BatchingPubsub) notifySpaceAvailable() {
p.spaceMu.Lock()
defer p.spaceMu.Unlock()
close(p.spaceSignal)
p.spaceSignal = make(chan struct{})
}
func batchChannelClass(event string) string {
switch {
case strings.HasPrefix(event, "chat:stream:"):
return batchChannelClassStreamNotify
case strings.HasPrefix(event, "chat:owner:"):
return batchChannelClassOwnerEvent
case event == "chat:config_change":
return batchChannelClassConfigChange
default:
return batchChannelClassOther
}
}
func (p *BatchingPubsub) observeDelegateFallback(channelClass string, reason string, stage string) {
p.metrics.DelegateFallbacksTotal.WithLabelValues(channelClass, reason, stage).Inc()
}
func (p *BatchingPubsub) observeDelegateFallbackBatch(batch []queuedPublish, reason string, stage string) {
if len(batch) == 0 {
return
}
counts := make(map[string]int)
for _, item := range batch {
counts[item.channelClass]++
}
for channelClass, count := range counts {
p.metrics.DelegateFallbacksTotal.WithLabelValues(channelClass, reason, stage).Add(float64(count))
}
}
func batchFlushStage(err error) string {
var flushErr *batchFlushError
if errors.As(err, &flushErr) {
return flushErr.stage
}
return "unknown"
}
func (p *BatchingPubsub) run() {
defer close(p.doneCh)
defer p.warnTicker.Stop("pubsubBatcher", "warn")
batch := make([]queuedPublish, 0, 64)
timer := p.clock.NewTimer(p.flushInterval, "pubsubBatcher", "scheduledFlush")
defer timer.Stop("pubsubBatcher", "scheduledFlush")
flush := func(reason string) {
batch = p.drainIntoBatch(batch)
batch, _ = p.flushBatch(p.runCtx, batch, reason)
timer.Reset(p.flushInterval, "pubsubBatcher", reason+"Flush")
}
for {
select {
case item := <-p.publishCh:
// An item arrived before the timer fired. Append it and
// let the timer or pressure signal trigger the actual
// flush so that nearby publishes coalesce naturally.
batch = append(batch, item)
p.notifySpaceAvailable()
case <-timer.C:
flush(batchFlushScheduled)
case <-p.flushCh:
flush("pressure")
case <-p.closeCh:
p.runErr = errors.Join(p.drain(batch), p.sender.Close())
return
}
}
}
func (p *BatchingPubsub) drainIntoBatch(batch []queuedPublish) []queuedPublish {
drained := false
for {
select {
case item := <-p.publishCh:
batch = append(batch, item)
drained = true
default:
if drained {
p.notifySpaceAvailable()
}
return batch
}
}
}
func (p *BatchingPubsub) flushBatch(
ctx context.Context,
batch []queuedPublish,
reason string,
) ([]queuedPublish, error) {
if len(batch) == 0 {
return batch[:0], nil
}
count := len(batch)
totalBytes := 0
for _, item := range batch {
totalBytes += len(item.message)
}
p.metrics.BatchSize.Observe(float64(count))
start := p.clock.Now()
senderErr := p.sender.Flush(ctx, batch)
elapsed := p.clock.Since(start)
p.metrics.FlushDuration.WithLabelValues(reason).Observe(elapsed.Seconds())
var err error
if senderErr != nil {
stage := batchFlushStage(senderErr)
delivered, failed, fallbackErr := p.replayBatchViaDelegate(batch, batchDelegateFallbackReasonFlushError, stage)
var resetErr error
if reason != batchFlushShutdown {
resetErr = p.resetSender()
}
p.logFlushFailure(reason, stage, count, totalBytes, delivered, failed, senderErr, fallbackErr, resetErr)
if fallbackErr != nil || resetErr != nil {
err = errors.Join(senderErr, fallbackErr, resetErr)
}
} else if p.delegate != nil {
p.delegate.publishesTotal.WithLabelValues("true").Add(float64(count))
p.delegate.publishedBytesTotal.Add(float64(totalBytes))
}
queuedDepth := p.queuedCount.Add(-int64(count))
p.observeQueueDepth(queuedDepth)
clear(batch)
return batch[:0], err
}
func (p *BatchingPubsub) replayBatchViaDelegate(batch []queuedPublish, reason string, stage string) (delivered int, failed int, err error) {
if len(batch) == 0 {
return 0, 0, nil
}
p.observeDelegateFallbackBatch(batch, reason, stage)
if p.delegate == nil {
return 0, len(batch), xerrors.New("delegate pubsub is nil")
}
var errs []error
for _, item := range batch {
if err := p.delegate.Publish(item.event, item.message); err != nil {
failed++
errs = append(errs, xerrors.Errorf("delegate publish %q: %w", item.event, err))
continue
}
delivered++
}
return delivered, failed, errors.Join(errs...)
}
func (p *BatchingPubsub) resetSender() error {
if p.newSender == nil {
return nil
}
newSender, err := p.newSender(context.Background())
if err != nil {
p.metrics.SenderResetFailuresTotal.Inc()
return err
}
oldSender := p.sender
p.sender = newSender
p.metrics.SenderResetsTotal.Inc()
if oldSender == nil {
return nil
}
if err := oldSender.Close(); err != nil {
p.logger.Warn(context.Background(), "failed to close old batched pubsub sender after reset", slog.Error(err))
}
return nil
}
func (p *BatchingPubsub) logFlushFailure(reason string, stage string, count int, totalBytes int, delivered int, failed int, senderErr error, fallbackErr error, resetErr error) {
fields := []slog.Field{
slog.F("reason", reason),
slog.F("stage", stage),
slog.F("count", count),
slog.F("total_bytes", totalBytes),
slog.F("delegate_delivered", delivered),
slog.F("delegate_failed", failed),
slog.Error(senderErr),
}
if fallbackErr != nil {
fields = append(fields, slog.F("delegate_error", fallbackErr.Error()))
}
if resetErr != nil {
fields = append(fields, slog.F("sender_reset_error", resetErr.Error()))
}
p.logger.Error(context.Background(), "batched pubsub flush failed", fields...)
}
func (p *BatchingPubsub) drain(batch []queuedPublish) error {
ctx, cancel := context.WithTimeout(context.Background(), p.finalFlushTimeout)
defer cancel()
var errs []error
for {
batch = p.drainIntoBatch(batch)
if len(batch) == 0 {
break
}
var err error
batch, err = p.flushBatch(ctx, batch, batchFlushShutdown)
if err != nil {
errs = append(errs, err)
}
if ctx.Err() != nil {
break
}
}
dropped := p.dropPendingPublishes()
if dropped > 0 {
errs = append(errs, xerrors.Errorf("dropped %d queued notifications during shutdown", dropped))
}
if ctx.Err() != nil {
errs = append(errs, xerrors.Errorf("shutdown flush timed out: %w", ctx.Err()))
}
return errors.Join(errs...)
}
func (p *BatchingPubsub) dropPendingPublishes() int {
count := 0
for {
select {
case <-p.publishCh:
count++
default:
if count > 0 {
queuedDepth := p.queuedCount.Add(-int64(count))
p.observeQueueDepth(queuedDepth)
}
return count
}
}
}
func (p *BatchingPubsub) logPublishRejection(event string) {
fields := []slog.Field{
slog.F("event", event),
slog.F("queue_size", cap(p.publishCh)),
slog.F("queued", p.queuedCount.Load()),
}
select {
case <-p.warnTicker.C:
p.logger.Warn(context.Background(), "batched pubsub queue is full", fields...)
default:
p.logger.Debug(context.Background(), "batched pubsub queue is full", fields...)
}
}
type pgBatchSender struct {
logger slog.Logger
db *sql.DB
}
func newPGBatchSender(
ctx context.Context,
logger slog.Logger,
prototype *sql.DB,
connectURL string,
) (*pgBatchSender, error) {
connector, err := newConnector(ctx, logger, prototype, connectURL)
if err != nil {
return nil, err
}
db := sql.OpenDB(connector)
db.SetMaxOpenConns(1)
db.SetMaxIdleConns(1)
db.SetConnMaxIdleTime(0)
db.SetConnMaxLifetime(0)
pingCtx, cancel := context.WithTimeout(ctx, 30*time.Second)
defer cancel()
if err := db.PingContext(pingCtx); err != nil {
_ = db.Close()
return nil, xerrors.Errorf("ping batched pubsub sender database: %w", err)
}
return &pgBatchSender{logger: logger, db: db}, nil
}
func (s *pgBatchSender) Flush(ctx context.Context, batch []queuedPublish) error {
tx, err := s.db.BeginTx(ctx, nil)
if err != nil {
return &batchFlushError{stage: batchFlushStageBegin, err: xerrors.Errorf("begin batched pubsub transaction: %w", err)}
}
committed := false
defer func() {
if !committed {
_ = tx.Rollback()
}
}()
for _, item := range batch {
// This is safe because we are calling pq.QuoteLiteral. pg_notify does
// not support the first parameter being a prepared statement.
//nolint:gosec
_, err = tx.ExecContext(ctx, `select pg_notify(`+pq.QuoteLiteral(item.event)+`, $1)`, item.message)
if err != nil {
return &batchFlushError{stage: batchFlushStageExec, err: xerrors.Errorf("exec pg_notify: %w", err)}
}
}
if err := tx.Commit(); err != nil {
return &batchFlushError{stage: batchFlushStageCommit, err: xerrors.Errorf("commit batched pubsub transaction: %w", err)}
}
committed = true
return nil
}
func (s *pgBatchSender) Close() error {
return s.db.Close()
}
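`pgBatchSender.Flush` splices the channel name into the statement text because `pg_notify` cannot take the channel as a prepared-statement parameter, relying on `pq.QuoteLiteral` to make that safe. A simplified sketch of the quoting rule (assumption: the real `pq.QuoteLiteral` additionally handles backslash escaping with an `E''` prefix; this only shows the quote-doubling core):

```go
package main

import (
	"fmt"
	"strings"
)

// quoteLiteral is a simplified stand-in for pq.QuoteLiteral: single quotes
// are doubled and the value is wrapped in quotes, so the event name can be
// embedded directly in the pg_notify statement text.
func quoteLiteral(s string) string {
	return "'" + strings.ReplaceAll(s, "'", "''") + "'"
}

func main() {
	fmt.Println("select pg_notify(" + quoteLiteral("chat:stream:a") + ", $1)")
	// A quote inside the value cannot terminate the literal early.
	fmt.Println(quoteLiteral("it's"))
}
```

With quoting handled, the batch itself is just one transaction issuing one `pg_notify` per queued message, so N notifications cost a single commit.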
@@ -0,0 +1,520 @@
package pubsub
import (
"bytes"
"context"
"database/sql"
"sync"
"testing"
"time"
_ "github.com/lib/pq"
prom_testutil "github.com/prometheus/client_golang/prometheus/testutil"
dto "github.com/prometheus/client_model/go"
"github.com/stretchr/testify/require"
"cdr.dev/slog/v3"
"cdr.dev/slog/v3/sloggers/slogtest"
"github.com/coder/coder/v2/testutil"
"github.com/coder/quartz"
)
func TestBatchingPubsubScheduledFlush(t *testing.T) {
t.Parallel()
ctx := testutil.Context(t, testutil.WaitShort)
clock := quartz.NewMock(t)
newTimerTrap := clock.Trap().NewTimer("pubsubBatcher", "scheduledFlush")
defer newTimerTrap.Close()
resetTrap := clock.Trap().TimerReset("pubsubBatcher", "scheduledFlush")
defer resetTrap.Close()
sender := newFakeBatchSender()
ps, _ := newTestBatchingPubsub(t, sender, BatchingConfig{
Clock: clock,
FlushInterval: 10 * time.Millisecond,
QueueSize: 8,
})
call, err := newTimerTrap.Wait(ctx)
require.NoError(t, err)
call.MustRelease(ctx)
require.NoError(t, ps.Publish("chat:stream:a", []byte("one")))
require.NoError(t, ps.Publish("chat:stream:a", []byte("two")))
require.Empty(t, sender.Batches())
clock.Advance(10 * time.Millisecond).MustWait(ctx)
resetCall, err := resetTrap.Wait(ctx)
require.NoError(t, err)
resetCall.MustRelease(ctx)
batch := testutil.TryReceive(ctx, t, sender.flushes)
require.Len(t, batch, 2)
require.Equal(t, []byte("one"), batch[0].message)
require.Equal(t, []byte("two"), batch[1].message)
batchSizeCount, batchSizeSum := histogramCountAndSum(t, ps.metrics.BatchSize)
require.Equal(t, uint64(1), batchSizeCount)
require.InDelta(t, 2, batchSizeSum, 0.000001)
flushDurationCount, _ := histogramCountAndSum(t, ps.metrics.FlushDuration.WithLabelValues(batchFlushScheduled))
require.Equal(t, uint64(1), flushDurationCount)
require.Zero(t, prom_testutil.ToFloat64(ps.metrics.QueueDepth))
}
func TestBatchingPubsubDefaultConfigUsesDedicatedSenderFirstDefaults(t *testing.T) {
t.Parallel()
clock := quartz.NewMock(t)
sender := newFakeBatchSender()
ps, _ := newTestBatchingPubsub(t, sender, BatchingConfig{Clock: clock})
require.Equal(t, DefaultBatchingFlushInterval, ps.flushInterval)
require.Equal(t, DefaultBatchingQueueSize, cap(ps.publishCh))
require.Equal(t, defaultBatchingPressureWait, ps.pressureWait)
require.Equal(t, defaultBatchingFinalFlushLimit, ps.finalFlushTimeout)
}
func TestBatchChannelClass(t *testing.T) {
t.Parallel()
tests := []struct {
name string
event string
want string
}{
{name: "stream notify", event: "chat:stream:123", want: batchChannelClassStreamNotify},
{name: "owner event", event: "chat:owner:123", want: batchChannelClassOwnerEvent},
{name: "config change", event: "chat:config_change", want: batchChannelClassConfigChange},
{name: "fallback", event: "workspace:owner:123", want: batchChannelClassOther},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
require.Equal(t, tt.want, batchChannelClass(tt.event))
})
}
}
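The table above fully pins down the classification behavior, which can be reconstructed as a prefix switch. A sketch consistent with the test cases (assumption: the real constants are strings along these lines; only the mapping is implied by the tests):

```go
package main

import (
	"fmt"
	"strings"
)

// batchChannelClass buckets pubsub events into metric label classes, as
// implied by the TestBatchChannelClass table: chat stream and owner events
// get dedicated classes, the config-change event is matched exactly, and
// anything else (including non-chat events) falls back to "other".
func batchChannelClass(event string) string {
	switch {
	case strings.HasPrefix(event, "chat:stream:"):
		return "stream_notify"
	case strings.HasPrefix(event, "chat:owner:"):
		return "owner_event"
	case event == "chat:config_change":
		return "config_change"
	default:
		return "other"
	}
}

func main() {
	fmt.Println(batchChannelClass("chat:stream:123"))    // stream_notify
	fmt.Println(batchChannelClass("workspace:owner:123")) // other
}
```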
func TestBatchingPubsubTimerFlushDrainsAll(t *testing.T) {
t.Parallel()
ctx := testutil.Context(t, testutil.WaitShort)
clock := quartz.NewMock(t)
newTimerTrap := clock.Trap().NewTimer("pubsubBatcher", "scheduledFlush")
defer newTimerTrap.Close()
resetTrap := clock.Trap().TimerReset("pubsubBatcher", "scheduledFlush")
defer resetTrap.Close()
sender := newFakeBatchSender()
ps, _ := newTestBatchingPubsub(t, sender, BatchingConfig{
Clock: clock,
FlushInterval: 10 * time.Millisecond,
QueueSize: 64,
})
call, err := newTimerTrap.Wait(ctx)
require.NoError(t, err)
call.MustRelease(ctx)
// Enqueue many messages before the timer fires — all should be
// drained and flushed in a single batch.
for _, msg := range []string{"one", "two", "three", "four", "five"} {
require.NoError(t, ps.Publish("chat:stream:a", []byte(msg)))
}
require.Empty(t, sender.Batches())
clock.Advance(10 * time.Millisecond).MustWait(ctx)
resetCall, err := resetTrap.Wait(ctx)
require.NoError(t, err)
resetCall.MustRelease(ctx)
batch := testutil.TryReceive(ctx, t, sender.flushes)
require.Len(t, batch, 5)
require.Equal(t, []byte("one"), batch[0].message)
require.Equal(t, []byte("five"), batch[4].message)
require.Zero(t, prom_testutil.ToFloat64(ps.metrics.QueueDepth))
}
func TestBatchingPubsubQueueFullFallsBackToDelegate(t *testing.T) {
t.Parallel()
ctx := testutil.Context(t, testutil.WaitShort)
clock := quartz.NewMock(t)
newTimerTrap := clock.Trap().NewTimer("pubsubBatcher", "scheduledFlush")
defer newTimerTrap.Close()
resetTrap := clock.Trap().TimerReset("pubsubBatcher", "scheduledFlush")
defer resetTrap.Close()
pressureTrap := clock.Trap().NewTimer("pubsubBatcher", "pressureWait")
defer pressureTrap.Close()
sender := newFakeBatchSender()
sender.blockCh = make(chan struct{})
ps, _ := newTestBatchingPubsub(t, sender, BatchingConfig{
Clock: clock,
FlushInterval: 10 * time.Millisecond,
QueueSize: 1,
PressureWait: 10 * time.Millisecond,
})
call, err := newTimerTrap.Wait(ctx)
require.NoError(t, err)
call.MustRelease(ctx)
// Fill the queue (capacity 1).
require.NoError(t, ps.Publish("chat:stream:a", []byte("one")))
// Fire the timer so the run loop starts flushing "one" — the
// sender blocks on blockCh so the flush stays in-flight.
clock.Advance(10 * time.Millisecond).MustWait(ctx)
<-sender.started
// The run loop is blocked in flushBatch. Fill the queue again.
require.NoError(t, ps.Publish("chat:stream:a", []byte("two")))
// A third publish should fall back to the delegate (which has a
// closed db, so the delegate Publish itself will error — but we
// verify the fallback metric was incremented).
errCh := make(chan error, 1)
go func() {
errCh <- ps.Publish("chat:stream:a", []byte("three"))
}()
pressureCall, err := pressureTrap.Wait(ctx)
require.NoError(t, err)
pressureCall.MustRelease(ctx)
clock.Advance(10 * time.Millisecond).MustWait(ctx)
err = testutil.TryReceive(ctx, t, errCh)
// The delegate has a closed db so it returns an error from the
// shared pool, not a batching-specific sentinel.
require.Error(t, err)
require.Equal(t, float64(1), prom_testutil.ToFloat64(ps.metrics.DelegateFallbacksTotal.WithLabelValues(batchChannelClassStreamNotify, batchDelegateFallbackReasonQueueFull, batchFlushStageNone)))
close(sender.blockCh)
// Let the run loop finish the blocked flush and process "two".
resetCall, err := resetTrap.Wait(ctx)
require.NoError(t, err)
resetCall.MustRelease(ctx)
require.NoError(t, ps.Close())
}
func TestBatchingPubsubCloseDrainsQueue(t *testing.T) {
t.Parallel()
ctx := testutil.Context(t, testutil.WaitShort)
clock := quartz.NewMock(t)
newTimerTrap := clock.Trap().NewTimer("pubsubBatcher", "scheduledFlush")
defer newTimerTrap.Close()
sender := newFakeBatchSender()
ps, _ := newTestBatchingPubsub(t, sender, BatchingConfig{
Clock: clock,
FlushInterval: time.Hour,
QueueSize: 8,
})
call, err := newTimerTrap.Wait(ctx)
require.NoError(t, err)
call.MustRelease(ctx)
require.NoError(t, ps.Publish("chat:stream:a", []byte("one")))
require.NoError(t, ps.Publish("chat:stream:a", []byte("two")))
require.NoError(t, ps.Publish("chat:stream:a", []byte("three")))
require.NoError(t, ps.Close())
batches := sender.Batches()
require.Len(t, batches, 1)
require.Len(t, batches[0], 3)
require.Equal(t, []byte("one"), batches[0][0].message)
require.Equal(t, []byte("two"), batches[0][1].message)
require.Equal(t, []byte("three"), batches[0][2].message)
require.Zero(t, prom_testutil.ToFloat64(ps.metrics.QueueDepth))
require.Equal(t, 1, sender.CloseCalls())
}
func TestBatchingPubsubPreservesOrder(t *testing.T) {
t.Parallel()
ctx := testutil.Context(t, testutil.WaitShort)
clock := quartz.NewMock(t)
newTimerTrap := clock.Trap().NewTimer("pubsubBatcher", "scheduledFlush")
defer newTimerTrap.Close()
sender := newFakeBatchSender()
ps, _ := newTestBatchingPubsub(t, sender, BatchingConfig{
Clock: clock,
FlushInterval: time.Hour,
QueueSize: 8,
})
call, err := newTimerTrap.Wait(ctx)
require.NoError(t, err)
call.MustRelease(ctx)
for _, msg := range []string{"one", "two", "three", "four", "five"} {
require.NoError(t, ps.Publish("chat:stream:a", []byte(msg)))
}
require.NoError(t, ps.Close())
batches := sender.Batches()
require.NotEmpty(t, batches)
messages := make([]string, 0, 5)
for _, batch := range batches {
for _, item := range batch {
messages = append(messages, string(item.message))
}
}
require.Equal(t, []string{"one", "two", "three", "four", "five"}, messages)
}
func TestBatchingPubsubFlushFailureMetrics(t *testing.T) {
t.Parallel()
ctx := testutil.Context(t, testutil.WaitShort)
clock := quartz.NewMock(t)
newTimerTrap := clock.Trap().NewTimer("pubsubBatcher", "scheduledFlush")
defer newTimerTrap.Close()
resetTrap := clock.Trap().TimerReset("pubsubBatcher", "scheduledFlush")
defer resetTrap.Close()
sender := newFakeBatchSender()
sender.err = context.DeadlineExceeded
sender.errStage = batchFlushStageExec
ps, delegate := newTestBatchingPubsub(t, sender, BatchingConfig{
Clock: clock,
FlushInterval: 10 * time.Millisecond,
QueueSize: 8,
})
call, err := newTimerTrap.Wait(ctx)
require.NoError(t, err)
call.MustRelease(ctx)
require.NoError(t, ps.Publish("chat:stream:a", []byte("one")))
clock.Advance(10 * time.Millisecond).MustWait(ctx)
resetCall, err := resetTrap.Wait(ctx)
require.NoError(t, err)
resetCall.MustRelease(ctx)
batchSizeCount, batchSizeSum := histogramCountAndSum(t, ps.metrics.BatchSize)
require.Equal(t, uint64(1), batchSizeCount)
require.InDelta(t, 1, batchSizeSum, 0.000001)
flushDurationCount, _ := histogramCountAndSum(t, ps.metrics.FlushDuration.WithLabelValues(batchFlushScheduled))
require.Equal(t, uint64(1), flushDurationCount)
require.Equal(t, float64(1), prom_testutil.ToFloat64(delegate.publishesTotal.WithLabelValues("false")))
require.Zero(t, prom_testutil.ToFloat64(delegate.publishesTotal.WithLabelValues("true")))
require.Zero(t, prom_testutil.ToFloat64(ps.metrics.QueueDepth))
require.Equal(t, float64(1), prom_testutil.ToFloat64(ps.metrics.DelegateFallbacksTotal.WithLabelValues(batchChannelClassStreamNotify, batchDelegateFallbackReasonFlushError, batchFlushStageExec)))
}
func TestBatchingPubsubFlushFailureStageAccounting(t *testing.T) {
t.Parallel()
stages := []string{batchFlushStageBegin, batchFlushStageExec, batchFlushStageCommit}
for _, stage := range stages {
stage := stage
t.Run(stage, func(t *testing.T) {
t.Parallel()
sender := newFakeBatchSender()
sender.err = context.DeadlineExceeded
sender.errStage = stage
ps, delegate := newTestBatchingPubsub(t, sender, BatchingConfig{Clock: quartz.NewMock(t)})
batch := []queuedPublish{{
event: "chat:stream:test",
channelClass: batchChannelClass("chat:stream:test"),
message: []byte("fallback-" + stage),
}}
ps.queuedCount.Store(int64(len(batch)))
_, err := ps.flushBatch(context.Background(), batch, batchFlushScheduled)
require.Error(t, err)
require.Equal(t, float64(1), prom_testutil.ToFloat64(ps.metrics.DelegateFallbacksTotal.WithLabelValues(batchChannelClassStreamNotify, batchDelegateFallbackReasonFlushError, stage)))
require.Equal(t, float64(1), prom_testutil.ToFloat64(delegate.publishesTotal.WithLabelValues("false")))
})
}
}
func TestBatchingPubsubFlushFailureResetSender(t *testing.T) {
t.Parallel()
clock := quartz.NewMock(t)
firstSender := newFakeBatchSender()
firstSender.err = context.DeadlineExceeded
firstSender.errStage = batchFlushStageExec
secondSender := newFakeBatchSender()
ps, _ := newTestBatchingPubsub(t, firstSender, BatchingConfig{Clock: clock})
ps.newSender = func(context.Context) (batchSender, error) {
return secondSender, nil
}
firstBatch := []queuedPublish{{
event: "chat:stream:first",
channelClass: batchChannelClass("chat:stream:first"),
message: []byte("first"),
}}
ps.queuedCount.Store(int64(len(firstBatch)))
_, err := ps.flushBatch(context.Background(), firstBatch, batchFlushScheduled)
require.Error(t, err)
require.Equal(t, float64(1), prom_testutil.ToFloat64(ps.metrics.SenderResetsTotal))
require.Equal(t, 1, firstSender.CloseCalls())
secondBatch := []queuedPublish{{
event: "chat:stream:second",
channelClass: batchChannelClass("chat:stream:second"),
message: []byte("second"),
}}
ps.queuedCount.Store(int64(len(secondBatch)))
_, err = ps.flushBatch(context.Background(), secondBatch, batchFlushScheduled)
require.NoError(t, err)
batches := secondSender.Batches()
require.Len(t, batches, 1)
require.Len(t, batches[0], 1)
require.Equal(t, []byte("second"), batches[0][0].message)
}
func TestBatchingPubsubFlushFailureReturnsJoinedErrorWhenReplayFails(t *testing.T) {
t.Parallel()
sender := newFakeBatchSender()
sender.err = context.DeadlineExceeded
sender.errStage = batchFlushStageExec
ps, _ := newTestBatchingPubsub(t, sender, BatchingConfig{Clock: quartz.NewMock(t)})
batch := []queuedPublish{{
event: "chat:stream:error",
channelClass: batchChannelClass("chat:stream:error"),
message: []byte("error"),
}}
ps.queuedCount.Store(int64(len(batch)))
_, err := ps.flushBatch(context.Background(), batch, batchFlushScheduled)
require.Error(t, err)
require.ErrorContains(t, err, context.DeadlineExceeded.Error())
require.ErrorContains(t, err, `delegate publish "chat:stream:error"`)
require.Equal(t, float64(1), prom_testutil.ToFloat64(ps.metrics.DelegateFallbacksTotal.WithLabelValues(batchChannelClassStreamNotify, batchDelegateFallbackReasonFlushError, batchFlushStageExec)))
}
func newTestBatchingPubsub(t *testing.T, sender batchSender, cfg BatchingConfig) (*BatchingPubsub, *PGPubsub) {
t.Helper()
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true}).Leveled(slog.LevelDebug)
// Use a closed *sql.DB so that delegate.Publish returns a real
// error instead of panicking on a nil pointer when the batching
// queue falls back to the shared pool under pressure.
closedDB := newClosedDB(t)
delegate := newWithoutListener(logger.Named("delegate"), closedDB)
ps, err := newBatchingPubsub(logger.Named("batcher"), delegate, sender, cfg)
require.NoError(t, err)
t.Cleanup(func() {
_ = ps.Close()
})
return ps, delegate
}
// newClosedDB returns an *sql.DB whose connections have been closed,
// so any ExecContext call returns an error rather than panicking.
func newClosedDB(t *testing.T) *sql.DB {
t.Helper()
db, err := sql.Open("postgres", "host=localhost dbname=closed_db_stub sslmode=disable connect_timeout=1")
require.NoError(t, err)
require.NoError(t, db.Close())
return db
}
type fakeBatchSender struct {
mu sync.Mutex
batches [][]queuedPublish
flushes chan []queuedPublish
started chan struct{}
blockCh chan struct{}
err error
errStage string
closeErr error
closeCall int
}
func newFakeBatchSender() *fakeBatchSender {
return &fakeBatchSender{
flushes: make(chan []queuedPublish, 16),
started: make(chan struct{}, 16),
}
}
func (s *fakeBatchSender) Flush(ctx context.Context, batch []queuedPublish) error {
select {
case s.started <- struct{}{}:
default:
}
if s.blockCh != nil {
select {
case <-s.blockCh:
case <-ctx.Done():
return ctx.Err()
}
}
clone := make([]queuedPublish, len(batch))
for i, item := range batch {
clone[i] = queuedPublish{
event: item.event,
message: bytes.Clone(item.message),
}
}
s.mu.Lock()
s.batches = append(s.batches, clone)
s.mu.Unlock()
select {
case s.flushes <- clone:
default:
}
if s.err == nil {
return nil
}
if s.errStage != "" {
return &batchFlushError{stage: s.errStage, err: s.err}
}
return s.err
}
type metricWriter interface {
Write(*dto.Metric) error
}
func histogramCountAndSum(t *testing.T, observer any) (uint64, float64) {
t.Helper()
writer, ok := observer.(metricWriter)
require.True(t, ok)
metric := &dto.Metric{}
require.NoError(t, writer.Write(metric))
histogram := metric.GetHistogram()
require.NotNil(t, histogram)
return histogram.GetSampleCount(), histogram.GetSampleSum()
}
func (s *fakeBatchSender) Close() error {
s.mu.Lock()
defer s.mu.Unlock()
s.closeCall++
return s.closeErr
}
func (s *fakeBatchSender) Batches() [][]queuedPublish {
s.mu.Lock()
defer s.mu.Unlock()
clone := make([][]queuedPublish, len(s.batches))
for i, batch := range s.batches {
clone[i] = make([]queuedPublish, len(batch))
copy(clone[i], batch)
}
return clone
}
func (s *fakeBatchSender) CloseCalls() int {
s.mu.Lock()
defer s.mu.Unlock()
return s.closeCall
}
@@ -0,0 +1,130 @@
package pubsub_test
import (
"context"
"database/sql"
"fmt"
"testing"
"time"
"github.com/stretchr/testify/require"
"cdr.dev/slog/v3"
"cdr.dev/slog/v3/sloggers/slogtest"
"github.com/coder/coder/v2/coderd/database/dbtestutil"
"github.com/coder/coder/v2/coderd/database/pubsub"
"github.com/coder/coder/v2/testutil"
)
func TestBatchingPubsubDedicatedSenderConnection(t *testing.T) {
t.Parallel()
ctx := testutil.Context(t, testutil.WaitLong)
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true}).Leveled(slog.LevelDebug)
connectionURL, err := dbtestutil.Open(t)
require.NoError(t, err)
trackedDriver := dbtestutil.NewDriver()
defer trackedDriver.Close()
tconn, err := trackedDriver.Connector(connectionURL)
require.NoError(t, err)
trackedDB := sql.OpenDB(tconn)
defer trackedDB.Close()
base, err := pubsub.New(ctx, logger.Named("base"), trackedDB, connectionURL)
require.NoError(t, err)
defer base.Close()
listenerConn := testutil.TryReceive(ctx, t, trackedDriver.Connections)
batched, err := pubsub.NewBatching(ctx, logger.Named("batched"), base, trackedDB, connectionURL, pubsub.BatchingConfig{
FlushInterval: 10 * time.Millisecond,
QueueSize: 8,
})
require.NoError(t, err)
defer batched.Close()
senderConn := testutil.TryReceive(ctx, t, trackedDriver.Connections)
require.NotEqual(t, fmt.Sprintf("%p", listenerConn), fmt.Sprintf("%p", senderConn))
event := t.Name()
messageCh := make(chan []byte, 1)
cancel, err := batched.Subscribe(event, func(_ context.Context, message []byte) {
messageCh <- message
})
require.NoError(t, err)
defer cancel()
require.NoError(t, batched.Publish(event, []byte("hello")))
require.Equal(t, []byte("hello"), testutil.TryReceive(ctx, t, messageCh))
}
func TestBatchingPubsubReconnectsAfterSenderDisconnect(t *testing.T) {
t.Parallel()
ctx := testutil.Context(t, testutil.WaitLong)
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true}).Leveled(slog.LevelDebug)
connectionURL, err := dbtestutil.Open(t)
require.NoError(t, err)
trackedDriver := dbtestutil.NewDriver()
defer trackedDriver.Close()
tconn, err := trackedDriver.Connector(connectionURL)
require.NoError(t, err)
trackedDB := sql.OpenDB(tconn)
defer trackedDB.Close()
base, err := pubsub.New(ctx, logger.Named("base"), trackedDB, connectionURL)
require.NoError(t, err)
defer base.Close()
_ = testutil.TryReceive(ctx, t, trackedDriver.Connections) // listener connection
batched, err := pubsub.NewBatching(ctx, logger.Named("batched"), base, trackedDB, connectionURL, pubsub.BatchingConfig{
FlushInterval: 10 * time.Millisecond,
QueueSize: 8,
})
require.NoError(t, err)
defer batched.Close()
senderConn := testutil.TryReceive(ctx, t, trackedDriver.Connections)
event := t.Name()
messageCh := make(chan []byte, 4)
cancel, err := batched.Subscribe(event, func(_ context.Context, message []byte) {
messageCh <- message
})
require.NoError(t, err)
defer cancel()
require.NoError(t, batched.Publish(event, []byte("before-disconnect")))
require.Equal(t, []byte("before-disconnect"), testutil.TryReceive(ctx, t, messageCh))
require.NoError(t, senderConn.Close())
reconnected := false
delivered := false
testutil.Eventually(ctx, t, func(ctx context.Context) bool {
if !reconnected {
select {
case conn := <-trackedDriver.Connections:
reconnected = conn != nil
default:
}
}
select {
case <-messageCh:
default:
}
if err := batched.Publish(event, []byte("after-disconnect")); err != nil {
return false
}
select {
case msg := <-messageCh:
delivered = string(msg) == "after-disconnect"
case <-time.After(testutil.IntervalFast):
delivered = false
}
return reconnected && delivered
}, testutil.IntervalMedium, "batched sender did not recover after disconnect")
}
@@ -81,8 +81,8 @@ func newMsgQueue(ctx context.Context, l Listener, le ListenerWithErr) *msgQueue
 }

 func (q *msgQueue) run() {
-	var batch [maxDrainBatch]msgOrErr
 	for {
 		// wait until there is something on the queue or we are closed
 		q.cond.L.Lock()
 		for q.size == 0 && !q.closed {
 			q.cond.Wait()
@@ -91,32 +91,28 @@ func (q *msgQueue) run() {
 			q.cond.L.Unlock()
 			return
 		}
-		// Drain up to maxDrainBatch items while holding the lock.
-		n := min(q.size, maxDrainBatch)
-		for i := range n {
-			batch[i] = q.q[q.front]
-			q.front = (q.front + 1) % BufferSize
-		}
-		q.size -= n
+		item := q.q[q.front]
+		q.front = (q.front + 1) % BufferSize
+		q.size--
 		q.cond.L.Unlock()
-		// Dispatch each message individually without holding the lock.
-		for i := range n {
-			item := batch[i]
-			if item.err == nil {
-				if q.l != nil {
-					q.l(q.ctx, item.msg)
-					continue
-				}
-				if q.le != nil {
-					q.le(q.ctx, item.msg, nil)
-					continue
-				}
-				// unhittable
-				continue
-			}
-			// if the listener wants errors, send it.
-			if q.le != nil {
-				q.le(q.ctx, nil, item.err)
-			}
-		}
+		// process item without holding lock
+		if item.err == nil {
+			// real message
+			if q.l != nil {
+				q.l(q.ctx, item.msg)
+				continue
+			}
+			if q.le != nil {
+				q.le(q.ctx, item.msg, nil)
+				continue
+			}
+			// unhittable
+			continue
+		}
+		// if the listener wants errors, send it.
+		if q.le != nil {
+			q.le(q.ctx, nil, item.err)
+		}
 	}
 }
@@ -237,12 +233,6 @@ type PGPubsub struct {
 // for a subscriber before dropping messages.
 const BufferSize = 2048

-// maxDrainBatch is the maximum number of messages to drain from the ring
-// buffer per iteration. Batching amortizes the cost of mutex
-// acquire/release and cond.Wait across many messages, improving drain
-// throughput during bursts.
-const maxDrainBatch = 256

 // Subscribe calls the listener when an event matching the name is received.
func (p *PGPubsub) Subscribe(event string, listener Listener) (cancel func(), err error) {
return p.subscribeQueue(event, newMsgQueue(context.Background(), listener, nil))
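The `msgQueue` drain above operates on a fixed-size ring buffer indexed by `front` and `size`, wrapping modulo `BufferSize`. A minimal self-contained sketch of that index arithmetic (hypothetical `ring` type; the real `BufferSize` is 2048 and the elements are `msgOrErr`):

```go
package main

import "fmt"

const bufferSize = 8 // small for illustration; the real buffer holds 2048

// ring mirrors msgQueue's buffer bookkeeping: front points at the oldest
// element, size counts occupied slots, and both wrap modulo the capacity.
type ring struct {
	q     [bufferSize]string
	front int
	size  int
}

func (r *ring) push(v string) {
	r.q[(r.front+r.size)%bufferSize] = v
	r.size++
}

func (r *ring) pop() string {
	v := r.q[r.front]
	r.front = (r.front + 1) % bufferSize
	r.size--
	return v
}

func main() {
	var r ring
	r.push("a")
	r.push("b")
	fmt.Println(r.pop(), r.pop()) // a b
}
```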
@@ -497,12 +487,14 @@ func (d logDialer) DialContext(ctx context.Context, network, address string) (ne
 	return conn, nil
 }

-func (p *PGPubsub) startListener(ctx context.Context, connectURL string) error {
-	p.connected.Set(0)
-	// Creates a new listener using pq.
+func newConnector(ctx context.Context, logger slog.Logger, db *sql.DB, connectURL string) (driver.Connector, error) {
+	if db == nil {
+		return nil, xerrors.New("database is nil")
+	}
 	var (
 		dialer = logDialer{
-			logger: p.logger,
+			logger: logger,
 			// pq.defaultDialer uses a zero net.Dialer as well.
 			d: net.Dialer{},
 		}
@@ -511,28 +503,38 @@ func (p *PGPubsub) startListener(ctx context.Context, connectURL string) error {
 	)

 	// Create a custom connector if the database driver supports it.
-	connectorCreator, ok := p.db.Driver().(database.ConnectorCreator)
+	connectorCreator, ok := db.Driver().(database.ConnectorCreator)
 	if ok {
 		connector, err = connectorCreator.Connector(connectURL)
 		if err != nil {
-			return xerrors.Errorf("create custom connector: %w", err)
+			return nil, xerrors.Errorf("create custom connector: %w", err)
 		}
 	} else {
-		// use the default pq connector otherwise
+		// Use the default pq connector otherwise.
 		connector, err = pq.NewConnector(connectURL)
 		if err != nil {
-			return xerrors.Errorf("create pq connector: %w", err)
+			return nil, xerrors.Errorf("create pq connector: %w", err)
 		}
 	}

 	// Set the dialer if the connector supports it.
 	dc, ok := connector.(database.DialerConnector)
 	if !ok {
-		p.logger.Critical(ctx, "connector does not support setting log dialer, database connection debug logs will be missing")
+		logger.Critical(ctx, "connector does not support setting log dialer, database connection debug logs will be missing")
 	} else {
 		dc.Dialer(dialer)
 	}
+	return connector, nil
+}
+
+func (p *PGPubsub) startListener(ctx context.Context, connectURL string) error {
+	p.connected.Set(0)
+	connector, err := newConnector(ctx, p.logger, p.db, connectURL)
+	if err != nil {
+		return err
+	}
 	var (
 		errCh     = make(chan error, 1)
 		sentErrCh = false
@@ -128,6 +128,22 @@ type sqlcQuerier interface {
// connection events (connect, disconnect, open, close) which are handled
// separately by DeleteOldAuditLogConnectionEvents.
DeleteOldAuditLogs(ctx context.Context, arg DeleteOldAuditLogsParams) (int64, error)
// TODO(cian): Add indexes on chats(archived, updated_at) and
// chat_files(created_at) for purge query performance.
// See: https://github.com/coder/internal/issues/1438
// Deletes chat files that are older than the given threshold and are
// not referenced by any chat that is still active or was archived
// within the same threshold window. This covers two cases:
// 1. Orphaned files not linked to any chat.
// 2. Files whose every referencing chat has been archived for longer
// than the retention period.
DeleteOldChatFiles(ctx context.Context, arg DeleteOldChatFilesParams) (int64, error)
// Deletes chats that have been archived for longer than the given
// threshold. Active (non-archived) chats are never deleted.
// Related chat_messages, chat_diff_statuses, and
// chat_queued_messages are removed via ON DELETE CASCADE.
// Parent/root references on child chats are SET NULL.
DeleteOldChats(ctx context.Context, arg DeleteOldChatsParams) (int64, error)
DeleteOldConnectionLogs(ctx context.Context, arg DeleteOldConnectionLogsParams) (int64, error)
// Delete all notification messages which have not been updated for over a week.
DeleteOldNotificationMessages(ctx context.Context) error
@@ -255,16 +271,27 @@ type sqlcQuerier interface {
// otherwise the setting defaults to true.
GetChatIncludeDefaultSystemPrompt(ctx context.Context) (bool, error)
GetChatMessageByID(ctx context.Context, id int64) (ChatMessage, error)
// Aggregates message-level metrics per chat for messages created
// after the given timestamp. Uses message created_at so that
// ongoing activity in long-running chats is captured each window.
GetChatMessageSummariesPerChat(ctx context.Context, createdAfter time.Time) ([]GetChatMessageSummariesPerChatRow, error)
GetChatMessagesByChatID(ctx context.Context, arg GetChatMessagesByChatIDParams) ([]ChatMessage, error)
GetChatMessagesByChatIDAscPaginated(ctx context.Context, arg GetChatMessagesByChatIDAscPaginatedParams) ([]ChatMessage, error)
GetChatMessagesByChatIDDescPaginated(ctx context.Context, arg GetChatMessagesByChatIDDescPaginatedParams) ([]ChatMessage, error)
GetChatMessagesForPromptByChatID(ctx context.Context, chatID uuid.UUID) ([]ChatMessage, error)
GetChatModelConfigByID(ctx context.Context, id uuid.UUID) (ChatModelConfig, error)
GetChatModelConfigs(ctx context.Context) ([]ChatModelConfig, error)
// Returns all model configurations for telemetry snapshot collection.
GetChatModelConfigsForTelemetry(ctx context.Context) ([]GetChatModelConfigsForTelemetryRow, error)
GetChatProviderByID(ctx context.Context, id uuid.UUID) (ChatProvider, error)
GetChatProviderByProvider(ctx context.Context, provider string) (ChatProvider, error)
GetChatProviders(ctx context.Context) ([]ChatProvider, error)
GetChatQueuedMessages(ctx context.Context, chatID uuid.UUID) ([]ChatQueuedMessage, error)
// Returns the chat retention period in days. Chats archived longer
// than this and orphaned chat files older than this are purged by
// dbpurge. Returns 30 (days) when no value has been configured.
// A value of 0 disables chat purging entirely.
GetChatRetentionDays(ctx context.Context) (int32, error)
GetChatSystemPrompt(ctx context.Context) (string, error)
// GetChatSystemPromptConfig returns both chat system prompt settings in a
// single read to avoid torn reads between separate site-config lookups.
@@ -283,6 +310,10 @@ type sqlcQuerier interface {
GetChatWorkspaceTTL(ctx context.Context) (string, error)
GetChats(ctx context.Context, arg GetChatsParams) ([]GetChatsRow, error)
GetChatsByWorkspaceIDs(ctx context.Context, ids []uuid.UUID) ([]Chat, error)
// Retrieves chats updated after the given timestamp for telemetry
// snapshot collection. Uses updated_at so that long-running chats
// still appear in each snapshot window while they are active.
GetChatsUpdatedAfter(ctx context.Context, updatedAfter time.Time) ([]GetChatsUpdatedAfterRow, error)
GetConnectionLogsOffset(ctx context.Context, arg GetConnectionLogsOffsetParams) ([]GetConnectionLogsOffsetRow, error)
GetCryptoKeyByFeatureAndSequence(ctx context.Context, arg GetCryptoKeyByFeatureAndSequenceParams) (CryptoKey, error)
GetCryptoKeys(ctx context.Context) ([]CryptoKey, error)
@@ -478,8 +509,10 @@ type sqlcQuerier interface {
GetReplicasUpdatedAfter(ctx context.Context, updatedAt time.Time) ([]Replica, error)
GetRunningPrebuiltWorkspaces(ctx context.Context) ([]GetRunningPrebuiltWorkspacesRow, error)
GetRuntimeConfig(ctx context.Context, key string) (string, error)
-	// Find chats that appear stuck (running but heartbeat has expired).
-	// Used for recovery after coderd crashes or long hangs.
+	// Find chats that appear stuck and need recovery. This covers:
+	// 1. Running chats whose heartbeat has expired (worker crash).
+	// 2. Chats awaiting client action (requires_action) past the
+	//    timeout threshold (client disappeared).
GetStaleChats(ctx context.Context, staleThreshold time.Time) ([]Chat, error)
GetTailnetPeers(ctx context.Context, id uuid.UUID) ([]TailnetPeer, error)
GetTailnetTunnelPeerBindingsBatch(ctx context.Context, ids []uuid.UUID) ([]GetTailnetTunnelPeerBindingsBatchRow, error)
@@ -865,6 +898,10 @@ type sqlcQuerier interface {
// This must be called from within a transaction. The lock will be automatically
// released when the transaction ends.
TryAcquireLock(ctx context.Context, pgTryAdvisoryXactLock int64) (bool, error)
// Unarchives a chat (and its children). Stale file references are
// handled automatically by FK cascades on chat_file_links: when
// dbpurge deletes a chat_files row, the corresponding
// chat_file_links rows are cascade-deleted by PostgreSQL.
UnarchiveChatByID(ctx context.Context, id uuid.UUID) ([]Chat, error)
// This will always work regardless of the current state of the template version.
UnarchiveTemplateVersion(ctx context.Context, arg UnarchiveTemplateVersionParams) error
@@ -1006,6 +1043,7 @@ type sqlcQuerier interface {
UpsertChatDiffStatus(ctx context.Context, arg UpsertChatDiffStatusParams) (ChatDiffStatus, error)
UpsertChatDiffStatusReference(ctx context.Context, arg UpsertChatDiffStatusReferenceParams) (ChatDiffStatus, error)
UpsertChatIncludeDefaultSystemPrompt(ctx context.Context, includeDefaultSystemPrompt bool) error
UpsertChatRetentionDays(ctx context.Context, retentionDays int32) error
UpsertChatSystemPrompt(ctx context.Context, value string) error
UpsertChatTemplateAllowlist(ctx context.Context, templateAllowlist string) error
UpsertChatUsageLimitConfig(ctx context.Context, arg UpsertChatUsageLimitConfigParams) (ChatUsageLimitConfig, error)
@@ -9085,10 +9085,11 @@ func TestUpdateAIBridgeInterceptionEnded(t *testing.T) {
 	for _, uid := range []uuid.UUID{{1}, {2}, {3}} {
 		insertParams := database.InsertAIBridgeInterceptionParams{
-			ID:          uid,
-			InitiatorID: user.ID,
-			Metadata:    json.RawMessage("{}"),
-			Client:      sql.NullString{String: "client", Valid: true},
+			ID:             uid,
+			InitiatorID:    user.ID,
+			Metadata:       json.RawMessage("{}"),
+			Client:         sql.NullString{String: "client", Valid: true},
+			CredentialKind: database.CredentialKindCentralized,
 		}
 		intc, err := db.InsertAIBridgeInterception(ctx, insertParams)
@@ -442,7 +442,7 @@ func (q *sqlQuerier) DeleteOldAIBridgeRecords(ctx context.Context, beforeTime ti
const getAIBridgeInterceptionByID = `-- name: GetAIBridgeInterceptionByID :one
SELECT
-	id, initiator_id, provider, model, started_at, metadata, ended_at, api_key_id, client, thread_parent_id, thread_root_id, client_session_id, session_id, provider_name
+	id, initiator_id, provider, model, started_at, metadata, ended_at, api_key_id, client, thread_parent_id, thread_root_id, client_session_id, session_id, provider_name, credential_kind, credential_hint
FROM
aibridge_interceptions
WHERE
@@ -467,6 +467,8 @@ func (q *sqlQuerier) GetAIBridgeInterceptionByID(ctx context.Context, id uuid.UU
&i.ClientSessionID,
&i.SessionID,
&i.ProviderName,
&i.CredentialKind,
&i.CredentialHint,
)
return i, err
}
@@ -501,7 +503,7 @@ func (q *sqlQuerier) GetAIBridgeInterceptionLineageByToolCallID(ctx context.Cont
const getAIBridgeInterceptions = `-- name: GetAIBridgeInterceptions :many
SELECT
id, initiator_id, provider, model, started_at, metadata, ended_at, api_key_id, client, thread_parent_id, thread_root_id, client_session_id, session_id, provider_name
id, initiator_id, provider, model, started_at, metadata, ended_at, api_key_id, client, thread_parent_id, thread_root_id, client_session_id, session_id, provider_name, credential_kind, credential_hint
FROM
aibridge_interceptions
`
@@ -530,6 +532,8 @@ func (q *sqlQuerier) GetAIBridgeInterceptions(ctx context.Context) ([]AIBridgeIn
&i.ClientSessionID,
&i.SessionID,
&i.ProviderName,
&i.CredentialKind,
&i.CredentialHint,
); err != nil {
return nil, err
}
@@ -678,11 +682,11 @@ func (q *sqlQuerier) GetAIBridgeUserPromptsByInterceptionID(ctx context.Context,
const insertAIBridgeInterception = `-- name: InsertAIBridgeInterception :one
INSERT INTO aibridge_interceptions (
id, api_key_id, initiator_id, provider, provider_name, model, metadata, started_at, client, client_session_id, thread_parent_id, thread_root_id
id, api_key_id, initiator_id, provider, provider_name, model, metadata, started_at, client, client_session_id, thread_parent_id, thread_root_id, credential_kind, credential_hint
) VALUES (
$1, $2, $3, $4, $5, $6, COALESCE($7::jsonb, '{}'::jsonb), $8, $9, $10, $11::uuid, $12::uuid
$1, $2, $3, $4, $5, $6, COALESCE($7::jsonb, '{}'::jsonb), $8, $9, $10, $11::uuid, $12::uuid, $13, $14
)
RETURNING id, initiator_id, provider, model, started_at, metadata, ended_at, api_key_id, client, thread_parent_id, thread_root_id, client_session_id, session_id, provider_name
RETURNING id, initiator_id, provider, model, started_at, metadata, ended_at, api_key_id, client, thread_parent_id, thread_root_id, client_session_id, session_id, provider_name, credential_kind, credential_hint
`
type InsertAIBridgeInterceptionParams struct {
@@ -698,6 +702,8 @@ type InsertAIBridgeInterceptionParams struct {
ClientSessionID sql.NullString `db:"client_session_id" json:"client_session_id"`
ThreadParentInterceptionID uuid.NullUUID `db:"thread_parent_interception_id" json:"thread_parent_interception_id"`
ThreadRootInterceptionID uuid.NullUUID `db:"thread_root_interception_id" json:"thread_root_interception_id"`
CredentialKind CredentialKind `db:"credential_kind" json:"credential_kind"`
CredentialHint string `db:"credential_hint" json:"credential_hint"`
}
func (q *sqlQuerier) InsertAIBridgeInterception(ctx context.Context, arg InsertAIBridgeInterceptionParams) (AIBridgeInterception, error) {
@@ -714,6 +720,8 @@ func (q *sqlQuerier) InsertAIBridgeInterception(ctx context.Context, arg InsertA
arg.ClientSessionID,
arg.ThreadParentInterceptionID,
arg.ThreadRootInterceptionID,
arg.CredentialKind,
arg.CredentialHint,
)
var i AIBridgeInterception
err := row.Scan(
@@ -731,6 +739,8 @@ func (q *sqlQuerier) InsertAIBridgeInterception(ctx context.Context, arg InsertA
&i.ClientSessionID,
&i.SessionID,
&i.ProviderName,
&i.CredentialKind,
&i.CredentialHint,
)
return i, err
}
@@ -963,7 +973,7 @@ func (q *sqlQuerier) ListAIBridgeClients(ctx context.Context, arg ListAIBridgeCl
const listAIBridgeInterceptions = `-- name: ListAIBridgeInterceptions :many
SELECT
aibridge_interceptions.id, aibridge_interceptions.initiator_id, aibridge_interceptions.provider, aibridge_interceptions.model, aibridge_interceptions.started_at, aibridge_interceptions.metadata, aibridge_interceptions.ended_at, aibridge_interceptions.api_key_id, aibridge_interceptions.client, aibridge_interceptions.thread_parent_id, aibridge_interceptions.thread_root_id, aibridge_interceptions.client_session_id, aibridge_interceptions.session_id, aibridge_interceptions.provider_name,
aibridge_interceptions.id, aibridge_interceptions.initiator_id, aibridge_interceptions.provider, aibridge_interceptions.model, aibridge_interceptions.started_at, aibridge_interceptions.metadata, aibridge_interceptions.ended_at, aibridge_interceptions.api_key_id, aibridge_interceptions.client, aibridge_interceptions.thread_parent_id, aibridge_interceptions.thread_root_id, aibridge_interceptions.client_session_id, aibridge_interceptions.session_id, aibridge_interceptions.provider_name, aibridge_interceptions.credential_kind, aibridge_interceptions.credential_hint,
visible_users.id, visible_users.username, visible_users.name, visible_users.avatar_url
FROM
aibridge_interceptions
@@ -1077,6 +1087,8 @@ func (q *sqlQuerier) ListAIBridgeInterceptions(ctx context.Context, arg ListAIBr
&i.AIBridgeInterception.ClientSessionID,
&i.AIBridgeInterception.SessionID,
&i.AIBridgeInterception.ProviderName,
&i.AIBridgeInterception.CredentialKind,
&i.AIBridgeInterception.CredentialHint,
&i.VisibleUser.ID,
&i.VisibleUser.Username,
&i.VisibleUser.Name,
@@ -1272,7 +1284,7 @@ WITH paginated_threads AS (
)
SELECT
COALESCE(aibridge_interceptions.thread_root_id, aibridge_interceptions.id) AS thread_id,
aibridge_interceptions.id, aibridge_interceptions.initiator_id, aibridge_interceptions.provider, aibridge_interceptions.model, aibridge_interceptions.started_at, aibridge_interceptions.metadata, aibridge_interceptions.ended_at, aibridge_interceptions.api_key_id, aibridge_interceptions.client, aibridge_interceptions.thread_parent_id, aibridge_interceptions.thread_root_id, aibridge_interceptions.client_session_id, aibridge_interceptions.session_id, aibridge_interceptions.provider_name
aibridge_interceptions.id, aibridge_interceptions.initiator_id, aibridge_interceptions.provider, aibridge_interceptions.model, aibridge_interceptions.started_at, aibridge_interceptions.metadata, aibridge_interceptions.ended_at, aibridge_interceptions.api_key_id, aibridge_interceptions.client, aibridge_interceptions.thread_parent_id, aibridge_interceptions.thread_root_id, aibridge_interceptions.client_session_id, aibridge_interceptions.session_id, aibridge_interceptions.provider_name, aibridge_interceptions.credential_kind, aibridge_interceptions.credential_hint
FROM
aibridge_interceptions
JOIN
@@ -1334,6 +1346,8 @@ func (q *sqlQuerier) ListAIBridgeSessionThreads(ctx context.Context, arg ListAIB
&i.AIBridgeInterception.ClientSessionID,
&i.AIBridgeInterception.SessionID,
&i.AIBridgeInterception.ProviderName,
&i.AIBridgeInterception.CredentialKind,
&i.AIBridgeInterception.CredentialHint,
); err != nil {
return nil, err
}
@@ -1718,7 +1732,7 @@ UPDATE aibridge_interceptions
WHERE
id = $2::uuid
AND ended_at IS NULL
RETURNING id, initiator_id, provider, model, started_at, metadata, ended_at, api_key_id, client, thread_parent_id, thread_root_id, client_session_id, session_id, provider_name
RETURNING id, initiator_id, provider, model, started_at, metadata, ended_at, api_key_id, client, thread_parent_id, thread_root_id, client_session_id, session_id, provider_name, credential_kind, credential_hint
`
type UpdateAIBridgeInterceptionEndedParams struct {
@@ -1744,6 +1758,8 @@ func (q *sqlQuerier) UpdateAIBridgeInterceptionEnded(ctx context.Context, arg Up
&i.ClientSessionID,
&i.SessionID,
&i.ProviderName,
&i.CredentialKind,
&i.CredentialHint,
)
return i, err
}
@@ -2884,6 +2900,54 @@ func (q *sqlQuerier) UpsertBoundaryUsageStats(ctx context.Context, arg UpsertBou
return new_period, err
}
const deleteOldChatFiles = `-- name: DeleteOldChatFiles :execrows
WITH kept_file_ids AS (
-- NOTE: This uses updated_at as a proxy for archive time
-- because there is no archived_at column. Correctness
-- requires that updated_at is never backdated on archived
-- chats. See ArchiveChatByID.
SELECT DISTINCT cfl.file_id
FROM chat_file_links cfl
JOIN chats c ON c.id = cfl.chat_id
WHERE c.archived = false
OR c.updated_at >= $1::timestamptz
),
deletable AS (
SELECT cf.id
FROM chat_files cf
LEFT JOIN kept_file_ids k ON cf.id = k.file_id
WHERE cf.created_at < $1::timestamptz
AND k.file_id IS NULL
ORDER BY cf.created_at ASC
LIMIT $2
)
DELETE FROM chat_files
USING deletable
WHERE chat_files.id = deletable.id
`
type DeleteOldChatFilesParams struct {
BeforeTime time.Time `db:"before_time" json:"before_time"`
LimitCount int32 `db:"limit_count" json:"limit_count"`
}
// TODO(cian): Add indexes on chats(archived, updated_at) and
// chat_files(created_at) for purge query performance.
// See: https://github.com/coder/internal/issues/1438
// Deletes chat files that are older than the given threshold and are
// not referenced by any chat that is still active or was archived
// within the same threshold window. This covers two cases:
// 1. Orphaned files not linked to any chat.
// 2. Files whose every referencing chat has been archived for longer
// than the retention period.
func (q *sqlQuerier) DeleteOldChatFiles(ctx context.Context, arg DeleteOldChatFilesParams) (int64, error) {
result, err := q.db.ExecContext(ctx, deleteOldChatFiles, arg.BeforeTime, arg.LimitCount)
if err != nil {
return 0, err
}
return result.RowsAffected()
}
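Because `DeleteOldChatFiles` takes a `LimitCount` and reports rows affected, a caller can purge in bounded batches until a batch comes back short. The sketch below shows that loop pattern with the query stubbed out; `purgeInBatches` and its callback are illustrative names, not part of the generated querier.

```go
package main

import "fmt"

// purgeInBatches repeatedly invokes deleteBatch (standing in for a
// call like q.DeleteOldChatFiles with a fixed BeforeTime) until a
// batch deletes fewer rows than the limit, which bounds the lock
// footprint of any single DELETE.
func purgeInBatches(deleteBatch func(limit int32) (int64, error), limit int32) (int64, error) {
	var total int64
	for {
		n, err := deleteBatch(limit)
		if err != nil {
			return total, err
		}
		total += n
		if n < int64(limit) {
			// A short batch means no more deletable rows remain.
			return total, nil
		}
	}
}

func main() {
	// Simulate 2500 deletable rows purged 1000 at a time.
	remaining := int64(2500)
	total, err := purgeInBatches(func(limit int32) (int64, error) {
		n := int64(limit)
		if remaining < n {
			n = remaining
		}
		remaining -= n
		return n, nil
	}, 1000)
	fmt.Println(total, err) // 2500 <nil>
}
```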
const getChatFileByID = `-- name: GetChatFileByID :one
SELECT id, owner_id, organization_id, created_at, name, mimetype, data FROM chat_files WHERE id = $1::uuid
`
@@ -4180,7 +4244,7 @@ WHERE
$3::int
)
RETURNING
id, owner_id, workspace_id, title, status, worker_id, started_at, heartbeat_at, created_at, updated_at, parent_chat_id, root_chat_id, last_model_config_id, archived, last_error, mode, mcp_server_ids, labels, build_id, agent_id, pin_order, last_read_message_id, last_injected_context
id, owner_id, workspace_id, title, status, worker_id, started_at, heartbeat_at, created_at, updated_at, parent_chat_id, root_chat_id, last_model_config_id, archived, last_error, mode, mcp_server_ids, labels, build_id, agent_id, pin_order, last_read_message_id, last_injected_context, dynamic_tools
`
type AcquireChatsParams struct {
@@ -4224,6 +4288,7 @@ func (q *sqlQuerier) AcquireChats(ctx context.Context, arg AcquireChatsParams) (
&i.PinOrder,
&i.LastReadMessageID,
&i.LastInjectedContext,
&i.DynamicTools,
); err != nil {
return nil, err
}
@@ -4362,9 +4427,9 @@ WITH chats AS (
UPDATE chats
SET archived = true, pin_order = 0, updated_at = NOW()
WHERE id = $1::uuid OR root_chat_id = $1::uuid
RETURNING id, owner_id, workspace_id, title, status, worker_id, started_at, heartbeat_at, created_at, updated_at, parent_chat_id, root_chat_id, last_model_config_id, archived, last_error, mode, mcp_server_ids, labels, build_id, agent_id, pin_order, last_read_message_id, last_injected_context
RETURNING id, owner_id, workspace_id, title, status, worker_id, started_at, heartbeat_at, created_at, updated_at, parent_chat_id, root_chat_id, last_model_config_id, archived, last_error, mode, mcp_server_ids, labels, build_id, agent_id, pin_order, last_read_message_id, last_injected_context, dynamic_tools
)
SELECT id, owner_id, workspace_id, title, status, worker_id, started_at, heartbeat_at, created_at, updated_at, parent_chat_id, root_chat_id, last_model_config_id, archived, last_error, mode, mcp_server_ids, labels, build_id, agent_id, pin_order, last_read_message_id, last_injected_context
SELECT id, owner_id, workspace_id, title, status, worker_id, started_at, heartbeat_at, created_at, updated_at, parent_chat_id, root_chat_id, last_model_config_id, archived, last_error, mode, mcp_server_ids, labels, build_id, agent_id, pin_order, last_read_message_id, last_injected_context, dynamic_tools
FROM chats
ORDER BY (id = $1::uuid) DESC, created_at ASC, id ASC
`
@@ -4402,6 +4467,7 @@ func (q *sqlQuerier) ArchiveChatByID(ctx context.Context, id uuid.UUID) ([]Chat,
&i.PinOrder,
&i.LastReadMessageID,
&i.LastInjectedContext,
&i.DynamicTools,
); err != nil {
return nil, err
}
@@ -4504,9 +4570,42 @@ func (q *sqlQuerier) DeleteChatUsageLimitUserOverride(ctx context.Context, userI
return err
}
const deleteOldChats = `-- name: DeleteOldChats :execrows
WITH deletable AS (
SELECT id
FROM chats
WHERE archived = true
AND updated_at < $1::timestamptz
ORDER BY updated_at ASC
LIMIT $2
)
DELETE FROM chats
USING deletable
WHERE chats.id = deletable.id
AND chats.archived = true
`
type DeleteOldChatsParams struct {
BeforeTime time.Time `db:"before_time" json:"before_time"`
LimitCount int32 `db:"limit_count" json:"limit_count"`
}
// Deletes chats that have been archived for longer than the given
// threshold. Active (non-archived) chats are never deleted.
// Related chat_messages, chat_diff_statuses, and
// chat_queued_messages are removed via ON DELETE CASCADE.
// Parent/root references on child chats are SET NULL.
func (q *sqlQuerier) DeleteOldChats(ctx context.Context, arg DeleteOldChatsParams) (int64, error) {
result, err := q.db.ExecContext(ctx, deleteOldChats, arg.BeforeTime, arg.LimitCount)
if err != nil {
return 0, err
}
return result.RowsAffected()
}
const getChatByID = `-- name: GetChatByID :one
SELECT
id, owner_id, workspace_id, title, status, worker_id, started_at, heartbeat_at, created_at, updated_at, parent_chat_id, root_chat_id, last_model_config_id, archived, last_error, mode, mcp_server_ids, labels, build_id, agent_id, pin_order, last_read_message_id, last_injected_context
id, owner_id, workspace_id, title, status, worker_id, started_at, heartbeat_at, created_at, updated_at, parent_chat_id, root_chat_id, last_model_config_id, archived, last_error, mode, mcp_server_ids, labels, build_id, agent_id, pin_order, last_read_message_id, last_injected_context, dynamic_tools
FROM
chats
WHERE
@@ -4540,12 +4639,13 @@ func (q *sqlQuerier) GetChatByID(ctx context.Context, id uuid.UUID) (Chat, error
&i.PinOrder,
&i.LastReadMessageID,
&i.LastInjectedContext,
&i.DynamicTools,
)
return i, err
}
const getChatByIDForUpdate = `-- name: GetChatByIDForUpdate :one
SELECT id, owner_id, workspace_id, title, status, worker_id, started_at, heartbeat_at, created_at, updated_at, parent_chat_id, root_chat_id, last_model_config_id, archived, last_error, mode, mcp_server_ids, labels, build_id, agent_id, pin_order, last_read_message_id, last_injected_context FROM chats WHERE id = $1::uuid FOR UPDATE
SELECT id, owner_id, workspace_id, title, status, worker_id, started_at, heartbeat_at, created_at, updated_at, parent_chat_id, root_chat_id, last_model_config_id, archived, last_error, mode, mcp_server_ids, labels, build_id, agent_id, pin_order, last_read_message_id, last_injected_context, dynamic_tools FROM chats WHERE id = $1::uuid FOR UPDATE
`
func (q *sqlQuerier) GetChatByIDForUpdate(ctx context.Context, id uuid.UUID) (Chat, error) {
@@ -4575,6 +4675,7 @@ func (q *sqlQuerier) GetChatByIDForUpdate(ctx context.Context, id uuid.UUID) (Ch
&i.PinOrder,
&i.LastReadMessageID,
&i.LastInjectedContext,
&i.DynamicTools,
)
return i, err
}
@@ -5104,6 +5205,89 @@ func (q *sqlQuerier) GetChatMessageByID(ctx context.Context, id int64) (ChatMess
return i, err
}
const getChatMessageSummariesPerChat = `-- name: GetChatMessageSummariesPerChat :many
SELECT
cm.chat_id,
COUNT(*)::bigint AS message_count,
COUNT(*) FILTER (WHERE cm.role = 'user')::bigint AS user_message_count,
COUNT(*) FILTER (WHERE cm.role = 'assistant')::bigint AS assistant_message_count,
COUNT(*) FILTER (WHERE cm.role = 'tool')::bigint AS tool_message_count,
COUNT(*) FILTER (WHERE cm.role = 'system')::bigint AS system_message_count,
COALESCE(SUM(cm.input_tokens), 0)::bigint AS total_input_tokens,
COALESCE(SUM(cm.output_tokens), 0)::bigint AS total_output_tokens,
COALESCE(SUM(cm.reasoning_tokens), 0)::bigint AS total_reasoning_tokens,
COALESCE(SUM(cm.cache_creation_tokens), 0)::bigint AS total_cache_creation_tokens,
COALESCE(SUM(cm.cache_read_tokens), 0)::bigint AS total_cache_read_tokens,
COALESCE(SUM(cm.total_cost_micros), 0)::bigint AS total_cost_micros,
COALESCE(SUM(cm.runtime_ms), 0)::bigint AS total_runtime_ms,
COUNT(DISTINCT cm.model_config_id)::bigint AS distinct_model_count,
COUNT(*) FILTER (WHERE cm.compressed)::bigint AS compressed_message_count
FROM chat_messages cm
WHERE cm.created_at > $1
AND cm.deleted = false
GROUP BY cm.chat_id
`
type GetChatMessageSummariesPerChatRow struct {
ChatID uuid.UUID `db:"chat_id" json:"chat_id"`
MessageCount int64 `db:"message_count" json:"message_count"`
UserMessageCount int64 `db:"user_message_count" json:"user_message_count"`
AssistantMessageCount int64 `db:"assistant_message_count" json:"assistant_message_count"`
ToolMessageCount int64 `db:"tool_message_count" json:"tool_message_count"`
SystemMessageCount int64 `db:"system_message_count" json:"system_message_count"`
TotalInputTokens int64 `db:"total_input_tokens" json:"total_input_tokens"`
TotalOutputTokens int64 `db:"total_output_tokens" json:"total_output_tokens"`
TotalReasoningTokens int64 `db:"total_reasoning_tokens" json:"total_reasoning_tokens"`
TotalCacheCreationTokens int64 `db:"total_cache_creation_tokens" json:"total_cache_creation_tokens"`
TotalCacheReadTokens int64 `db:"total_cache_read_tokens" json:"total_cache_read_tokens"`
TotalCostMicros int64 `db:"total_cost_micros" json:"total_cost_micros"`
TotalRuntimeMs int64 `db:"total_runtime_ms" json:"total_runtime_ms"`
DistinctModelCount int64 `db:"distinct_model_count" json:"distinct_model_count"`
CompressedMessageCount int64 `db:"compressed_message_count" json:"compressed_message_count"`
}
// Aggregates message-level metrics per chat for messages created
// after the given timestamp. Uses message created_at so that
// ongoing activity in long-running chats is captured each window.
func (q *sqlQuerier) GetChatMessageSummariesPerChat(ctx context.Context, createdAfter time.Time) ([]GetChatMessageSummariesPerChatRow, error) {
rows, err := q.db.QueryContext(ctx, getChatMessageSummariesPerChat, createdAfter)
if err != nil {
return nil, err
}
defer rows.Close()
var items []GetChatMessageSummariesPerChatRow
for rows.Next() {
var i GetChatMessageSummariesPerChatRow
if err := rows.Scan(
&i.ChatID,
&i.MessageCount,
&i.UserMessageCount,
&i.AssistantMessageCount,
&i.ToolMessageCount,
&i.SystemMessageCount,
&i.TotalInputTokens,
&i.TotalOutputTokens,
&i.TotalReasoningTokens,
&i.TotalCacheCreationTokens,
&i.TotalCacheReadTokens,
&i.TotalCostMicros,
&i.TotalRuntimeMs,
&i.DistinctModelCount,
&i.CompressedMessageCount,
); err != nil {
return nil, err
}
items = append(items, i)
}
if err := rows.Close(); err != nil {
return nil, err
}
if err := rows.Err(); err != nil {
return nil, err
}
return items, nil
}
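To make the query's `COUNT(*) FILTER` / `SUM` semantics concrete, here is the same per-chat aggregation reproduced in Go over an in-memory slice, for a subset of the columns. The `msg`/`summary` types and `summarize` are illustrative only; they mirror the SQL, not the generated structs.

```go
package main

import "fmt"

// msg mirrors the chat_messages columns the summary query touches.
type msg struct {
	chatID  string
	role    string
	deleted bool
	inTok   int64
	outTok  int64
}

type summary struct {
	messages     int64 // COUNT(*)
	userMessages int64 // COUNT(*) FILTER (WHERE role = 'user')
	inputTokens  int64 // SUM(input_tokens)
	outputTokens int64 // SUM(output_tokens)
}

// summarize groups by chat_id, skipping deleted rows, exactly as the
// WHERE cm.deleted = false ... GROUP BY cm.chat_id query does.
func summarize(msgs []msg) map[string]summary {
	out := map[string]summary{}
	for _, m := range msgs {
		if m.deleted {
			continue
		}
		s := out[m.chatID]
		s.messages++
		if m.role == "user" {
			s.userMessages++
		}
		s.inputTokens += m.inTok
		s.outputTokens += m.outTok
		out[m.chatID] = s
	}
	return out
}

func main() {
	got := summarize([]msg{
		{"a", "user", false, 10, 0},
		{"a", "assistant", false, 0, 25},
		{"a", "tool", true, 99, 99}, // deleted: excluded from every aggregate
		{"b", "user", false, 5, 0},
	})
	fmt.Println(got["a"].messages, got["a"].userMessages, got["a"].inputTokens, got["a"].outputTokens) // 2 1 10 25
}
```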
const getChatMessagesByChatID = `-- name: GetChatMessagesByChatID :many
SELECT
id, chat_id, model_config_id, created_at, role, content, visibility, input_tokens, output_tokens, total_tokens, reasoning_tokens, cache_creation_tokens, cache_read_tokens, context_limit, compressed, created_by, content_version, total_cost_micros, runtime_ms, deleted, provider_response_id
@@ -5409,6 +5593,52 @@ func (q *sqlQuerier) GetChatMessagesForPromptByChatID(ctx context.Context, chatI
return items, nil
}
const getChatModelConfigsForTelemetry = `-- name: GetChatModelConfigsForTelemetry :many
SELECT id, provider, model, context_limit, enabled, is_default
FROM chat_model_configs
WHERE deleted = false
`
type GetChatModelConfigsForTelemetryRow struct {
ID uuid.UUID `db:"id" json:"id"`
Provider string `db:"provider" json:"provider"`
Model string `db:"model" json:"model"`
ContextLimit int64 `db:"context_limit" json:"context_limit"`
Enabled bool `db:"enabled" json:"enabled"`
IsDefault bool `db:"is_default" json:"is_default"`
}
// Returns all model configurations for telemetry snapshot collection.
func (q *sqlQuerier) GetChatModelConfigsForTelemetry(ctx context.Context) ([]GetChatModelConfigsForTelemetryRow, error) {
rows, err := q.db.QueryContext(ctx, getChatModelConfigsForTelemetry)
if err != nil {
return nil, err
}
defer rows.Close()
var items []GetChatModelConfigsForTelemetryRow
for rows.Next() {
var i GetChatModelConfigsForTelemetryRow
if err := rows.Scan(
&i.ID,
&i.Provider,
&i.Model,
&i.ContextLimit,
&i.Enabled,
&i.IsDefault,
); err != nil {
return nil, err
}
items = append(items, i)
}
if err := rows.Close(); err != nil {
return nil, err
}
if err := rows.Err(); err != nil {
return nil, err
}
return items, nil
}
const getChatQueuedMessages = `-- name: GetChatQueuedMessages :many
SELECT id, chat_id, content, created_at FROM chat_queued_messages
WHERE chat_id = $1
@@ -5500,7 +5730,7 @@ func (q *sqlQuerier) GetChatUsageLimitUserOverride(ctx context.Context, userID u
const getChats = `-- name: GetChats :many
SELECT
chats.id, chats.owner_id, chats.workspace_id, chats.title, chats.status, chats.worker_id, chats.started_at, chats.heartbeat_at, chats.created_at, chats.updated_at, chats.parent_chat_id, chats.root_chat_id, chats.last_model_config_id, chats.archived, chats.last_error, chats.mode, chats.mcp_server_ids, chats.labels, chats.build_id, chats.agent_id, chats.pin_order, chats.last_read_message_id, chats.last_injected_context,
chats.id, chats.owner_id, chats.workspace_id, chats.title, chats.status, chats.worker_id, chats.started_at, chats.heartbeat_at, chats.created_at, chats.updated_at, chats.parent_chat_id, chats.root_chat_id, chats.last_model_config_id, chats.archived, chats.last_error, chats.mode, chats.mcp_server_ids, chats.labels, chats.build_id, chats.agent_id, chats.pin_order, chats.last_read_message_id, chats.last_injected_context, chats.dynamic_tools,
EXISTS (
SELECT 1 FROM chat_messages cm
WHERE cm.chat_id = chats.id
@@ -5608,6 +5838,7 @@ func (q *sqlQuerier) GetChats(ctx context.Context, arg GetChatsParams) ([]GetCha
&i.Chat.PinOrder,
&i.Chat.LastReadMessageID,
&i.Chat.LastInjectedContext,
&i.Chat.DynamicTools,
&i.HasUnread,
); err != nil {
return nil, err
@@ -5624,7 +5855,7 @@ func (q *sqlQuerier) GetChats(ctx context.Context, arg GetChatsParams) ([]GetCha
}
const getChatsByWorkspaceIDs = `-- name: GetChatsByWorkspaceIDs :many
SELECT id, owner_id, workspace_id, title, status, worker_id, started_at, heartbeat_at, created_at, updated_at, parent_chat_id, root_chat_id, last_model_config_id, archived, last_error, mode, mcp_server_ids, labels, build_id, agent_id, pin_order, last_read_message_id, last_injected_context
SELECT id, owner_id, workspace_id, title, status, worker_id, started_at, heartbeat_at, created_at, updated_at, parent_chat_id, root_chat_id, last_model_config_id, archived, last_error, mode, mcp_server_ids, labels, build_id, agent_id, pin_order, last_read_message_id, last_injected_context, dynamic_tools
FROM chats
WHERE archived = false
AND workspace_id = ANY($1::uuid[])
@@ -5664,6 +5895,69 @@ func (q *sqlQuerier) GetChatsByWorkspaceIDs(ctx context.Context, ids []uuid.UUID
&i.PinOrder,
&i.LastReadMessageID,
&i.LastInjectedContext,
&i.DynamicTools,
); err != nil {
return nil, err
}
items = append(items, i)
}
if err := rows.Close(); err != nil {
return nil, err
}
if err := rows.Err(); err != nil {
return nil, err
}
return items, nil
}
const getChatsUpdatedAfter = `-- name: GetChatsUpdatedAfter :many
SELECT
id, owner_id, created_at, updated_at, status,
(parent_chat_id IS NOT NULL)::bool AS has_parent,
root_chat_id, workspace_id,
mode, archived, last_model_config_id
FROM chats
WHERE updated_at > $1
`
type GetChatsUpdatedAfterRow struct {
ID uuid.UUID `db:"id" json:"id"`
OwnerID uuid.UUID `db:"owner_id" json:"owner_id"`
CreatedAt time.Time `db:"created_at" json:"created_at"`
UpdatedAt time.Time `db:"updated_at" json:"updated_at"`
Status ChatStatus `db:"status" json:"status"`
HasParent bool `db:"has_parent" json:"has_parent"`
RootChatID uuid.NullUUID `db:"root_chat_id" json:"root_chat_id"`
WorkspaceID uuid.NullUUID `db:"workspace_id" json:"workspace_id"`
Mode NullChatMode `db:"mode" json:"mode"`
Archived bool `db:"archived" json:"archived"`
LastModelConfigID uuid.UUID `db:"last_model_config_id" json:"last_model_config_id"`
}
// Retrieves chats updated after the given timestamp for telemetry
// snapshot collection. Uses updated_at so that long-running chats
// still appear in each snapshot window while they are active.
func (q *sqlQuerier) GetChatsUpdatedAfter(ctx context.Context, updatedAfter time.Time) ([]GetChatsUpdatedAfterRow, error) {
rows, err := q.db.QueryContext(ctx, getChatsUpdatedAfter, updatedAfter)
if err != nil {
return nil, err
}
defer rows.Close()
var items []GetChatsUpdatedAfterRow
for rows.Next() {
var i GetChatsUpdatedAfterRow
if err := rows.Scan(
&i.ID,
&i.OwnerID,
&i.CreatedAt,
&i.UpdatedAt,
&i.Status,
&i.HasParent,
&i.RootChatID,
&i.WorkspaceID,
&i.Mode,
&i.Archived,
&i.LastModelConfigID,
); err != nil {
return nil, err
}
@@ -5729,16 +6023,20 @@ func (q *sqlQuerier) GetLastChatMessageByRole(ctx context.Context, arg GetLastCh
const getStaleChats = `-- name: GetStaleChats :many
SELECT
id, owner_id, workspace_id, title, status, worker_id, started_at, heartbeat_at, created_at, updated_at, parent_chat_id, root_chat_id, last_model_config_id, archived, last_error, mode, mcp_server_ids, labels, build_id, agent_id, pin_order, last_read_message_id, last_injected_context
id, owner_id, workspace_id, title, status, worker_id, started_at, heartbeat_at, created_at, updated_at, parent_chat_id, root_chat_id, last_model_config_id, archived, last_error, mode, mcp_server_ids, labels, build_id, agent_id, pin_order, last_read_message_id, last_injected_context, dynamic_tools
FROM
chats
WHERE
status = 'running'::chat_status
AND heartbeat_at < $1::timestamptz
(status = 'running'::chat_status
AND heartbeat_at < $1::timestamptz)
OR (status = 'requires_action'::chat_status
AND updated_at < $1::timestamptz)
`
// Find chats that appear stuck (running but heartbeat has expired).
// Used for recovery after coderd crashes or long hangs.
// Find chats that appear stuck and need recovery. This covers:
// 1. Running chats whose heartbeat has expired (worker crash).
// 2. Chats awaiting client action (requires_action) past the
// timeout threshold (client disappeared).
func (q *sqlQuerier) GetStaleChats(ctx context.Context, staleThreshold time.Time) ([]Chat, error) {
rows, err := q.db.QueryContext(ctx, getStaleChats, staleThreshold)
if err != nil {
@@ -5772,6 +6070,7 @@ func (q *sqlQuerier) GetStaleChats(ctx context.Context, staleThreshold time.Time
&i.PinOrder,
&i.LastReadMessageID,
&i.LastInjectedContext,
&i.DynamicTools,
); err != nil {
return nil, err
}
@@ -5839,7 +6138,8 @@ INSERT INTO chats (
mode,
status,
mcp_server_ids,
labels
labels,
dynamic_tools
) VALUES (
$1::uuid,
$2::uuid,
@@ -5852,10 +6152,11 @@ INSERT INTO chats (
$9::chat_mode,
$10::chat_status,
COALESCE($11::uuid[], '{}'::uuid[]),
COALESCE($12::jsonb, '{}'::jsonb)
COALESCE($12::jsonb, '{}'::jsonb),
$13::jsonb
)
RETURNING
id, owner_id, workspace_id, title, status, worker_id, started_at, heartbeat_at, created_at, updated_at, parent_chat_id, root_chat_id, last_model_config_id, archived, last_error, mode, mcp_server_ids, labels, build_id, agent_id, pin_order, last_read_message_id, last_injected_context
id, owner_id, workspace_id, title, status, worker_id, started_at, heartbeat_at, created_at, updated_at, parent_chat_id, root_chat_id, last_model_config_id, archived, last_error, mode, mcp_server_ids, labels, build_id, agent_id, pin_order, last_read_message_id, last_injected_context, dynamic_tools
`
type InsertChatParams struct {
@@ -5871,6 +6172,7 @@ type InsertChatParams struct {
Status ChatStatus `db:"status" json:"status"`
MCPServerIDs []uuid.UUID `db:"mcp_server_ids" json:"mcp_server_ids"`
Labels pqtype.NullRawMessage `db:"labels" json:"labels"`
DynamicTools pqtype.NullRawMessage `db:"dynamic_tools" json:"dynamic_tools"`
}
func (q *sqlQuerier) InsertChat(ctx context.Context, arg InsertChatParams) (Chat, error) {
@@ -5887,6 +6189,7 @@ func (q *sqlQuerier) InsertChat(ctx context.Context, arg InsertChatParams) (Chat
arg.Status,
pq.Array(arg.MCPServerIDs),
arg.Labels,
arg.DynamicTools,
)
var i Chat
err := row.Scan(
@@ -5913,6 +6216,7 @@ func (q *sqlQuerier) InsertChat(ctx context.Context, arg InsertChatParams) (Chat
&i.PinOrder,
&i.LastReadMessageID,
&i.LastInjectedContext,
&i.DynamicTools,
)
return i, err
}
@@ -6404,16 +6708,21 @@ func (q *sqlQuerier) SoftDeleteChatMessagesAfterID(ctx context.Context, arg Soft
const unarchiveChatByID = `-- name: UnarchiveChatByID :many
WITH chats AS (
UPDATE chats
SET archived = false, updated_at = NOW()
UPDATE chats SET
archived = false,
updated_at = NOW()
WHERE id = $1::uuid OR root_chat_id = $1::uuid
RETURNING id, owner_id, workspace_id, title, status, worker_id, started_at, heartbeat_at, created_at, updated_at, parent_chat_id, root_chat_id, last_model_config_id, archived, last_error, mode, mcp_server_ids, labels, build_id, agent_id, pin_order, last_read_message_id, last_injected_context
RETURNING id, owner_id, workspace_id, title, status, worker_id, started_at, heartbeat_at, created_at, updated_at, parent_chat_id, root_chat_id, last_model_config_id, archived, last_error, mode, mcp_server_ids, labels, build_id, agent_id, pin_order, last_read_message_id, last_injected_context, dynamic_tools
)
SELECT id, owner_id, workspace_id, title, status, worker_id, started_at, heartbeat_at, created_at, updated_at, parent_chat_id, root_chat_id, last_model_config_id, archived, last_error, mode, mcp_server_ids, labels, build_id, agent_id, pin_order, last_read_message_id, last_injected_context
SELECT id, owner_id, workspace_id, title, status, worker_id, started_at, heartbeat_at, created_at, updated_at, parent_chat_id, root_chat_id, last_model_config_id, archived, last_error, mode, mcp_server_ids, labels, build_id, agent_id, pin_order, last_read_message_id, last_injected_context, dynamic_tools
FROM chats
ORDER BY (id = $1::uuid) DESC, created_at ASC, id ASC
`
// Unarchives a chat (and its children). Stale file references are
// handled automatically by FK cascades on chat_file_links: when
// dbpurge deletes a chat_files row, the corresponding
// chat_file_links rows are cascade-deleted by PostgreSQL.
func (q *sqlQuerier) UnarchiveChatByID(ctx context.Context, id uuid.UUID) ([]Chat, error) {
rows, err := q.db.QueryContext(ctx, unarchiveChatByID, id)
if err != nil {
@@ -6447,6 +6756,7 @@ func (q *sqlQuerier) UnarchiveChatByID(ctx context.Context, id uuid.UUID) ([]Cha
&i.PinOrder,
&i.LastReadMessageID,
&i.LastInjectedContext,
&i.DynamicTools,
); err != nil {
return nil, err
}
@@ -6527,7 +6837,7 @@ UPDATE chats SET
updated_at = NOW()
WHERE
id = $3::uuid
RETURNING id, owner_id, workspace_id, title, status, worker_id, started_at, heartbeat_at, created_at, updated_at, parent_chat_id, root_chat_id, last_model_config_id, archived, last_error, mode, mcp_server_ids, labels, build_id, agent_id, pin_order, last_read_message_id, last_injected_context
RETURNING id, owner_id, workspace_id, title, status, worker_id, started_at, heartbeat_at, created_at, updated_at, parent_chat_id, root_chat_id, last_model_config_id, archived, last_error, mode, mcp_server_ids, labels, build_id, agent_id, pin_order, last_read_message_id, last_injected_context, dynamic_tools
`
type UpdateChatBuildAgentBindingParams struct {
@@ -6563,6 +6873,7 @@ func (q *sqlQuerier) UpdateChatBuildAgentBinding(ctx context.Context, arg Update
&i.PinOrder,
&i.LastReadMessageID,
&i.LastInjectedContext,
&i.DynamicTools,
)
return i, err
}
@@ -6576,7 +6887,7 @@ SET
WHERE
id = $2::uuid
RETURNING
id, owner_id, workspace_id, title, status, worker_id, started_at, heartbeat_at, created_at, updated_at, parent_chat_id, root_chat_id, last_model_config_id, archived, last_error, mode, mcp_server_ids, labels, build_id, agent_id, pin_order, last_read_message_id, last_injected_context
id, owner_id, workspace_id, title, status, worker_id, started_at, heartbeat_at, created_at, updated_at, parent_chat_id, root_chat_id, last_model_config_id, archived, last_error, mode, mcp_server_ids, labels, build_id, agent_id, pin_order, last_read_message_id, last_injected_context, dynamic_tools
`
type UpdateChatByIDParams struct {
@@ -6611,6 +6922,7 @@ func (q *sqlQuerier) UpdateChatByID(ctx context.Context, arg UpdateChatByIDParam
&i.PinOrder,
&i.LastReadMessageID,
&i.LastInjectedContext,
&i.DynamicTools,
)
return i, err
}
@@ -6669,7 +6981,7 @@ SET
WHERE
id = $2::uuid
RETURNING
id, owner_id, workspace_id, title, status, worker_id, started_at, heartbeat_at, created_at, updated_at, parent_chat_id, root_chat_id, last_model_config_id, archived, last_error, mode, mcp_server_ids, labels, build_id, agent_id, pin_order, last_read_message_id, last_injected_context
id, owner_id, workspace_id, title, status, worker_id, started_at, heartbeat_at, created_at, updated_at, parent_chat_id, root_chat_id, last_model_config_id, archived, last_error, mode, mcp_server_ids, labels, build_id, agent_id, pin_order, last_read_message_id, last_injected_context, dynamic_tools
`
type UpdateChatLabelsByIDParams struct {
@@ -6704,6 +7016,7 @@ func (q *sqlQuerier) UpdateChatLabelsByID(ctx context.Context, arg UpdateChatLab
&i.PinOrder,
&i.LastReadMessageID,
&i.LastInjectedContext,
&i.DynamicTools,
)
return i, err
}
@@ -6713,7 +7026,7 @@ UPDATE chats SET
last_injected_context = $1::jsonb
WHERE
id = $2::uuid
RETURNING id, owner_id, workspace_id, title, status, worker_id, started_at, heartbeat_at, created_at, updated_at, parent_chat_id, root_chat_id, last_model_config_id, archived, last_error, mode, mcp_server_ids, labels, build_id, agent_id, pin_order, last_read_message_id, last_injected_context
RETURNING id, owner_id, workspace_id, title, status, worker_id, started_at, heartbeat_at, created_at, updated_at, parent_chat_id, root_chat_id, last_model_config_id, archived, last_error, mode, mcp_server_ids, labels, build_id, agent_id, pin_order, last_read_message_id, last_injected_context, dynamic_tools
`
type UpdateChatLastInjectedContextParams struct {
@@ -6752,6 +7065,7 @@ func (q *sqlQuerier) UpdateChatLastInjectedContext(ctx context.Context, arg Upda
&i.PinOrder,
&i.LastReadMessageID,
&i.LastInjectedContext,
&i.DynamicTools,
)
return i, err
}
@@ -6765,7 +7079,7 @@ SET
WHERE
id = $2::uuid
RETURNING
id, owner_id, workspace_id, title, status, worker_id, started_at, heartbeat_at, created_at, updated_at, parent_chat_id, root_chat_id, last_model_config_id, archived, last_error, mode, mcp_server_ids, labels, build_id, agent_id, pin_order, last_read_message_id, last_injected_context
id, owner_id, workspace_id, title, status, worker_id, started_at, heartbeat_at, created_at, updated_at, parent_chat_id, root_chat_id, last_model_config_id, archived, last_error, mode, mcp_server_ids, labels, build_id, agent_id, pin_order, last_read_message_id, last_injected_context, dynamic_tools
`
type UpdateChatLastModelConfigByIDParams struct {
@@ -6800,6 +7114,7 @@ func (q *sqlQuerier) UpdateChatLastModelConfigByID(ctx context.Context, arg Upda
&i.PinOrder,
&i.LastReadMessageID,
&i.LastInjectedContext,
&i.DynamicTools,
)
return i, err
}
@@ -6831,7 +7146,7 @@ SET
WHERE
id = $2::uuid
RETURNING
id, owner_id, workspace_id, title, status, worker_id, started_at, heartbeat_at, created_at, updated_at, parent_chat_id, root_chat_id, last_model_config_id, archived, last_error, mode, mcp_server_ids, labels, build_id, agent_id, pin_order, last_read_message_id, last_injected_context
id, owner_id, workspace_id, title, status, worker_id, started_at, heartbeat_at, created_at, updated_at, parent_chat_id, root_chat_id, last_model_config_id, archived, last_error, mode, mcp_server_ids, labels, build_id, agent_id, pin_order, last_read_message_id, last_injected_context, dynamic_tools
`
type UpdateChatMCPServerIDsParams struct {
@@ -6866,6 +7181,7 @@ func (q *sqlQuerier) UpdateChatMCPServerIDs(ctx context.Context, arg UpdateChatM
&i.PinOrder,
&i.LastReadMessageID,
&i.LastInjectedContext,
&i.DynamicTools,
)
return i, err
}
@@ -7001,7 +7317,7 @@ SET
WHERE
id = $6::uuid
RETURNING
id, owner_id, workspace_id, title, status, worker_id, started_at, heartbeat_at, created_at, updated_at, parent_chat_id, root_chat_id, last_model_config_id, archived, last_error, mode, mcp_server_ids, labels, build_id, agent_id, pin_order, last_read_message_id, last_injected_context
id, owner_id, workspace_id, title, status, worker_id, started_at, heartbeat_at, created_at, updated_at, parent_chat_id, root_chat_id, last_model_config_id, archived, last_error, mode, mcp_server_ids, labels, build_id, agent_id, pin_order, last_read_message_id, last_injected_context, dynamic_tools
`
type UpdateChatStatusParams struct {
@@ -7047,6 +7363,7 @@ func (q *sqlQuerier) UpdateChatStatus(ctx context.Context, arg UpdateChatStatusP
&i.PinOrder,
&i.LastReadMessageID,
&i.LastInjectedContext,
&i.DynamicTools,
)
return i, err
}
@@ -7064,7 +7381,7 @@ SET
WHERE
id = $7::uuid
RETURNING
id, owner_id, workspace_id, title, status, worker_id, started_at, heartbeat_at, created_at, updated_at, parent_chat_id, root_chat_id, last_model_config_id, archived, last_error, mode, mcp_server_ids, labels, build_id, agent_id, pin_order, last_read_message_id, last_injected_context
id, owner_id, workspace_id, title, status, worker_id, started_at, heartbeat_at, created_at, updated_at, parent_chat_id, root_chat_id, last_model_config_id, archived, last_error, mode, mcp_server_ids, labels, build_id, agent_id, pin_order, last_read_message_id, last_injected_context, dynamic_tools
`
type UpdateChatStatusPreserveUpdatedAtParams struct {
@@ -7112,6 +7429,7 @@ func (q *sqlQuerier) UpdateChatStatusPreserveUpdatedAt(ctx context.Context, arg
&i.PinOrder,
&i.LastReadMessageID,
&i.LastInjectedContext,
&i.DynamicTools,
)
return i, err
}
@@ -7123,7 +7441,7 @@ UPDATE chats SET
agent_id = $3::uuid,
updated_at = NOW()
WHERE id = $4::uuid
RETURNING id, owner_id, workspace_id, title, status, worker_id, started_at, heartbeat_at, created_at, updated_at, parent_chat_id, root_chat_id, last_model_config_id, archived, last_error, mode, mcp_server_ids, labels, build_id, agent_id, pin_order, last_read_message_id, last_injected_context
RETURNING id, owner_id, workspace_id, title, status, worker_id, started_at, heartbeat_at, created_at, updated_at, parent_chat_id, root_chat_id, last_model_config_id, archived, last_error, mode, mcp_server_ids, labels, build_id, agent_id, pin_order, last_read_message_id, last_injected_context, dynamic_tools
`
type UpdateChatWorkspaceBindingParams struct {
@@ -7165,6 +7483,7 @@ func (q *sqlQuerier) UpdateChatWorkspaceBinding(ctx context.Context, arg UpdateC
&i.PinOrder,
&i.LastReadMessageID,
&i.LastInjectedContext,
&i.DynamicTools,
)
return i, err
}
@@ -18774,6 +19093,25 @@ func (q *sqlQuerier) GetChatIncludeDefaultSystemPrompt(ctx context.Context) (boo
return include_default_system_prompt, err
}
const getChatRetentionDays = `-- name: GetChatRetentionDays :one
SELECT COALESCE(
(SELECT value::integer FROM site_configs
WHERE key = 'agents_chat_retention_days'),
30
) :: integer AS retention_days
`
// Returns the chat retention period in days. Chats archived longer
// than this and orphaned chat files older than this are purged by
// dbpurge. Returns 30 (days) when no value has been configured.
// A value of 0 disables chat purging entirely.
func (q *sqlQuerier) GetChatRetentionDays(ctx context.Context) (int32, error) {
row := q.db.QueryRowContext(ctx, getChatRetentionDays)
var retention_days int32
err := row.Scan(&retention_days)
return retention_days, err
}
const getChatSystemPrompt = `-- name: GetChatSystemPrompt :one
SELECT
COALESCE((SELECT value FROM site_configs WHERE key = 'agents_chat_system_prompt'), '') :: text AS chat_system_prompt
@@ -19074,6 +19412,18 @@ func (q *sqlQuerier) UpsertChatIncludeDefaultSystemPrompt(ctx context.Context, i
return err
}
const upsertChatRetentionDays = `-- name: UpsertChatRetentionDays :exec
INSERT INTO site_configs (key, value)
VALUES ('agents_chat_retention_days', CAST($1 AS integer)::text)
ON CONFLICT (key) DO UPDATE SET value = CAST($1 AS integer)::text
WHERE site_configs.key = 'agents_chat_retention_days'
`
func (q *sqlQuerier) UpsertChatRetentionDays(ctx context.Context, retentionDays int32) error {
_, err := q.db.ExecContext(ctx, upsertChatRetentionDays, retentionDays)
return err
}
const upsertChatSystemPrompt = `-- name: UpsertChatSystemPrompt :exec
INSERT INTO site_configs (key, value) VALUES ('agents_chat_system_prompt', $1)
ON CONFLICT (key) DO UPDATE SET value = $1 WHERE site_configs.key = 'agents_chat_system_prompt'
+2 -2
@@ -1,8 +1,8 @@
-- name: InsertAIBridgeInterception :one
INSERT INTO aibridge_interceptions (
id, api_key_id, initiator_id, provider, provider_name, model, metadata, started_at, client, client_session_id, thread_parent_id, thread_root_id
id, api_key_id, initiator_id, provider, provider_name, model, metadata, started_at, client, client_session_id, thread_parent_id, thread_root_id, credential_kind, credential_hint
) VALUES (
@id, @api_key_id, @initiator_id, @provider, @provider_name, @model, COALESCE(@metadata::jsonb, '{}'::jsonb), @started_at, @client, sqlc.narg('client_session_id'), sqlc.narg('thread_parent_interception_id')::uuid, sqlc.narg('thread_root_interception_id')::uuid
@id, @api_key_id, @initiator_id, @provider, @provider_name, @model, COALESCE(@metadata::jsonb, '{}'::jsonb), @started_at, @client, sqlc.narg('client_session_id'), sqlc.narg('thread_parent_interception_id')::uuid, sqlc.narg('thread_root_interception_id')::uuid, @credential_kind, @credential_hint
)
RETURNING *;
+34
@@ -18,3 +18,37 @@ FROM chat_files cf
JOIN chat_file_links cfl ON cfl.file_id = cf.id
WHERE cfl.chat_id = @chat_id::uuid
ORDER BY cf.created_at ASC;
-- TODO(cian): Add indexes on chats(archived, updated_at) and
-- chat_files(created_at) for purge query performance.
-- See: https://github.com/coder/internal/issues/1438
-- name: DeleteOldChatFiles :execrows
-- Deletes chat files that are older than the given threshold and are
-- not referenced by any chat that is still active or was archived
-- within the same threshold window. This covers two cases:
-- 1. Orphaned files not linked to any chat.
-- 2. Files whose every referencing chat has been archived for longer
-- than the retention period.
WITH kept_file_ids AS (
-- NOTE: This uses updated_at as a proxy for archive time
-- because there is no archived_at column. Correctness
-- requires that updated_at is never backdated on archived
-- chats. See ArchiveChatByID.
SELECT DISTINCT cfl.file_id
FROM chat_file_links cfl
JOIN chats c ON c.id = cfl.chat_id
WHERE c.archived = false
OR c.updated_at >= @before_time::timestamptz
),
deletable AS (
SELECT cf.id
FROM chat_files cf
LEFT JOIN kept_file_ids k ON cf.id = k.file_id
WHERE cf.created_at < @before_time::timestamptz
AND k.file_id IS NULL
ORDER BY cf.created_at ASC
LIMIT @limit_count
)
DELETE FROM chat_files
USING deletable
WHERE chat_files.id = deletable.id;
+81 -8
@@ -10,9 +10,14 @@ FROM chats
ORDER BY (id = @id::uuid) DESC, created_at ASC, id ASC;
-- name: UnarchiveChatByID :many
-- Unarchives a chat (and its children). Stale file references are
-- handled automatically by FK cascades on chat_file_links: when
-- dbpurge deletes a chat_files row, the corresponding
-- chat_file_links rows are cascade-deleted by PostgreSQL.
WITH chats AS (
UPDATE chats
SET archived = false, updated_at = NOW()
UPDATE chats SET
archived = false,
updated_at = NOW()
WHERE id = @id::uuid OR root_chat_id = @id::uuid
RETURNING *
)
@@ -394,7 +399,8 @@ INSERT INTO chats (
mode,
status,
mcp_server_ids,
labels
labels,
dynamic_tools
) VALUES (
@owner_id::uuid,
sqlc.narg('workspace_id')::uuid,
@@ -407,7 +413,8 @@ INSERT INTO chats (
sqlc.narg('mode')::chat_mode,
@status::chat_status,
COALESCE(@mcp_server_ids::uuid[], '{}'::uuid[]),
COALESCE(sqlc.narg('labels')::jsonb, '{}'::jsonb)
COALESCE(sqlc.narg('labels')::jsonb, '{}'::jsonb),
sqlc.narg('dynamic_tools')::jsonb
)
RETURNING
*;
@@ -664,15 +671,19 @@ RETURNING
*;
-- name: GetStaleChats :many
-- Find chats that appear stuck (running but heartbeat has expired).
-- Used for recovery after coderd crashes or long hangs.
-- Find chats that appear stuck and need recovery. This covers:
-- 1. Running chats whose heartbeat has expired (worker crash).
-- 2. Chats awaiting client action (requires_action) past the
-- timeout threshold (client disappeared).
SELECT
*
FROM
chats
WHERE
status = 'running'::chat_status
AND heartbeat_at < @stale_threshold::timestamptz;
(status = 'running'::chat_status
AND heartbeat_at < @stale_threshold::timestamptz)
OR (status = 'requires_action'::chat_status
AND updated_at < @stale_threshold::timestamptz);
-- name: UpdateChatHeartbeats :many
-- Bumps the heartbeat timestamp for the given set of chat IDs,
@@ -1220,3 +1231,65 @@ LIMIT 1;
UPDATE chats
SET last_read_message_id = @last_read_message_id::bigint
WHERE id = @id::uuid;
-- name: DeleteOldChats :execrows
-- Deletes chats that have been archived for longer than the given
-- threshold. Active (non-archived) chats are never deleted.
-- Related chat_messages, chat_diff_statuses, and
-- chat_queued_messages are removed via ON DELETE CASCADE.
-- Parent/root references on child chats are SET NULL.
WITH deletable AS (
SELECT id
FROM chats
WHERE archived = true
AND updated_at < @before_time::timestamptz
ORDER BY updated_at ASC
LIMIT @limit_count
)
DELETE FROM chats
USING deletable
WHERE chats.id = deletable.id
AND chats.archived = true;
-- name: GetChatsUpdatedAfter :many
-- Retrieves chats updated after the given timestamp for telemetry
-- snapshot collection. Uses updated_at so that long-running chats
-- still appear in each snapshot window while they are active.
SELECT
id, owner_id, created_at, updated_at, status,
(parent_chat_id IS NOT NULL)::bool AS has_parent,
root_chat_id, workspace_id,
mode, archived, last_model_config_id
FROM chats
WHERE updated_at > @updated_after;
-- name: GetChatMessageSummariesPerChat :many
-- Aggregates message-level metrics per chat for messages created
-- after the given timestamp. Uses message created_at so that
-- ongoing activity in long-running chats is captured each window.
SELECT
cm.chat_id,
COUNT(*)::bigint AS message_count,
COUNT(*) FILTER (WHERE cm.role = 'user')::bigint AS user_message_count,
COUNT(*) FILTER (WHERE cm.role = 'assistant')::bigint AS assistant_message_count,
COUNT(*) FILTER (WHERE cm.role = 'tool')::bigint AS tool_message_count,
COUNT(*) FILTER (WHERE cm.role = 'system')::bigint AS system_message_count,
COALESCE(SUM(cm.input_tokens), 0)::bigint AS total_input_tokens,
COALESCE(SUM(cm.output_tokens), 0)::bigint AS total_output_tokens,
COALESCE(SUM(cm.reasoning_tokens), 0)::bigint AS total_reasoning_tokens,
COALESCE(SUM(cm.cache_creation_tokens), 0)::bigint AS total_cache_creation_tokens,
COALESCE(SUM(cm.cache_read_tokens), 0)::bigint AS total_cache_read_tokens,
COALESCE(SUM(cm.total_cost_micros), 0)::bigint AS total_cost_micros,
COALESCE(SUM(cm.runtime_ms), 0)::bigint AS total_runtime_ms,
COUNT(DISTINCT cm.model_config_id)::bigint AS distinct_model_count,
COUNT(*) FILTER (WHERE cm.compressed)::bigint AS compressed_message_count
FROM chat_messages cm
WHERE cm.created_at > @created_after
AND cm.deleted = false
GROUP BY cm.chat_id;
-- name: GetChatModelConfigsForTelemetry :many
-- Returns all model configurations for telemetry snapshot collection.
SELECT id, provider, model, context_limit, enabled, is_default
FROM chat_model_configs
WHERE deleted = false;
+17
@@ -236,3 +236,20 @@ VALUES ('agents_workspace_ttl', @workspace_ttl::text)
ON CONFLICT (key) DO UPDATE
SET value = @workspace_ttl::text
WHERE site_configs.key = 'agents_workspace_ttl';
-- name: GetChatRetentionDays :one
-- Returns the chat retention period in days. Chats archived longer
-- than this and orphaned chat files older than this are purged by
-- dbpurge. Returns 30 (days) when no value has been configured.
-- A value of 0 disables chat purging entirely.
SELECT COALESCE(
(SELECT value::integer FROM site_configs
WHERE key = 'agents_chat_retention_days'),
30
) :: integer AS retention_days;
-- name: UpsertChatRetentionDays :exec
INSERT INTO site_configs (key, value)
VALUES ('agents_chat_retention_days', CAST(@retention_days AS integer)::text)
ON CONFLICT (key) DO UPDATE SET value = CAST(@retention_days AS integer)::text
WHERE site_configs.key = 'agents_chat_retention_days';
+187
@@ -398,6 +398,10 @@ func (api *API) postChats(rw http.ResponseWriter, r *http.Request) {
return
}
// Cap the raw request body to prevent excessive memory use
// from large dynamic tool schemas.
r.Body = http.MaxBytesReader(rw, r.Body, int64(2*maxSystemPromptLenBytes))
var req codersdk.CreateChatRequest
if !httpapi.Read(ctx, rw, r, &req) {
return
@@ -488,6 +492,50 @@ func (api *API) postChats(rw http.ResponseWriter, r *http.Request) {
return
}
if len(req.UnsafeDynamicTools) > 250 {
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
Message: "Too many dynamic tools.",
Detail: "Maximum 250 dynamic tools per chat.",
})
return
}
// Validate that dynamic tool names are non-empty and unique
// within the list. Name collision with built-in tools is
// checked at chatloop time when the full tool set is known.
if len(req.UnsafeDynamicTools) > 0 {
seenNames := make(map[string]struct{}, len(req.UnsafeDynamicTools))
for _, dt := range req.UnsafeDynamicTools {
if dt.Name == "" {
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
Message: "Dynamic tool name must not be empty.",
})
return
}
if _, exists := seenNames[dt.Name]; exists {
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
Message: "Duplicate dynamic tool name.",
Detail: fmt.Sprintf("Tool %q appears more than once.", dt.Name),
})
return
}
seenNames[dt.Name] = struct{}{}
}
}
var dynamicToolsJSON json.RawMessage
if len(req.UnsafeDynamicTools) > 0 {
var err error
dynamicToolsJSON, err = json.Marshal(req.UnsafeDynamicTools)
if err != nil {
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
Message: "Failed to marshal dynamic tools.",
Detail: err.Error(),
})
return
}
}
chat, err := api.chatDaemon.CreateChat(ctx, chatd.CreateOptions{
OwnerID: apiKey.UserID,
WorkspaceID: workspaceSelection.WorkspaceID,
@@ -497,6 +545,7 @@ func (api *API) postChats(rw http.ResponseWriter, r *http.Request) {
InitialUserContent: contentBlocks,
MCPServerIDs: mcpServerIDs,
Labels: labels,
DynamicTools: dynamicToolsJSON,
})
if err != nil {
if maybeWriteLimitErr(ctx, rw, err) {
@@ -3183,6 +3232,70 @@ func (api *API) putChatWorkspaceTTL(rw http.ResponseWriter, r *http.Request) {
rw.WriteHeader(http.StatusNoContent)
}
// @Summary Get chat retention days
// @ID get-chat-retention-days
// @Security CoderSessionToken
// @Tags Chats
// @Produce json
// @Success 200 {object} codersdk.ChatRetentionDaysResponse
// @Router /experimental/chats/config/retention-days [get]
// @x-apidocgen {"skip": true}
//
//nolint:revive // get-return: revive assumes get* must be a getter, but this is an HTTP handler.
func (api *API) getChatRetentionDays(rw http.ResponseWriter, r *http.Request) {
ctx := r.Context()
retentionDays, err := api.Database.GetChatRetentionDays(ctx)
if err != nil {
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
Message: "Failed to get chat retention days.",
Detail: err.Error(),
})
return
}
httpapi.Write(ctx, rw, http.StatusOK, codersdk.ChatRetentionDaysResponse{
RetentionDays: retentionDays,
})
}
// Keep in sync with retentionDaysMaximum in
// site/src/pages/AgentsPage/AgentSettingsBehaviorPageView.tsx.
const retentionDaysMaximum = 3650 // ~10 years
// @Summary Update chat retention days
// @ID update-chat-retention-days
// @Security CoderSessionToken
// @Tags Chats
// @Accept json
// @Param request body codersdk.UpdateChatRetentionDaysRequest true "Request body"
// @Success 204
// @Router /experimental/chats/config/retention-days [put]
// @x-apidocgen {"skip": true}
func (api *API) putChatRetentionDays(rw http.ResponseWriter, r *http.Request) {
ctx := r.Context()
if !api.Authorize(r, policy.ActionUpdate, rbac.ResourceDeploymentConfig) {
httpapi.Forbidden(rw)
return
}
var req codersdk.UpdateChatRetentionDaysRequest
if !httpapi.Read(ctx, rw, r, &req) {
return
}
if req.RetentionDays < 0 || req.RetentionDays > retentionDaysMaximum {
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
Message: fmt.Sprintf("Retention days must be between 0 and %d.", retentionDaysMaximum),
})
return
}
if err := api.Database.UpsertChatRetentionDays(ctx, req.RetentionDays); err != nil {
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
Message: "Failed to update chat retention days.",
Detail: err.Error(),
})
return
}
rw.WriteHeader(http.StatusNoContent)
}
// EXPERIMENTAL: this endpoint is experimental and is subject to change.
//
//nolint:revive // get-return: revive assumes get* must be a getter, but this is an HTTP handler.
@@ -5687,3 +5800,77 @@ func (api *API) prInsights(rw http.ResponseWriter, r *http.Request) {
RecentPRs: prEntries,
})
}
// EXPERIMENTAL: this endpoint is experimental and is subject to change.
//
//nolint:revive // HTTP handler writes to ResponseWriter.
func (api *API) postChatToolResults(rw http.ResponseWriter, r *http.Request) {
ctx := r.Context()
chat := httpmw.ChatParam(r)
apiKey := httpmw.APIKey(r)
// Cap the raw request body to prevent excessive memory use.
r.Body = http.MaxBytesReader(rw, r.Body, int64(2*maxSystemPromptLenBytes))
var req codersdk.SubmitToolResultsRequest
if !httpapi.Read(ctx, rw, r, &req) {
return
}
if len(req.Results) == 0 {
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
Message: "At least one tool result is required.",
})
return
}
// Fast-path check outside the transaction. The authoritative
// check happens inside SubmitToolResults under a row lock.
if chat.Status != database.ChatStatusRequiresAction {
httpapi.Write(ctx, rw, http.StatusConflict, codersdk.Response{
Message: "Chat is not waiting for tool results.",
Detail: fmt.Sprintf("Chat status is %q, expected %q.", chat.Status, database.ChatStatusRequiresAction),
})
return
}
var dynamicTools json.RawMessage
if chat.DynamicTools.Valid {
dynamicTools = chat.DynamicTools.RawMessage
}
err := api.chatDaemon.SubmitToolResults(ctx, chatd.SubmitToolResultsOptions{
ChatID: chat.ID,
UserID: apiKey.UserID,
ModelConfigID: chat.LastModelConfigID,
Results: req.Results,
DynamicTools: dynamicTools,
})
if err != nil {
var validationErr *chatd.ToolResultValidationError
var conflictErr *chatd.ToolResultStatusConflictError
switch {
case errors.As(err, &conflictErr):
httpapi.Write(ctx, rw, http.StatusConflict, codersdk.Response{
Message: "Chat is not waiting for tool results.",
Detail: err.Error(),
})
case errors.As(err, &validationErr):
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
Message: validationErr.Message,
Detail: validationErr.Detail,
})
default:
api.Logger.Error(ctx, "tool results submission failed",
slog.F("chat_id", chat.ID),
slog.Error(err),
)
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
Message: "Internal error submitting tool results.",
})
}
return
}
rw.WriteHeader(http.StatusNoContent)
}
+431 -18
@@ -16,7 +16,9 @@ import (
"time"
"github.com/google/uuid"
"github.com/mark3labs/mcp-go/mcp"
"github.com/shopspring/decimal"
"github.com/sqlc-dev/pqtype"
"github.com/stretchr/testify/require"
"golang.org/x/xerrors"
@@ -268,17 +270,10 @@ func TestPostChats(t *testing.T) {
_ = createChatModelConfig(t, client)
// Member without agents-access should be denied.
memberClientRaw, member := coderdtest.CreateAnotherUser(t, client.Client, firstUser.OrganizationID)
memberClientRaw, _ := coderdtest.CreateAnotherUser(t, client.Client, firstUser.OrganizationID)
memberClient := codersdk.NewExperimentalClient(memberClientRaw)
// Strip the auto-assigned agents-access role to test
// the denied case.
_, err := client.Client.UpdateUserRoles(ctx, member.Username, codersdk.UpdateRoles{
Roles: []string{},
})
require.NoError(t, err)
_, err = memberClient.CreateChat(ctx, codersdk.CreateChatRequest{
_, err := memberClient.CreateChat(ctx, codersdk.CreateChatRequest{
Content: []codersdk.ChatInputPart{
{
Type: codersdk.ChatInputPartTypeText,
@@ -288,6 +283,7 @@ func TestPostChats(t *testing.T) {
})
requireSDKError(t, err, http.StatusForbidden)
})
t.Run("HidesSystemPromptMessages", func(t *testing.T) {
t.Parallel()
@@ -756,15 +752,7 @@ func TestListChats(t *testing.T) {
// returning empty because no chats exist.
memberClientRaw, member := coderdtest.CreateAnotherUser(t, client.Client, firstUser.OrganizationID)
memberClient := codersdk.NewExperimentalClient(memberClientRaw)
// Strip the auto-assigned agents-access role to test
// the denied case.
_, err := client.Client.UpdateUserRoles(ctx, member.Username, codersdk.UpdateRoles{
Roles: []string{},
})
require.NoError(t, err)
_, err = db.InsertChat(dbauthz.AsSystemRestricted(ctx), database.InsertChatParams{
_, err := db.InsertChat(dbauthz.AsSystemRestricted(ctx), database.InsertChatParams{
Status: database.ChatStatusWaiting,
OwnerID: member.ID,
LastModelConfigID: modelConfig.ID,
@@ -7747,6 +7735,62 @@ func TestChatWorkspaceTTL(t *testing.T) {
requireSDKError(t, err, http.StatusBadRequest)
}
func TestChatRetentionDays(t *testing.T) {
t.Parallel()
ctx := testutil.Context(t, testutil.WaitLong)
adminClient := newChatClient(t)
firstUser := coderdtest.CreateFirstUser(t, adminClient.Client)
memberClientRaw, _ := coderdtest.CreateAnotherUser(t, adminClient.Client, firstUser.OrganizationID)
memberClient := codersdk.NewExperimentalClient(memberClientRaw)
// Default value is 30 (days) when nothing has been configured.
resp, err := adminClient.GetChatRetentionDays(ctx)
require.NoError(t, err, "get default")
require.Equal(t, int32(30), resp.RetentionDays, "default should be 30")
// Admin can set retention days to 90.
err = adminClient.UpdateChatRetentionDays(ctx, codersdk.UpdateChatRetentionDaysRequest{
RetentionDays: 90,
})
require.NoError(t, err, "admin set 90")
resp, err = adminClient.GetChatRetentionDays(ctx)
require.NoError(t, err, "get after set")
require.Equal(t, int32(90), resp.RetentionDays, "should return 90")
// Non-admin member can read the value.
resp, err = memberClient.GetChatRetentionDays(ctx)
require.NoError(t, err, "member get")
require.Equal(t, int32(90), resp.RetentionDays, "member should see same value")
// Non-admin member cannot write.
err = memberClient.UpdateChatRetentionDays(ctx, codersdk.UpdateChatRetentionDaysRequest{RetentionDays: 7})
requireSDKError(t, err, http.StatusForbidden)
// Admin can disable purge by setting 0.
err = adminClient.UpdateChatRetentionDays(ctx, codersdk.UpdateChatRetentionDaysRequest{
RetentionDays: 0,
})
require.NoError(t, err, "admin set 0")
resp, err = adminClient.GetChatRetentionDays(ctx)
require.NoError(t, err, "get after zero")
require.Equal(t, int32(0), resp.RetentionDays, "should be 0 after disable")
// Validation: negative value is rejected.
err = adminClient.UpdateChatRetentionDays(ctx, codersdk.UpdateChatRetentionDaysRequest{
RetentionDays: -1,
})
requireSDKError(t, err, http.StatusBadRequest)
// Validation: exceeding the 3650-day maximum is rejected.
err = adminClient.UpdateChatRetentionDays(ctx, codersdk.UpdateChatRetentionDaysRequest{
RetentionDays: 3651, // retentionDaysMaximum + 1; keep in sync with coderd/exp_chats.go.
})
requireSDKError(t, err, http.StatusBadRequest)
}
//nolint:tparallel,paralleltest // Subtests share a single coderdtest instance.
func TestUserChatCompactionThresholds(t *testing.T) {
t.Parallel()
@@ -8153,6 +8197,375 @@ func TestGetChatsByWorkspace(t *testing.T) {
})
}
func TestSubmitToolResults(t *testing.T) {
t.Parallel()
// setupRequiresAction creates a chat via the DB with dynamic tools,
// inserts an assistant message containing tool-call parts for each
// given toolCallID, and sets the chat status to requires_action.
// It returns the chat row so callers can exercise the endpoint.
setupRequiresAction := func(
ctx context.Context,
t *testing.T,
db database.Store,
ownerID uuid.UUID,
modelConfigID uuid.UUID,
dynamicToolName string,
toolCallIDs []string,
) database.Chat {
t.Helper()
// Marshal dynamic tools into the chat row.
dynamicTools := []mcp.Tool{{
Name: dynamicToolName,
Description: "a test dynamic tool",
InputSchema: mcp.ToolInputSchema{Type: "object"},
}}
dtJSON, err := json.Marshal(dynamicTools)
require.NoError(t, err)
chat, err := db.InsertChat(dbauthz.AsSystemRestricted(ctx), database.InsertChatParams{
Status: database.ChatStatusWaiting,
OwnerID: ownerID,
LastModelConfigID: modelConfigID,
Title: "tool-results-test",
DynamicTools: pqtype.NullRawMessage{RawMessage: dtJSON, Valid: true},
})
require.NoError(t, err)
// Build assistant message with tool-call parts.
parts := make([]codersdk.ChatMessagePart, 0, len(toolCallIDs))
for _, id := range toolCallIDs {
parts = append(parts, codersdk.ChatMessagePart{
Type: codersdk.ChatMessagePartTypeToolCall,
ToolCallID: id,
ToolName: dynamicToolName,
Args: json.RawMessage(`{"key":"value"}`),
})
}
content, err := chatprompt.MarshalParts(parts)
require.NoError(t, err)
_, err = db.InsertChatMessages(dbauthz.AsSystemRestricted(ctx), database.InsertChatMessagesParams{
ChatID: chat.ID,
CreatedBy: []uuid.UUID{uuid.Nil},
ModelConfigID: []uuid.UUID{modelConfigID},
Role: []database.ChatMessageRole{database.ChatMessageRoleAssistant},
ContentVersion: []int16{chatprompt.CurrentContentVersion},
Content: []string{string(content.RawMessage)},
Visibility: []database.ChatMessageVisibility{database.ChatMessageVisibilityBoth},
InputTokens: []int64{0},
OutputTokens: []int64{0},
TotalTokens: []int64{0},
ReasoningTokens: []int64{0},
CacheCreationTokens: []int64{0},
CacheReadTokens: []int64{0},
ContextLimit: []int64{0},
Compressed: []bool{false},
TotalCostMicros: []int64{0},
RuntimeMs: []int64{0},
})
require.NoError(t, err)
// Transition to requires_action.
chat, err = db.UpdateChatStatus(dbauthz.AsSystemRestricted(ctx), database.UpdateChatStatusParams{
ID: chat.ID,
Status: database.ChatStatusRequiresAction,
})
require.NoError(t, err)
require.Equal(t, database.ChatStatusRequiresAction, chat.Status)
return chat
}
t.Run("Success", func(t *testing.T) {
t.Parallel()
ctx := testutil.Context(t, testutil.WaitLong)
client, db := newChatClientWithDatabase(t)
user := coderdtest.CreateFirstUser(t, client.Client)
modelConfig := createChatModelConfig(t, client)
const toolName = "my_dynamic_tool"
toolCallIDs := []string{"call_abc", "call_def"}
chat := setupRequiresAction(ctx, t, db, user.UserID, modelConfig.ID, toolName, toolCallIDs)
err := client.SubmitToolResults(ctx, chat.ID, codersdk.SubmitToolResultsRequest{
Results: []codersdk.ToolResult{
{ToolCallID: "call_abc", Output: json.RawMessage(`"result_a"`)},
{ToolCallID: "call_def", Output: json.RawMessage(`"result_b"`)},
},
})
require.NoError(t, err)
// Verify status is no longer requires_action. The chatd
// loop may have already picked the chat up and
// transitioned it further (pending → running → …), so we
// accept any non-requires_action status.
gotChat, err := client.GetChat(ctx, chat.ID)
require.NoError(t, err)
require.NotEqual(t, codersdk.ChatStatusRequiresAction, gotChat.Status,
"chat should no longer be in requires_action after submitting tool results")
// Verify tool-result messages were persisted.
msgsResp, err := client.GetChatMessages(ctx, chat.ID, nil)
require.NoError(t, err)
var toolResultCount int
for _, msg := range msgsResp.Messages {
if msg.Role == codersdk.ChatMessageRoleTool {
toolResultCount++
}
}
require.Equal(t, len(toolCallIDs), toolResultCount,
"expected one tool-result message per submitted result")
})
t.Run("WrongStatus", func(t *testing.T) {
t.Parallel()
ctx := testutil.Context(t, testutil.WaitLong)
client, db := newChatClientWithDatabase(t)
user := coderdtest.CreateFirstUser(t, client.Client)
modelConfig := createChatModelConfig(t, client)
// Create a chat that is NOT in requires_action status.
chat, err := db.InsertChat(dbauthz.AsSystemRestricted(ctx), database.InsertChatParams{
Status: database.ChatStatusWaiting,
OwnerID: user.UserID,
LastModelConfigID: modelConfig.ID,
Title: "wrong-status-test",
})
require.NoError(t, err)
err = client.SubmitToolResults(ctx, chat.ID, codersdk.SubmitToolResultsRequest{
Results: []codersdk.ToolResult{
{ToolCallID: "call_xyz", Output: json.RawMessage(`"nope"`)},
},
})
requireSDKError(t, err, http.StatusConflict)
})
t.Run("MissingResult", func(t *testing.T) {
t.Parallel()
ctx := testutil.Context(t, testutil.WaitLong)
client, db := newChatClientWithDatabase(t)
user := coderdtest.CreateFirstUser(t, client.Client)
modelConfig := createChatModelConfig(t, client)
const toolName = "my_dynamic_tool"
toolCallIDs := []string{"call_one", "call_two"}
chat := setupRequiresAction(ctx, t, db, user.UserID, modelConfig.ID, toolName, toolCallIDs)
// Submit only one of the two required results.
err := client.SubmitToolResults(ctx, chat.ID, codersdk.SubmitToolResultsRequest{
Results: []codersdk.ToolResult{
{ToolCallID: "call_one", Output: json.RawMessage(`"partial"`)},
},
})
requireSDKError(t, err, http.StatusBadRequest)
})
t.Run("UnexpectedResult", func(t *testing.T) {
t.Parallel()
ctx := testutil.Context(t, testutil.WaitLong)
client, db := newChatClientWithDatabase(t)
user := coderdtest.CreateFirstUser(t, client.Client)
modelConfig := createChatModelConfig(t, client)
const toolName = "my_dynamic_tool"
toolCallIDs := []string{"call_real"}
chat := setupRequiresAction(ctx, t, db, user.UserID, modelConfig.ID, toolName, toolCallIDs)
// Submit a result with a wrong tool_call_id.
err := client.SubmitToolResults(ctx, chat.ID, codersdk.SubmitToolResultsRequest{
Results: []codersdk.ToolResult{
{ToolCallID: "call_bogus", Output: json.RawMessage(`"wrong"`)},
},
})
requireSDKError(t, err, http.StatusBadRequest)
})
t.Run("InvalidJSONOutput", func(t *testing.T) {
t.Parallel()
ctx := testutil.Context(t, testutil.WaitLong)
client, db := newChatClientWithDatabase(t)
user := coderdtest.CreateFirstUser(t, client.Client)
modelConfig := createChatModelConfig(t, client)
const toolName = "my_dynamic_tool"
toolCallIDs := []string{"call_json"}
chat := setupRequiresAction(ctx, t, db, user.UserID, modelConfig.ID, toolName, toolCallIDs)
// We must bypass the SDK client because json.RawMessage
// rejects invalid JSON during json.Marshal. A raw HTTP
// request lets the invalid payload reach the server so we
// can verify server-side validation.
rawBody := `{"results":[{"tool_call_id":"call_json","output":not-json,"is_error":false}]}`
url := client.URL.JoinPath(fmt.Sprintf("/api/experimental/chats/%s/tool-results", chat.ID)).String()
req, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewBufferString(rawBody))
require.NoError(t, err)
req.Header.Set("Content-Type", "application/json")
req.Header.Set(codersdk.SessionTokenHeader, client.SessionToken())
resp, err := http.DefaultClient.Do(req)
require.NoError(t, err)
defer resp.Body.Close()
require.Equal(t, http.StatusBadRequest, resp.StatusCode)
})
t.Run("DuplicateToolCallID", func(t *testing.T) {
t.Parallel()
ctx := testutil.Context(t, testutil.WaitLong)
client, db := newChatClientWithDatabase(t)
user := coderdtest.CreateFirstUser(t, client.Client)
modelConfig := createChatModelConfig(t, client)
const toolName = "my_dynamic_tool"
toolCallIDs := []string{"call_dup1", "call_dup2"}
chat := setupRequiresAction(ctx, t, db, user.UserID, modelConfig.ID, toolName, toolCallIDs)
err := client.SubmitToolResults(ctx, chat.ID, codersdk.SubmitToolResultsRequest{
Results: []codersdk.ToolResult{
{ToolCallID: "call_dup1", Output: json.RawMessage(`"result_a"`)},
{ToolCallID: "call_dup1", Output: json.RawMessage(`"result_b"`)},
},
})
sdkErr := requireSDKError(t, err, http.StatusBadRequest)
require.Contains(t, sdkErr.Message, "Duplicate tool_call_id")
})
t.Run("EmptyResults", func(t *testing.T) {
t.Parallel()
ctx := testutil.Context(t, testutil.WaitLong)
client, db := newChatClientWithDatabase(t)
user := coderdtest.CreateFirstUser(t, client.Client)
modelConfig := createChatModelConfig(t, client)
const toolName = "my_dynamic_tool"
toolCallIDs := []string{"call_empty"}
chat := setupRequiresAction(ctx, t, db, user.UserID, modelConfig.ID, toolName, toolCallIDs)
err := client.SubmitToolResults(ctx, chat.ID, codersdk.SubmitToolResultsRequest{
Results: []codersdk.ToolResult{},
})
requireSDKError(t, err, http.StatusBadRequest)
})
t.Run("NotFoundForDifferentUser", func(t *testing.T) {
t.Parallel()
ctx := testutil.Context(t, testutil.WaitLong)
client, db := newChatClientWithDatabase(t)
user := coderdtest.CreateFirstUser(t, client.Client)
modelConfig := createChatModelConfig(t, client)
const toolName = "my_dynamic_tool"
toolCallIDs := []string{"call_other"}
chat := setupRequiresAction(ctx, t, db, user.UserID, modelConfig.ID, toolName, toolCallIDs)
// Create a second user and try to submit tool results
// to user A's chat.
otherClientRaw, _ := coderdtest.CreateAnotherUser(
t, client.Client, user.OrganizationID,
rbac.RoleAgentsAccess(),
)
otherClient := codersdk.NewExperimentalClient(otherClientRaw)
err := otherClient.SubmitToolResults(ctx, chat.ID, codersdk.SubmitToolResultsRequest{
Results: []codersdk.ToolResult{
{ToolCallID: "call_other", Output: json.RawMessage(`"nope"`)},
},
})
requireSDKError(t, err, http.StatusNotFound)
})
}
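Taken together, the MissingResult, UnexpectedResult, DuplicateToolCallID, and EmptyResults cases pin down the server-side contract: every pending tool-call ID must receive exactly one result, with no extras and no duplicates. A stdlib-only sketch of that check (the function name and types are hypothetical, not the actual coderd handler):

```go
package main

import "fmt"

// validateToolResults checks submitted results against the pending tool-call
// IDs: every pending ID must get exactly one result, no extras, no duplicates.
// Hypothetical sketch of the contract the tests above exercise; not the
// actual coderd implementation.
func validateToolResults(pending, submitted []string) error {
	if len(submitted) == 0 {
		return fmt.Errorf("no results submitted")
	}
	want := make(map[string]bool, len(pending))
	for _, id := range pending {
		want[id] = true
	}
	seen := make(map[string]bool, len(submitted))
	for _, id := range submitted {
		if seen[id] {
			return fmt.Errorf("duplicate tool_call_id: %s", id)
		}
		seen[id] = true
		if !want[id] {
			return fmt.Errorf("unexpected tool_call_id: %s", id)
		}
	}
	for _, id := range pending {
		if !seen[id] {
			return fmt.Errorf("missing result for tool_call_id: %s", id)
		}
	}
	return nil
}

func main() {
	// Mirrors the MissingResult case: only one of two results submitted.
	fmt.Println(validateToolResults([]string{"call_one", "call_two"}, []string{"call_one"}))
}
```

Each failure mode maps to one HTTP 400 response in the tests above.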
func TestPostChats_DynamicToolValidation(t *testing.T) {
t.Parallel()
t.Run("TooManyTools", func(t *testing.T) {
t.Parallel()
ctx := testutil.Context(t, testutil.WaitLong)
client := newChatClient(t)
_ = coderdtest.CreateFirstUser(t, client.Client)
_ = createChatModelConfig(t, client)
tools := make([]codersdk.DynamicTool, 251)
for i := range tools {
tools[i] = codersdk.DynamicTool{
Name: fmt.Sprintf("tool-%d", i),
}
}
_, err := client.CreateChat(ctx, codersdk.CreateChatRequest{
Content: []codersdk.ChatInputPart{{
Type: codersdk.ChatInputPartTypeText,
Text: "hello",
}},
UnsafeDynamicTools: tools,
})
sdkErr := requireSDKError(t, err, http.StatusBadRequest)
require.Equal(t, "Too many dynamic tools.", sdkErr.Message)
})
t.Run("EmptyToolName", func(t *testing.T) {
t.Parallel()
ctx := testutil.Context(t, testutil.WaitLong)
client := newChatClient(t)
_ = coderdtest.CreateFirstUser(t, client.Client)
_ = createChatModelConfig(t, client)
_, err := client.CreateChat(ctx, codersdk.CreateChatRequest{
Content: []codersdk.ChatInputPart{{
Type: codersdk.ChatInputPartTypeText,
Text: "hello",
}},
UnsafeDynamicTools: []codersdk.DynamicTool{
{Name: ""},
},
})
sdkErr := requireSDKError(t, err, http.StatusBadRequest)
require.Equal(t, "Dynamic tool name must not be empty.", sdkErr.Message)
})
t.Run("DuplicateToolName", func(t *testing.T) {
t.Parallel()
ctx := testutil.Context(t, testutil.WaitLong)
client := newChatClient(t)
_ = coderdtest.CreateFirstUser(t, client.Client)
_ = createChatModelConfig(t, client)
_, err := client.CreateChat(ctx, codersdk.CreateChatRequest{
Content: []codersdk.ChatInputPart{{
Type: codersdk.ChatInputPartTypeText,
Text: "hello",
}},
UnsafeDynamicTools: []codersdk.DynamicTool{
{Name: "dup-tool"},
{Name: "dup-tool"},
},
})
sdkErr := requireSDKError(t, err, http.StatusBadRequest)
require.Equal(t, "Duplicate dynamic tool name.", sdkErr.Message)
})
}
func requireSDKError(t *testing.T, err error, expectedStatus int) *codersdk.Error {
t.Helper()
+1 -1
@@ -148,7 +148,7 @@ func TestGetOrgMembersFilter(t *testing.T) {
 	setupCtx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong)
 	defer cancel()
-	coderdtest.UsersFilter(setupCtx, t, client, api.Database, nil, func(testCtx context.Context, req codersdk.UsersRequest) []codersdk.ReducedUser {
+	coderdtest.UsersFilter(setupCtx, t, client, api.Database, nil, nil, func(testCtx context.Context, req codersdk.UsersRequest) []codersdk.ReducedUser {
 		res, err := client.OrganizationMembersPaginated(testCtx, first.OrganizationID, req)
 		require.NoError(t, err)
 		reduced := make([]codersdk.ReducedUser, len(res.Members))
+4 -2
@@ -32,8 +32,9 @@ func HandleChatEvent(cb func(ctx context.Context, payload ChatEvent, err error))
}
 type ChatEvent struct {
-	Kind ChatEventKind `json:"kind"`
-	Chat codersdk.Chat `json:"chat"`
+	Kind      ChatEventKind                 `json:"kind"`
+	Chat      codersdk.Chat                 `json:"chat"`
+	ToolCalls []codersdk.ChatStreamToolCall `json:"tool_calls,omitempty"`
 }
type ChatEventKind string
@@ -44,4 +45,5 @@ const (
 	ChatEventKindCreated          ChatEventKind = "created"
 	ChatEventKindDeleted          ChatEventKind = "deleted"
 	ChatEventKindDiffStatusChange ChatEventKind = "diff_status_change"
+	ChatEventKindActionRequired   ChatEventKind = "action_required"
 )
+17 -35
@@ -19,7 +19,6 @@ import (
 	"github.com/google/uuid"
 	"github.com/prometheus/client_golang/prometheus"
 	"go.opentelemetry.io/otel/trace"
-	"golang.org/x/sync/singleflight"
 	"golang.org/x/xerrors"
 	"tailscale.com/derp"
 	"tailscale.com/tailcfg"
@@ -390,7 +389,6 @@ type MultiAgentController struct {
 	// connections to the destination
 	tickets      map[uuid.UUID]map[uuid.UUID]struct{}
 	coordination *tailnet.BasicCoordination
-	sendGroup    singleflight.Group
 	cancel              context.CancelFunc
 	expireOldAgentsDone chan struct{}
@@ -420,44 +418,28 @@ func (m *MultiAgentController) New(client tailnet.CoordinatorClient) tailnet.Clo
 func (m *MultiAgentController) ensureAgent(agentID uuid.UUID) error {
 	m.mu.Lock()
+	defer m.mu.Unlock()
 	_, ok := m.connectionTimes[agentID]
-	if ok {
-		m.connectionTimes[agentID] = time.Now()
-		m.mu.Unlock()
-		return nil
-	}
-	m.mu.Unlock()
-	m.logger.Debug(context.Background(),
-		"subscribing to agent", slog.F("agent_id", agentID))
-	_, err, _ := m.sendGroup.Do(agentID.String(), func() (interface{}, error) {
-		m.mu.Lock()
-		coord := m.coordination
-		m.mu.Unlock()
-		if coord == nil {
-			return nil, xerrors.New("no active coordination")
-		}
+	// If we don't have the agent, subscribe.
+	if !ok {
+		m.logger.Debug(context.Background(),
+			"subscribing to agent", slog.F("agent_id", agentID))
+		if m.coordination != nil {
+			err := m.coordination.Client.Send(&proto.CoordinateRequest{
+				AddTunnel: &proto.CoordinateRequest_Tunnel{Id: agentID[:]},
+			})
+			if err != nil {
+				err = xerrors.Errorf("subscribe agent: %w", err)
+				m.coordination.SendErr(err)
+				_ = m.coordination.Client.Close()
+				m.coordination = nil
+				return err
+			}
+		}
-		err := coord.Client.Send(&proto.CoordinateRequest{
-			AddTunnel: &proto.CoordinateRequest_Tunnel{Id: agentID[:]},
-		})
-		if err != nil {
-			return nil, err
-		}
-		m.mu.Lock()
 		m.tickets[agentID] = map[uuid.UUID]struct{}{}
-		m.mu.Unlock()
-		return nil, nil
-	})
-	if err != nil {
-		m.logger.Error(context.Background(), "ensureAgent send failed",
-			slog.F("agent_id", agentID), slog.Error(err))
-		return xerrors.Errorf("send AddTunnel: %w", err)
 	}
-	m.mu.Lock()
 	m.connectionTimes[agentID] = time.Now()
-	m.mu.Unlock()
 	return nil
 }
+144
@@ -776,6 +776,40 @@ func (r *remoteReporter) createSnapshot() (*Snapshot, error) {
return nil
})
eg.Go(func() error {
chats, err := r.options.Database.GetChatsUpdatedAfter(ctx, createdAfter)
if err != nil {
return xerrors.Errorf("get chats updated after: %w", err)
}
snapshot.Chats = make([]Chat, 0, len(chats))
for _, chat := range chats {
snapshot.Chats = append(snapshot.Chats, ConvertChat(chat))
}
return nil
})
eg.Go(func() error {
summaries, err := r.options.Database.GetChatMessageSummariesPerChat(ctx, createdAfter)
if err != nil {
return xerrors.Errorf("get chat message summaries: %w", err)
}
snapshot.ChatMessageSummaries = make([]ChatMessageSummary, 0, len(summaries))
for _, s := range summaries {
snapshot.ChatMessageSummaries = append(snapshot.ChatMessageSummaries, ConvertChatMessageSummary(s))
}
return nil
})
eg.Go(func() error {
configs, err := r.options.Database.GetChatModelConfigsForTelemetry(ctx)
if err != nil {
return xerrors.Errorf("get chat model configs: %w", err)
}
snapshot.ChatModelConfigs = make([]ChatModelConfig, 0, len(configs))
for _, c := range configs {
snapshot.ChatModelConfigs = append(snapshot.ChatModelConfigs, ConvertChatModelConfig(c))
}
return nil
})
err := eg.Wait()
if err != nil {
return nil, err
@@ -1503,6 +1537,9 @@ type Snapshot struct {
AIBridgeInterceptionsSummaries []AIBridgeInterceptionsSummary `json:"aibridge_interceptions_summaries"`
BoundaryUsageSummary *BoundaryUsageSummary `json:"boundary_usage_summary"`
FirstUserOnboarding *FirstUserOnboarding `json:"first_user_onboarding"`
Chats []Chat `json:"chats"`
ChatMessageSummaries []ChatMessageSummary `json:"chat_message_summaries"`
ChatModelConfigs []ChatModelConfig `json:"chat_model_configs"`
}
// Deployment contains information about the host running Coder.
@@ -2113,6 +2150,66 @@ func ConvertTask(task database.Task) Task {
return t
}
// ConvertChat converts a database chat row to a telemetry Chat.
func ConvertChat(dbChat database.GetChatsUpdatedAfterRow) Chat {
c := Chat{
ID: dbChat.ID,
OwnerID: dbChat.OwnerID,
CreatedAt: dbChat.CreatedAt,
UpdatedAt: dbChat.UpdatedAt,
Status: string(dbChat.Status),
HasParent: dbChat.HasParent,
Archived: dbChat.Archived,
LastModelConfigID: dbChat.LastModelConfigID,
}
if dbChat.RootChatID.Valid {
c.RootChatID = &dbChat.RootChatID.UUID
}
if dbChat.WorkspaceID.Valid {
c.WorkspaceID = &dbChat.WorkspaceID.UUID
}
if dbChat.Mode.Valid {
mode := string(dbChat.Mode.ChatMode)
c.Mode = &mode
}
return c
}
// ConvertChatMessageSummary converts a database chat message
// summary row to a telemetry ChatMessageSummary.
func ConvertChatMessageSummary(dbRow database.GetChatMessageSummariesPerChatRow) ChatMessageSummary {
return ChatMessageSummary{
ChatID: dbRow.ChatID,
MessageCount: dbRow.MessageCount,
UserMessageCount: dbRow.UserMessageCount,
AssistantMessageCount: dbRow.AssistantMessageCount,
ToolMessageCount: dbRow.ToolMessageCount,
SystemMessageCount: dbRow.SystemMessageCount,
TotalInputTokens: dbRow.TotalInputTokens,
TotalOutputTokens: dbRow.TotalOutputTokens,
TotalReasoningTokens: dbRow.TotalReasoningTokens,
TotalCacheCreationTokens: dbRow.TotalCacheCreationTokens,
TotalCacheReadTokens: dbRow.TotalCacheReadTokens,
TotalCostMicros: dbRow.TotalCostMicros,
TotalRuntimeMs: dbRow.TotalRuntimeMs,
DistinctModelCount: dbRow.DistinctModelCount,
CompressedMessageCount: dbRow.CompressedMessageCount,
}
}
// ConvertChatModelConfig converts a database model config row to a
// telemetry ChatModelConfig.
func ConvertChatModelConfig(dbRow database.GetChatModelConfigsForTelemetryRow) ChatModelConfig {
return ChatModelConfig{
ID: dbRow.ID,
Provider: dbRow.Provider,
Model: dbRow.Model,
ContextLimit: dbRow.ContextLimit,
Enabled: dbRow.Enabled,
IsDefault: dbRow.IsDefault,
}
}
type telemetryItemKey string
// The comment below gets rid of the warning that the name "TelemetryItemKey" has
@@ -2234,6 +2331,53 @@ type BoundaryUsageSummary struct {
PeriodDurationMilliseconds int64 `json:"period_duration_ms"`
}
// Chat contains anonymized metadata about a chat for telemetry.
// Titles and message content are excluded to avoid PII leakage.
type Chat struct {
ID uuid.UUID `json:"id"`
OwnerID uuid.UUID `json:"owner_id"`
CreatedAt time.Time `json:"created_at"`
UpdatedAt time.Time `json:"updated_at"`
Status string `json:"status"`
HasParent bool `json:"has_parent"`
RootChatID *uuid.UUID `json:"root_chat_id"`
WorkspaceID *uuid.UUID `json:"workspace_id"`
Mode *string `json:"mode"`
Archived bool `json:"archived"`
LastModelConfigID uuid.UUID `json:"last_model_config_id"`
}
// ChatMessageSummary contains per-chat aggregated message metrics
// for telemetry. Individual message content is never included.
type ChatMessageSummary struct {
ChatID uuid.UUID `json:"chat_id"`
MessageCount int64 `json:"message_count"`
UserMessageCount int64 `json:"user_message_count"`
AssistantMessageCount int64 `json:"assistant_message_count"`
ToolMessageCount int64 `json:"tool_message_count"`
SystemMessageCount int64 `json:"system_message_count"`
TotalInputTokens int64 `json:"total_input_tokens"`
TotalOutputTokens int64 `json:"total_output_tokens"`
TotalReasoningTokens int64 `json:"total_reasoning_tokens"`
TotalCacheCreationTokens int64 `json:"total_cache_creation_tokens"`
TotalCacheReadTokens int64 `json:"total_cache_read_tokens"`
TotalCostMicros int64 `json:"total_cost_micros"`
TotalRuntimeMs int64 `json:"total_runtime_ms"`
DistinctModelCount int64 `json:"distinct_model_count"`
CompressedMessageCount int64 `json:"compressed_message_count"`
}
// ChatModelConfig contains model configuration metadata for
// telemetry. Sensitive fields like API keys are excluded.
type ChatModelConfig struct {
ID uuid.UUID `json:"id"`
Provider string `json:"provider"`
Model string `json:"model"`
ContextLimit int64 `json:"context_limit"`
Enabled bool `json:"enabled"`
IsDefault bool `json:"is_default"`
}
func ConvertAIBridgeInterceptionsSummary(endTime time.Time, provider, model, client string, summary database.CalculateAIBridgeInterceptionsTelemetrySummaryRow) AIBridgeInterceptionsSummary {
return AIBridgeInterceptionsSummary{
ID: uuid.New(),
+300
@@ -1549,3 +1549,303 @@ func TestTelemetry_BoundaryUsageSummary(t *testing.T) {
require.Nil(t, snapshot2.BoundaryUsageSummary)
})
}
func TestChatsTelemetry(t *testing.T) {
t.Parallel()
ctx := testutil.Context(t, testutil.WaitMedium)
db, _ := dbtestutil.NewDB(t)
user := dbgen.User(t, db, database.User{})
// Create chat providers (required FK for model configs).
_, err := db.InsertChatProvider(ctx, database.InsertChatProviderParams{
Provider: "anthropic",
DisplayName: "Anthropic",
Enabled: true,
CentralApiKeyEnabled: true,
})
require.NoError(t, err)
_, err = db.InsertChatProvider(ctx, database.InsertChatProviderParams{
Provider: "openai",
DisplayName: "OpenAI",
Enabled: true,
CentralApiKeyEnabled: true,
})
require.NoError(t, err)
// Create a model config.
modelCfg, err := db.InsertChatModelConfig(ctx, database.InsertChatModelConfigParams{
Provider: "anthropic",
Model: "claude-sonnet-4-20250514",
DisplayName: "Claude Sonnet",
Enabled: true,
IsDefault: true,
ContextLimit: 200000,
CompressionThreshold: 70,
Options: json.RawMessage("{}"),
})
require.NoError(t, err)
// Create a second model config to test full dump.
modelCfg2, err := db.InsertChatModelConfig(ctx, database.InsertChatModelConfigParams{
Provider: "openai",
Model: "gpt-4o",
DisplayName: "GPT-4o",
Enabled: true,
IsDefault: false,
ContextLimit: 128000,
CompressionThreshold: 70,
Options: json.RawMessage("{}"),
})
require.NoError(t, err)
// Create a soft-deleted model config — should NOT appear in telemetry.
deletedCfg, err := db.InsertChatModelConfig(ctx, database.InsertChatModelConfigParams{
Provider: "anthropic",
Model: "claude-deleted",
DisplayName: "Deleted Model",
Enabled: true,
IsDefault: false,
ContextLimit: 100000,
CompressionThreshold: 70,
Options: json.RawMessage("{}"),
})
require.NoError(t, err)
err = db.DeleteChatModelConfigByID(ctx, deletedCfg.ID)
require.NoError(t, err)
// Create a root chat with a workspace.
org, err := db.GetDefaultOrganization(ctx)
require.NoError(t, err)
job := dbgen.ProvisionerJob(t, db, nil, database.ProvisionerJob{
OrganizationID: org.ID,
Type: database.ProvisionerJobTypeTemplateVersionDryRun,
})
tpl := dbgen.Template(t, db, database.Template{
OrganizationID: org.ID,
CreatedBy: user.ID,
})
tv := dbgen.TemplateVersion(t, db, database.TemplateVersion{
OrganizationID: org.ID,
TemplateID: uuid.NullUUID{UUID: tpl.ID, Valid: true},
CreatedBy: user.ID,
JobID: job.ID,
})
ws := dbgen.Workspace(t, db, database.WorkspaceTable{
OwnerID: user.ID,
OrganizationID: org.ID,
TemplateID: tpl.ID,
})
_ = dbgen.WorkspaceBuild(t, db, database.WorkspaceBuild{
Transition: database.WorkspaceTransitionStart,
Reason: database.BuildReasonInitiator,
WorkspaceID: ws.ID,
TemplateVersionID: tv.ID,
JobID: job.ID,
})
rootChat, err := db.InsertChat(ctx, database.InsertChatParams{
OwnerID: user.ID,
LastModelConfigID: modelCfg.ID,
Title: "Root Chat",
Status: database.ChatStatusRunning,
WorkspaceID: uuid.NullUUID{UUID: ws.ID, Valid: true},
Mode: database.NullChatMode{ChatMode: database.ChatModeComputerUse, Valid: true},
})
require.NoError(t, err)
// Create a child chat (has parent + root).
childChat, err := db.InsertChat(ctx, database.InsertChatParams{
OwnerID: user.ID,
LastModelConfigID: modelCfg2.ID,
Title: "Child Chat",
Status: database.ChatStatusCompleted,
ParentChatID: uuid.NullUUID{UUID: rootChat.ID, Valid: true},
RootChatID: uuid.NullUUID{UUID: rootChat.ID, Valid: true},
})
require.NoError(t, err)
// Insert messages for root chat: 2 user, 2 assistant, 1 tool.
_, err = db.InsertChatMessages(ctx, database.InsertChatMessagesParams{
ChatID: rootChat.ID,
CreatedBy: []uuid.UUID{user.ID, uuid.Nil, user.ID, uuid.Nil, uuid.Nil},
ModelConfigID: []uuid.UUID{modelCfg.ID, modelCfg.ID, modelCfg.ID, modelCfg.ID, modelCfg.ID},
Role: []database.ChatMessageRole{database.ChatMessageRoleUser, database.ChatMessageRoleAssistant, database.ChatMessageRoleUser, database.ChatMessageRoleAssistant, database.ChatMessageRoleTool},
Content: []string{`[{"type":"text","text":"hello"}]`, `[{"type":"text","text":"hi"}]`, `[{"type":"text","text":"help"}]`, `[{"type":"text","text":"sure"}]`, `[{"type":"text","text":"result"}]`},
ContentVersion: []int16{1, 1, 1, 1, 1},
Visibility: []database.ChatMessageVisibility{database.ChatMessageVisibilityBoth, database.ChatMessageVisibilityBoth, database.ChatMessageVisibilityBoth, database.ChatMessageVisibilityBoth, database.ChatMessageVisibilityBoth},
InputTokens: []int64{100, 200, 150, 300, 0},
OutputTokens: []int64{0, 50, 0, 100, 0},
TotalTokens: []int64{100, 250, 150, 400, 0},
ReasoningTokens: []int64{0, 10, 0, 20, 0},
CacheCreationTokens: []int64{50, 0, 30, 0, 0},
CacheReadTokens: []int64{0, 25, 0, 40, 0},
ContextLimit: []int64{200000, 200000, 200000, 200000, 200000},
Compressed: []bool{false, false, false, false, false},
TotalCostMicros: []int64{1000, 2000, 1500, 3000, 0},
RuntimeMs: []int64{0, 500, 0, 800, 100},
ProviderResponseID: []string{"", "resp-1", "", "resp-2", ""},
})
require.NoError(t, err)
// Insert messages for child chat: 1 user, 1 assistant (compressed).
_, err = db.InsertChatMessages(ctx, database.InsertChatMessagesParams{
ChatID: childChat.ID,
CreatedBy: []uuid.UUID{user.ID, uuid.Nil},
ModelConfigID: []uuid.UUID{modelCfg2.ID, modelCfg2.ID},
Role: []database.ChatMessageRole{database.ChatMessageRoleUser, database.ChatMessageRoleAssistant},
Content: []string{`[{"type":"text","text":"q"}]`, `[{"type":"text","text":"a"}]`},
ContentVersion: []int16{1, 1},
Visibility: []database.ChatMessageVisibility{database.ChatMessageVisibilityBoth, database.ChatMessageVisibilityBoth},
InputTokens: []int64{500, 600},
OutputTokens: []int64{0, 200},
TotalTokens: []int64{500, 800},
ReasoningTokens: []int64{0, 50},
CacheCreationTokens: []int64{100, 0},
CacheReadTokens: []int64{0, 75},
ContextLimit: []int64{128000, 128000},
Compressed: []bool{false, true},
TotalCostMicros: []int64{5000, 8000},
RuntimeMs: []int64{0, 1200},
ProviderResponseID: []string{"", "resp-3"},
})
require.NoError(t, err)
// Insert a soft-deleted message on root chat with large token values.
// This acts as "poison" — if the deleted filter is missing, totals
// will be inflated and assertions below will fail.
poisonMsgs, err := db.InsertChatMessages(ctx, database.InsertChatMessagesParams{
ChatID: rootChat.ID,
CreatedBy: []uuid.UUID{uuid.Nil},
ModelConfigID: []uuid.UUID{modelCfg.ID},
Role: []database.ChatMessageRole{database.ChatMessageRoleAssistant},
Content: []string{`[{"type":"text","text":"poison"}]`},
ContentVersion: []int16{1},
Visibility: []database.ChatMessageVisibility{database.ChatMessageVisibilityBoth},
InputTokens: []int64{999999},
OutputTokens: []int64{999999},
TotalTokens: []int64{999999},
ReasoningTokens: []int64{999999},
CacheCreationTokens: []int64{999999},
CacheReadTokens: []int64{999999},
ContextLimit: []int64{200000},
Compressed: []bool{false},
TotalCostMicros: []int64{999999},
RuntimeMs: []int64{999999},
ProviderResponseID: []string{""},
})
require.NoError(t, err)
err = db.SoftDeleteChatMessageByID(ctx, poisonMsgs[0].ID)
require.NoError(t, err)
_, snapshot := collectSnapshot(ctx, t, db, nil)
// --- Assert Chats ---
require.Len(t, snapshot.Chats, 2)
// Find root and child by HasParent flag.
var foundRoot, foundChild *telemetry.Chat
for i := range snapshot.Chats {
if !snapshot.Chats[i].HasParent {
foundRoot = &snapshot.Chats[i]
} else {
foundChild = &snapshot.Chats[i]
}
}
require.NotNil(t, foundRoot, "expected root chat")
require.NotNil(t, foundChild, "expected child chat")
// Root chat assertions.
assert.Equal(t, rootChat.ID, foundRoot.ID)
assert.Equal(t, user.ID, foundRoot.OwnerID)
assert.Equal(t, "running", foundRoot.Status)
assert.False(t, foundRoot.HasParent)
assert.Nil(t, foundRoot.RootChatID)
require.NotNil(t, foundRoot.WorkspaceID)
assert.Equal(t, ws.ID, *foundRoot.WorkspaceID)
assert.Equal(t, modelCfg.ID, foundRoot.LastModelConfigID)
require.NotNil(t, foundRoot.Mode)
assert.Equal(t, "computer_use", *foundRoot.Mode)
assert.False(t, foundRoot.Archived)
// Child chat assertions.
assert.Equal(t, childChat.ID, foundChild.ID)
assert.Equal(t, user.ID, foundChild.OwnerID)
assert.True(t, foundChild.HasParent)
require.NotNil(t, foundChild.RootChatID)
assert.Equal(t, rootChat.ID, *foundChild.RootChatID)
assert.Nil(t, foundChild.WorkspaceID)
assert.Equal(t, "completed", foundChild.Status)
assert.Equal(t, modelCfg2.ID, foundChild.LastModelConfigID)
assert.Nil(t, foundChild.Mode)
assert.False(t, foundChild.Archived)
// --- Assert ChatMessageSummaries ---
require.Len(t, snapshot.ChatMessageSummaries, 2)
summaryMap := make(map[uuid.UUID]telemetry.ChatMessageSummary)
for _, s := range snapshot.ChatMessageSummaries {
summaryMap[s.ChatID] = s
}
// Root chat summary: 2 user + 2 assistant + 1 tool = 5 messages.
rootSummary, ok := summaryMap[rootChat.ID]
require.True(t, ok, "expected summary for root chat")
assert.Equal(t, int64(5), rootSummary.MessageCount)
assert.Equal(t, int64(2), rootSummary.UserMessageCount)
assert.Equal(t, int64(2), rootSummary.AssistantMessageCount)
assert.Equal(t, int64(1), rootSummary.ToolMessageCount)
assert.Equal(t, int64(0), rootSummary.SystemMessageCount)
assert.Equal(t, int64(750), rootSummary.TotalInputTokens) // 100+200+150+300+0
assert.Equal(t, int64(150), rootSummary.TotalOutputTokens) // 0+50+0+100+0
assert.Equal(t, int64(30), rootSummary.TotalReasoningTokens) // 0+10+0+20+0
assert.Equal(t, int64(80), rootSummary.TotalCacheCreationTokens) // 50+0+30+0+0
assert.Equal(t, int64(65), rootSummary.TotalCacheReadTokens) // 0+25+0+40+0
assert.Equal(t, int64(7500), rootSummary.TotalCostMicros) // 1000+2000+1500+3000+0
assert.Equal(t, int64(1400), rootSummary.TotalRuntimeMs) // 0+500+0+800+100
assert.Equal(t, int64(1), rootSummary.DistinctModelCount)
assert.Equal(t, int64(0), rootSummary.CompressedMessageCount)
// Child chat summary: 1 user + 1 assistant = 2 messages, 1 compressed.
childSummary, ok := summaryMap[childChat.ID]
require.True(t, ok, "expected summary for child chat")
assert.Equal(t, int64(2), childSummary.MessageCount)
assert.Equal(t, int64(1), childSummary.UserMessageCount)
assert.Equal(t, int64(1), childSummary.AssistantMessageCount)
assert.Equal(t, int64(1100), childSummary.TotalInputTokens) // 500+600
assert.Equal(t, int64(200), childSummary.TotalOutputTokens) // 0+200
assert.Equal(t, int64(50), childSummary.TotalReasoningTokens) // 0+50
assert.Equal(t, int64(0), childSummary.ToolMessageCount)
assert.Equal(t, int64(0), childSummary.SystemMessageCount)
assert.Equal(t, int64(100), childSummary.TotalCacheCreationTokens) // 100+0
assert.Equal(t, int64(75), childSummary.TotalCacheReadTokens) // 0+75
assert.Equal(t, int64(13000), childSummary.TotalCostMicros) // 5000+8000
assert.Equal(t, int64(1200), childSummary.TotalRuntimeMs) // 0+1200
assert.Equal(t, int64(1), childSummary.DistinctModelCount)
assert.Equal(t, int64(1), childSummary.CompressedMessageCount)
// --- Assert ChatModelConfigs ---
require.Len(t, snapshot.ChatModelConfigs, 2)
configMap := make(map[uuid.UUID]telemetry.ChatModelConfig)
for _, c := range snapshot.ChatModelConfigs {
configMap[c.ID] = c
}
cfg1, ok := configMap[modelCfg.ID]
require.True(t, ok)
assert.Equal(t, "anthropic", cfg1.Provider)
assert.Equal(t, "claude-sonnet-4-20250514", cfg1.Model)
assert.Equal(t, int64(200000), cfg1.ContextLimit)
assert.True(t, cfg1.Enabled)
assert.True(t, cfg1.IsDefault)
cfg2, ok := configMap[modelCfg2.ID]
require.True(t, ok)
assert.Equal(t, "openai", cfg2.Provider)
assert.Equal(t, "gpt-4o", cfg2.Model)
assert.Equal(t, int64(128000), cfg2.ContextLimit)
assert.True(t, cfg2.Enabled)
assert.False(t, cfg2.IsDefault)
}
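The poison-message setup in the test above relies on the summary query skipping soft-deleted rows: if the deleted filter were missing, the 999999-token row would inflate every total and the exact-sum assertions would fail. The shape of that filter, sketched in plain Go rather than SQL (the types are illustrative, not the database package):

```go
package main

import "fmt"

// msg is an illustrative stand-in for a chat message row.
type msg struct {
	tokens  int64
	deleted bool
}

// sumTokens mirrors the invariant the "poison" message verifies: soft-deleted
// rows must be excluded from telemetry aggregates. Illustrative only; the
// real filter lives in the SQL query.
func sumTokens(msgs []msg) int64 {
	var total int64
	for _, m := range msgs {
		if m.deleted {
			continue // soft-deleted rows never count toward totals
		}
		total += m.tokens
	}
	return total
}

func main() {
	// A huge deleted row must not leak into the sum.
	fmt.Println(sumTokens([]msg{{100, false}, {999999, true}, {200, false}})) // 300
}
```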
+8 -12
@@ -475,6 +475,14 @@ func (api *API) postUser(rw http.ResponseWriter, r *http.Request) {
 	}
 	req.UserLoginType = codersdk.LoginTypeNone
+
+	// Service accounts are a Premium feature.
+	if !api.Entitlements.Enabled(codersdk.FeatureServiceAccounts) {
+		httpapi.Write(ctx, rw, http.StatusForbidden, codersdk.Response{
+			Message: fmt.Sprintf("%s is a Premium feature. Contact sales!", codersdk.FeatureServiceAccounts.Humanize()),
+		})
+		return
+	}
 } else if req.UserLoginType == "" {
 	// Default to password auth
 	req.UserLoginType = codersdk.LoginTypePassword
@@ -1630,18 +1638,6 @@ func (api *API) CreateUser(ctx context.Context, store database.Store, req Create
 		rbacRoles = req.RBACRoles
 	}
-
-	// When the agents experiment is enabled, auto-assign the
-	// agents-access role so new users can use Coder Agents
-	// without manual admin intervention. Skip this for OIDC
-	// users when site role sync is enabled, because the sync
-	// will overwrite roles on every login anyway — those
-	// admins should use --oidc-user-role-default instead.
-	if api.Experiments.Enabled(codersdk.ExperimentAgents) &&
-		!(req.LoginType == database.LoginTypeOIDC && api.IDPSync.SiteRoleSyncEnabled()) &&
-		!slices.Contains(rbacRoles, codersdk.RoleAgentsAccess) {
-		rbacRoles = append(rbacRoles, codersdk.RoleAgentsAccess)
-	}
 	var user database.User
 	err := store.InTx(func(tx database.Store) error {
 		orgRoles := make([]string, 0)
+5 -142
@@ -829,35 +829,6 @@ func TestPostUsers(t *testing.T) {
 		assert.Equal(t, firstUser.OrganizationID, user.OrganizationIDs[0])
 	})
-
-	// CreateWithAgentsExperiment verifies that new users
-	// are auto-assigned the agents-access role when the
-	// experiment is enabled. The experiment-disabled case
-	// is implicitly covered by TestInitialRoles, which
-	// asserts exactly [owner] with no experiment — it
-	// would fail if agents-access leaked through.
-	t.Run("CreateWithAgentsExperiment", func(t *testing.T) {
-		t.Parallel()
-		dv := coderdtest.DeploymentValues(t)
-		dv.Experiments = []string{string(codersdk.ExperimentAgents)}
-		client := coderdtest.New(t, &coderdtest.Options{DeploymentValues: dv})
-		firstUser := coderdtest.CreateFirstUser(t, client)
-		ctx := testutil.Context(t, testutil.WaitLong)
-		user, err := client.CreateUserWithOrgs(ctx, codersdk.CreateUserRequestWithOrgs{
-			OrganizationIDs: []uuid.UUID{firstUser.OrganizationID},
-			Email:           "another@user.org",
-			Username:        "someone-else",
-			Password:        "SomeSecurePassword!",
-		})
-		require.NoError(t, err)
-		roles, err := client.UserRoles(ctx, user.Username)
-		require.NoError(t, err)
-		require.Contains(t, roles.Roles, codersdk.RoleAgentsAccess,
-			"new user should have agents-access role when agents experiment is enabled")
-	})
 	t.Run("CreateWithStatus", func(t *testing.T) {
 		t.Parallel()
 		auditor := audit.NewMock()
@@ -979,7 +950,7 @@ func TestPostUsers(t *testing.T) {
 		require.Equal(t, found.LoginType, codersdk.LoginTypeOIDC)
 	})
-	t.Run("ServiceAccount/OK", func(t *testing.T) {
+	t.Run("ServiceAccount/Unlicensed", func(t *testing.T) {
 		t.Parallel()
 		client := coderdtest.New(t, nil)
 		first := coderdtest.CreateFirstUser(t, client)
@@ -987,98 +958,16 @@ func TestPostUsers(t *testing.T) {
 		ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong)
 		defer cancel()
-		user, err := client.CreateUserWithOrgs(ctx, codersdk.CreateUserRequestWithOrgs{
+		_, err := client.CreateUserWithOrgs(ctx, codersdk.CreateUserRequestWithOrgs{
 			OrganizationIDs: []uuid.UUID{first.OrganizationID},
 			Username:        "service-acct-ok",
 			UserLoginType:   codersdk.LoginTypeNone,
 			ServiceAccount:  true,
 		})
-		require.NoError(t, err)
-		require.Equal(t, codersdk.LoginTypeNone, user.LoginType)
-		require.Empty(t, user.Email)
-		require.Equal(t, "service-acct-ok", user.Username)
-		require.Equal(t, codersdk.UserStatusDormant, user.Status)
 	})
t.Run("ServiceAccount/WithEmail", func(t *testing.T) {
t.Parallel()
client := coderdtest.New(t, nil)
first := coderdtest.CreateFirstUser(t, client)
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong)
defer cancel()
_, err := client.CreateUserWithOrgs(ctx, codersdk.CreateUserRequestWithOrgs{
OrganizationIDs: []uuid.UUID{first.OrganizationID},
Username: "service-acct-email",
Email: "should-not-have@email.com",
ServiceAccount: true,
})
var apiErr *codersdk.Error
require.ErrorAs(t, err, &apiErr)
require.Equal(t, http.StatusBadRequest, apiErr.StatusCode())
require.Contains(t, apiErr.Message, "Email cannot be set for service accounts")
})
t.Run("ServiceAccount/WithPassword", func(t *testing.T) {
t.Parallel()
client := coderdtest.New(t, nil)
first := coderdtest.CreateFirstUser(t, client)
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong)
defer cancel()
_, err := client.CreateUserWithOrgs(ctx, codersdk.CreateUserRequestWithOrgs{
OrganizationIDs: []uuid.UUID{first.OrganizationID},
Username: "service-acct-password",
Password: "ShouldNotHavePassword123!",
ServiceAccount: true,
})
var apiErr *codersdk.Error
require.ErrorAs(t, err, &apiErr)
require.Equal(t, http.StatusBadRequest, apiErr.StatusCode())
require.Contains(t, apiErr.Message, "Password cannot be set for service accounts")
})
t.Run("ServiceAccount/WithInvalidLoginType", func(t *testing.T) {
t.Parallel()
client := coderdtest.New(t, nil)
first := coderdtest.CreateFirstUser(t, client)
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong)
defer cancel()
_, err := client.CreateUserWithOrgs(ctx, codersdk.CreateUserRequestWithOrgs{
OrganizationIDs: []uuid.UUID{first.OrganizationID},
Username: "service-acct-login-type",
UserLoginType: codersdk.LoginTypePassword,
ServiceAccount: true,
})
var apiErr *codersdk.Error
require.ErrorAs(t, err, &apiErr)
require.Equal(t, http.StatusBadRequest, apiErr.StatusCode())
require.Contains(t, apiErr.Message, "Service accounts must use login type 'none'")
})
t.Run("ServiceAccount/DefaultLoginType", func(t *testing.T) {
t.Parallel()
client := coderdtest.New(t, nil)
first := coderdtest.CreateFirstUser(t, client)
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong)
defer cancel()
user, err := client.CreateUserWithOrgs(ctx, codersdk.CreateUserRequestWithOrgs{
OrganizationIDs: []uuid.UUID{first.OrganizationID},
Username: "service-acct-default-login",
ServiceAccount: true,
})
require.NoError(t, err)
found, err := client.User(ctx, user.ID.String())
require.NoError(t, err)
require.Equal(t, codersdk.LoginTypeNone, found.LoginType)
require.Empty(t, found.Email)
require.Equal(t, http.StatusForbidden, apiErr.StatusCode())
require.Contains(t, apiErr.Message, "Premium feature")
})
t.Run("NonServiceAccount/WithoutEmail", func(t *testing.T) {
@@ -1098,32 +987,6 @@ func TestPostUsers(t *testing.T) {
require.ErrorAs(t, err, &apiErr)
require.Equal(t, http.StatusBadRequest, apiErr.StatusCode())
})
t.Run("ServiceAccount/MultipleWithoutEmail", func(t *testing.T) {
t.Parallel()
client := coderdtest.New(t, nil)
first := coderdtest.CreateFirstUser(t, client)
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong)
defer cancel()
user1, err := client.CreateUserWithOrgs(ctx, codersdk.CreateUserRequestWithOrgs{
OrganizationIDs: []uuid.UUID{first.OrganizationID},
Username: "service-acct-multi-1",
ServiceAccount: true,
})
require.NoError(t, err)
require.Empty(t, user1.Email)
user2, err := client.CreateUserWithOrgs(ctx, codersdk.CreateUserRequestWithOrgs{
OrganizationIDs: []uuid.UUID{first.OrganizationID},
Username: "service-acct-multi-2",
ServiceAccount: true,
})
require.NoError(t, err)
require.Empty(t, user2.Email)
require.NotEqual(t, user1.ID, user2.ID)
})
}
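The service-account tests above exercise a consistent set of creation rules: no email, no password, and a login type that is either empty or `none`. A minimal standalone sketch of those rules (the `createUserRequest` struct and `validateServiceAccountRequest` helper are illustrative, not the actual coderd handler or codersdk types):

```go
package main

import (
	"errors"
	"fmt"
)

// createUserRequest mirrors only the fields relevant to the
// service-account rules; it is not the real codersdk type.
type createUserRequest struct {
	Email          string
	Password       string
	LoginType      string // "", "none", "password", ...
	ServiceAccount bool
}

// validateServiceAccountRequest applies the rules the tests above
// assert: service accounts reject email and password, and accept
// only an empty or "none" login type.
func validateServiceAccountRequest(req createUserRequest) error {
	if !req.ServiceAccount {
		return nil
	}
	if req.Email != "" {
		return errors.New("Email cannot be set for service accounts")
	}
	if req.Password != "" {
		return errors.New("Password cannot be set for service accounts")
	}
	if req.LoginType != "" && req.LoginType != "none" {
		return errors.New("Service accounts must use login type 'none'")
	}
	return nil
}

func main() {
	fmt.Println(validateServiceAccountRequest(createUserRequest{
		ServiceAccount: true,
		Email:          "should-not-have@email.com",
	}))
}
```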
func TestNotifyCreatedUser(t *testing.T) {
@@ -1832,7 +1695,7 @@ func TestGetUsersFilter(t *testing.T) {
setupCtx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong)
defer cancel()
coderdtest.UsersFilter(setupCtx, t, client, api.Database, nil, func(testCtx context.Context, req codersdk.UsersRequest) []codersdk.ReducedUser {
coderdtest.UsersFilter(setupCtx, t, client, api.Database, nil, nil, func(testCtx context.Context, req codersdk.UsersRequest) []codersdk.ReducedUser {
res, err := client.Users(testCtx, req)
require.NoError(t, err)
reduced := make([]codersdk.ReducedUser, len(res.Users))
@@ -181,8 +181,9 @@ func (api *API) patchWorkspaceAgentLogs(rw http.ResponseWriter, r *http.Request)
level := make([]database.LogLevel, 0)
outputLength := 0
for _, logEntry := range req.Logs {
output = append(output, logEntry.Output)
outputLength += len(logEntry.Output)
sanitizedOutput := agentsdk.SanitizeLogOutput(logEntry.Output)
output = append(output, sanitizedOutput)
outputLength += len(sanitizedOutput)
if logEntry.Level == "" {
// Default to "info" to support older agents that didn't have the level field.
logEntry.Level = codersdk.LogLevelInfo
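The hunk above routes log output through `agentsdk.SanitizeLogOutput` before insert, because NUL bytes are valid in Go/protobuf strings but rejected by Postgres text columns. A minimal illustrative stand-in — assuming NUL bytes are dropped and invalid UTF-8 is replaced with the replacement rune; the real function's exact behavior may differ:

```go
package main

import (
	"fmt"
	"strings"
)

// sanitizeLogOutput is an illustrative stand-in for
// agentsdk.SanitizeLogOutput: strip NUL bytes (rejected by
// Postgres text columns) and replace invalid UTF-8 sequences
// with U+FFFD so the string is always safe to insert.
func sanitizeLogOutput(s string) string {
	s = strings.ReplaceAll(s, "\x00", "")
	return strings.ToValidUTF8(s, "\uFFFD")
}

func main() {
	// Mirrors the "before\x00after" payload used by the test below.
	fmt.Printf("%q\n", sanitizeLogOutput("before\x00after"))
}
```

Note that the tracked `outputLength` must be computed from the sanitized string, not the raw input, which is exactly what the patched loop does.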
@@ -260,6 +260,50 @@ func TestWorkspaceAgentLogs(t *testing.T) {
require.Equal(t, "testing", logChunk[0].Output)
require.Equal(t, "testing2", logChunk[1].Output)
})
t.Run("SanitizesNulBytesAndTracksSanitizedLength", func(t *testing.T) {
t.Parallel()
ctx := testutil.Context(t, testutil.WaitMedium)
client, db := coderdtest.NewWithDatabase(t, nil)
user := coderdtest.CreateFirstUser(t, client)
r := dbfake.WorkspaceBuild(t, db, database.WorkspaceTable{
OrganizationID: user.OrganizationID,
OwnerID: user.UserID,
}).WithAgent().Do()
rawOutput := "before\x00after"
sanitizedOutput := agentsdk.SanitizeLogOutput(rawOutput)
agentClient := agentsdk.New(client.URL, agentsdk.WithFixedToken(r.AgentToken))
err := agentClient.PatchLogs(ctx, agentsdk.PatchLogs{
Logs: []agentsdk.Log{
{
CreatedAt: dbtime.Now(),
Output: rawOutput,
},
},
})
require.NoError(t, err)
agent, err := db.GetWorkspaceAgentByID(dbauthz.AsSystemRestricted(ctx), r.Agents[0].ID)
require.NoError(t, err)
require.EqualValues(t, len(sanitizedOutput), agent.LogsLength)
workspace, err := client.Workspace(ctx, r.Workspace.ID)
require.NoError(t, err)
logs, closer, err := client.WorkspaceAgentLogsAfter(ctx, workspace.LatestBuild.Resources[0].Agents[0].ID, 0, true)
require.NoError(t, err)
defer func() {
_ = closer.Close()
}()
var logChunk []codersdk.WorkspaceAgentLog
select {
case <-ctx.Done():
case logChunk = <-logs:
}
require.NoError(t, ctx.Err())
require.Len(t, logChunk, 1)
require.Equal(t, sanitizedOutput, logChunk[0].Output)
})
t.Run("Close logs on outdated build", func(t *testing.T) {
t.Parallel()
ctx := testutil.Context(t, testutil.WaitMedium)
@@ -730,10 +730,7 @@ func (s *Server) workspaceAgentPTY(rw http.ResponseWriter, r *http.Request) {
if !ok {
return
}
log := s.Logger.With(
slog.F("agent_id", appToken.AgentID),
slog.F("workspace_id", appToken.WorkspaceID),
)
log := s.Logger.With(slog.F("agent_id", appToken.AgentID))
log.Debug(ctx, "resolved PTY request")
values := r.URL.Query()
@@ -768,21 +765,19 @@ func (s *Server) workspaceAgentPTY(rw http.ResponseWriter, r *http.Request) {
})
return
}
go httpapi.HeartbeatClose(ctx, s.Logger, cancel, conn)
ctx, wsNetConn := WebsocketNetConn(ctx, conn, websocket.MessageBinary)
defer wsNetConn.Close() // Also closes conn.
go httpapi.HeartbeatClose(ctx, log, cancel, conn)
dialStart := time.Now()
agentConn, release, err := s.AgentProvider.AgentConn(ctx, appToken.AgentID)
if err != nil {
log.Debug(ctx, "dial workspace agent", slog.Error(err), slog.F("elapsed_ms", time.Since(dialStart).Milliseconds()))
log.Debug(ctx, "dial workspace agent", slog.Error(err))
_ = conn.Close(websocket.StatusInternalError, httpapi.WebsocketCloseSprintf("dial workspace agent: %s", err))
return
}
defer release()
log.Debug(ctx, "dialed workspace agent", slog.F("elapsed_ms", time.Since(dialStart).Milliseconds()))
log.Debug(ctx, "dialed workspace agent")
// #nosec G115 - Safe conversion for terminal height/width which are expected to be within uint16 range (0-65535)
ptNetConn, err := agentConn.ReconnectingPTY(ctx, reconnect, uint16(height), uint16(width), r.URL.Query().Get("command"), func(arp *workspacesdk.AgentReconnectingPTYInit) {
arp.Container = container
@@ -790,12 +785,12 @@ func (s *Server) workspaceAgentPTY(rw http.ResponseWriter, r *http.Request) {
arp.BackendType = backendType
})
if err != nil {
log.Debug(ctx, "dial reconnecting pty server in workspace agent", slog.Error(err), slog.F("elapsed_ms", time.Since(dialStart).Milliseconds()))
log.Debug(ctx, "dial reconnecting pty server in workspace agent", slog.Error(err))
_ = conn.Close(websocket.StatusInternalError, httpapi.WebsocketCloseSprintf("dial: %s", err))
return
}
defer ptNetConn.Close()
log.Debug(ctx, "obtained PTY", slog.F("elapsed_ms", time.Since(dialStart).Milliseconds()))
log.Debug(ctx, "obtained PTY")
report := newStatsReportFromSignedToken(*appToken)
s.collectStats(report)
@@ -805,7 +800,7 @@ func (s *Server) workspaceAgentPTY(rw http.ResponseWriter, r *http.Request) {
}()
agentssh.Bicopy(ctx, wsNetConn, ptNetConn)
log.Debug(ctx, "pty Bicopy finished", slog.F("elapsed_ms", time.Since(dialStart).Milliseconds()))
log.Debug(ctx, "pty Bicopy finished")
}
func (s *Server) collectStats(stats StatsReport) {
@@ -788,6 +788,7 @@ type CreateOptions struct {
InitialUserContent []codersdk.ChatMessagePart
MCPServerIDs []uuid.UUID
Labels database.StringMap
DynamicTools json.RawMessage
}
// SendMessageBusyBehavior controls what happens when a chat is already active.
@@ -899,6 +900,10 @@ func (p *Server) CreateChat(ctx context.Context, opts CreateOptions) (database.C
RawMessage: labelsJSON,
Valid: true,
},
DynamicTools: pqtype.NullRawMessage{
RawMessage: opts.DynamicTools,
Valid: len(opts.DynamicTools) > 0,
},
})
if err != nil {
return xerrors.Errorf("insert chat: %w", err)
@@ -1546,6 +1551,238 @@ func (p *Server) PromoteQueued(
return result, nil
}
// SubmitToolResultsOptions controls tool result submission.
type SubmitToolResultsOptions struct {
ChatID uuid.UUID
UserID uuid.UUID
ModelConfigID uuid.UUID
Results []codersdk.ToolResult
DynamicTools json.RawMessage
}
// ToolResultValidationError indicates the submitted tool results
// failed validation (e.g. missing, duplicate, or unexpected IDs,
// or invalid JSON output).
type ToolResultValidationError struct {
Message string
Detail string
}
func (e *ToolResultValidationError) Error() string {
if e.Detail != "" {
return e.Message + ": " + e.Detail
}
return e.Message
}
// ToolResultStatusConflictError indicates the chat is not in the
// requires_action state expected for tool result submission.
type ToolResultStatusConflictError struct {
ActualStatus database.ChatStatus
}
func (e *ToolResultStatusConflictError) Error() string {
return fmt.Sprintf(
"chat status is %q, expected %q",
e.ActualStatus, database.ChatStatusRequiresAction,
)
}
// SubmitToolResults validates and persists client-provided tool
// results, transitions the chat to pending, and wakes the run
// loop. The caller is responsible for the fast-path status check;
// this method performs an authoritative re-check under a row lock.
func (p *Server) SubmitToolResults(
ctx context.Context,
opts SubmitToolResultsOptions,
) error {
dynamicToolNames, err := parseDynamicToolNames(pqtype.NullRawMessage{
RawMessage: opts.DynamicTools,
Valid: len(opts.DynamicTools) > 0,
})
if err != nil {
return xerrors.Errorf("parse chat dynamic tools: %w", err)
}
// The GetLastChatMessageByRole lookup and all subsequent
// validation and persistence run inside a single transaction
// so the assistant message cannot change between reads.
var statusConflict *ToolResultStatusConflictError
txErr := p.db.InTx(func(tx database.Store) error {
// Authoritative status check under row lock.
locked, lockErr := tx.GetChatByIDForUpdate(ctx, opts.ChatID)
if lockErr != nil {
return xerrors.Errorf("lock chat for update: %w", lockErr)
}
if locked.Status != database.ChatStatusRequiresAction {
statusConflict = &ToolResultStatusConflictError{
ActualStatus: locked.Status,
}
return statusConflict
}
// Get the last assistant message inside the transaction
// for consistency with the row lock above.
lastAssistant, err := tx.GetLastChatMessageByRole(ctx, database.GetLastChatMessageByRoleParams{
ChatID: opts.ChatID,
Role: database.ChatMessageRoleAssistant,
})
if err != nil {
return xerrors.Errorf("get last assistant message: %w", err)
}
// Collect tool-call IDs that already have results.
// When a dynamic tool name collides with a built-in,
// the chatloop executes it as a built-in and persists
// the result. Those calls must not count as pending.
afterMsgs, afterErr := tx.GetChatMessagesByChatID(ctx, database.GetChatMessagesByChatIDParams{
ChatID: opts.ChatID,
AfterID: lastAssistant.ID,
})
if afterErr != nil {
return xerrors.Errorf("get messages after assistant: %w", afterErr)
}
handledCallIDs := make(map[string]bool)
for _, msg := range afterMsgs {
if msg.Role != database.ChatMessageRoleTool {
continue
}
msgParts, msgParseErr := chatprompt.ParseContent(msg)
if msgParseErr != nil {
continue
}
for _, mp := range msgParts {
if mp.Type == codersdk.ChatMessagePartTypeToolResult {
handledCallIDs[mp.ToolCallID] = true
}
}
}
// Extract pending dynamic tool-call IDs, skipping any
// that were already handled by the chatloop.
pendingCallIDs := make(map[string]bool)
toolCallIDToName := make(map[string]string)
parts, parseErr := chatprompt.ParseContent(lastAssistant)
if parseErr != nil {
return xerrors.Errorf("parse assistant message: %w", parseErr)
}
for _, part := range parts {
if part.Type == codersdk.ChatMessagePartTypeToolCall &&
dynamicToolNames[part.ToolName] &&
!handledCallIDs[part.ToolCallID] {
pendingCallIDs[part.ToolCallID] = true
toolCallIDToName[part.ToolCallID] = part.ToolName
}
}
// Validate submitted results match pending calls exactly.
submittedIDs := make(map[string]bool, len(opts.Results))
for _, result := range opts.Results {
if submittedIDs[result.ToolCallID] {
return &ToolResultValidationError{
Message: "Duplicate tool_call_id in results.",
Detail: fmt.Sprintf("Duplicate tool call ID %q.", result.ToolCallID),
}
}
submittedIDs[result.ToolCallID] = true
}
for id := range pendingCallIDs {
if !submittedIDs[id] {
return &ToolResultValidationError{
Message: "Missing tool result.",
Detail: fmt.Sprintf("Missing result for tool call %q.", id),
}
}
}
for id := range submittedIDs {
if !pendingCallIDs[id] {
return &ToolResultValidationError{
Message: "Unexpected tool result.",
Detail: fmt.Sprintf("No pending tool call with ID %q.", id),
}
}
}
// Marshal each tool result into a separate message row.
resultContents := make([]pqtype.NullRawMessage, 0, len(opts.Results))
for _, result := range opts.Results {
if !json.Valid(result.Output) {
return &ToolResultValidationError{
Message: "Tool result output must be valid JSON.",
Detail: fmt.Sprintf("Output for tool call %q is not valid JSON.", result.ToolCallID),
}
}
part := codersdk.ChatMessagePart{
Type: codersdk.ChatMessagePartTypeToolResult,
ToolCallID: result.ToolCallID,
ToolName: toolCallIDToName[result.ToolCallID],
Result: result.Output,
IsError: result.IsError,
}
marshaled, marshalErr := chatprompt.MarshalParts([]codersdk.ChatMessagePart{part})
if marshalErr != nil {
return xerrors.Errorf("marshal tool result: %w", marshalErr)
}
resultContents = append(resultContents, marshaled)
}
// Insert tool-result messages.
n := len(resultContents)
params := database.InsertChatMessagesParams{
ChatID: opts.ChatID,
CreatedBy: make([]uuid.UUID, n),
ModelConfigID: make([]uuid.UUID, n),
Role: make([]database.ChatMessageRole, n),
Content: make([]string, n),
ContentVersion: make([]int16, n),
Visibility: make([]database.ChatMessageVisibility, n),
InputTokens: make([]int64, n),
OutputTokens: make([]int64, n),
TotalTokens: make([]int64, n),
ReasoningTokens: make([]int64, n),
CacheCreationTokens: make([]int64, n),
CacheReadTokens: make([]int64, n),
ContextLimit: make([]int64, n),
Compressed: make([]bool, n),
TotalCostMicros: make([]int64, n),
RuntimeMs: make([]int64, n),
ProviderResponseID: make([]string, n),
}
for i, rc := range resultContents {
params.CreatedBy[i] = opts.UserID
params.ModelConfigID[i] = opts.ModelConfigID
params.Role[i] = database.ChatMessageRoleTool
params.Content[i] = string(rc.RawMessage)
params.ContentVersion[i] = chatprompt.CurrentContentVersion
params.Visibility[i] = database.ChatMessageVisibilityBoth
}
if _, insertErr := tx.InsertChatMessages(ctx, params); insertErr != nil {
return xerrors.Errorf("insert tool results: %w", insertErr)
}
// Transition chat to pending.
if _, updateErr := tx.UpdateChatStatus(ctx, database.UpdateChatStatusParams{
ID: opts.ChatID,
Status: database.ChatStatusPending,
WorkerID: uuid.NullUUID{},
StartedAt: sql.NullTime{},
HeartbeatAt: sql.NullTime{},
LastError: sql.NullString{},
}); updateErr != nil {
return xerrors.Errorf("update chat status: %w", updateErr)
}
return nil
}, nil)
if txErr != nil {
return txErr
}
// Wake the chatd run loop so it processes the chat immediately.
p.signalWake()
return nil
}
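The three validation loops inside the transaction enforce an exact set match between submitted and pending tool-call IDs: no duplicates, no missing results, no unexpected extras. Extracted into a standalone helper for clarity (illustrative; the real code returns the typed `ToolResultValidationError`):

```go
package main

import "fmt"

// validateExactMatch returns an error unless submitted contains
// each pending ID exactly once and nothing else — the invariant
// SubmitToolResults enforces before persisting results.
func validateExactMatch(pending map[string]bool, submitted []string) error {
	seen := make(map[string]bool, len(submitted))
	for _, id := range submitted {
		if seen[id] {
			return fmt.Errorf("duplicate tool call ID %q", id)
		}
		seen[id] = true
	}
	for id := range pending {
		if !seen[id] {
			return fmt.Errorf("missing result for tool call %q", id)
		}
	}
	for id := range seen {
		if !pending[id] {
			return fmt.Errorf("no pending tool call with ID %q", id)
		}
	}
	return nil
}

func main() {
	pending := map[string]bool{"call_1": true, "call_2": true}
	fmt.Println(validateExactMatch(pending, []string{"call_1", "call_2"}))
}
```

The pending set itself excludes calls already handled by the chatloop (built-in name collisions), which is why the code collects `handledCallIDs` before building `pendingCallIDs`.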
// InterruptChat interrupts execution, sets waiting status, and broadcasts status updates.
func (p *Server) InterruptChat(
ctx context.Context,
@@ -1555,6 +1792,32 @@ func (p *Server) InterruptChat(
return chat
}
// If the chat is in requires_action, insert synthetic error
// tool-result messages for each pending dynamic tool call
// before transitioning to waiting. Without this, the LLM
// would see unmatched tool-call parts on the next run.
if chat.Status == database.ChatStatusRequiresAction {
if txErr := p.db.InTx(func(tx database.Store) error {
locked, lockErr := tx.GetChatByIDForUpdate(ctx, chat.ID)
if lockErr != nil {
return xerrors.Errorf("lock chat for interrupt: %w", lockErr)
}
// Another request may have already transitioned
// the chat (e.g. SubmitToolResults committed
// between our snapshot and this lock).
if locked.Status != database.ChatStatusRequiresAction {
return nil
}
return insertSyntheticToolResultsTx(ctx, tx, locked, "Tool execution interrupted by user")
}, nil); txErr != nil {
p.logger.Error(ctx, "failed to insert synthetic tool results during interrupt",
slog.F("chat_id", chat.ID),
slog.Error(txErr),
)
// Fall through — still try to set waiting status.
}
}
updatedChat, err := p.setChatWaiting(ctx, chat.ID)
if err != nil {
p.logger.Error(ctx, "failed to mark chat as waiting",
@@ -2345,7 +2608,7 @@ func insertUserMessageAndSetPending(
// queued while a chat is active.
func shouldQueueUserMessage(status database.ChatStatus) bool {
switch status {
case database.ChatStatusRunning, database.ChatStatusPending:
case database.ChatStatusRunning, database.ChatStatusPending, database.ChatStatusRequiresAction:
return true
default:
return false
@@ -3218,8 +3481,12 @@ func (p *Server) Subscribe(
// Pubsub will deliver a duplicate status
// later; the frontend deduplicates it
// (setChatStatus is idempotent).
// action_required is also transient and
// only published on the local stream, so
// it must be forwarded here.
if event.Type == codersdk.ChatStreamEventTypeMessagePart ||
event.Type == codersdk.ChatStreamEventTypeStatus {
event.Type == codersdk.ChatStreamEventTypeStatus ||
event.Type == codersdk.ChatStreamEventTypeActionRequired {
select {
case <-mergedCtx.Done():
return
@@ -3345,6 +3612,51 @@ func (p *Server) publishChatPubsubEvent(chat database.Chat, kind coderdpubsub.Ch
}
}
// pendingToStreamToolCalls converts a slice of chatloop pending
// tool calls into the SDK streaming representation.
func pendingToStreamToolCalls(pending []chatloop.PendingToolCall) []codersdk.ChatStreamToolCall {
calls := make([]codersdk.ChatStreamToolCall, len(pending))
for i, tc := range pending {
calls[i] = codersdk.ChatStreamToolCall{
ToolCallID: tc.ToolCallID,
ToolName: tc.ToolName,
Args: tc.Args,
}
}
return calls
}
// publishChatActionRequired broadcasts an action_required event via
// PostgreSQL pubsub so that global watchers can react to dynamic
// tool calls without streaming each chat individually.
func (p *Server) publishChatActionRequired(chat database.Chat, pending []chatloop.PendingToolCall) {
if p.pubsub == nil {
return
}
toolCalls := pendingToStreamToolCalls(pending)
sdkChat := db2sdk.Chat(chat, nil, nil)
event := coderdpubsub.ChatEvent{
Kind: coderdpubsub.ChatEventKindActionRequired,
Chat: sdkChat,
ToolCalls: toolCalls,
}
payload, err := json.Marshal(event)
if err != nil {
p.logger.Error(context.Background(), "failed to marshal chat action_required pubsub event",
slog.F("chat_id", chat.ID),
slog.Error(err),
)
return
}
if err := p.pubsub.Publish(coderdpubsub.ChatEventChannel(chat.OwnerID), payload); err != nil {
p.logger.Error(context.Background(), "failed to publish chat action_required pubsub event",
slog.F("chat_id", chat.ID),
slog.Error(err),
)
}
}
// PublishDiffStatusChange broadcasts a diff_status_change event for
// the given chat so that watching clients know to re-fetch the diff
// status. This is called from the HTTP layer after the diff status
@@ -3849,6 +4161,21 @@ func (p *Server) processChat(ctx context.Context, chat database.Chat) {
}
p.publishChatPubsubEvent(updatedChat, coderdpubsub.ChatEventKindStatusChange, nil)
// When the chat is parked in requires_action,
// publish the stream event and global pubsub event
// after the DB status has committed. Publishing
// here (not in runChat) prevents a race where a
// fast client reacts before the status is visible.
if status == database.ChatStatusRequiresAction && len(runResult.PendingDynamicToolCalls) > 0 {
toolCalls := pendingToStreamToolCalls(runResult.PendingDynamicToolCalls)
p.publishEvent(chat.ID, codersdk.ChatStreamEvent{
Type: codersdk.ChatStreamEventTypeActionRequired,
ActionRequired: &codersdk.ChatStreamActionRequired{
ToolCalls: toolCalls,
},
})
p.publishChatActionRequired(updatedChat, runResult.PendingDynamicToolCalls)
}
if !wasInterrupted {
p.maybeSendPushNotification(cleanupCtx, updatedChat, status, lastError, runResult, logger)
}
@@ -3877,6 +4204,13 @@ func (p *Server) processChat(ctx context.Context, chat database.Chat) {
return
}
// The LLM invoked a dynamic tool — park the chat in
// requires_action so the client can supply tool results.
if len(runResult.PendingDynamicToolCalls) > 0 {
status = database.ChatStatusRequiresAction
return
}
// If runChat completed successfully but the server context was
// canceled (e.g. during Close()), the chat should be returned
// to pending so another replica can pick it up. There is a
@@ -3943,9 +4277,10 @@ func (t *generatedChatTitle) Load() (string, bool) {
}
type runChatResult struct {
FinalAssistantText string
PushSummaryModel fantasy.LanguageModel
ProviderKeys chatprovider.ProviderAPIKeys
FinalAssistantText string
PushSummaryModel fantasy.LanguageModel
ProviderKeys chatprovider.ProviderAPIKeys
PendingDynamicToolCalls []chatloop.PendingToolCall
}
func (p *Server) runChat(
@@ -4249,8 +4584,8 @@ func (p *Server) runChat(
// server.
toolNameToConfigID := make(map[string]uuid.UUID)
for _, t := range mcpTools {
if mcp, ok := t.(mcpclient.MCPToolIdentifier); ok {
toolNameToConfigID[t.Info().Name] = mcp.MCPServerConfigID()
if mcpTool, ok := t.(mcpclient.MCPToolIdentifier); ok {
toolNameToConfigID[t.Info().Name] = mcpTool.MCPServerConfigID()
}
}
@@ -4269,6 +4604,7 @@ func (p *Server) runChat(
// (which is the common case).
modelConfigContextLimit := modelConfig.ContextLimit
var finalAssistantText string
var pendingDynamicCalls []chatloop.PendingToolCall
persistStep := func(persistCtx context.Context, step chatloop.PersistedStep) error {
// If the chat context has been canceled, bail out before
@@ -4288,6 +4624,10 @@ func (p *Server) runChat(
return persistCtx.Err()
}
// Capture pending dynamic tool calls so the caller
// can surface them after chatloop.Run returns.
pendingDynamicCalls = step.PendingDynamicToolCalls
// Split the step content into assistant blocks and tool
// result blocks so they can be stored as separate messages
// with the appropriate roles. Provider-executed tool results
@@ -4325,6 +4665,21 @@ func (p *Server) runChat(
part.MCPServerConfigID = uuid.NullUUID{UUID: configID, Valid: true}
}
}
// Apply recorded timestamps so persisted
// tool-call parts carry accurate CreatedAt.
if part.Type == codersdk.ChatMessagePartTypeToolCall && part.ToolCallID != "" && step.ToolCallCreatedAt != nil {
if ts, ok := step.ToolCallCreatedAt[part.ToolCallID]; ok {
part.CreatedAt = &ts
}
}
// Provider-executed tool results appear in
// assistantBlocks rather than toolResults,
// so apply their timestamps here as well.
if part.Type == codersdk.ChatMessagePartTypeToolResult && part.ToolCallID != "" && step.ToolResultCreatedAt != nil {
if ts, ok := step.ToolResultCreatedAt[part.ToolCallID]; ok {
part.CreatedAt = &ts
}
}
sdkParts = append(sdkParts, part)
}
finalAssistantText = strings.TrimSpace(contentBlocksToText(sdkParts))
@@ -4343,6 +4698,13 @@ func (p *Server) runChat(
trPart.MCPServerConfigID = uuid.NullUUID{UUID: configID, Valid: true}
}
}
// Apply recorded timestamps so persisted
// tool-result parts carry accurate CreatedAt.
if trPart.ToolCallID != "" && step.ToolResultCreatedAt != nil {
if ts, ok := step.ToolResultCreatedAt[trPart.ToolCallID]; ok {
trPart.CreatedAt = &ts
}
}
var marshalErr error
toolResultContents[i], marshalErr = chatprompt.MarshalParts([]codersdk.ChatMessagePart{trPart})
if marshalErr != nil {
@@ -4674,6 +5036,39 @@ func (p *Server) runChat(
tools = append(tools, mcpTools...)
tools = append(tools, workspaceMCPTools...)
// Append dynamic tools declared by the client at chat
// creation time. These appear in the LLM's tool list but
// are never executed by the chatloop — the client handles
// execution via POST /tool-results.
dynamicToolNames, err := parseDynamicToolNames(chat.DynamicTools)
if err != nil {
return result, xerrors.Errorf("parse dynamic tool names: %w", err)
}
// Unmarshal the full definitions separately so we can
// build the filtered list below. parseDynamicToolNames
// already validated the JSON, so this cannot fail.
var dynamicToolDefs []codersdk.DynamicTool
if chat.DynamicTools.Valid {
if err := json.Unmarshal(chat.DynamicTools.RawMessage, &dynamicToolDefs); err != nil {
return result, xerrors.Errorf("unmarshal dynamic tools: %w", err)
}
}
for _, t := range tools {
info := t.Info()
if dynamicToolNames[info.Name] {
logger.Warn(ctx, "dynamic tool name collides with built-in tool, built-in takes precedence",
slog.F("tool_name", info.Name))
delete(dynamicToolNames, info.Name)
}
}
var filteredDefs []codersdk.DynamicTool
for _, dt := range dynamicToolDefs {
if dynamicToolNames[dt.Name] {
filteredDefs = append(filteredDefs, dt)
}
}
tools = append(tools, dynamicToolsFromSDK(p.logger, filteredDefs)...)
// Build provider-native tools (e.g., web search) based on
// the model configuration.
var providerTools []chatloop.ProviderTool
@@ -4717,8 +5112,7 @@ func (p *Server) runChat(
)
prompt = filterPromptForChainMode(prompt, chainInfo.trailingUserCount)
}
err := chatloop.Run(ctx, chatloop.RunOptions{
err = chatloop.Run(ctx, chatloop.RunOptions{
Model: model,
Messages: prompt,
Tools: tools, MaxSteps: maxChatSteps,
@@ -4726,6 +5120,9 @@ func (p *Server) runChat(
ModelConfig: callConfig,
ProviderOptions: providerOptions,
ProviderTools: providerTools,
// dynamicToolNames now contains only names that don't
// collide with built-in/MCP tools.
DynamicToolNames: dynamicToolNames,
ContextLimitFallback: modelConfigContextLimit,
@@ -4803,6 +5200,15 @@ func (p *Server) runChat(
p.logger.Warn(ctx, "failed to persist interrupted chat step", slog.Error(err))
},
})
if errors.Is(err, chatloop.ErrDynamicToolCall) {
// The stream event is published in processChat's
// defer after the DB status transitions to
// requires_action, preventing a race where a fast
// client reacts before the status is committed.
result.FinalAssistantText = finalAssistantText
result.PendingDynamicToolCalls = pendingDynamicCalls
return result, nil
}
if err != nil {
classified := chaterror.Classify(err).WithProvider(model.Provider())
return result, chaterror.WithClassification(err, classified)
@@ -5424,7 +5830,9 @@ func (p *Server) recoverStaleChats(ctx context.Context) {
recovered := 0
for _, chat := range staleChats {
p.logger.Info(ctx, "recovering stale chat", slog.F("chat_id", chat.ID))
p.logger.Info(ctx, "recovering stale chat",
slog.F("chat_id", chat.ID),
slog.F("status", chat.Status))
// Use a transaction with FOR UPDATE to avoid a TOCTOU race:
// between GetStaleChats (a bare SELECT) and here, the chat's
@@ -5436,34 +5844,73 @@ func (p *Server) recoverStaleChats(ctx context.Context) {
return xerrors.Errorf("lock chat for recovery: %w", lockErr)
}
// Only recover chats that are still running.
// Between GetStaleChats and this lock, the chat
// may have completed normally.
if locked.Status != database.ChatStatusRunning {
switch locked.Status {
case database.ChatStatusRunning:
// Re-check: only recover if the chat is still stale.
// A valid heartbeat at or after the threshold means
// the chat was refreshed after our snapshot.
if locked.HeartbeatAt.Valid && !locked.HeartbeatAt.Time.Before(staleAfter) {
p.logger.Debug(ctx, "chat heartbeat refreshed since snapshot, skipping recovery",
slog.F("chat_id", chat.ID))
return nil
}
case database.ChatStatusRequiresAction:
// Re-check: the chat may have been updated after
// our snapshot, similar to the heartbeat check for
// running chats.
if !locked.UpdatedAt.Before(staleAfter) {
p.logger.Debug(ctx, "chat updated since snapshot, skipping recovery",
slog.F("chat_id", chat.ID))
return nil
}
default:
// Status changed since our snapshot; skip.
p.logger.Debug(ctx, "chat status changed since snapshot, skipping recovery",
slog.F("chat_id", chat.ID),
slog.F("status", locked.Status))
return nil
}
// Re-check: only recover if the chat is still stale.
// A valid heartbeat that is at or after the stale
// threshold means the chat was refreshed after our
// initial snapshot — skip it.
if locked.HeartbeatAt.Valid && !locked.HeartbeatAt.Time.Before(staleAfter) {
p.logger.Debug(ctx, "chat heartbeat refreshed since snapshot, skipping recovery",
slog.F("chat_id", chat.ID))
return nil
lastError := sql.NullString{}
if locked.Status == database.ChatStatusRequiresAction {
lastError = sql.NullString{
String: "Dynamic tool execution timed out",
Valid: true,
}
}
// Reset to pending so any replica can pick it up.
recoverStatus := database.ChatStatusPending
if locked.Status == database.ChatStatusRequiresAction {
// Timed-out requires_action chats have dangling
// tool calls with no matching results. Setting
// them back to pending would replay incomplete
// tool calls to the LLM, so mark them as errors.
recoverStatus = database.ChatStatusError
}
// Insert synthetic error tool-result messages
// so the LLM history remains valid if the user
// retries the chat later.
if locked.Status == database.ChatStatusRequiresAction {
if synthErr := insertSyntheticToolResultsTx(ctx, tx, locked, "Dynamic tool execution timed out"); synthErr != nil {
p.logger.Warn(ctx, "failed to insert synthetic tool results during stale recovery",
slog.F("chat_id", chat.ID),
slog.Error(synthErr),
)
// Continue with error status even if
// synthetic results fail to insert.
}
}
// Reset so any replica can pick it up (pending) or
// the client sees the failure (error).
_, updateErr := tx.UpdateChatStatus(ctx, database.UpdateChatStatusParams{
ID: chat.ID,
Status: database.ChatStatusPending,
Status: recoverStatus,
WorkerID: uuid.NullUUID{},
StartedAt: sql.NullTime{},
HeartbeatAt: sql.NullTime{},
LastError: sql.NullString{},
LastError: lastError,
})
if updateErr != nil {
return updateErr
@@ -5482,6 +5929,119 @@ func (p *Server) recoverStaleChats(ctx context.Context) {
}
}
// insertSyntheticToolResultsTx inserts error tool-result messages for
// every pending dynamic tool call in the last assistant message. This
// keeps the LLM message history valid (every tool-call has a matching
// tool-result) when a requires_action chat times out or is interrupted.
// It operates on the provided store, which may be a transaction handle.
func insertSyntheticToolResultsTx(
ctx context.Context,
store database.Store,
chat database.Chat,
reason string,
) error {
dynamicToolNames, err := parseDynamicToolNames(chat.DynamicTools)
if err != nil {
return xerrors.Errorf("parse dynamic tools: %w", err)
}
if len(dynamicToolNames) == 0 {
return nil
}
// Get the last assistant message to find pending tool calls.
lastAssistant, err := store.GetLastChatMessageByRole(ctx, database.GetLastChatMessageByRoleParams{
ChatID: chat.ID,
Role: database.ChatMessageRoleAssistant,
})
if err != nil {
return xerrors.Errorf("get last assistant message: %w", err)
}
parts, err := chatprompt.ParseContent(lastAssistant)
if err != nil {
return xerrors.Errorf("parse assistant message: %w", err)
}
// Collect dynamic tool calls that need synthetic results.
var resultContents []pqtype.NullRawMessage
for _, part := range parts {
if part.Type != codersdk.ChatMessagePartTypeToolCall || !dynamicToolNames[part.ToolName] {
continue
}
resultPart := codersdk.ChatMessagePart{
Type: codersdk.ChatMessagePartTypeToolResult,
ToolCallID: part.ToolCallID,
ToolName: part.ToolName,
Result: json.RawMessage(fmt.Sprintf("%q", reason)),
IsError: true,
}
marshaled, marshalErr := chatprompt.MarshalParts([]codersdk.ChatMessagePart{resultPart})
if marshalErr != nil {
return xerrors.Errorf("marshal synthetic tool result: %w", marshalErr)
}
resultContents = append(resultContents, marshaled)
}
if len(resultContents) == 0 {
return nil
}
// Insert tool-result messages using the same pattern as
// SubmitToolResults.
n := len(resultContents)
params := database.InsertChatMessagesParams{
ChatID: chat.ID,
CreatedBy: make([]uuid.UUID, n),
ModelConfigID: make([]uuid.UUID, n),
Role: make([]database.ChatMessageRole, n),
Content: make([]string, n),
ContentVersion: make([]int16, n),
Visibility: make([]database.ChatMessageVisibility, n),
InputTokens: make([]int64, n),
OutputTokens: make([]int64, n),
TotalTokens: make([]int64, n),
ReasoningTokens: make([]int64, n),
CacheCreationTokens: make([]int64, n),
CacheReadTokens: make([]int64, n),
ContextLimit: make([]int64, n),
Compressed: make([]bool, n),
TotalCostMicros: make([]int64, n),
RuntimeMs: make([]int64, n),
ProviderResponseID: make([]string, n),
}
for i, rc := range resultContents {
params.CreatedBy[i] = uuid.Nil
params.ModelConfigID[i] = chat.LastModelConfigID
params.Role[i] = database.ChatMessageRoleTool
params.Content[i] = string(rc.RawMessage)
params.ContentVersion[i] = chatprompt.CurrentContentVersion
params.Visibility[i] = database.ChatMessageVisibilityBoth
}
if _, err := store.InsertChatMessages(ctx, params); err != nil {
return xerrors.Errorf("insert synthetic tool results: %w", err)
}
return nil
}
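The invariant insertSyntheticToolResultsTx preserves — every tool call gets a matching tool result before the history is replayed — can be sketched standalone. This is a simplified illustration, not the repository's types: `part` stands in for a parsed chat message part keyed by tool-call ID.

```go
package main

import "fmt"

// part is a simplified stand-in for a parsed message part:
// either a tool call or a tool result, keyed by the call ID.
type part struct {
	isCall bool
	callID string
}

// syntheticResults returns synthetic error results for every
// tool call with no matching result, keeping the history valid
// (every call answered) if the chat is retried later.
func syntheticResults(parts []part, reason string) []string {
	answered := make(map[string]bool)
	for _, p := range parts {
		if !p.isCall {
			answered[p.callID] = true
		}
	}
	var out []string
	for _, p := range parts {
		if p.isCall && !answered[p.callID] {
			out = append(out, fmt.Sprintf("%s: %q", p.callID, reason))
		}
	}
	return out
}

func main() {
	parts := []part{
		{isCall: true, callID: "call-1"},
		{isCall: false, callID: "call-1"}, // call-1 already answered
		{isCall: true, callID: "call-2"},  // call-2 is pending
	}
	fmt.Println(syntheticResults(parts, "Dynamic tool execution timed out"))
}
```

Only the unanswered call receives a synthetic result; completed calls are left untouched, matching the "don't clobber" behavior of the real implementation.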
// parseDynamicToolNames unmarshals the dynamic tools JSON column
// and returns a map of tool names. This centralizes the repeated
// pattern of deserializing DynamicTools into a name set.
func parseDynamicToolNames(raw pqtype.NullRawMessage) (map[string]bool, error) {
if !raw.Valid || len(raw.RawMessage) == 0 {
return make(map[string]bool), nil
}
var tools []codersdk.DynamicTool
if err := json.Unmarshal(raw.RawMessage, &tools); err != nil {
return nil, xerrors.Errorf("unmarshal dynamic tools: %w", err)
}
names := make(map[string]bool, len(tools))
for _, t := range tools {
names[t.Name] = true
}
return names, nil
}
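Once DynamicTools is reduced to a name set, partitioning tool calls into built-in and dynamic (as the chatloop does before executing only the built-in ones) is a plain map lookup. A minimal sketch, with `toolCall` as a simplified stand-in for fantasy.ToolCallContent:

```go
package main

import "fmt"

// toolCall is a simplified stand-in for fantasy.ToolCallContent.
type toolCall struct {
	ID   string
	Name string
}

// partition splits calls into built-in and dynamic using the
// name set produced by parseDynamicToolNames.
func partition(calls []toolCall, dynamic map[string]bool) (builtin, dyn []toolCall) {
	for _, tc := range calls {
		if dynamic[tc.Name] {
			dyn = append(dyn, tc)
		} else {
			builtin = append(builtin, tc)
		}
	}
	return builtin, dyn
}

func main() {
	dynamic := map[string]bool{"my_dynamic_tool": true}
	b, d := partition([]toolCall{
		{ID: "1", Name: "read_file"},
		{ID: "2", Name: "my_dynamic_tool"},
	}, dynamic)
	fmt.Println(len(b), len(d))
}
```

An unknown name falls through to the built-in bucket, which is why an empty name set degenerates to executing every call locally.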
// maybeSendPushNotification sends a web push notification when an
// agent chat reaches a terminal state. For errors it dispatches
// synchronously; for successful completions it spawns a goroutine
@@ -1531,6 +1531,70 @@ func TestRecoverStaleChatsPeriodically(t *testing.T) {
}, testutil.WaitMedium, testutil.IntervalFast)
}
func TestRecoverStaleRequiresActionChat(t *testing.T) {
t.Parallel()
db, ps, rawDB := dbtestutil.NewDBWithSQLDB(t)
ctx := testutil.Context(t, testutil.WaitLong)
user, model := seedChatDependencies(ctx, t, db)
// Use a very short stale threshold so the periodic recovery
// kicks in quickly during the test.
staleAfter := 500 * time.Millisecond
// Create a chat and set it to requires_action to simulate a
// client that disappeared while the chat was waiting for
// dynamic tool results.
chat, err := db.InsertChat(ctx, database.InsertChatParams{
Status: database.ChatStatusWaiting,
OwnerID: user.ID,
Title: "stale-requires-action",
LastModelConfigID: model.ID,
})
require.NoError(t, err)
_, err = db.UpdateChatStatus(ctx, database.UpdateChatStatusParams{
ID: chat.ID,
Status: database.ChatStatusRequiresAction,
})
require.NoError(t, err)
// Backdate updated_at so the chat appears stale to the
// recovery loop without needing time.Sleep.
_, err = rawDB.ExecContext(ctx,
"UPDATE chats SET updated_at = $1 WHERE id = $2",
time.Now().Add(-time.Hour), chat.ID)
require.NoError(t, err)
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true})
server := chatd.New(chatd.Config{
Logger: logger,
Database: db,
ReplicaID: uuid.New(),
Pubsub: ps,
PendingChatAcquireInterval: testutil.WaitLong,
InFlightChatStaleAfter: staleAfter,
})
t.Cleanup(func() {
require.NoError(t, server.Close())
})
// The stale recovery should transition the requires_action
// chat to error with the timeout message.
var chatResult database.Chat
require.Eventually(t, func() bool {
chatResult, err = db.GetChatByID(ctx, chat.ID)
if err != nil {
return false
}
return chatResult.Status == database.ChatStatusError
}, testutil.WaitMedium, testutil.IntervalFast)
require.Contains(t, chatResult.LastError.String, "Dynamic tool execution timed out")
require.False(t, chatResult.WorkerID.Valid)
}
func TestNewReplicaRecoversStaleChatFromDeadReplica(t *testing.T) {
t.Parallel()
@@ -1882,6 +1946,518 @@ func TestPersistToolResultWithBinaryData(t *testing.T) {
require.True(t, foundToolResultInSecondCall, "expected second streamed model call to include execute tool output")
}
func TestDynamicToolCallPausesAndResumes(t *testing.T) {
t.Parallel()
db, ps := dbtestutil.NewDB(t)
ctx := testutil.Context(t, testutil.WaitLong)
// Track streaming calls to the mock LLM.
var streamedCallCount atomic.Int32
var streamedCallsMu sync.Mutex
streamedCalls := make([]chattest.OpenAIRequest, 0, 2)
openAIURL := chattest.NewOpenAI(t, func(req *chattest.OpenAIRequest) chattest.OpenAIResponse {
// Non-streaming requests are title generation — return a
// simple title.
if !req.Stream {
return chattest.OpenAINonStreamingResponse("Dynamic tool test")
}
// Capture the full request for later assertions.
streamedCallsMu.Lock()
streamedCalls = append(streamedCalls, chattest.OpenAIRequest{
Messages: append([]chattest.OpenAIMessage(nil), req.Messages...),
Tools: append([]chattest.OpenAITool(nil), req.Tools...),
Stream: req.Stream,
})
streamedCallsMu.Unlock()
if streamedCallCount.Add(1) == 1 {
// First call: the LLM invokes our dynamic tool.
return chattest.OpenAIStreamingResponse(
chattest.OpenAIToolCallChunk(
"my_dynamic_tool",
`{"input":"hello world"}`,
),
)
}
// Second call: the LLM returns a normal text response.
return chattest.OpenAIStreamingResponse(
chattest.OpenAITextChunks("Dynamic tool result received.")...,
)
})
user, model := seedChatDependenciesWithProvider(ctx, t, db, "openai-compat", openAIURL)
// Dynamic tools do not need a workspace connection, but the
// chatd server always builds workspace tools. Use an active
// server without an agent connection — the built-in tools
// are never invoked because the only tool call targets our
// dynamic tool.
server := newActiveTestServer(t, db, ps)
// Create a chat with a dynamic tool.
dynamicToolsJSON, err := json.Marshal([]mcpgo.Tool{{
Name: "my_dynamic_tool",
Description: "A test dynamic tool.",
InputSchema: mcpgo.ToolInputSchema{
Type: "object",
Properties: map[string]any{
"input": map[string]any{"type": "string"},
},
Required: []string{"input"},
},
}})
require.NoError(t, err)
chat, err := server.CreateChat(ctx, chatd.CreateOptions{
OwnerID: user.ID,
Title: "dynamic-tool-pause-resume",
ModelConfigID: model.ID,
InitialUserContent: []codersdk.ChatMessagePart{
codersdk.ChatMessageText("Please call the dynamic tool."),
},
DynamicTools: dynamicToolsJSON,
})
require.NoError(t, err)
// 1. Wait for the chat to reach requires_action status.
var chatResult database.Chat
require.Eventually(t, func() bool {
got, getErr := db.GetChatByID(ctx, chat.ID)
if getErr != nil {
return false
}
chatResult = got
return got.Status == database.ChatStatusRequiresAction ||
got.Status == database.ChatStatusError
}, testutil.WaitLong, testutil.IntervalFast)
require.Equal(t, database.ChatStatusRequiresAction, chatResult.Status,
"expected requires_action, got %s (last_error=%q)",
chatResult.Status, chatResult.LastError.String)
// 2. Read the assistant message to find the tool-call ID.
var toolCallID string
var toolCallFound bool
testutil.Eventually(ctx, t, func(ctx context.Context) bool {
messages, dbErr := db.GetChatMessagesByChatID(ctx, database.GetChatMessagesByChatIDParams{
ChatID: chat.ID,
AfterID: 0,
})
if dbErr != nil {
return false
}
for _, msg := range messages {
if msg.Role != database.ChatMessageRoleAssistant {
continue
}
parts, parseErr := chatprompt.ParseContent(msg)
if parseErr != nil {
continue
}
for _, part := range parts {
if part.Type == codersdk.ChatMessagePartTypeToolCall && part.ToolName == "my_dynamic_tool" {
toolCallID = part.ToolCallID
toolCallFound = true
return true
}
}
}
return false
}, testutil.IntervalFast)
require.True(t, toolCallFound, "expected to find tool call for my_dynamic_tool")
require.NotEmpty(t, toolCallID)
// 3. Submit tool results via SubmitToolResults.
toolResultOutput := json.RawMessage(`{"result":"dynamic tool output"}`)
err = server.SubmitToolResults(ctx, chatd.SubmitToolResultsOptions{
ChatID: chat.ID,
UserID: user.ID,
ModelConfigID: chatResult.LastModelConfigID,
Results: []codersdk.ToolResult{{
ToolCallID: toolCallID,
Output: toolResultOutput,
}},
DynamicTools: dynamicToolsJSON,
})
require.NoError(t, err)
// 4. Wait for the chat to reach a terminal status.
require.Eventually(t, func() bool {
got, getErr := db.GetChatByID(ctx, chat.ID)
if getErr != nil {
return false
}
chatResult = got
return got.Status == database.ChatStatusWaiting || got.Status == database.ChatStatusError
}, testutil.WaitLong, testutil.IntervalFast)
// 5. Verify the chat completed successfully.
if chatResult.Status == database.ChatStatusError {
require.FailNowf(t, "chat run failed", "last_error=%q", chatResult.LastError.String)
}
// 6. Verify the mock received exactly 2 streaming calls.
require.Equal(t, int32(2), streamedCallCount.Load(),
"expected exactly 2 streaming calls to the LLM")
streamedCallsMu.Lock()
recordedCalls := append([]chattest.OpenAIRequest(nil), streamedCalls...)
streamedCallsMu.Unlock()
require.Len(t, recordedCalls, 2)
// 7. Verify the dynamic tool appeared in the first call's tool list.
var foundDynamicTool bool
for _, tool := range recordedCalls[0].Tools {
if tool.Function.Name == "my_dynamic_tool" {
foundDynamicTool = true
break
}
}
require.True(t, foundDynamicTool,
"expected 'my_dynamic_tool' in the first LLM call's tool list")
// 8. Verify the second call's messages contain the tool result.
var foundToolResultInSecondCall bool
for _, message := range recordedCalls[1].Messages {
if message.Role != "tool" {
continue
}
if strings.Contains(message.Content, "dynamic tool output") {
foundToolResultInSecondCall = true
break
}
}
require.True(t, foundToolResultInSecondCall,
"expected second LLM call to include the submitted dynamic tool result")
}
func TestDynamicToolCallMixedWithBuiltIn(t *testing.T) {
t.Parallel()
db, ps := dbtestutil.NewDB(t)
ctx := testutil.Context(t, testutil.WaitLong)
// Track streaming calls to the mock LLM.
var streamedCallCount atomic.Int32
var streamedCallsMu sync.Mutex
streamedCalls := make([]chattest.OpenAIRequest, 0, 2)
openAIURL := chattest.NewOpenAI(t, func(req *chattest.OpenAIRequest) chattest.OpenAIResponse {
if !req.Stream {
return chattest.OpenAINonStreamingResponse("Mixed tool test")
}
streamedCallsMu.Lock()
streamedCalls = append(streamedCalls, chattest.OpenAIRequest{
Messages: append([]chattest.OpenAIMessage(nil), req.Messages...),
Tools: append([]chattest.OpenAITool(nil), req.Tools...),
Stream: req.Stream,
})
streamedCallsMu.Unlock()
if streamedCallCount.Add(1) == 1 {
// First call: return TWO tool calls in one
// response — a built-in tool (read_file) and a
// dynamic tool (my_dynamic_tool).
builtinChunk := chattest.OpenAIToolCallChunk(
"read_file",
`{"path":"/tmp/test.txt"}`,
)
dynamicChunk := chattest.OpenAIToolCallChunk(
"my_dynamic_tool",
`{"input":"hello world"}`,
)
// Merge both tool calls into one chunk with
// separate indices so the LLM appears to have
// requested both tools simultaneously.
mergedChunk := builtinChunk
dynCall := dynamicChunk.Choices[0].ToolCalls[0]
dynCall.Index = 1
mergedChunk.Choices[0].ToolCalls = append(
mergedChunk.Choices[0].ToolCalls,
dynCall,
)
return chattest.OpenAIStreamingResponse(mergedChunk)
}
// Second call (after tool results): normal text
// response.
return chattest.OpenAIStreamingResponse(
chattest.OpenAITextChunks("All done.")...,
)
})
user, model := seedChatDependenciesWithProvider(ctx, t, db, "openai-compat", openAIURL)
server := newActiveTestServer(t, db, ps)
// Create a chat with a dynamic tool.
dynamicToolsJSON, err := json.Marshal([]mcpgo.Tool{{
Name: "my_dynamic_tool",
Description: "A test dynamic tool.",
InputSchema: mcpgo.ToolInputSchema{
Type: "object",
Properties: map[string]any{
"input": map[string]any{"type": "string"},
},
Required: []string{"input"},
},
}})
require.NoError(t, err)
chat, err := server.CreateChat(ctx, chatd.CreateOptions{
OwnerID: user.ID,
Title: "mixed-builtin-dynamic",
ModelConfigID: model.ID,
InitialUserContent: []codersdk.ChatMessagePart{
codersdk.ChatMessageText("Call both tools."),
},
DynamicTools: dynamicToolsJSON,
})
require.NoError(t, err)
// 1. Wait for the chat to reach requires_action status.
var chatResult database.Chat
require.Eventually(t, func() bool {
got, getErr := db.GetChatByID(ctx, chat.ID)
if getErr != nil {
return false
}
chatResult = got
return got.Status == database.ChatStatusRequiresAction ||
got.Status == database.ChatStatusError
}, testutil.WaitLong, testutil.IntervalFast)
require.Equal(t, database.ChatStatusRequiresAction, chatResult.Status,
"expected requires_action, got %s (last_error=%q)",
chatResult.Status, chatResult.LastError.String)
// 2. Verify the built-in tool (read_file) was already
// executed by checking that a tool result message
// exists for it in the database.
var builtinToolResultFound bool
var toolCallID string
testutil.Eventually(ctx, t, func(ctx context.Context) bool {
messages, dbErr := db.GetChatMessagesByChatID(ctx, database.GetChatMessagesByChatIDParams{
ChatID: chat.ID,
AfterID: 0,
})
if dbErr != nil {
return false
}
for _, msg := range messages {
parts, parseErr := chatprompt.ParseContent(msg)
if parseErr != nil {
continue
}
for _, part := range parts {
// Check for the built-in tool result.
if part.Type == codersdk.ChatMessagePartTypeToolResult && part.ToolName == "read_file" {
builtinToolResultFound = true
}
// Find the dynamic tool call ID.
if part.Type == codersdk.ChatMessagePartTypeToolCall && part.ToolName == "my_dynamic_tool" {
toolCallID = part.ToolCallID
}
}
}
return builtinToolResultFound && toolCallID != ""
}, testutil.IntervalFast)
require.True(t, builtinToolResultFound,
"expected read_file tool result in the DB before dynamic tool resolution")
require.NotEmpty(t, toolCallID)
// 3. Submit dynamic tool results.
err = server.SubmitToolResults(ctx, chatd.SubmitToolResultsOptions{
ChatID: chat.ID,
UserID: user.ID,
ModelConfigID: chatResult.LastModelConfigID,
Results: []codersdk.ToolResult{{
ToolCallID: toolCallID,
Output: json.RawMessage(`{"result":"dynamic output"}`),
}},
DynamicTools: dynamicToolsJSON,
})
require.NoError(t, err)
// 4. Wait for the chat to complete.
require.Eventually(t, func() bool {
got, getErr := db.GetChatByID(ctx, chat.ID)
if getErr != nil {
return false
}
chatResult = got
return got.Status == database.ChatStatusWaiting || got.Status == database.ChatStatusError
}, testutil.WaitLong, testutil.IntervalFast)
if chatResult.Status == database.ChatStatusError {
require.FailNowf(t, "chat run failed", "last_error=%q", chatResult.LastError.String)
}
// 5. Verify the LLM received exactly 2 streaming calls.
require.Equal(t, int32(2), streamedCallCount.Load(),
"expected exactly 2 streaming calls to the LLM")
}
func TestSubmitToolResultsConcurrency(t *testing.T) {
t.Parallel()
db, ps := dbtestutil.NewDB(t)
ctx := testutil.Context(t, testutil.WaitLong)
// The mock LLM returns a dynamic tool call on the first streaming
// request, then a plain text reply on the second.
var streamedCallCount atomic.Int32
openAIURL := chattest.NewOpenAI(t, func(req *chattest.OpenAIRequest) chattest.OpenAIResponse {
if !req.Stream {
return chattest.OpenAINonStreamingResponse("Concurrency test")
}
if streamedCallCount.Add(1) == 1 {
return chattest.OpenAIStreamingResponse(
chattest.OpenAIToolCallChunk(
"my_dynamic_tool",
`{"input":"hello"}`,
),
)
}
return chattest.OpenAIStreamingResponse(
chattest.OpenAITextChunks("Done.")...,
)
})
user, model := seedChatDependenciesWithProvider(ctx, t, db, "openai-compat", openAIURL)
server := newActiveTestServer(t, db, ps)
// Create a chat with a dynamic tool.
dynamicToolsJSON, err := json.Marshal([]mcpgo.Tool{{
Name: "my_dynamic_tool",
Description: "A test dynamic tool.",
InputSchema: mcpgo.ToolInputSchema{
Type: "object",
Properties: map[string]any{
"input": map[string]any{"type": "string"},
},
Required: []string{"input"},
},
}})
require.NoError(t, err)
chat, err := server.CreateChat(ctx, chatd.CreateOptions{
OwnerID: user.ID,
Title: "concurrency-tool-results",
ModelConfigID: model.ID,
InitialUserContent: []codersdk.ChatMessagePart{
codersdk.ChatMessageText("Please call the dynamic tool."),
},
DynamicTools: dynamicToolsJSON,
})
require.NoError(t, err)
// Wait for the chat to reach requires_action status.
var chatResult database.Chat
require.Eventually(t, func() bool {
got, getErr := db.GetChatByID(ctx, chat.ID)
if getErr != nil {
return false
}
chatResult = got
return got.Status == database.ChatStatusRequiresAction ||
got.Status == database.ChatStatusError
}, testutil.WaitLong, testutil.IntervalFast)
require.Equal(t, database.ChatStatusRequiresAction, chatResult.Status,
"expected requires_action, got %s (last_error=%q)",
chatResult.Status, chatResult.LastError.String)
// Find the tool call ID from the assistant message.
var toolCallID string
testutil.Eventually(ctx, t, func(ctx context.Context) bool {
messages, dbErr := db.GetChatMessagesByChatID(ctx, database.GetChatMessagesByChatIDParams{
ChatID: chat.ID,
AfterID: 0,
})
if dbErr != nil {
return false
}
for _, msg := range messages {
if msg.Role != database.ChatMessageRoleAssistant {
continue
}
parts, parseErr := chatprompt.ParseContent(msg)
if parseErr != nil {
continue
}
for _, part := range parts {
if part.Type == codersdk.ChatMessagePartTypeToolCall && part.ToolName == "my_dynamic_tool" {
toolCallID = part.ToolCallID
return true
}
}
}
return false
}, testutil.IntervalFast)
require.NotEmpty(t, toolCallID)
// Spawn N goroutines that all try to submit tool results at the
// same time. Exactly one should succeed; the rest must get a
// ToolResultStatusConflictError.
const numGoroutines = 10
var (
wg sync.WaitGroup
ready = make(chan struct{})
successes atomic.Int32
conflicts atomic.Int32
unexpectedErrors = make(chan error, numGoroutines)
)
for range numGoroutines {
wg.Go(func() {
// Wait for all goroutines to be ready.
<-ready
submitErr := server.SubmitToolResults(ctx, chatd.SubmitToolResultsOptions{
ChatID: chat.ID,
UserID: user.ID,
ModelConfigID: chatResult.LastModelConfigID,
Results: []codersdk.ToolResult{{
ToolCallID: toolCallID,
Output: json.RawMessage(`{"result":"concurrent output"}`),
}},
DynamicTools: dynamicToolsJSON,
})
if submitErr == nil {
successes.Add(1)
return
}
var conflict *chatd.ToolResultStatusConflictError
if errors.As(submitErr, &conflict) {
conflicts.Add(1)
return
}
// Collect unexpected errors for assertion
// outside the goroutine (require.NoError
// calls t.FailNow which is illegal here).
unexpectedErrors <- submitErr
})
}
// Release all goroutines at once.
close(ready)
wg.Wait()
close(unexpectedErrors)
for ue := range unexpectedErrors {
require.NoError(t, ue, "unexpected error from SubmitToolResults")
}
require.Equal(t, int32(1), successes.Load(),
"expected exactly 1 goroutine to succeed")
require.Equal(t, int32(numGoroutines-1), conflicts.Load(),
"expected %d conflict errors", numGoroutines-1)
}
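The release-all-at-once barrier used above is a reusable pattern: park N goroutines on a channel, close it to start them simultaneously, and tally outcomes with atomics so assertions happen on the test goroutine. A minimal sketch, assuming the contended operation is modeled by a single CompareAndSwap (the real contention is a DB status transition):

```go
package main

import (
	"fmt"
	"sync"
	"sync/atomic"
)

// raceOnce runs n workers that all attempt an operation at the
// same instant; exactly one wins, mirroring the single-success
// invariant the concurrency test asserts.
func raceOnce(n int) (successes, conflicts int32) {
	var (
		wg    sync.WaitGroup
		ready = make(chan struct{})
		won   atomic.Bool
		s, c  atomic.Int32
	)
	for i := 0; i < n; i++ {
		wg.Add(1)
		go func() {
			defer wg.Done()
			<-ready // park until every worker is spawned
			if won.CompareAndSwap(false, true) {
				s.Add(1)
				return
			}
			c.Add(1)
		}()
	}
	close(ready) // release all workers at once
	wg.Wait()
	return s.Load(), c.Load()
}

func main() {
	s, c := raceOnce(10)
	fmt.Println(s, c)
}
```

Closing the channel is the key move: a closed channel unblocks every receiver at once, maximizing the window in which the workers actually race.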
func ptrRef[T any](v T) *T {
return &v
}
@@ -5,6 +5,7 @@ import (
"database/sql"
"encoding/json"
"errors"
"maps"
"slices"
"strconv"
"strings"
@@ -18,6 +19,7 @@ import (
"charm.land/fantasy/schema"
"golang.org/x/xerrors"
"github.com/coder/coder/v2/coderd/database/dbtime"
"github.com/coder/coder/v2/coderd/x/chatd/chaterror"
"github.com/coder/coder/v2/coderd/x/chatd/chatprompt"
"github.com/coder/coder/v2/coderd/x/chatd/chatretry"
@@ -38,13 +40,23 @@ const (
)
var (
ErrInterrupted = xerrors.New("chat interrupted")
ErrDynamicToolCall = xerrors.New("dynamic tool call")
errStartupTimeout = xerrors.New(
"chat response did not start before the startup timeout",
)
)
// PendingToolCall describes a tool call that targets a dynamic
// tool. These calls are not executed by the chatloop; instead
// they are persisted so the caller can fulfill them externally.
type PendingToolCall struct {
ToolCallID string
ToolName string
Args string
}
// PersistedStep contains the full content of a completed or
// interrupted agent step. Content includes both assistant blocks
// (text, reasoning, tool calls) and tool result blocks. The
@@ -60,6 +72,21 @@ type PersistedStep struct {
// Zero indicates the duration was not measured (e.g.
// interrupted steps).
Runtime time.Duration
// PendingDynamicToolCalls lists tool calls that target
// dynamic tools. When non-empty the chatloop exits with
// ErrDynamicToolCall so the caller can execute them
// externally and resume the loop.
PendingDynamicToolCalls []PendingToolCall
// ToolCallCreatedAt maps tool-call IDs to the time
// the model emitted each tool call. Applied by the
// persistence layer to set CreatedAt on persisted
// tool-call ChatMessageParts.
ToolCallCreatedAt map[string]time.Time
// ToolResultCreatedAt maps tool-call IDs to the time
// each tool result was produced (or interrupted).
// Applied by the persistence layer to set CreatedAt
// on persisted tool-result ChatMessageParts.
ToolResultCreatedAt map[string]time.Time
}
// RunOptions configures a single streaming chat loop run.
@@ -77,6 +104,12 @@ type RunOptions struct {
ActiveTools []string
ContextLimitFallback int64
// DynamicToolNames lists tool names that are handled
// externally. When the model invokes one of these tools
// the chatloop persists partial results and exits with
// ErrDynamicToolCall instead of executing the tool.
DynamicToolNames map[string]bool
// ModelConfig holds per-call LLM parameters (temperature,
// max tokens, etc.) read from the chat model configuration.
ModelConfig codersdk.ChatModelCallConfig
@@ -128,12 +161,14 @@ type ProviderTool struct {
// step. Since we own the stream consumer, all content is tracked
// directly here — no shadow draft state needed.
type stepResult struct {
content []fantasy.Content
usage fantasy.Usage
providerMetadata fantasy.ProviderMetadata
finishReason fantasy.FinishReason
toolCalls []fantasy.ToolCallContent
shouldContinue bool
toolCallCreatedAt map[string]time.Time
toolResultCreatedAt map[string]time.Time
}
// toResponseMessages converts step content into messages suitable
@@ -385,16 +420,72 @@ func Run(ctx context.Context, opts RunOptions) error {
return ctx.Err()
}
// Partition tool calls into built-in and dynamic.
var builtinCalls, dynamicCalls []fantasy.ToolCallContent
if len(opts.DynamicToolNames) > 0 {
for _, tc := range result.toolCalls {
if opts.DynamicToolNames[tc.ToolName] {
dynamicCalls = append(dynamicCalls, tc)
} else {
builtinCalls = append(builtinCalls, tc)
}
}
} else {
builtinCalls = result.toolCalls
}
// Execute only built-in tools.
toolResults = executeTools(ctx, opts.Tools, opts.ProviderTools, builtinCalls, func(tr fantasy.ToolResultContent, completedAt time.Time) {
recordToolResultTimestamp(&result, tr.ToolCallID, completedAt)
ssePart := chatprompt.PartFromContent(tr)
ssePart.CreatedAt = &completedAt
publishMessagePart(codersdk.ChatMessageRoleTool, ssePart)
})
for _, tr := range toolResults {
result.content = append(result.content, tr)
}
// If dynamic tools were called, persist what we
// have (assistant + built-in results) and exit so
// the caller can execute them externally.
if len(dynamicCalls) > 0 {
pending := make([]PendingToolCall, 0, len(dynamicCalls))
for _, dc := range dynamicCalls {
pending = append(pending, PendingToolCall{
ToolCallID: dc.ToolCallID,
ToolName: dc.ToolName,
Args: dc.Input,
})
}
contextLimit := extractContextLimit(result.providerMetadata)
if !contextLimit.Valid && opts.ContextLimitFallback > 0 {
contextLimit = sql.NullInt64{
Int64: opts.ContextLimitFallback,
Valid: true,
}
}
if err := opts.PersistStep(ctx, PersistedStep{
Content: result.content,
Usage: result.usage,
ContextLimit: contextLimit,
ProviderResponseID: extractOpenAIResponseIDIfStored(opts.ProviderOptions, result.providerMetadata),
Runtime: time.Since(stepStart),
PendingDynamicToolCalls: pending,
}); err != nil {
if errors.Is(err, ErrInterrupted) {
persistInterruptedStep(ctx, opts, &result)
return ErrInterrupted
}
return xerrors.Errorf("persist step: %w", err)
}
tryCompactOnExit(ctx, opts, result.usage, result.providerMetadata)
return ErrDynamicToolCall
}
// Check for interruption after tool execution.
// Tools that were canceled mid-flight produce error
// results via ctx cancellation. Persist the full
@@ -421,11 +512,13 @@ func Run(ctx context.Context, opts RunOptions) error {
// check and here, fall back to the interrupt-safe
// path so partial content is not lost.
if err := opts.PersistStep(ctx, PersistedStep{
Content: result.content,
Usage: result.usage,
ContextLimit: contextLimit,
ProviderResponseID: extractOpenAIResponseIDIfStored(opts.ProviderOptions, result.providerMetadata),
Runtime: time.Since(stepStart),
ToolCallCreatedAt: result.toolCallCreatedAt,
ToolResultCreatedAt: result.toolResultCreatedAt,
}); err != nil {
if errors.Is(err, ErrInterrupted) {
persistInterruptedStep(ctx, opts, &result)
@@ -758,9 +851,20 @@ func processStepStream(
// Clean up active tool call tracking.
delete(activeToolCalls, part.ID)
// Record when the model emitted this tool call
// so the persisted part carries an accurate
// timestamp for duration computation.
now := dbtime.Now()
if result.toolCallCreatedAt == nil {
result.toolCallCreatedAt = make(map[string]time.Time)
}
result.toolCallCreatedAt[part.ID] = now
ssePart := chatprompt.PartFromContent(tc)
ssePart.CreatedAt = &now
publishMessagePart(
codersdk.ChatMessageRoleAssistant,
ssePart,
)
case fantasy.StreamPartTypeSource:
@@ -790,9 +894,18 @@ func processStepStream(
ProviderMetadata: part.ProviderMetadata,
}
result.content = append(result.content, tr)
now := dbtime.Now()
if result.toolResultCreatedAt == nil {
result.toolResultCreatedAt = make(map[string]time.Time)
}
result.toolResultCreatedAt[part.ID] = now
ssePart := chatprompt.PartFromContent(tr)
ssePart.CreatedAt = &now
publishMessagePart(
codersdk.ChatMessageRoleTool,
ssePart,
)
}
case fantasy.StreamPartTypeFinish:
@@ -861,7 +974,7 @@ func executeTools(
allTools []fantasy.AgentTool,
providerTools []ProviderTool,
toolCalls []fantasy.ToolCallContent,
onResult func(fantasy.ToolResultContent, time.Time),
) []fantasy.ToolResultContent {
if len(toolCalls) == 0 {
return nil
@@ -894,10 +1007,11 @@ func executeTools(
}
results := make([]fantasy.ToolResultContent, len(localToolCalls))
completedAt := make([]time.Time, len(localToolCalls))
var wg sync.WaitGroup
wg.Add(len(localToolCalls))
for i, tc := range localToolCalls {
go func() {
defer wg.Done()
defer func() {
if r := recover(); r != nil {
@@ -909,17 +1023,21 @@ func executeTools(
},
}
}
// Record when this tool completed (or panicked).
// Captured per-goroutine so parallel tools get
// accurate individual completion times.
completedAt[i] = dbtime.Now()
}()
results[i] = executeSingleTool(ctx, toolMap, tc)
}()
}
wg.Wait()
// Publish results in the original tool-call order so SSE
// subscribers see a deterministic event sequence.
if onResult != nil {
for i, tr := range results {
onResult(tr, completedAt[i])
}
}
return results
@@ -1055,11 +1173,24 @@ func persistInterruptedStep(
}
}
// Copy existing timestamps and add result timestamps for
// interrupted tool calls so the frontend can show partial
// duration.
toolCallCreatedAt := maps.Clone(result.toolCallCreatedAt)
if toolCallCreatedAt == nil {
toolCallCreatedAt = make(map[string]time.Time)
}
toolResultCreatedAt := maps.Clone(result.toolResultCreatedAt)
if toolResultCreatedAt == nil {
toolResultCreatedAt = make(map[string]time.Time)
}
// Build combined content: all accumulated content + synthetic
// interrupted results for any unanswered tool calls.
content := make([]fantasy.Content, 0, len(result.content))
content = append(content, result.content...)
interruptedAt := dbtime.Now()
for _, tc := range result.toolCalls {
if tc.ToolCallID == "" {
continue
@@ -1075,12 +1206,20 @@ func persistInterruptedStep(
Error: xerrors.New(interruptedToolResultErrorMessage),
},
})
// Only stamp synthetic results; don't clobber
// timestamps from tools that completed before
// the interruption arrived.
if _, exists := toolResultCreatedAt[tc.ToolCallID]; !exists {
toolResultCreatedAt[tc.ToolCallID] = interruptedAt
}
answeredToolCalls[tc.ToolCallID] = struct{}{}
}
persistCtx := context.WithoutCancel(ctx)
if err := opts.PersistStep(persistCtx, PersistedStep{
Content: content,
ToolCallCreatedAt: toolCallCreatedAt,
ToolResultCreatedAt: toolResultCreatedAt,
}); err != nil {
if opts.OnInterruptedPersistError != nil {
opts.OnInterruptedPersistError(err)
@@ -1088,6 +1227,38 @@ func persistInterruptedStep(
}
}
// tryCompactOnExit runs compaction when the chatloop is about
// to exit early (e.g. via ErrDynamicToolCall). The normal
// inline and post-run compaction paths are unreachable in
// early-exit scenarios, so this ensures the context window
// doesn't grow unbounded.
func tryCompactOnExit(
ctx context.Context,
opts RunOptions,
usage fantasy.Usage,
metadata fantasy.ProviderMetadata,
) {
if opts.Compaction == nil || opts.ReloadMessages == nil {
return
}
reloaded, err := opts.ReloadMessages(ctx)
if err != nil {
return
}
_, compactErr := tryCompact(
ctx,
opts.Model,
opts.Compaction,
opts.ContextLimitFallback,
usage,
metadata,
reloaded,
)
if compactErr != nil && opts.Compaction.OnError != nil {
opts.Compaction.OnError(compactErr)
}
}
// buildToolDefinitions converts AgentTool definitions into the
// fantasy.Tool slice expected by fantasy.Call. When activeTools
// is non-empty, only function tools whose name appears in the
@@ -1239,6 +1410,16 @@ func isResponsesStoreEnabled(providerOptions fantasy.ProviderOptions) bool {
return false
}
// recordToolResultTimestamp lazily initializes the
// toolResultCreatedAt map on the stepResult and records
// the completion timestamp for the given tool-call ID.
func recordToolResultTimestamp(result *stepResult, toolCallID string, ts time.Time) {
if result.toolResultCreatedAt == nil {
result.toolResultCreatedAt = make(map[string]time.Time)
}
result.toolResultCreatedAt[toolCallID] = ts
}
func extractContextLimit(metadata fantasy.ProviderMetadata) sql.NullInt64 {
if len(metadata) == 0 {
return sql.NullInt64{}
@@ -86,6 +86,54 @@ func TestRun_ActiveToolsPrepareBehavior(t *testing.T) {
require.True(t, hasAnthropicEphemeralCacheControl(capturedCall.Prompt[4]))
}
func TestProcessStepStream_AnthropicUsageMatchesFinalDelta(t *testing.T) {
t.Parallel()
model := &loopTestModel{
provider: fantasyanthropic.Name,
streamFn: func(_ context.Context, _ fantasy.Call) (fantasy.StreamResponse, error) {
return streamFromParts([]fantasy.StreamPart{
{Type: fantasy.StreamPartTypeTextStart, ID: "text-1"},
{Type: fantasy.StreamPartTypeTextDelta, ID: "text-1", Delta: "cached response"},
{Type: fantasy.StreamPartTypeTextEnd, ID: "text-1"},
{
Type: fantasy.StreamPartTypeFinish,
Usage: fantasy.Usage{
InputTokens: 200,
OutputTokens: 75,
TotalTokens: 275,
CacheCreationTokens: 30,
CacheReadTokens: 150,
ReasoningTokens: 0,
},
FinishReason: fantasy.FinishReasonStop,
},
}), nil
},
}
var persistedStep PersistedStep
err := Run(context.Background(), RunOptions{
Model: model,
Messages: []fantasy.Message{
textMessage(fantasy.MessageRoleUser, "hello"),
},
MaxSteps: 1,
ContextLimitFallback: 4096,
PersistStep: func(_ context.Context, step PersistedStep) error {
persistedStep = step
return nil
},
})
require.NoError(t, err)
require.Equal(t, int64(200), persistedStep.Usage.InputTokens)
require.Equal(t, int64(75), persistedStep.Usage.OutputTokens)
require.Equal(t, int64(275), persistedStep.Usage.TotalTokens)
require.Equal(t, int64(30), persistedStep.Usage.CacheCreationTokens)
require.Equal(t, int64(150), persistedStep.Usage.CacheReadTokens)
}
func TestRun_OnRetryEnrichesProvider(t *testing.T) {
t.Parallel()
@@ -535,6 +583,7 @@ func TestRun_InterruptedStepPersistsSyntheticToolResult(t *testing.T) {
persistedAssistantCtxErr := xerrors.New("unset")
var persistedContent []fantasy.Content
var persistedStep PersistedStep
err := Run(ctx, RunOptions{
Model: model,
@@ -548,6 +597,7 @@ func TestRun_InterruptedStepPersistsSyntheticToolResult(t *testing.T) {
PersistStep: func(persistCtx context.Context, step PersistedStep) error {
persistedAssistantCtxErr = persistCtx.Err()
persistedContent = append([]fantasy.Content(nil), step.Content...)
persistedStep = step
return nil
},
})
@@ -587,6 +637,14 @@ func TestRun_InterruptedStepPersistsSyntheticToolResult(t *testing.T) {
require.True(t, foundText)
require.True(t, foundToolCall)
require.True(t, foundToolResult)
// The interrupted tool was flushed mid-stream (never reached
// StreamPartTypeToolCall), so it has no call timestamp.
// But the synthetic error result must have a result timestamp.
require.Contains(t, persistedStep.ToolResultCreatedAt, "interrupt-tool-1",
"interrupted tool result must have a result timestamp")
require.NotContains(t, persistedStep.ToolCallCreatedAt, "interrupt-tool-1",
"interrupted tool should have no call timestamp (never reached StreamPartTypeToolCall)")
}
type loopTestModel struct {
@@ -727,6 +785,7 @@ func TestRun_MultiStepToolExecution(t *testing.T) {
}
var persistStepCalls int
var persistedSteps []PersistedStep
err := Run(context.Background(), RunOptions{
Model: model,
Messages: []fantasy.Message{
@@ -736,8 +795,9 @@ func TestRun_MultiStepToolExecution(t *testing.T) {
newNoopTool("read_file"),
},
MaxSteps: 5,
PersistStep: func(_ context.Context, _ PersistedStep) error {
PersistStep: func(_ context.Context, step PersistedStep) error {
persistStepCalls++
persistedSteps = append(persistedSteps, step)
return nil
},
})
@@ -778,6 +838,112 @@ func TestRun_MultiStepToolExecution(t *testing.T) {
}
require.True(t, foundAssistantToolCall, "second call prompt should contain assistant tool call from step 0")
require.True(t, foundToolResult, "second call prompt should contain tool result message")
// The first persisted step (tool-call step) must carry
// accurate timestamps for duration computation.
require.Len(t, persistedSteps, 2)
toolStep := persistedSteps[0]
require.Contains(t, toolStep.ToolCallCreatedAt, "tc-1",
"tool-call step must record when the model emitted the call")
require.Contains(t, toolStep.ToolResultCreatedAt, "tc-1",
"tool-call step must record when the tool result was produced")
require.False(t, toolStep.ToolResultCreatedAt["tc-1"].Before(toolStep.ToolCallCreatedAt["tc-1"]),
"tool-result timestamp must be >= tool-call timestamp")
}
func TestRun_ParallelToolExecutionTimestamps(t *testing.T) {
t.Parallel()
var mu sync.Mutex
var streamCalls int
model := &loopTestModel{
provider: "fake",
streamFn: func(_ context.Context, call fantasy.Call) (fantasy.StreamResponse, error) {
mu.Lock()
step := streamCalls
streamCalls++
mu.Unlock()
_ = call
switch step {
case 0:
// Step 0: produce two tool calls in one stream.
return streamFromParts([]fantasy.StreamPart{
{Type: fantasy.StreamPartTypeToolInputStart, ID: "tc-1", ToolCallName: "read_file"},
{Type: fantasy.StreamPartTypeToolInputDelta, ID: "tc-1", Delta: `{"path":"a.go"}`},
{Type: fantasy.StreamPartTypeToolInputEnd, ID: "tc-1"},
{
Type: fantasy.StreamPartTypeToolCall,
ID: "tc-1",
ToolCallName: "read_file",
ToolCallInput: `{"path":"a.go"}`,
},
{Type: fantasy.StreamPartTypeToolInputStart, ID: "tc-2", ToolCallName: "write_file"},
{Type: fantasy.StreamPartTypeToolInputDelta, ID: "tc-2", Delta: `{"path":"b.go"}`},
{Type: fantasy.StreamPartTypeToolInputEnd, ID: "tc-2"},
{
Type: fantasy.StreamPartTypeToolCall,
ID: "tc-2",
ToolCallName: "write_file",
ToolCallInput: `{"path":"b.go"}`,
},
{Type: fantasy.StreamPartTypeFinish, FinishReason: fantasy.FinishReasonToolCalls},
}), nil
default:
// Step 1: return plain text.
return streamFromParts([]fantasy.StreamPart{
{Type: fantasy.StreamPartTypeTextStart, ID: "text-1"},
{Type: fantasy.StreamPartTypeTextDelta, ID: "text-1", Delta: "all done"},
{Type: fantasy.StreamPartTypeTextEnd, ID: "text-1"},
{Type: fantasy.StreamPartTypeFinish, FinishReason: fantasy.FinishReasonStop},
}), nil
}
},
}
var persistedSteps []PersistedStep
err := Run(context.Background(), RunOptions{
Model: model,
Messages: []fantasy.Message{
textMessage(fantasy.MessageRoleUser, "do both"),
},
Tools: []fantasy.AgentTool{
newNoopTool("read_file"),
newNoopTool("write_file"),
},
MaxSteps: 5,
PersistStep: func(_ context.Context, step PersistedStep) error {
persistedSteps = append(persistedSteps, step)
return nil
},
})
require.NoError(t, err)
// Two steps: tool-call step + text step.
require.Equal(t, 2, streamCalls)
require.Len(t, persistedSteps, 2)
toolStep := persistedSteps[0]
// Both tool-call IDs must appear in ToolCallCreatedAt.
require.Contains(t, toolStep.ToolCallCreatedAt, "tc-1",
"tool-call step must record when tc-1 was emitted")
require.Contains(t, toolStep.ToolCallCreatedAt, "tc-2",
"tool-call step must record when tc-2 was emitted")
// Both tool-call IDs must appear in ToolResultCreatedAt.
require.Contains(t, toolStep.ToolResultCreatedAt, "tc-1",
"tool-call step must record when tc-1 result was produced")
require.Contains(t, toolStep.ToolResultCreatedAt, "tc-2",
"tool-call step must record when tc-2 result was produced")
// Result timestamps must be >= call timestamps for both.
require.False(t, toolStep.ToolResultCreatedAt["tc-1"].Before(toolStep.ToolCallCreatedAt["tc-1"]),
"tc-1 tool-result timestamp must be >= tool-call timestamp")
require.False(t, toolStep.ToolResultCreatedAt["tc-2"].Before(toolStep.ToolCallCreatedAt["tc-2"]),
"tc-2 tool-result timestamp must be >= tool-call timestamp")
}
func TestRun_PersistStepErrorPropagates(t *testing.T) {
@@ -1183,6 +1349,77 @@ func TestRun_InterruptedDuringToolExecutionPersistsStep(t *testing.T) {
require.True(t, foundToolResult, "persisted content should include the tool result (error from cancellation)")
}
// TestRun_ProviderExecutedToolResultTimestamps verifies that
// provider-executed tool results (e.g. web search) have their
// timestamps recorded in PersistedStep.ToolResultCreatedAt so
// the persistence layer can stamp CreatedAt on the parts.
func TestRun_ProviderExecutedToolResultTimestamps(t *testing.T) {
t.Parallel()
model := &loopTestModel{
provider: "fake",
streamFn: func(_ context.Context, _ fantasy.Call) (fantasy.StreamResponse, error) {
// Simulate a provider-executed tool call and result
// (e.g. Anthropic web search) followed by a text
// response — all in a single stream.
return streamFromParts([]fantasy.StreamPart{
{Type: fantasy.StreamPartTypeToolInputStart, ID: "ws-1", ToolCallName: "web_search", ProviderExecuted: true},
{Type: fantasy.StreamPartTypeToolInputDelta, ID: "ws-1", Delta: `{"query":"coder"}`, ProviderExecuted: true},
{Type: fantasy.StreamPartTypeToolInputEnd, ID: "ws-1"},
{
Type: fantasy.StreamPartTypeToolCall,
ID: "ws-1",
ToolCallName: "web_search",
ToolCallInput: `{"query":"coder"}`,
ProviderExecuted: true,
},
// Provider-executed tool result — emitted by
// the provider, not our tool runner.
{
Type: fantasy.StreamPartTypeToolResult,
ID: "ws-1",
ToolCallName: "web_search",
ProviderExecuted: true,
},
{Type: fantasy.StreamPartTypeTextStart, ID: "text-1"},
{Type: fantasy.StreamPartTypeTextDelta, ID: "text-1", Delta: "search done"},
{Type: fantasy.StreamPartTypeTextEnd, ID: "text-1"},
{Type: fantasy.StreamPartTypeFinish, FinishReason: fantasy.FinishReasonStop},
}), nil
},
}
var persistedSteps []PersistedStep
err := Run(context.Background(), RunOptions{
Model: model,
Messages: []fantasy.Message{
textMessage(fantasy.MessageRoleUser, "search for coder"),
},
MaxSteps: 1,
PersistStep: func(_ context.Context, step PersistedStep) error {
persistedSteps = append(persistedSteps, step)
return nil
},
})
require.NoError(t, err)
require.Len(t, persistedSteps, 1)
step := persistedSteps[0]
// Provider-executed tool call should have a call timestamp.
require.Contains(t, step.ToolCallCreatedAt, "ws-1",
"provider-executed tool call must record its timestamp")
// Provider-executed tool result should have a result
// timestamp so the frontend can compute duration.
require.Contains(t, step.ToolResultCreatedAt, "ws-1",
"provider-executed tool result must record its timestamp")
require.False(t,
step.ToolResultCreatedAt["ws-1"].Before(step.ToolCallCreatedAt["ws-1"]),
"tool-result timestamp must be >= tool-call timestamp")
}
// TestRun_PersistStepInterruptedFallback verifies that when the normal
// PersistStep call returns ErrInterrupted (e.g., context canceled in a
// race), the step is retried via the interrupt-safe path.
@@ -713,4 +713,76 @@ func TestRun_Compaction(t *testing.T) {
}
require.True(t, hasUser, "re-entry prompt must contain a user message (the compaction summary)")
})
t.Run("TriggersOnDynamicToolExit", func(t *testing.T) {
t.Parallel()
var persistCompactionCalls int
const summaryText = "compaction summary for dynamic tool exit"
// The LLM calls a dynamic tool. Usage is above the
// compaction threshold so compaction should fire even
// though the chatloop exits via ErrDynamicToolCall.
model := &loopTestModel{
provider: "fake",
streamFn: func(_ context.Context, _ fantasy.Call) (fantasy.StreamResponse, error) {
return streamFromParts([]fantasy.StreamPart{
{Type: fantasy.StreamPartTypeToolInputStart, ID: "tc-1", ToolCallName: "my_dynamic_tool"},
{Type: fantasy.StreamPartTypeToolInputDelta, ID: "tc-1", Delta: `{"query": "test"}`},
{Type: fantasy.StreamPartTypeToolInputEnd, ID: "tc-1"},
{
Type: fantasy.StreamPartTypeToolCall,
ID: "tc-1",
ToolCallName: "my_dynamic_tool",
ToolCallInput: `{"query": "test"}`,
},
{
Type: fantasy.StreamPartTypeFinish,
FinishReason: fantasy.FinishReasonToolCalls,
Usage: fantasy.Usage{
InputTokens: 80,
TotalTokens: 85,
},
},
}), nil
},
generateFn: func(_ context.Context, _ fantasy.Call) (*fantasy.Response, error) {
return &fantasy.Response{
Content: []fantasy.Content{
fantasy.TextContent{Text: summaryText},
},
}, nil
},
}
err := Run(context.Background(), RunOptions{
Model: model,
Messages: []fantasy.Message{
textMessage(fantasy.MessageRoleUser, "hello"),
},
MaxSteps: 5,
DynamicToolNames: map[string]bool{"my_dynamic_tool": true},
PersistStep: func(_ context.Context, _ PersistedStep) error {
return nil
},
ContextLimitFallback: 100,
Compaction: &CompactionOptions{
ThresholdPercent: 70,
SummaryPrompt: "summarize now",
Persist: func(_ context.Context, result CompactionResult) error {
persistCompactionCalls++
require.Contains(t, result.SystemSummary, summaryText)
return nil
},
},
ReloadMessages: func(_ context.Context) ([]fantasy.Message, error) {
return []fantasy.Message{
textMessage(fantasy.MessageRoleUser, "hello"),
}, nil
},
})
require.ErrorIs(t, err, ErrDynamicToolCall)
require.Equal(t, 1, persistCompactionCalls,
"compaction must fire before dynamic tool exit")
})
}
@@ -2329,3 +2329,48 @@ func TestMediaToolResultRoundTrip(t *testing.T) {
require.True(t, isText, "expected ToolResultOutputContentText")
})
}
func TestPartFromContent_CreatedAtNotStamped(t *testing.T) {
t.Parallel()
// PartFromContent must NOT stamp CreatedAt itself.
// The chatloop layer records timestamps separately and
// the persistence layer applies them. PartFromContent
// is called in multiple contexts (SSE publishing,
// persistence) so stamping inside it would produce
// inaccurate durations.
t.Run("ToolCallHasNilCreatedAt", func(t *testing.T) {
t.Parallel()
part := chatprompt.PartFromContent(fantasy.ToolCallContent{
ToolCallID: "tc-1",
ToolName: "execute",
})
assert.Nil(t, part.CreatedAt)
})
t.Run("ToolCallPointerHasNilCreatedAt", func(t *testing.T) {
t.Parallel()
part := chatprompt.PartFromContent(&fantasy.ToolCallContent{
ToolCallID: "tc-1",
ToolName: "execute",
})
assert.Nil(t, part.CreatedAt)
})
t.Run("ToolResultHasNilCreatedAt", func(t *testing.T) {
t.Parallel()
part := chatprompt.PartFromContent(fantasy.ToolResultContent{
ToolCallID: "tc-1",
ToolName: "execute",
Result: fantasy.ToolResultOutputContentText{Text: "{}"},
})
assert.Nil(t, part.CreatedAt)
})
t.Run("TextHasNilCreatedAt", func(t *testing.T) {
t.Parallel()
part := chatprompt.PartFromContent(fantasy.TextContent{Text: "hello"})
assert.Nil(t, part.CreatedAt)
})
}
+89 -7
@@ -53,8 +53,10 @@ type AnthropicMessage struct {
// AnthropicUsage represents usage information in an Anthropic response.
type AnthropicUsage struct {
InputTokens int `json:"input_tokens"`
OutputTokens int `json:"output_tokens"`
InputTokens int `json:"input_tokens"`
OutputTokens int `json:"output_tokens"`
CacheCreationInputTokens int `json:"cache_creation_input_tokens,omitempty"`
CacheReadInputTokens int `json:"cache_read_input_tokens,omitempty"`
}
// AnthropicChunk represents a streaming chunk from Anthropic.
@@ -67,14 +69,16 @@ type AnthropicChunk struct {
StopReason string `json:"stop_reason,omitempty"`
StopSequence *string `json:"stop_sequence,omitempty"`
Usage AnthropicUsage `json:"usage,omitempty"`
UsageMap map[string]int `json:"-"`
}
// AnthropicChunkMessage represents message metadata in a chunk.
type AnthropicChunkMessage struct {
ID string `json:"id"`
Type string `json:"type"`
Role string `json:"role"`
Model string `json:"model"`
ID string `json:"id"`
Type string `json:"type"`
Role string `json:"role"`
Model string `json:"model"`
Usage map[string]int `json:"usage,omitempty"`
}
// AnthropicContentBlock represents a content block in a chunk.
@@ -206,7 +210,11 @@ func (s *anthropicServer) writeStreamingResponse(w http.ResponseWriter, chunks <
"stop_reason": chunk.StopReason,
"stop_sequence": chunk.StopSequence,
}
chunkData["usage"] = chunk.Usage
if chunk.UsageMap != nil {
chunkData["usage"] = chunk.UsageMap
} else {
chunkData["usage"] = chunk.Usage
}
case "message_stop":
// No additional fields
}
@@ -342,6 +350,80 @@ func AnthropicTextChunks(deltas ...string) []AnthropicChunk {
return chunks
}
// AnthropicTextChunksWithCacheUsage creates a streaming response with text
// deltas and explicit cache token usage. The message_start event carries
// the initial input and cache token counts, and the final message_delta
// carries the output token count.
func AnthropicTextChunksWithCacheUsage(usage AnthropicUsage, deltas ...string) []AnthropicChunk {
if len(deltas) == 0 {
return nil
}
messageID := fmt.Sprintf("msg-%s", uuid.New().String()[:8])
model := "claude-3-opus-20240229"
messageUsage := map[string]int{
"input_tokens": usage.InputTokens,
}
if usage.CacheCreationInputTokens != 0 {
messageUsage["cache_creation_input_tokens"] = usage.CacheCreationInputTokens
}
if usage.CacheReadInputTokens != 0 {
messageUsage["cache_read_input_tokens"] = usage.CacheReadInputTokens
}
chunks := []AnthropicChunk{
{
Type: "message_start",
Message: AnthropicChunkMessage{
ID: messageID,
Type: "message",
Role: "assistant",
Model: model,
Usage: messageUsage,
},
},
{
Type: "content_block_start",
Index: 0,
ContentBlock: AnthropicContentBlock{
Type: "text",
Text: "",
},
},
}
for _, delta := range deltas {
chunks = append(chunks, AnthropicChunk{
Type: "content_block_delta",
Index: 0,
Delta: AnthropicDeltaBlock{
Type: "text_delta",
Text: delta,
},
})
}
chunks = append(chunks,
AnthropicChunk{
Type: "content_block_stop",
Index: 0,
},
AnthropicChunk{
Type: "message_delta",
StopReason: "end_turn",
UsageMap: map[string]int{
"output_tokens": usage.OutputTokens,
},
},
AnthropicChunk{
Type: "message_stop",
},
)
return chunks
}
// AnthropicToolCallChunks creates a complete streaming response for a tool call.
// Input JSON can be split across multiple deltas, matching Anthropic's
// input_json_delta streaming behavior.
+53
@@ -63,6 +63,59 @@ func TestAnthropic_Streaming(t *testing.T) {
require.Equal(t, len(expectedDeltas), deltaIndex, "Expected %d deltas, got %d. Total parts received: %d", len(expectedDeltas), deltaIndex, len(allParts))
}
func TestAnthropic_StreamingUsageIncludesCacheTokens(t *testing.T) {
t.Parallel()
serverURL := chattest.NewAnthropic(t, func(req *chattest.AnthropicRequest) chattest.AnthropicResponse {
return chattest.AnthropicStreamingResponse(
chattest.AnthropicTextChunksWithCacheUsage(chattest.AnthropicUsage{
InputTokens: 200,
OutputTokens: 75,
CacheCreationInputTokens: 30,
CacheReadInputTokens: 150,
}, "cached", " response")...,
)
})
client, err := fantasyanthropic.New(
fantasyanthropic.WithAPIKey("test-key"),
fantasyanthropic.WithBaseURL(serverURL),
)
require.NoError(t, err)
model, err := client.LanguageModel(context.Background(), "claude-3-opus-20240229")
require.NoError(t, err)
stream, err := model.Stream(context.Background(), fantasy.Call{
Prompt: []fantasy.Message{
{
Role: fantasy.MessageRoleUser,
Content: []fantasy.MessagePart{fantasy.TextPart{Text: "hello"}},
},
},
})
require.NoError(t, err)
var (
finishPart fantasy.StreamPart
found bool
)
for part := range stream {
if part.Type != fantasy.StreamPartTypeFinish {
continue
}
finishPart = part
found = true
}
require.True(t, found)
require.Equal(t, int64(200), finishPart.Usage.InputTokens)
require.Equal(t, int64(75), finishPart.Usage.OutputTokens)
require.Equal(t, int64(275), finishPart.Usage.TotalTokens)
require.Equal(t, int64(30), finishPart.Usage.CacheCreationTokens)
require.Equal(t, int64(150), finishPart.Usage.CacheReadTokens)
}
func TestAnthropic_ToolCalls(t *testing.T) {
t.Parallel()
+91
@@ -0,0 +1,91 @@
package chatd
import (
"context"
"encoding/json"
"charm.land/fantasy"
"cdr.dev/slog/v3"
"github.com/coder/coder/v2/codersdk"
)
// dynamicTool wraps a codersdk.DynamicTool as a fantasy.AgentTool.
// These tools are presented to the LLM but never executed by the
// chatloop — when the LLM calls one, the chatloop exits with
// requires_action status and the client handles execution.
// The Run method should never be called; it returns an error if
// it is, as a safety net.
type dynamicTool struct {
name string
description string
parameters map[string]any
required []string
opts fantasy.ProviderOptions
}
// dynamicToolsFromSDK converts codersdk.DynamicTool definitions
// into fantasy.AgentTool implementations for inclusion in the LLM
// tool list.
func dynamicToolsFromSDK(logger slog.Logger, tools []codersdk.DynamicTool) []fantasy.AgentTool {
if len(tools) == 0 {
return nil
}
result := make([]fantasy.AgentTool, 0, len(tools))
for _, t := range tools {
dt := &dynamicTool{
name: t.Name,
description: t.Description,
}
// InputSchema is a full JSON Schema object stored as
// json.RawMessage. Extract the "properties" and
// "required" fields that fantasy.ToolInfo expects.
if len(t.InputSchema) > 0 {
var schema struct {
Properties map[string]any `json:"properties"`
Required []string `json:"required"`
}
if err := json.Unmarshal(t.InputSchema, &schema); err != nil {
// Defensive: present the tool with no parameter
// constraints rather than failing. The LLM may
// hallucinate argument shapes, but the tool will
// still appear in the tool list.
logger.Warn(context.Background(), "failed to parse dynamic tool input schema",
slog.F("tool_name", t.Name),
slog.Error(err))
} else {
dt.parameters = schema.Properties
dt.required = schema.Required
}
}
result = append(result, dt)
}
return result
}
func (t *dynamicTool) Info() fantasy.ToolInfo {
return fantasy.ToolInfo{
Name: t.name,
Description: t.description,
Parameters: t.parameters,
Required: t.required,
}
}
func (*dynamicTool) Run(_ context.Context, _ fantasy.ToolCall) (fantasy.ToolResponse, error) {
// Dynamic tools are never executed by the chatloop. If this
// method is called, it indicates a bug in the chatloop's
// dynamic tool detection logic.
return fantasy.NewTextErrorResponse(
"dynamic tool called in chatloop — this is a bug; " +
"dynamic tools should be handled by the client",
), nil
}
func (t *dynamicTool) ProviderOptions() fantasy.ProviderOptions {
return t.opts
}
func (t *dynamicTool) SetProviderOptions(opts fantasy.ProviderOptions) {
t.opts = opts
}
+114
@@ -0,0 +1,114 @@
package chatd
import (
"encoding/json"
"testing"
"github.com/stretchr/testify/require"
"cdr.dev/slog/v3/sloggers/slogtest"
"github.com/coder/coder/v2/codersdk"
)
func TestDynamicToolsFromSDK(t *testing.T) {
t.Parallel()
t.Run("EmptySlice", func(t *testing.T) {
t.Parallel()
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true})
result := dynamicToolsFromSDK(logger, nil)
require.Nil(t, result)
})
t.Run("ValidToolWithSchema", func(t *testing.T) {
t.Parallel()
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true})
tools := []codersdk.DynamicTool{
{
Name: "my_tool",
Description: "A useful tool",
InputSchema: json.RawMessage(`{"type":"object","properties":{"input":{"type":"string"}},"required":["input"]}`),
},
}
result := dynamicToolsFromSDK(logger, tools)
require.Len(t, result, 1)
info := result[0].Info()
require.Equal(t, "my_tool", info.Name)
require.Equal(t, "A useful tool", info.Description)
require.NotNil(t, info.Parameters)
require.Contains(t, info.Parameters, "input")
require.Equal(t, []string{"input"}, info.Required)
})
t.Run("ToolWithoutSchema", func(t *testing.T) {
t.Parallel()
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true})
tools := []codersdk.DynamicTool{
{
Name: "no_schema",
Description: "Tool with no schema",
},
}
result := dynamicToolsFromSDK(logger, tools)
require.Len(t, result, 1)
info := result[0].Info()
require.Equal(t, "no_schema", info.Name)
require.Nil(t, info.Parameters)
require.Nil(t, info.Required)
})
t.Run("MalformedSchema", func(t *testing.T) {
t.Parallel()
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true})
tools := []codersdk.DynamicTool{
{
Name: "bad_schema",
Description: "Tool with malformed schema",
InputSchema: json.RawMessage("not-json"),
},
}
result := dynamicToolsFromSDK(logger, tools)
require.Len(t, result, 1)
info := result[0].Info()
require.Equal(t, "bad_schema", info.Name)
require.Nil(t, info.Parameters)
require.Nil(t, info.Required)
})
t.Run("MultipleTools", func(t *testing.T) {
t.Parallel()
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true})
tools := []codersdk.DynamicTool{
{Name: "first", Description: "First tool"},
{Name: "second", Description: "Second tool"},
{Name: "third", Description: "Third tool"},
}
result := dynamicToolsFromSDK(logger, tools)
require.Len(t, result, 3)
require.Equal(t, "first", result[0].Info().Name)
require.Equal(t, "second", result[1].Info().Name)
require.Equal(t, "third", result[2].Info().Name)
})
t.Run("SchemaWithoutProperties", func(t *testing.T) {
t.Parallel()
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true})
tools := []codersdk.DynamicTool{
{
Name: "bare_schema",
Description: "Schema with no properties",
InputSchema: json.RawMessage(`{"type":"object"}`),
},
}
result := dynamicToolsFromSDK(logger, tools)
require.Len(t, result, 1)
info := result[0].Info()
require.Equal(t, "bare_schema", info.Name)
require.Nil(t, info.Parameters)
require.Nil(t, info.Required)
})
}
+24 -8
@@ -180,7 +180,7 @@ func generateTitle(
model fantasy.LanguageModel,
input string,
) (string, error) {
title, _, err := generateStructuredTitle(ctx, model, titleGenerationPrompt, input)
title, err := generateStructuredTitle(ctx, model, titleGenerationPrompt, input)
if err != nil {
return "", err
}
@@ -192,6 +192,24 @@ func generateStructuredTitle(
model fantasy.LanguageModel,
systemPrompt string,
userInput string,
) (string, error) {
title, _, err := generateStructuredTitleWithUsage(
ctx,
model,
systemPrompt,
userInput,
)
if err != nil {
return "", err
}
return title, nil
}
func generateStructuredTitleWithUsage(
ctx context.Context,
model fantasy.LanguageModel,
systemPrompt string,
userInput string,
) (string, fantasy.Usage, error) {
userInput = strings.TrimSpace(userInput)
if userInput == "" {
@@ -226,8 +244,6 @@ func generateStructuredTitle(
return genErr
}, nil)
if err != nil {
// Extract usage from the error when available so that
// failed attempts are still accounted for in usage tracking.
var usage fantasy.Usage
var noObjErr *fantasy.NoObjectGeneratedError
if errors.As(err, &noObjErr) {
@@ -529,7 +545,7 @@ func generateManualTitle(
userInput = strings.TrimSpace(firstUserText)
}
title, usage, err := generateStructuredTitle(
title, usage, err := generateStructuredTitleWithUsage(
titleCtx,
fallbackModel,
systemPrompt,
@@ -579,7 +595,7 @@ func generatePushSummary(
candidates = append(candidates, fallbackModel)
for _, model := range candidates {
summary, _, err := generateShortText(summaryCtx, model, pushSummaryPrompt, input)
summary, err := generateShortText(summaryCtx, model, pushSummaryPrompt, input)
if err != nil {
logger.Debug(ctx, "push summary model candidate failed",
slog.Error(err),
@@ -601,7 +617,7 @@ func generateShortText(
model fantasy.LanguageModel,
systemPrompt string,
userInput string,
) (string, fantasy.Usage, error) {
) (string, error) {
prompt := []fantasy.Message{
{
Role: fantasy.MessageRoleSystem,
@@ -629,7 +645,7 @@ func generateShortText(
return genErr
}, nil)
if err != nil {
return "", fantasy.Usage{}, xerrors.Errorf("generate short text: %w", err)
return "", xerrors.Errorf("generate short text: %w", err)
}
responseParts := make([]codersdk.ChatMessagePart, 0, len(response.Content))
@@ -639,5 +655,5 @@ func generateShortText(
}
}
text := normalizeShortTextOutput(contentBlocksToText(responseParts))
return text, response.Usage, nil
return text, nil
}
+1 -4
@@ -515,12 +515,9 @@ func Test_generateShortText_NormalizesQuotedOutput(t *testing.T) {
},
}
text, usage, err := generateShortText(context.Background(), model, "system", "user")
text, err := generateShortText(context.Background(), model, "system", "user")
require.NoError(t, err)
require.Equal(t, "Quoted summary", text)
require.Equal(t, int64(3), usage.InputTokens)
require.Equal(t, int64(2), usage.OutputTokens)
require.Equal(t, int64(5), usage.TotalTokens)
}
type stubModel struct {
+1 -1
@@ -376,7 +376,7 @@ func ProtoFromLog(log Log) (*proto.Log, error) {
}
return &proto.Log{
CreatedAt: timestamppb.New(log.CreatedAt),
Output: strings.ToValidUTF8(log.Output, "❌"),
Output: SanitizeLogOutput(log.Output),
Level: proto.Log_Level(lvl),
}, nil
}
+6 -6
@@ -229,7 +229,7 @@ func TestLogSender_SkipHugeLog(t *testing.T) {
require.ErrorIs(t, err, context.Canceled)
}
func TestLogSender_InvalidUTF8(t *testing.T) {
func TestLogSender_SanitizeOutput(t *testing.T) {
t.Parallel()
testCtx := testutil.Context(t, testutil.WaitShort)
ctx, cancel := context.WithCancel(testCtx)
@@ -243,7 +243,7 @@ func TestLogSender_InvalidUTF8(t *testing.T) {
uut.Enqueue(ls1,
Log{
CreatedAt: t0,
Output: "test log 0, src 1\xc3\x28",
Output: "test log 0, src 1\x00\xc3\x28",
Level: codersdk.LogLevelInfo,
},
Log{
@@ -260,10 +260,10 @@ func TestLogSender_InvalidUTF8(t *testing.T) {
req := testutil.TryReceive(ctx, t, fDest.reqs)
require.NotNil(t, req)
require.Len(t, req.Logs, 2, "it should sanitize invalid UTF-8, but still send")
// the 0xc3, 0x28 is an invalid 2-byte sequence in UTF-8. The sanitizer replaces 0xc3 with ❌, and then
// interprets 0x28 as a 1-byte sequence "("
require.Equal(t, "test log 0, src 1❌(", req.Logs[0].GetOutput())
require.Len(t, req.Logs, 2, "it should sanitize invalid output, but still send")
// The sanitizer replaces the NUL byte and invalid UTF-8 with ❌ while
// preserving the valid "(" byte that follows 0xc3.
require.Equal(t, "test log 0, src 1❌(", req.Logs[0].GetOutput())
require.Equal(t, proto.Log_INFO, req.Logs[0].GetLevel())
require.Equal(t, "test log 1, src 1", req.Logs[1].GetOutput())
require.Equal(t, proto.Log_INFO, req.Logs[1].GetLevel())
+11
@@ -0,0 +1,11 @@
package agentsdk
import "strings"
// SanitizeLogOutput replaces invalid UTF-8 and NUL characters in log output.
// Invalid UTF-8 cannot be transported in protobuf string fields, and PostgreSQL
// rejects NUL bytes in text columns.
func SanitizeLogOutput(s string) string {
s = strings.ToValidUTF8(s, "❌")
return strings.ReplaceAll(s, "\x00", "❌")
}
+48
@@ -17,6 +17,54 @@ import (
"github.com/coder/coder/v2/testutil"
)
func TestSanitizeLogOutput(t *testing.T) {
t.Parallel()
tests := []struct {
name string
in string
want string
}{
{
name: "valid",
in: "hello world",
want: "hello world",
},
{
name: "invalid utf8",
in: "test log\xc3\x28",
want: "test log❌(",
},
{
name: "nul byte",
in: "before\x00after",
want: "before❌after",
},
{
name: "invalid utf8 and nul byte",
in: "before\x00middle\xc3\x28after",
want: "before❌middle❌(after",
},
{
name: "nul byte at edges",
in: "\x00middle\x00",
want: "❌middle❌",
},
{
name: "invalid utf8 at edges",
in: "\xc3middle\xc3",
want: "❌middle❌",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
require.Equal(t, tt.want, agentsdk.SanitizeLogOutput(tt.in))
})
}
}
func TestStartupLogsWriter_Write(t *testing.T) {
t.Parallel()
+269 -20
@@ -15,6 +15,7 @@ import (
"time"
"github.com/google/uuid"
"github.com/invopop/jsonschema"
"github.com/shopspring/decimal"
"golang.org/x/xerrors"
@@ -42,12 +43,13 @@ func CompactionThresholdKey(modelConfigID uuid.UUID) string {
type ChatStatus string
const (
ChatStatusWaiting ChatStatus = "waiting"
ChatStatusPending ChatStatus = "pending"
ChatStatusRunning ChatStatus = "running"
ChatStatusPaused ChatStatus = "paused"
ChatStatusCompleted ChatStatus = "completed"
ChatStatusError ChatStatus = "error"
ChatStatusWaiting ChatStatus = "waiting"
ChatStatusPending ChatStatus = "pending"
ChatStatusRunning ChatStatus = "running"
ChatStatusPaused ChatStatus = "paused"
ChatStatusCompleted ChatStatus = "completed"
ChatStatusError ChatStatus = "error"
ChatStatusRequiresAction ChatStatus = "requires_action"
)
// Chat represents a chat session with an AI agent.
@@ -212,6 +214,10 @@ type ChatMessagePart struct {
// ProviderExecuted indicates the tool call was executed by
// the provider (e.g. Anthropic computer use).
ProviderExecuted bool `json:"provider_executed,omitempty" variants:"tool-call?,tool-result?"`
// CreatedAt records when this part was produced. Present on
// tool-call and tool-result parts so the frontend can compute
// tool execution duration.
CreatedAt *time.Time `json:"created_at,omitempty" format:"date-time" variants:"tool-call?,tool-result?"`
// ContextFilePath is the absolute path of a file loaded into
// the LLM context (e.g. an AGENTS.md instruction file).
ContextFilePath string `json:"context_file_path" variants:"context-file"`
@@ -361,6 +367,18 @@ type ChatInputPart struct {
Content string `json:"content,omitempty"`
}
// SubmitToolResultsRequest is the body for POST /chats/{id}/tool-results.
type SubmitToolResultsRequest struct {
Results []ToolResult `json:"results"`
}
// ToolResult is the client's response to a dynamic tool call.
type ToolResult struct {
ToolCallID string `json:"tool_call_id"`
Output json.RawMessage `json:"output"`
IsError bool `json:"is_error"`
}
// CreateChatRequest is the request to create a new chat.
type CreateChatRequest struct {
Content []ChatInputPart `json:"content"`
@@ -369,6 +387,10 @@ type CreateChatRequest struct {
ModelConfigID *uuid.UUID `json:"model_config_id,omitempty" format:"uuid"`
MCPServerIDs []uuid.UUID `json:"mcp_server_ids,omitempty" format:"uuid"`
Labels map[string]string `json:"labels,omitempty"`
// UnsafeDynamicTools declares client-executed tools that the
// LLM can invoke. This API is highly experimental and subject
// to change.
UnsafeDynamicTools []DynamicTool `json:"unsafe_dynamic_tools,omitempty"`
}
// UpdateChatRequest is the request to update a chat.
@@ -545,6 +567,17 @@ type UpdateChatWorkspaceTTLRequest struct {
WorkspaceTTLMillis int64 `json:"workspace_ttl_ms"`
}
// ChatRetentionDaysResponse contains the current chat retention setting.
type ChatRetentionDaysResponse struct {
RetentionDays int32 `json:"retention_days"`
}
// UpdateChatRetentionDaysRequest is a request to update the chat
// retention period.
type UpdateChatRetentionDaysRequest struct {
RetentionDays int32 `json:"retention_days"`
}
// ParseChatWorkspaceTTL parses a stored TTL string, returning the
// default when the value is empty.
func ParseChatWorkspaceTTL(s string) (time.Duration, error) {
@@ -917,12 +950,13 @@ type ChatDiffContents struct {
type ChatStreamEventType string
const (
ChatStreamEventTypeMessagePart ChatStreamEventType = "message_part"
ChatStreamEventTypeMessage ChatStreamEventType = "message"
ChatStreamEventTypeStatus ChatStreamEventType = "status"
ChatStreamEventTypeError ChatStreamEventType = "error"
ChatStreamEventTypeQueueUpdate ChatStreamEventType = "queue_update"
ChatStreamEventTypeRetry ChatStreamEventType = "retry"
ChatStreamEventTypeActionRequired ChatStreamEventType = "action_required"
)
// ChatQueuedMessage represents a queued message waiting to be processed.
@@ -977,16 +1011,123 @@ type ChatStreamRetry struct {
RetryingAt time.Time `json:"retrying_at" format:"date-time"`
}
// ChatStreamActionRequired is the payload of an action_required stream event.
type ChatStreamActionRequired struct {
ToolCalls []ChatStreamToolCall `json:"tool_calls"`
}
// ChatStreamToolCall describes a pending dynamic tool call that the client
// must execute.
type ChatStreamToolCall struct {
ToolCallID string `json:"tool_call_id"`
ToolName string `json:"tool_name"`
Args string `json:"args"`
}
// DynamicToolCall represents a pending tool invocation from the
// chat stream that the client must execute and submit back.
type DynamicToolCall struct {
ToolCallID string `json:"tool_call_id"`
ToolName string `json:"tool_name"`
Args string `json:"args"`
}
// DynamicToolResponse holds the output of a dynamic tool
// execution. IsError indicates a tool-level error the LLM
// should see, as opposed to an infrastructure failure
// (returned as the error return value).
type DynamicToolResponse struct {
Content string `json:"content"`
IsError bool `json:"is_error"`
}
// DynamicTool describes a client-declared tool definition. On the
// client side, the Handler callback executes the tool when the LLM
// invokes it. On the server side, only Name, Description, and
// InputSchema are used (Handler is not serialized).
type DynamicTool struct {
Name string `json:"name"`
Description string `json:"description,omitempty"`
// InputSchema's JSON key "input_schema" uses snake_case for
// SDK consistency, deviating from the camelCase "inputSchema"
// convention used by MCP.
InputSchema json.RawMessage `json:"input_schema"`
// Handler executes the tool when the LLM invokes it.
// Not serialized — this only exists on the client side.
Handler func(ctx context.Context, call DynamicToolCall) (DynamicToolResponse, error) `json:"-"`
}
// NewDynamicTool creates a DynamicTool with a typed handler.
// The JSON schema is derived from T using invopop/jsonschema.
// The handler receives deserialized args and the DynamicToolCall metadata.
func NewDynamicTool[T any](
name, description string,
handler func(ctx context.Context, args T, call DynamicToolCall) (DynamicToolResponse, error),
) DynamicTool {
reflector := jsonschema.Reflector{
DoNotReference: true,
Anonymous: true,
AllowAdditionalProperties: true,
}
schema := reflector.Reflect(new(T))
schema.Version = ""
schemaJSON, err := json.Marshal(schema)
if err != nil {
panic(fmt.Sprintf("codersdk: failed to marshal schema for %q: %v", name, err))
}
return DynamicTool{
Name: name,
Description: description,
InputSchema: schemaJSON,
Handler: func(ctx context.Context, call DynamicToolCall) (DynamicToolResponse, error) {
var parsed T
if err := json.Unmarshal([]byte(call.Args), &parsed); err != nil {
return DynamicToolResponse{
Content: fmt.Sprintf("invalid parameters: %s", err),
IsError: true,
}, nil
}
return handler(ctx, parsed, call)
},
}
}
// ChatWatchEventKind represents the kind of event in the chat watch stream.
type ChatWatchEventKind string
const (
ChatWatchEventKindStatusChange ChatWatchEventKind = "status_change"
ChatWatchEventKindTitleChange ChatWatchEventKind = "title_change"
ChatWatchEventKindCreated ChatWatchEventKind = "created"
ChatWatchEventKindDeleted ChatWatchEventKind = "deleted"
ChatWatchEventKindDiffStatusChange ChatWatchEventKind = "diff_status_change"
ChatWatchEventKindActionRequired ChatWatchEventKind = "action_required"
)
// ChatWatchEvent represents an event from the global chat watch stream.
// It delivers lifecycle events (created, status change, title change)
// for all of the authenticated user's chats. When Kind is
// ActionRequired, ToolCalls contains the pending dynamic tool
// invocations the client must execute and submit back.
type ChatWatchEvent struct {
Kind ChatWatchEventKind `json:"kind"`
Chat Chat `json:"chat"`
ToolCalls []ChatStreamToolCall `json:"tool_calls,omitempty"`
}
// ChatStreamEvent represents a real-time update for chat streaming.
type ChatStreamEvent struct {
Type ChatStreamEventType `json:"type"`
ChatID uuid.UUID `json:"chat_id" format:"uuid"`
Message *ChatMessage `json:"message,omitempty"`
MessagePart *ChatStreamMessagePart `json:"message_part,omitempty"`
Status *ChatStreamStatus `json:"status,omitempty"`
Error *ChatStreamError `json:"error,omitempty"`
Retry *ChatStreamRetry `json:"retry,omitempty"`
QueuedMessages []ChatQueuedMessage `json:"queued_messages,omitempty"`
ActionRequired *ChatStreamActionRequired `json:"action_required,omitempty"`
}
type chatStreamEnvelope struct {
@@ -1667,6 +1808,33 @@ func (c *ExperimentalClient) UpdateChatWorkspaceTTL(ctx context.Context, req Upd
return nil
}
// GetChatRetentionDays returns the configured chat retention period.
func (c *ExperimentalClient) GetChatRetentionDays(ctx context.Context) (ChatRetentionDaysResponse, error) {
res, err := c.Request(ctx, http.MethodGet, "/api/experimental/chats/config/retention-days", nil)
if err != nil {
return ChatRetentionDaysResponse{}, err
}
defer res.Body.Close()
if res.StatusCode != http.StatusOK {
return ChatRetentionDaysResponse{}, ReadBodyAsError(res)
}
var resp ChatRetentionDaysResponse
return resp, json.NewDecoder(res.Body).Decode(&resp)
}
// UpdateChatRetentionDays updates the chat retention period.
func (c *ExperimentalClient) UpdateChatRetentionDays(ctx context.Context, req UpdateChatRetentionDaysRequest) error {
res, err := c.Request(ctx, http.MethodPut, "/api/experimental/chats/config/retention-days", req)
if err != nil {
return err
}
defer res.Body.Close()
if res.StatusCode != http.StatusNoContent {
return ReadBodyAsError(res)
}
return nil
}
// GetChatTemplateAllowlist returns the deployment-wide chat template allowlist.
func (c *ExperimentalClient) GetChatTemplateAllowlist(ctx context.Context) (ChatTemplateAllowlist, error) {
res, err := c.Request(ctx, http.MethodGet, "/api/experimental/chats/config/template-allowlist", nil)
@@ -1902,6 +2070,73 @@ func (c *ExperimentalClient) StreamChat(ctx context.Context, chatID uuid.UUID, o
}), nil
}
// WatchChats streams lifecycle events for all of the authenticated
// user's chats in real time. The returned channel emits
// ChatWatchEvent values for status changes, title changes, creation,
// deletion, diff-status changes, and action-required notifications.
// Callers must close the returned io.Closer to release the websocket
// connection when done.
func (c *ExperimentalClient) WatchChats(ctx context.Context) (<-chan ChatWatchEvent, io.Closer, error) {
conn, err := c.Dial(
ctx,
"/api/experimental/chats/watch",
&websocket.DialOptions{CompressionMode: websocket.CompressionDisabled},
)
if err != nil {
return nil, nil, err
}
conn.SetReadLimit(1 << 22) // 4MiB
streamCtx, streamCancel := context.WithCancel(ctx)
events := make(chan ChatWatchEvent, 128)
go func() {
defer close(events)
defer streamCancel()
defer func() {
_ = conn.Close(websocket.StatusNormalClosure, "")
}()
for {
var envelope chatStreamEnvelope
if err := wsjson.Read(streamCtx, conn, &envelope); err != nil {
if streamCtx.Err() != nil {
return
}
switch websocket.CloseStatus(err) {
case websocket.StatusNormalClosure, websocket.StatusGoingAway:
return
}
return
}
switch envelope.Type {
case ServerSentEventTypePing:
continue
case ServerSentEventTypeData:
var event ChatWatchEvent
if err := json.Unmarshal(envelope.Data, &event); err != nil {
return
}
select {
case <-streamCtx.Done():
return
case events <- event:
}
case ServerSentEventTypeError:
return
default:
return
}
}
}()
return events, closeFunc(func() error {
streamCancel()
return nil
}), nil
}
// GetChat returns a chat by ID.
func (c *ExperimentalClient) GetChat(ctx context.Context, chatID uuid.UUID) (Chat, error) {
res, err := c.Request(ctx, http.MethodGet, fmt.Sprintf("/api/experimental/chats/%s", chatID), nil)
@@ -2209,6 +2444,20 @@ func (c *ExperimentalClient) GetMyChatUsageLimitStatus(ctx context.Context) (Cha
return resp, json.NewDecoder(res.Body).Decode(&resp)
}
// SubmitToolResults submits the results of dynamic tool calls for a chat
// that is in requires_action status.
func (c *ExperimentalClient) SubmitToolResults(ctx context.Context, chatID uuid.UUID, req SubmitToolResultsRequest) error {
res, err := c.Request(ctx, http.MethodPost, fmt.Sprintf("/api/experimental/chats/%s/tool-results", chatID), req)
if err != nil {
return err
}
defer res.Body.Close()
if res.StatusCode != http.StatusNoContent {
return ReadBodyAsError(res)
}
return nil
}
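A client reacting to an action_required event would execute each pending call and post the outputs back via SubmitToolResults. The sketch below shows one way to dispatch calls through a handler registry; the `toolCall` and `toolResult` types are simplified local mirrors of the wire types above, not the SDK itself, and unknown tool names are turned into error results so the LLM sees the failure.

```go
package main

import (
	"encoding/json"
	"fmt"
)

// Local mirrors of ChatStreamToolCall and ToolResult, for illustration.
type toolCall struct {
	ToolCallID string
	ToolName   string
	Args       string
}

type toolResult struct {
	ToolCallID string          `json:"tool_call_id"`
	Output     json.RawMessage `json:"output"`
	IsError    bool            `json:"is_error"`
}

// executeToolCalls runs each pending call through a registered handler and
// collects the results a client would submit via SubmitToolResults.
func executeToolCalls(handlers map[string]func(args string) (string, error), calls []toolCall) []toolResult {
	results := make([]toolResult, 0, len(calls))
	for _, c := range calls {
		handler, ok := handlers[c.ToolName]
		if !ok {
			msg, _ := json.Marshal(fmt.Sprintf("unknown tool %q", c.ToolName))
			results = append(results, toolResult{ToolCallID: c.ToolCallID, Output: msg, IsError: true})
			continue
		}
		out, err := handler(c.Args)
		if err != nil {
			// Tool execution failed: report it as an error result.
			msg, _ := json.Marshal(err.Error())
			results = append(results, toolResult{ToolCallID: c.ToolCallID, Output: msg, IsError: true})
			continue
		}
		raw, _ := json.Marshal(out)
		results = append(results, toolResult{ToolCallID: c.ToolCallID, Output: raw})
	}
	return results
}

func main() {
	handlers := map[string]func(string) (string, error){
		"echo": func(args string) (string, error) { return args, nil },
	}
	calls := []toolCall{
		{ToolCallID: "tc-1", ToolName: "echo", Args: `{"msg":"hi"}`},
		{ToolCallID: "tc-2", ToolName: "missing", Args: `{}`},
	}
	for _, r := range executeToolCalls(handlers, calls) {
		fmt.Println(r.ToolCallID, r.IsError)
	}
}
```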
// GetChatsByWorkspace returns a mapping of workspace ID to the latest
// non-archived chat ID for each requested workspace. Workspaces with
// no chats are omitted from the response.
+98
@@ -329,6 +329,42 @@ func TestChatMessagePartVariantTags(t *testing.T) {
})
}
func TestChatMessagePart_CreatedAt_JSON(t *testing.T) {
t.Parallel()
t.Run("RoundTrips", func(t *testing.T) {
t.Parallel()
ts := time.Date(2025, 6, 15, 12, 30, 0, 0, time.UTC)
part := codersdk.ChatMessagePart{
Type: codersdk.ChatMessagePartTypeToolCall,
ToolCallID: "tc-1",
ToolName: "execute",
CreatedAt: &ts,
}
data, err := json.Marshal(part)
require.NoError(t, err)
require.Contains(t, string(data), `"created_at"`)
var decoded codersdk.ChatMessagePart
err = json.Unmarshal(data, &decoded)
require.NoError(t, err)
require.NotNil(t, decoded.CreatedAt)
require.True(t, ts.Equal(*decoded.CreatedAt))
})
t.Run("OmittedWhenNil", func(t *testing.T) {
t.Parallel()
part := codersdk.ChatMessagePart{
Type: codersdk.ChatMessagePartTypeToolCall,
ToolCallID: "tc-1",
ToolName: "execute",
}
data, err := json.Marshal(part)
require.NoError(t, err)
require.NotContains(t, string(data), `"created_at"`)
})
}
func TestModelCostConfig_LegacyNumericJSON(t *testing.T) {
t.Parallel()
@@ -469,6 +505,68 @@ func TestChat_JSONRoundTrip(t *testing.T) {
require.Equal(t, original, decoded)
}
func TestNewDynamicTool(t *testing.T) {
t.Parallel()
type testArgs struct {
Query string `json:"query"`
}
t.Run("CorrectSchema", func(t *testing.T) {
t.Parallel()
tool := codersdk.NewDynamicTool(
"search", "search things",
func(_ context.Context, args testArgs, _ codersdk.DynamicToolCall) (codersdk.DynamicToolResponse, error) {
return codersdk.DynamicToolResponse{Content: args.Query}, nil
},
)
require.Equal(t, "search", tool.Name)
require.Equal(t, "search things", tool.Description)
require.Contains(t, string(tool.InputSchema), `"query"`)
require.Contains(t, string(tool.InputSchema), `"string"`)
})
t.Run("HandlerReceivesArgs", func(t *testing.T) {
t.Parallel()
var received testArgs
tool := codersdk.NewDynamicTool(
"search", "search things",
func(_ context.Context, args testArgs, _ codersdk.DynamicToolCall) (codersdk.DynamicToolResponse, error) {
received = args
return codersdk.DynamicToolResponse{Content: "ok"}, nil
},
)
resp, err := tool.Handler(context.Background(), codersdk.DynamicToolCall{
Args: `{"query":"hello"}`,
})
require.NoError(t, err)
require.Equal(t, "ok", resp.Content)
require.Equal(t, "hello", received.Query)
})
t.Run("InvalidJSONArgs", func(t *testing.T) {
t.Parallel()
tool := codersdk.NewDynamicTool(
"search", "search things",
func(_ context.Context, args testArgs, _ codersdk.DynamicToolCall) (codersdk.DynamicToolResponse, error) {
return codersdk.DynamicToolResponse{Content: "should not reach"}, nil
},
)
resp, err := tool.Handler(context.Background(), codersdk.DynamicToolCall{
Args: "not-json",
})
require.NoError(t, err)
require.True(t, resp.IsError)
require.Contains(t, resp.Content, "invalid parameters")
})
}
//nolint:tparallel,paralleltest
func TestParseChatWorkspaceTTL(t *testing.T) {
t.Parallel()
+30 -2
@@ -196,6 +196,7 @@ const (
FeatureWorkspaceExternalAgent FeatureName = "workspace_external_agent"
FeatureAIBridge FeatureName = "aibridge"
FeatureBoundary FeatureName = "boundary"
FeatureServiceAccounts FeatureName = "service_accounts"
FeatureAIGovernanceUserLimit FeatureName = "ai_governance_user_limit"
)
@@ -227,6 +228,7 @@ var (
FeatureWorkspaceExternalAgent,
FeatureAIBridge,
FeatureBoundary,
FeatureServiceAccounts,
FeatureAIGovernanceUserLimit,
}
@@ -275,6 +277,7 @@ func (n FeatureName) AlwaysEnable() bool {
FeatureWorkspacePrebuilds: true,
FeatureWorkspaceExternalAgent: true,
FeatureBoundary: true,
FeatureServiceAccounts: true,
}[n]
}
@@ -282,7 +285,7 @@ func (n FeatureName) AlwaysEnable() bool {
func (n FeatureName) Enterprise() bool {
switch n {
// Add all features that should be excluded in the Enterprise feature set.
case FeatureMultipleOrganizations, FeatureCustomRoles:
case FeatureMultipleOrganizations, FeatureCustomRoles, FeatureServiceAccounts:
return false
default:
return true
@@ -3621,6 +3624,29 @@ Write out the current server config as YAML to stdout.`,
YAML: "acquireBatchSize",
Hidden: true, // Hidden because most operators should not need to modify this.
},
{
Name: "Chat: Pubsub Flush Interval",
Description: "The maximum time an accepted chatd pubsub publish waits before the batching loop schedules a flush.",
Flag: "chat-pubsub-flush-interval",
Env: "CODER_CHAT_PUBSUB_FLUSH_INTERVAL",
Value: &c.AI.Chat.PubsubFlushInterval,
Default: "50ms",
Group: &deploymentGroupChat,
YAML: "pubsubFlushInterval",
Annotations: serpent.Annotations{}.Mark(annotationFormatDuration, "true"),
Hidden: true,
},
{
Name: "Chat: Pubsub Queue Size",
Description: "How many chatd pubsub publishes can wait in memory for the dedicated sender path when PostgreSQL falls behind.",
Flag: "chat-pubsub-queue-size",
Env: "CODER_CHAT_PUBSUB_QUEUE_SIZE",
Value: &c.AI.Chat.PubsubQueueSize,
Default: "8192",
Group: &deploymentGroupChat,
YAML: "pubsubQueueSize",
Hidden: true,
},
// AI Bridge Options
{
Name: "AI Bridge Enabled",
@@ -4087,7 +4113,9 @@ type AIBridgeProxyConfig struct {
}
type ChatConfig struct {
AcquireBatchSize serpent.Int64 `json:"acquire_batch_size" typescript:",notnull"`
PubsubFlushInterval serpent.Duration `json:"pubsub_flush_interval" typescript:",notnull"`
PubsubQueueSize serpent.Int64 `json:"pubsub_queue_size" typescript:",notnull"`
}
type AIConfig struct {
+41
@@ -0,0 +1,41 @@
package codersdk
import (
"time"
"github.com/google/uuid"
)
// UserSecret represents a user secret's metadata. The secret value
// is never included in API responses.
type UserSecret struct {
ID uuid.UUID `json:"id" format:"uuid"`
Name string `json:"name"`
Description string `json:"description"`
EnvName string `json:"env_name"`
FilePath string `json:"file_path"`
CreatedAt time.Time `json:"created_at" format:"date-time"`
UpdatedAt time.Time `json:"updated_at" format:"date-time"`
}
// CreateUserSecretRequest is the payload for creating a new user
// secret. Name and Value are required. All other fields are optional
// and default to empty string.
type CreateUserSecretRequest struct {
Name string `json:"name"`
Value string `json:"value"`
Description string `json:"description,omitempty"`
EnvName string `json:"env_name,omitempty"`
FilePath string `json:"file_path,omitempty"`
}
// UpdateUserSecretRequest is the payload for partially updating a
// user secret. At least one field must be non-nil. Pointer fields
// distinguish "not sent" (nil) from "set to empty string" (pointer
// to empty string).
type UpdateUserSecretRequest struct {
Value *string `json:"value,omitempty"`
Description *string `json:"description,omitempty"`
EnvName *string `json:"env_name,omitempty"`
FilePath *string `json:"file_path,omitempty"`
}
+191
@@ -0,0 +1,191 @@
package codersdk
import (
"regexp"
"strings"
"golang.org/x/xerrors"
)
// UserSecretEnvValidationOptions controls deployment-aware behavior
// in environment variable name validation.
type UserSecretEnvValidationOptions struct {
// AIGatewayEnabled indicates that the deployment has AI Gateway
// configured. When true, AI Gateway environment variables
// (OPENAI_API_KEY, etc.) are reserved to prevent conflicts.
AIGatewayEnabled bool
}
var (
// posixEnvNameRegex matches valid POSIX environment variable names:
// must start with a letter or underscore, followed by letters,
// digits, or underscores.
posixEnvNameRegex = regexp.MustCompile(`^[a-zA-Z_][a-zA-Z0-9_]*$`)
// reservedEnvNames are system environment variables that must not
// be overridden by user secrets. This list is intentionally
// aggressive because it is easier to remove entries later than
// to add them after users have already created conflicting
// secrets.
reservedEnvNames = map[string]struct{}{
// Core POSIX/login variables. Overriding these breaks
// basic shell and session behavior.
"PATH": {},
"HOME": {},
"SHELL": {},
"USER": {},
"LOGNAME": {},
"PWD": {},
"OLDPWD": {},
// Locale and terminal. Agents and IDEs depend on these
// being set correctly by the system.
"LANG": {},
"TERM": {},
// Shell behavior. Overriding these can silently break
// word splitting, directory resolution, and script
// execution in every shell session and agent script.
"IFS": {},
"CDPATH": {},
// Shell startup files. ENV is sourced by POSIX sh for
// interactive shells; BASH_ENV is sourced by bash for
// every non-interactive invocation (scripts, subshells).
// Allowing users to set these would inject arbitrary
// code into every shell and script in the workspace.
"ENV": {},
"BASH_ENV": {},
// Temp directories. Overriding these is a security risk
// (symlink attacks, world-readable paths).
"TMPDIR": {},
"TMP": {},
"TEMP": {},
// Host identity.
"HOSTNAME": {},
// SSH session variables. The Coder agent sets
// SSH_AUTH_SOCK in agentssh.go; the others are set by
// sshd and should never be faked.
"SSH_AUTH_SOCK": {},
"SSH_CLIENT": {},
"SSH_CONNECTION": {},
"SSH_TTY": {},
// Editor/pager. The Coder agent sets these so that git
// operations inside workspaces work non-interactively.
"EDITOR": {},
"VISUAL": {},
"PAGER": {},
// IDE integration. The agent sets these for code-server
// and VS Code Remote proxying.
"VSCODE_PROXY_URI": {},
"CS_DISABLE_GETTING_STARTED_OVERRIDE": {},
// XDG base directories. Overriding these redirects
// config, cache, and runtime data for every tool in the
// workspace.
"XDG_RUNTIME_DIR": {},
"XDG_CONFIG_HOME": {},
"XDG_DATA_HOME": {},
"XDG_CACHE_HOME": {},
"XDG_STATE_HOME": {},
// OIDC token. The Coder agent injects a short-lived
// OIDC token for cloud auth flows (e.g. GCP workload
// identity). Overriding it could break provisioner and
// agent authentication.
"OIDC_TOKEN": {},
}
// aiGatewayReservedEnvNames are reserved only when AI Gateway
// is enabled on the deployment. When AI Gateway is disabled,
// users may legitimately want to inject their own API keys
// via secrets.
aiGatewayReservedEnvNames = map[string]struct{}{
"OPENAI_API_KEY": {},
"OPENAI_BASE_URL": {},
"ANTHROPIC_AUTH_TOKEN": {},
"ANTHROPIC_BASE_URL": {},
}
// reservedEnvPrefixes are namespace prefixes where every
// variable in the family is reserved. Checked after the
// exact-name map. The CODER / CODER_* namespace is handled
// separately with its own error message (see below).
reservedEnvPrefixes = []string{
// The Coder agent sets GIT_SSH_COMMAND, GIT_ASKPASS,
// GIT_AUTHOR_*, GIT_COMMITTER_*, and several others.
// Blocking the entire GIT_* namespace avoids an arms
// race with new git env vars.
"GIT_",
// Locale variables. LC_ALL, LC_CTYPE, LC_MESSAGES,
// etc. control character encoding, sorting, and
// formatting. Overriding them can break text
// processing in agents and IDEs.
"LC_",
// Dynamic linker variables. Allowing users to set
// these would let a secret inject arbitrary shared
// libraries into every process in the workspace.
"LD_",
"DYLD_",
}
)
// UserSecretEnvNameValid validates an environment variable name for
// a user secret. Empty string is allowed (means no env injection).
// The opts parameter controls deployment-aware checks such as AI
// Gateway variable reservation.
func UserSecretEnvNameValid(s string, opts UserSecretEnvValidationOptions) error {
if s == "" {
return nil
}
if !posixEnvNameRegex.MatchString(s) {
return xerrors.New("must start with a letter or underscore, followed by letters, digits, or underscores")
}
upper := strings.ToUpper(s)
if _, ok := reservedEnvNames[upper]; ok {
return xerrors.Errorf("%s is a reserved environment variable name", upper)
}
if upper == "CODER" || strings.HasPrefix(upper, "CODER_") {
return xerrors.New("environment variable names starting with CODER_ are reserved for internal use")
}
for _, prefix := range reservedEnvPrefixes {
if strings.HasPrefix(upper, prefix) {
return xerrors.Errorf("environment variables starting with %s are reserved", prefix)
}
}
if opts.AIGatewayEnabled {
if _, ok := aiGatewayReservedEnvNames[upper]; ok {
return xerrors.Errorf("%s is reserved when AI Gateway is enabled", upper)
}
}
return nil
}
// UserSecretFilePathValid validates a file path for a user secret.
// Empty string is allowed (means no file injection). Non-empty paths
// must start with ~/ or /.
func UserSecretFilePathValid(s string) error {
if s == "" {
return nil
}
if strings.HasPrefix(s, "~/") || strings.HasPrefix(s, "/") {
return nil
}
return xerrors.New("file path must start with ~/ or /")
}
+183
@@ -0,0 +1,183 @@
package codersdk_test
import (
"testing"
"github.com/stretchr/testify/assert"
"github.com/coder/coder/v2/codersdk"
)
func TestUserSecretEnvNameValid(t *testing.T) {
t.Parallel()
// noAIGateway is the default for most tests — AI Gateway disabled.
noAIGateway := codersdk.UserSecretEnvValidationOptions{}
withAIGateway := codersdk.UserSecretEnvValidationOptions{AIGatewayEnabled: true}
tests := []struct {
name string
input string
opts codersdk.UserSecretEnvValidationOptions
wantErr bool
errMsg string
}{
// Valid names.
{name: "SimpleUpper", input: "GITHUB_TOKEN", opts: noAIGateway},
{name: "SimpleLower", input: "github_token", opts: noAIGateway},
{name: "StartsWithUnderscore", input: "_FOO", opts: noAIGateway},
{name: "SingleChar", input: "A", opts: noAIGateway},
{name: "WithDigits", input: "A1B2", opts: noAIGateway},
{name: "Empty", input: "", opts: noAIGateway},
// Invalid POSIX names.
{name: "StartsWithDigit", input: "1FOO", opts: noAIGateway, wantErr: true, errMsg: "must start with"},
{name: "ContainsHyphen", input: "FOO-BAR", opts: noAIGateway, wantErr: true, errMsg: "must start with"},
{name: "ContainsDot", input: "FOO.BAR", opts: noAIGateway, wantErr: true, errMsg: "must start with"},
{name: "ContainsSpace", input: "FOO BAR", opts: noAIGateway, wantErr: true, errMsg: "must start with"},
// Reserved system names — core POSIX/login.
{name: "ReservedPATH", input: "PATH", opts: noAIGateway, wantErr: true, errMsg: "reserved"},
{name: "ReservedHOME", input: "HOME", opts: noAIGateway, wantErr: true, errMsg: "reserved"},
{name: "ReservedSHELL", input: "SHELL", opts: noAIGateway, wantErr: true, errMsg: "reserved"},
{name: "ReservedUSER", input: "USER", opts: noAIGateway, wantErr: true, errMsg: "reserved"},
{name: "ReservedLOGNAME", input: "LOGNAME", opts: noAIGateway, wantErr: true, errMsg: "reserved"},
{name: "ReservedPWD", input: "PWD", opts: noAIGateway, wantErr: true, errMsg: "reserved"},
{name: "ReservedOLDPWD", input: "OLDPWD", opts: noAIGateway, wantErr: true, errMsg: "reserved"},
// Reserved system names — locale/terminal.
{name: "ReservedLANG", input: "LANG", opts: noAIGateway, wantErr: true, errMsg: "reserved"},
{name: "ReservedTERM", input: "TERM", opts: noAIGateway, wantErr: true, errMsg: "reserved"},
// Reserved system names — shell behavior.
{name: "ReservedIFS", input: "IFS", opts: noAIGateway, wantErr: true, errMsg: "reserved"},
{name: "ReservedCDPATH", input: "CDPATH", opts: noAIGateway, wantErr: true, errMsg: "reserved"},
// Reserved system names — shell startup files.
{name: "ReservedENV", input: "ENV", opts: noAIGateway, wantErr: true, errMsg: "reserved"},
{name: "ReservedBASH_ENV", input: "BASH_ENV", opts: noAIGateway, wantErr: true, errMsg: "reserved"},
// Reserved system names — temp directories.
{name: "ReservedTMPDIR", input: "TMPDIR", opts: noAIGateway, wantErr: true, errMsg: "reserved"},
{name: "ReservedTMP", input: "TMP", opts: noAIGateway, wantErr: true, errMsg: "reserved"},
{name: "ReservedTEMP", input: "TEMP", opts: noAIGateway, wantErr: true, errMsg: "reserved"},
// Reserved system names — host identity.
{name: "ReservedHOSTNAME", input: "HOSTNAME", opts: noAIGateway, wantErr: true, errMsg: "reserved"},
// Reserved system names — SSH.
{name: "ReservedSSH_AUTH_SOCK", input: "SSH_AUTH_SOCK", opts: noAIGateway, wantErr: true, errMsg: "reserved"},
{name: "ReservedSSH_CLIENT", input: "SSH_CLIENT", opts: noAIGateway, wantErr: true, errMsg: "reserved"},
{name: "ReservedSSH_CONNECTION", input: "SSH_CONNECTION", opts: noAIGateway, wantErr: true, errMsg: "reserved"},
{name: "ReservedSSH_TTY", input: "SSH_TTY", opts: noAIGateway, wantErr: true, errMsg: "reserved"},
// Reserved system names — editor/pager.
{name: "ReservedEDITOR", input: "EDITOR", opts: noAIGateway, wantErr: true, errMsg: "reserved"},
{name: "ReservedVISUAL", input: "VISUAL", opts: noAIGateway, wantErr: true, errMsg: "reserved"},
{name: "ReservedPAGER", input: "PAGER", opts: noAIGateway, wantErr: true, errMsg: "reserved"},
// Reserved system names — IDE integration.
{name: "ReservedVSCODE_PROXY_URI", input: "VSCODE_PROXY_URI", opts: noAIGateway, wantErr: true, errMsg: "reserved"},
{name: "ReservedCS_DISABLE", input: "CS_DISABLE_GETTING_STARTED_OVERRIDE", opts: noAIGateway, wantErr: true, errMsg: "reserved"},
// Reserved system names — XDG.
{name: "ReservedXDG_RUNTIME_DIR", input: "XDG_RUNTIME_DIR", opts: noAIGateway, wantErr: true, errMsg: "reserved"},
{name: "ReservedXDG_CONFIG_HOME", input: "XDG_CONFIG_HOME", opts: noAIGateway, wantErr: true, errMsg: "reserved"},
{name: "ReservedXDG_DATA_HOME", input: "XDG_DATA_HOME", opts: noAIGateway, wantErr: true, errMsg: "reserved"},
{name: "ReservedXDG_CACHE_HOME", input: "XDG_CACHE_HOME", opts: noAIGateway, wantErr: true, errMsg: "reserved"},
{name: "ReservedXDG_STATE_HOME", input: "XDG_STATE_HOME", opts: noAIGateway, wantErr: true, errMsg: "reserved"},
// Reserved system names — OIDC.
{name: "ReservedOIDC_TOKEN", input: "OIDC_TOKEN", opts: noAIGateway, wantErr: true, errMsg: "reserved"},
// AI Gateway vars — blocked when AI Gateway is enabled.
{name: "AIGateway/OPENAI_API_KEY/Enabled", input: "OPENAI_API_KEY", opts: withAIGateway, wantErr: true, errMsg: "AI Gateway"},
{name: "AIGateway/OPENAI_BASE_URL/Enabled", input: "OPENAI_BASE_URL", opts: withAIGateway, wantErr: true, errMsg: "AI Gateway"},
{name: "AIGateway/ANTHROPIC_AUTH_TOKEN/Enabled", input: "ANTHROPIC_AUTH_TOKEN", opts: withAIGateway, wantErr: true, errMsg: "AI Gateway"},
{name: "AIGateway/ANTHROPIC_BASE_URL/Enabled", input: "ANTHROPIC_BASE_URL", opts: withAIGateway, wantErr: true, errMsg: "AI Gateway"},
// AI Gateway vars — allowed when AI Gateway is disabled.
{name: "AIGateway/OPENAI_API_KEY/Disabled", input: "OPENAI_API_KEY", opts: noAIGateway},
{name: "AIGateway/OPENAI_BASE_URL/Disabled", input: "OPENAI_BASE_URL", opts: noAIGateway},
{name: "AIGateway/ANTHROPIC_AUTH_TOKEN/Disabled", input: "ANTHROPIC_AUTH_TOKEN", opts: noAIGateway},
{name: "AIGateway/ANTHROPIC_BASE_URL/Disabled", input: "ANTHROPIC_BASE_URL", opts: noAIGateway},
// Case insensitivity.
{name: "ReservedCaseInsensitive", input: "path", opts: noAIGateway, wantErr: true, errMsg: "reserved"},
// CODER_ prefix.
{name: "CoderExact", input: "CODER", opts: noAIGateway, wantErr: true, errMsg: "CODER_"},
{name: "CoderPrefix", input: "CODER_WORKSPACE_NAME", opts: noAIGateway, wantErr: true, errMsg: "CODER_"},
{name: "CoderAgentToken", input: "CODER_AGENT_TOKEN", opts: noAIGateway, wantErr: true, errMsg: "CODER_"},
{name: "CoderLowerCase", input: "coder_foo", opts: noAIGateway, wantErr: true, errMsg: "CODER_"},
// GIT_* prefix.
{name: "GitSSHCommand", input: "GIT_SSH_COMMAND", opts: noAIGateway, wantErr: true, errMsg: "GIT_"},
{name: "GitAskpass", input: "GIT_ASKPASS", opts: noAIGateway, wantErr: true, errMsg: "GIT_"},
{name: "GitAuthorName", input: "GIT_AUTHOR_NAME", opts: noAIGateway, wantErr: true, errMsg: "GIT_"},
{name: "GitLowerCase", input: "git_editor", opts: noAIGateway, wantErr: true, errMsg: "GIT_"},
// LC_* prefix (locale).
{name: "LcAll", input: "LC_ALL", opts: noAIGateway, wantErr: true, errMsg: "LC_"},
{name: "LcCtype", input: "LC_CTYPE", opts: noAIGateway, wantErr: true, errMsg: "LC_"},
// LD_* prefix (dynamic linker).
{name: "LdPreload", input: "LD_PRELOAD", opts: noAIGateway, wantErr: true, errMsg: "LD_"},
{name: "LdLibraryPath", input: "LD_LIBRARY_PATH", opts: noAIGateway, wantErr: true, errMsg: "LD_"},
// DYLD_* prefix (macOS dynamic linker).
{name: "DyldInsert", input: "DYLD_INSERT_LIBRARIES", opts: noAIGateway, wantErr: true, errMsg: "DYLD_"},
{name: "DyldLibraryPath", input: "DYLD_LIBRARY_PATH", opts: noAIGateway, wantErr: true, errMsg: "DYLD_"},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
err := codersdk.UserSecretEnvNameValid(tt.input, tt.opts)
if tt.wantErr {
assert.Error(t, err)
if tt.errMsg != "" {
assert.Contains(t, err.Error(), tt.errMsg)
}
} else {
assert.NoError(t, err)
}
})
}
}
func TestUserSecretFilePathValid(t *testing.T) {
t.Parallel()
tests := []struct {
name string
input string
wantErr bool
}{
// Valid paths.
{name: "TildePath", input: "~/foo"},
{name: "TildeSSH", input: "~/.ssh/id_rsa"},
{name: "AbsolutePath", input: "/home/coder/.ssh/id_rsa"},
{name: "RootPath", input: "/"},
{name: "Empty", input: ""},
// Invalid paths.
{name: "BareRelative", input: "foo/bar", wantErr: true},
{name: "DotRelative", input: ".ssh/id_rsa", wantErr: true},
{name: "JustFilename", input: "credentials", wantErr: true},
{name: "TildeNoSlash", input: "~foo", wantErr: true},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
err := codersdk.UserSecretFilePathValid(tt.input)
if tt.wantErr {
assert.Error(t, err)
assert.Contains(t, err.Error(), "must start with")
} else {
assert.NoError(t, err)
}
})
}
}
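The path rules implied by these cases can be summarized as: the empty string is allowed (no path set), and any non-empty path must start with `/` or `~/`. The sketch below is an assumed reimplementation for illustration, not the real `codersdk.UserSecretFilePathValid`.

```go
package main

import (
	"fmt"
	"strings"
)

// filePathValid mirrors the behavior implied by the tests above:
// empty means unset (valid), and anything else must be absolute or
// tilde-relative with a slash ("~foo" is rejected, "~/foo" is not).
func filePathValid(path string) error {
	if path == "" {
		return nil
	}
	if strings.HasPrefix(path, "/") || strings.HasPrefix(path, "~/") {
		return nil
	}
	return fmt.Errorf("path %q must start with %q or %q", path, "/", "~/")
}

func main() {
	for _, p := range []string{"~/.ssh/id_rsa", "~foo", ".ssh/id_rsa", "/"} {
		fmt.Println(p, filePathValid(p))
	}
}
```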
Coder releases are initiated via
[`./scripts/release.sh`](https://github.com/coder/coder/blob/main/scripts/release.sh)
and automated via GitHub Actions. Specifically, the
[`release.yaml`](https://github.com/coder/coder/blob/main/.github/workflows/release.yaml)
workflow.
Release notes are automatically generated from commit titles and PR metadata.
### Release types
| Type | Tag | Branch | Purpose |
|------------------------|---------------|---------------|-----------------------------------------|
| RC (release candidate) | `vX.Y.0-rc.W` | `main` | Ad-hoc pre-release for customer testing |
| Release | `vX.Y.0` | `release/X.Y` | First release of a minor version |
| Patch | `vX.Y.Z` | `release/X.Y` | Bug fixes and security patches |
See `./scripts/release.sh --help` for more information.
### Workflow
RC tags are created directly on `main`. The `release/X.Y` branch is only cut
when the release is ready. This avoids cherry-picking main's progress onto
a release branch between the first RC and the release.
```text
main: ──●──●──●──●──●──●──●──●──●──
         ↑        ↑        ↑
       rc.0     rc.1     cut release/2.34, tag v2.34.0
                           \
release/2.34:               ──●── v2.34.1 (patch)
```
1. **RC:** On `main`, run `./scripts/release.sh`. The tool suggests the next
RC version and tags it on `main`.
2. **Release:** When the RC is blessed, create `release/X.Y` from `main` (or
the specific RC commit). Switch to that branch and run
`./scripts/release.sh`, which suggests `vX.Y.0`.
3. **Patch:** Cherry-pick fixes onto `release/X.Y` and run
`./scripts/release.sh` from that branch.
The release tool warns if you try to tag a non-RC on `main` or an RC on a
release branch.
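The tag/branch guard described above can be approximated with a small check. The regular expressions and rules here are assumptions based on the release-types table, not the release tool's actual implementation: RC tags (`vX.Y.0-rc.W`) belong on `main`, and non-RC tags belong on `release/X.Y` branches.

```go
package main

import (
	"fmt"
	"regexp"
)

var (
	rcTag         = regexp.MustCompile(`^v\d+\.\d+\.0-rc\.\d+$`)
	releaseTag    = regexp.MustCompile(`^v\d+\.\d+\.\d+$`)
	releaseBranch = regexp.MustCompile(`^release/\d+\.\d+$`)
)

// tagAllowed reports whether tagging `tag` on `branch` matches the
// conventions in the table above.
func tagAllowed(tag, branch string) bool {
	switch {
	case rcTag.MatchString(tag):
		return branch == "main" // RCs are tagged directly on main
	case releaseTag.MatchString(tag):
		return releaseBranch.MatchString(branch) // releases and patches
	default:
		return false // not a recognized tag shape
	}
}

func main() {
	fmt.Println(tagAllowed("v2.34.0-rc.0", "main"))
	fmt.Println(tagAllowed("v2.34.1", "release/2.34"))
	fmt.Println(tagAllowed("v2.34.0", "main"))
}
```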
### Creating a release (via workflow dispatch)
If the
[`release.yaml`](https://github.com/coder/coder/actions/workflows/release.yaml)
workflow fails after the tag has been pushed, retry it from the GitHub Actions
UI: press "Run workflow", set "Use workflow from" to the tag (e.g.
`Tag: v2.34.0`), select the correct release channel, and do **not** select
dry-run.
To test the workflow without publishing, select dry-run.
### Commit messages
specification, however, it's still possible to merge PRs on GitHub with a badly
formatted title. Take care when merging single-commit PRs as GitHub may prefer
to use the original commit title instead of the PR title.
### Backporting fixes to release branches
When a merged PR on `main` should also ship in older releases, add the
`backport` label to the PR. The
[backport workflow](https://github.com/coder/coder/blob/main/.github/workflows/backport.yaml)
will automatically detect the latest three `release/*` branches,
cherry-pick the merge commit onto each one, and open PRs for
review.
The label can be added before or after the PR is merged. Each backport
PR reuses the original title (e.g.
`fix(site): correct button alignment (#12345)`) so the change is
meaningful in release notes.
If the cherry-pick encounters conflicts, the backport PR is still created
with instructions for manual resolution — no conflict markers are committed.
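Detecting the latest three `release/*` branches, as the workflow does, could look like the following sketch. The parsing and sorting logic is an assumption for illustration, not the backport workflow's actual code.

```go
package main

import (
	"fmt"
	"sort"
	"strconv"
	"strings"
)

// latestReleaseBranches returns up to n "release/X.Y" branches,
// newest version first, ignoring anything that does not parse.
func latestReleaseBranches(branches []string, n int) []string {
	type ver struct {
		name       string
		maj, minor int
	}
	var vs []ver
	for _, b := range branches {
		rest, ok := strings.CutPrefix(b, "release/")
		if !ok {
			continue
		}
		parts := strings.SplitN(rest, ".", 2)
		if len(parts) != 2 {
			continue
		}
		maj, err1 := strconv.Atoi(parts[0])
		minor, err2 := strconv.Atoi(parts[1])
		if err1 != nil || err2 != nil {
			continue
		}
		vs = append(vs, ver{b, maj, minor})
	}
	sort.Slice(vs, func(i, j int) bool {
		if vs[i].maj != vs[j].maj {
			return vs[i].maj > vs[j].maj
		}
		return vs[i].minor > vs[j].minor
	})
	if len(vs) > n {
		vs = vs[:n]
	}
	out := make([]string, len(vs))
	for i, v := range vs {
		out[i] = v.name
	}
	return out
}

func main() {
	fmt.Println(latestReleaseBranches(
		[]string{"main", "release/2.29", "release/2.31", "release/2.30", "release/2.28"}, 3))
}
```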
### Breaking changes
Breaking changes can be triggered in two ways:
deployment. They will always be available from the agent.
| `coder_derp_server_sent_pong_total` | counter | Total pongs sent. | |
| `coder_derp_server_unknown_frames_total` | counter | Total unknown frames received. | |
| `coder_derp_server_watchers` | gauge | Current watchers. | |
| `coder_pubsub_batch_delegate_fallbacks_total` | counter | The number of chatd publishes that fell back to the shared pubsub pool by channel class, reason, and flush stage. | `channel_class` `reason` `stage` |
| `coder_pubsub_batch_flush_duration_seconds` | histogram | The time spent flushing one chatd batch to PostgreSQL. | `reason` |
| `coder_pubsub_batch_queue_depth` | gauge | The number of chatd notifications waiting in the batching queue. | |
| `coder_pubsub_batch_sender_reset_failures_total` | counter | The number of batched pubsub sender reset attempts that failed. | |
| `coder_pubsub_batch_sender_resets_total` | counter | The number of successful batched pubsub sender resets after flush failures. | |
| `coder_pubsub_batch_size` | histogram | The number of logical notifications sent in each chatd batch flush. | |
| `coder_pubsub_connected` | gauge | Whether we are connected (1) or not connected (0) to postgres | |
| `coder_pubsub_current_events` | gauge | The current number of pubsub event channels listened for | |
| `coder_pubsub_current_subscribers` | gauge | The current number of active pubsub subscribers | |
resource "docker_container" "workspace" {
resource "coder_agent" "main" {
arch = data.coder_provisioner.me.arch
os = "linux"
startup_script = <<-EOF
#!/bin/sh
set -e
sudo service docker start
EOF
}
```
resource "coder_agent" "main" {
os = "linux"
arch = "amd64"
dir = "/home/coder"
startup_script = <<-EOF
#!/bin/sh
set -e
sudo service docker start
EOF
}
# Headless Authentication
> [!NOTE]
> Creating service accounts requires a [Premium license](https://coder.com/pricing).
Service accounts are headless user accounts that cannot use the web UI to log in
to Coder. This is useful for creating accounts for automated systems, such as
CI/CD pipelines or for users who only consume Coder via another client/API. Service accounts do not have passwords or associated email addresses.
You must have the User Admin role or above to create service accounts.
## Create a service account
<div class="tabs">
## CLI
Use the `--service-account` flag to create a dedicated service account:
```sh
coder users create \
--email="coder-bot@coder.com" \
--username="coder-bot" \
--login-type="none" \
--service-account
```
## UI
Navigate to **Deployment** > **Users** > **Create user**, then select
**Service account** as the login type.
![Create a user via the UI](../../images/admin/users/headless-user.png)
</div>
## Authenticate as a service account
To make API or CLI requests on behalf of a service account, learn how to
[generate API tokens on behalf of a user](./sessions-tokens.md#generate-a-long-lived-api-token-on-behalf-of-another-user).
configuration set by an administrator.
|--------------|--------------------------------------------------------------------------------------------------------------------------------------------------|
| `web_search` | Searches the internet for up-to-date information. Available when web search is enabled for the configured Anthropic, OpenAI, or Google provider. |
### Workspace extension tools
These tools are conditionally available based on the workspace contents.
| Tool | What it does |
|-------------------|--------------------------------------------------------------------------------------------------------------------------------|
| `read_skill` | Reads the instructions for a workspace skill by name. Available when the workspace has skills discovered in `.agents/skills/`. |
| `read_skill_file` | Reads a supporting file from a skill's directory. |
## What runs where
Understanding the split between the control plane and the workspace is central
Because state lives in the database:
- The agent can resume work by targeting a new workspace and continuing from the
last git branch or checkpoint.
## Security posture
The control plane architecture provides built-in security properties for AI
coding workflows. These are structural guarantees, not configuration options —
they hold by default for every agent session.
### No API keys in workspaces
Once the server restarts with the experiment enabled:
1. Navigate to the **Agents** page in the Coder dashboard.
1. Open **Admin** settings and configure at least one LLM provider and model.
See [Models](./models.md) for detailed setup instructions.
1. Grant the **Coder Agents User** role to users who need to create chats.
Go to **Admin** > **Users**, click the roles icon next to each user,
and enable **Coder Agents User**.
1. Developers can then start a new chat from the Agents page.
## Licensing and availability
Before you begin, confirm the following:
for the agent to select when provisioning workspaces.
- **Admin access** to the Coder deployment for enabling the experiment and
configuring providers.
- **Coder Agents User role** assigned to each user who needs to interact with Coder Agents.
Owners can assign this from **Admin** > **Users**. See
[Grant Coder Agents User](#step-3-grant-coder-agents-user) below.
## Step 1: Enable the experiment
Detailed instructions for each provider and model option are in the
## Step 3: Grant Coder Agents User
The **Coder Agents User** role controls which users can interact with Coder Agents.
Members do not have Coder Agents User by default.
Owners always have full access and do not need the role. Repeat the following steps for each user who needs access.
> [!NOTE]
> Users who created conversations before this role was introduced are
> automatically granted the role during upgrade.
**Dashboard (individual):**
**CLI (bulk):**
You can also grant the role via CLI. For example, to grant the role to
all active users at once:
```sh
coder users list -o json \
done
```
## Step 4: Start your first Coder Agent
1. Go to the **Agents** page in the Coder dashboard.
model. Developers select from enabled models when starting a chat.
The agent has access to a set of workspace tools that it uses to accomplish
tasks:
| Tool | Description |
|----------------------------|--------------------------------------------------------------------------|
| `list_templates` | Browse available workspace templates |
| `read_template` | Get template details and configurable parameters |
| `create_workspace` | Create a workspace from a template |
| `start_workspace` | Start a stopped workspace for the current chat |
| `propose_plan` | Present a Markdown plan file for user review |
| `read_file` | Read file contents from the workspace |
| `write_file` | Write a file to the workspace |
| `edit_files` | Perform search-and-replace edits across files |
| `execute` | Run shell commands in the workspace |
| `process_output` | Retrieve output from a background process |
| `process_list` | List all tracked processes in the workspace |
| `process_signal` | Send a signal (terminate/kill) to a tracked process |
| `spawn_agent` | Delegate a task to a sub-agent running in parallel |
| `wait_agent` | Wait for a sub-agent to complete and collect its result |
| `message_agent` | Send a follow-up message to a running sub-agent |
| `close_agent` | Stop a running sub-agent |
| `spawn_computer_use_agent` | Spawn a sub-agent with desktop interaction (screenshot, mouse, keyboard) |
| `read_skill` | Read the instructions for a workspace skill by name |
| `read_skill_file` | Read a supporting file from a skill's directory |
| `web_search` | Search the internet (provider-native, when enabled) |
These tools connect to the workspace over the same secure connection used for
web terminals and IDE access. No additional ports or services are required in
the workspace.
Platform tools (`list_templates`, `read_template`, `create_workspace`,
`start_workspace`, `propose_plan`) and orchestration tools (`spawn_agent`,
`wait_agent`, `message_agent`, `close_agent`, `spawn_computer_use_agent`)
are only available to root chats. Sub-agents do not have access to these
tools and cannot create workspaces or spawn further sub-agents.
`spawn_computer_use_agent` additionally requires an Anthropic provider and
the virtual desktop feature to be enabled by an administrator.
`read_skill` and `read_skill_file` are available when the workspace contains
skills in its `.agents/skills/` directory.
## Comparison to Coder Tasks
# Models
Providers, models, and API keys are deployment-wide settings managed by
platform teams. Developers select from the set of models that an administrator
has enabled.
Optionally, administrators can allow developers to supply their own API keys
for specific providers. See [User API keys](#user-api-keys-byok) below.
## Providers
Each LLM provider has a type, an API key, and an optional base URL override.
access to LLM providers. See
[Architecture](./architecture.md#no-api-keys-in-workspaces) for details
on this security model.
### Key policy
Each provider has three policy flags that control how API keys are sourced:
| Setting | Default | Description |
|-------------------------|---------|-----------------------------------------------------------------------------------------------------|
| Central API key | On | The provider uses a deployment-managed API key entered by an administrator. |
| Allow user API keys | Off | Developers may supply their own API key for this provider. |
| Central key as fallback | Off | When user keys are allowed, fall back to the central key if a developer has not set a personal key. |
At least one credential source must be enabled. These settings appear in the
provider configuration form under **Key policy**.
The interaction between these flags determines whether a provider is available
to a given developer:
| Central key | User keys allowed | Fallback | Developer has key | Result |
|-------------|-------------------|----------|-------------------|----------------------|
| On | Off | — | — | Uses central key |
| Off | On | — | Yes | Uses developer's key |
| Off | On | — | No | Unavailable |
| On | On | Off | Yes | Uses developer's key |
| On | On | Off | No | Unavailable |
| On | On | On | Yes | Uses developer's key |
| On | On | On | No | Uses central key |
When a developer's personal key is present, it always takes precedence over
the central key. When user keys are required and fallback is disabled,
the provider is unavailable to developers who have not saved a personal key —
even if a central key exists. This is intentional: it enforces that each
developer authenticates with their own credentials.
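The resolution table above can be expressed as a few lines of logic. This is an illustrative model of the documented behavior, not Coder's actual implementation; the type and function names are invented for the example.

```go
package main

import "fmt"

// keyPolicy models the three per-provider flags described above.
type keyPolicy struct {
	centralKey bool // a deployment-managed API key is configured
	userKeys   bool // developers may supply their own key
	fallback   bool // fall back to the central key when no personal key
}

// resolveKey mirrors the resolution table: a personal key always wins;
// without one, the central key is used only when user keys are
// disallowed or fallback is enabled.
func resolveKey(p keyPolicy, hasPersonalKey bool) (source string, ok bool) {
	if p.userKeys && hasPersonalKey {
		return "personal", true
	}
	if !p.centralKey {
		return "", false // no personal key and no central key
	}
	if p.userKeys && !p.fallback {
		return "", false // personal key mandatory, fallback disabled
	}
	return "central", true
}

func main() {
	// On/On/Off with no personal key: the provider is unavailable,
	// even though a central key exists.
	_, ok := resolveKey(keyPolicy{centralKey: true, userKeys: true, fallback: false}, false)
	fmt.Println(ok)
}
```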
## Models
Each model belongs to a provider and has its own configuration for context limits,
fields appear dynamically in the admin UI when you select a provider.
#### OpenAI
| Option | Description |
|-----------------------|-------------------------------------------------------------------------------------------|
| Reasoning Effort | How much effort the model spends reasoning (`minimal`, `low`, `medium`, `high`, `xhigh`). |
| Max Completion Tokens | Cap on completion tokens for reasoning models. |
| Parallel Tool Calls | Whether the model can call multiple tools at once. |
#### Google
#### OpenRouter
| Option | Description |
|-------------------|---------------------------------------------------|
| Reasoning Enabled | Enable extended reasoning mode. |
| Reasoning Effort | Reasoning effort level (`low`, `medium`, `high`). |
#### Vercel AI Gateway
The model selector uses the following precedence to pre-select a model:
1. **Admin-designated default** — the model marked with the star icon.
1. **First available model** — if no default is set and no history exists.
Developers cannot add their own providers or models. If no models are
configured, the chat interface displays a message directing developers to
contact an administrator.
## User API keys (BYOK)
When an administrator enables **Allow user API keys** on a provider,
developers can supply their own API key from the Agents settings page.
### Managing personal API keys
1. Navigate to the **Agents** page in the Coder dashboard.
1. Open **Settings** and select the **API Keys** tab.
1. Each provider that allows user keys is listed with a status indicator:
- **Key saved** — your personal key is active and will be used for requests.
- **Using shared key** — no personal key set, but the central deployment
key is available as a fallback.
- **No key** — you must add a personal key before you can use this provider.
1. Enter your API key and click **Save**.
Personal API keys are encrypted at rest using the same database encryption
as deployment-managed keys. The dashboard never displays a saved key — only
whether one is set.
### How key selection works
When you start a chat, the control plane resolves which API key to use for
each provider:
1. If you have a personal key for the provider, it is used.
1. If you do not have a personal key and central key fallback is enabled,
the deployment-managed key is used.
1. If you do not have a personal key and fallback is disabled, the provider
is unavailable to you. Models from that provider will not appear in the
model selector.
### Removing a personal key
Click **Remove** on the provider card in the API Keys settings tab. If
central key fallback is enabled, subsequent requests will use the shared
deployment key. If fallback is disabled, the provider becomes unavailable
until you add a new personal key.
## Using an LLM proxy
Organizations that route LLM traffic through a centralized proxy — such as
# Conversation Data Retention
Coder Agents automatically cleans up old conversation data to manage database
growth. Archived conversations and their associated files are periodically
purged based on a configurable retention period.
## How it works
A background process runs approximately every 10 minutes to remove expired
conversation data. Only archived conversations are eligible for deletion —
active (non-archived) conversations are never purged.
When an archived conversation exceeds the retention period, it is deleted along
with its messages, diff statuses, and queued messages via cascade. Orphaned
files (not referenced by any active or recently-archived conversation) are also
deleted. Both operations run in batches of 1,000 rows per cycle.
## Configuration
Navigate to the **Agents** page, open **Settings**, and select the **Behavior**
tab to configure the conversation retention period. The default is 30 days. Use the toggle to
disable retention entirely.
The retention period is stored as the `agents_chat_retention_days` key in the
`site_configs` table and can also be managed via the API at
`/api/experimental/chats/config/retention-days`.
## What gets deleted
| Data | Condition | Cascade |
|------------------------|------------------------------------------------------------------------------------------------|---------------------------------------------------------------|
| Archived conversations | Archived longer than retention period | Messages, diff statuses, queued messages deleted via CASCADE. |
| Conversation files | Older than retention period AND not referenced by any active or recently-archived conversation | — |
## Unarchive safety
If a user unarchives a conversation whose files were purged, stale file
references are automatically cleaned up by FK cascades. The conversation
remains usable but previously attached files are no longer available.
This means:
- **All agent configuration is admin-level.** Providers, models, system prompts,
and tool permissions are set by platform teams from the control plane. These
are not user preferences — they are deployment-wide policies.
- **Developers never need to configure anything by default.** A developer just
describes the work they want done. They do not need to pick a provider or
write a system prompt — the platform team has already set all of that up.
When a platform team enables user API keys for a provider, developers may
optionally supply their own key — but this is an opt-in policy decision, not
a requirement.
- **Enforcement, not defaults.** Settings configured by administrators are
enforced server-side. Developers cannot override them. This is a deliberate
distinction — a setting that a user can change is a preference, not a policy.
self-hosted models), and per-model parameters like context limits, thinking
budgets, and reasoning effort.
Developers select from the set of models an administrator has enabled. They
cannot add their own providers or access models that have not been explicitly
configured.
When an administrator enables user API keys on a provider, developers can
supply their own key from the Agents settings page. See
[User API keys (BYOK)](../models.md#user-api-keys-byok) for details.
See [Models](../models.md) for setup instructions.
opt-out, or opt-in for each chat.
See [MCP Servers](./mcp-servers.md) for configuration details.
### Virtual desktop
Administrators can enable a virtual desktop within agent workspaces.
When enabled, agents can use `spawn_computer_use_agent` to interact with a
desktop environment using screenshots, mouse, and keyboard input.
This setting is available under **Agents** > **Settings** > **Behavior**.
It requires:
- The [portabledesktop](https://registry.coder.com/modules/coder/portabledesktop)
module to be installed in the workspace template.
- An Anthropic provider to be configured (computer use is an Anthropic
capability).
### Workspace autostop fallback
Administrators can set a default autostop timer for agent-created workspaces
that do not define one in their template. Template-defined autostop rules always
take precedence. Active conversations extend the stop time automatically.
This setting is available under **Agents** > **Settings** > **Behavior**.
The maximum configurable value is 30 days. When disabled, workspaces follow
their template's autostop rules (or none, if the template does not define any).
### Usage limits and analytics
Administrators can set spend limits to cap LLM usage per user within a rolling
breakdowns.
See [Usage & Analytics](./usage-insights.md) for details.
### Data retention
Administrators can configure a retention period for archived conversations.
When enabled, archived conversations and orphaned files older than the
retention period are automatically purged. The default is 30 days.
This setting is available under **Agents** > **Settings** > **Behavior**.
See [Data Retention](./chat-retention.md) for details.
## Where we are headed
The controls above cover providers, models, system prompts, templates, MCP
servers, usage limits, and data retention. We are continuing to invest in platform controls
based on what we hear from customers deploying agents in regulated and
enterprise environments.
Select a user to see:
bar shows current spend relative to the limit.
- **Per-model breakdown** — table of costs and token usage by model.
- **Per-chat breakdown** — table of costs and token usage by chat session.
> [!NOTE]
> Automatic title generation uses lightweight models, such as Claude Haiku or GPT-4o
> Mini. Its token usage is not counted towards usage limits or shown in usage
> summaries.
We provide an example Grafana dashboard that you can import as a starting point
These logs and metrics can be used to determine usage patterns, track costs, and evaluate tooling adoption.
## Structured Logging
AI Bridge can emit structured logs for every interception event to your
existing log pipeline. This is useful for exporting data to external SIEM or
observability platforms. See [Structured Logging](./setup.md#structured-logging)
in the setup guide for configuration and a full list of record types.
## Exporting Data
AI Bridge interception data can be exported for external analysis, compliance reporting, or integration with log aggregation systems.
ingestion, set `--log-json` to a file path or `/dev/stderr` so that records are
emitted as JSON.
Filter for AI Bridge records in your logging pipeline by matching on the
`"interception log"` message. Each log line includes a `record_type` field that
indicates the kind of event captured:
| `record_type` | Description | Key fields |
|----------------------|-----------------------------------------|--------------------------------------------------------------------------------|
| `interception_start` | A new intercepted request begins. | `interception_id`, `initiator_id`, `provider`, `model`, `client`, `started_at` |
| `interception_end` | An intercepted request completes. | `interception_id`, `ended_at` |
| `token_usage` | Token consumption for a response. | `interception_id`, `input_tokens`, `output_tokens`, `created_at` |
| `prompt_usage` | The last user prompt in a request. | `interception_id`, `prompt`, `created_at` |
| `tool_usage` | A tool/function call made by the model. | `interception_id`, `tool`, `input`, `server_url`, `injected`, `created_at` |
| `model_thought` | Model reasoning or thinking content. | `interception_id`, `content`, `created_at` |
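A log-pipeline consumer can select records of one type by decoding each JSON line and matching on `record_type`. The sketch below is illustrative: the `msg` field name and exact JSON casing are assumptions based on the `"interception log"` matching advice above.

```go
package main

import (
	"bufio"
	"encoding/json"
	"fmt"
	"strings"
)

// logRecord holds the fields shared by all AI Bridge record types in
// the table above.
type logRecord struct {
	Msg            string `json:"msg"`
	RecordType     string `json:"record_type"`
	InterceptionID string `json:"interception_id"`
}

// filterRecords returns the interception IDs of log lines whose
// record_type matches want, skipping non-JSON or unrelated lines.
func filterRecords(jsonLines string, want string) []string {
	var ids []string
	sc := bufio.NewScanner(strings.NewReader(jsonLines))
	for sc.Scan() {
		var r logRecord
		if err := json.Unmarshal(sc.Bytes(), &r); err != nil {
			continue // not a JSON record
		}
		if r.Msg == "interception log" && r.RecordType == want {
			ids = append(ids, r.InterceptionID)
		}
	}
	return ids
}

func main() {
	logs := `{"msg":"interception log","record_type":"token_usage","interception_id":"abc"}
{"msg":"interception log","record_type":"tool_usage","interception_id":"def"}`
	fmt.Println(filterRecords(logs, "token_usage"))
}
```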
pages.
| [2.26](https://coder.com/changelog/coder-2-26) | September 03, 2025 | Not Supported | [v2.26.6](https://github.com/coder/coder/releases/tag/v2.26.6) |
| [2.27](https://coder.com/changelog/coder-2-27) | October 02, 2025 | Not Supported | [v2.27.11](https://github.com/coder/coder/releases/tag/v2.27.11) |
| [2.28](https://coder.com/changelog/coder-2-28) | November 04, 2025 | Not Supported | [v2.28.11](https://github.com/coder/coder/releases/tag/v2.28.11) |
| [2.29](https://coder.com/changelog/coder-2-29) | December 02, 2025 | Security Support + ESR | [v2.29.8](https://github.com/coder/coder/releases/tag/v2.29.8) |
| [2.30](https://coder.com/changelog/coder-2-30) | February 03, 2026 | Stable | [v2.30.3](https://github.com/coder/coder/releases/tag/v2.30.3) |
| [2.31](https://coder.com/changelog/coder-2-31) | February 23, 2026 | Mainline | [v2.31.5](https://github.com/coder/coder/releases/tag/v2.31.5) |
| [2.29](https://coder.com/changelog/coder-2-29) | December 02, 2025 | Extended Support Release | [v2.29.9](https://github.com/coder/coder/releases/tag/v2.29.9) |
| [2.30](https://coder.com/changelog/coder-2-30) | February 03, 2026 | Security Support | [v2.30.6](https://github.com/coder/coder/releases/tag/v2.30.6) |
| [2.31](https://coder.com/changelog/coder-2-31) | February 23, 2026 | Stable | [v2.31.7](https://github.com/coder/coder/releases/tag/v2.31.7) |
| 2.32 | | Not Released | N/A |
<!-- RELEASE_CALENDAR_END -->
@@ -495,7 +495,8 @@
{
"title": "Headless Authentication",
"description": "Create and manage headless service accounts for automated systems and API integrations",
"path": "./admin/users/headless-auth.md"
"path": "./admin/users/headless-auth.md",
"state": ["premium"]
},
{
"title": "Groups \u0026 Roles",
@@ -1249,6 +1250,12 @@
"description": "Spend limits and cost tracking for Coder Agents",
"path": "./ai-coder/agents/platform-controls/usage-insights.md",
"state": ["early access"]
},
{
"title": "Data Retention",
"description": "Automatic cleanup of old conversation data",
"path": "./ai-coder/agents/platform-controls/chat-retention.md",
"state": ["early access"]
}
]
},
@@ -2025,6 +2025,20 @@ AuthorizationObject can represent a "set" of objects, such as: all workspaces in
|----------------------|---------|----------|--------------|-------------|
| `acquire_batch_size` | integer | false | | |
## codersdk.ChatRetentionDaysResponse
```json
{
"retention_days": 0
}
```
### Properties
| Name | Type | Required | Restrictions | Description |
|------------------|---------|----------|--------------|-------------|
| `retention_days` | integer | false | | |
## codersdk.ConnectionLatency
```json
@@ -10299,6 +10313,20 @@ Restarts will only happen on weekdays in this list on weeks which line up with W
| `logo_url` | string | false | | |
| `service_banner` | [codersdk.BannerConfig](#codersdkbannerconfig) | false | | Deprecated: ServiceBanner has been replaced by AnnouncementBanners. |
## codersdk.UpdateChatRetentionDaysRequest
```json
{
"retention_days": 0
}
```
### Properties
| Name | Type | Required | Restrictions | Description |
|------------------|---------|----------|--------------|-------------|
| `retention_days` | integer | false | | |
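The request and response schemas mirror each other: both carry a single `retention_days` integer. A minimal sketch of the wire format in Python (the dataclass names follow the codersdk type names, but this client-side code is illustrative, not part of codersdk):

```python
import json
from dataclasses import dataclass

@dataclass
class UpdateChatRetentionDaysRequest:
    retention_days: int

    def to_json(self) -> str:
        # Serializes to the shape shown in the schema above.
        return json.dumps({"retention_days": self.retention_days})

@dataclass
class ChatRetentionDaysResponse:
    retention_days: int

    @classmethod
    def from_json(cls, raw: str) -> "ChatRetentionDaysResponse":
        return cls(retention_days=json.loads(raw)["retention_days"])

req = UpdateChatRetentionDaysRequest(retention_days=30)
resp = ChatRetentionDaysResponse.from_json('{"retention_days": 30}')
print(req.to_json())        # → {"retention_days": 30}
print(resp.retention_days)  # → 30
```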
## codersdk.UpdateCheckResponse
```json