Compare commits

..

9 Commits

Author SHA1 Message Date
Kyle Carberry 459ea75c57 feat(coderd/automations): implement active mode — chat creation and rate limiting
Adds the Fire() executor that handles the full automation trigger
lifecycle for both webhook and cron paths:

- Rate limiting via CountAutomationChatCreatesInWindow and
  CountAutomationMessagesInWindow (both queries already existed
  but were never called)
- Label-based chat lookup to continue existing conversations
- New chat creation through chatd.Server via ChatCreator interface
- Event recording with matched_chat_id / created_chat_id
- Preview mode still only logs events without acting

Changes:
- automations/executor.go: shared Fire() function with full flow
- automations/chatadapter.go: ChatdAdapter bridges ChatCreator to
  chatd.Server
- exp_automations.go: webhook handler now calls Fire() instead of
  stub event insertion
- cronscheduler: accepts ChatCreator, delegates to Fire()
- coderd.go: ChatDaemon() public getter for chatd.Server
- GetActiveCronTriggers query expanded with rate limit and model
  fields
2026-03-26 14:01:20 +00:00
Kyle Carberry c04f4f6abf feat(coderd/automations): add cron trigger scheduler
Adds a background scheduler that evaluates cron-based automation
triggers every minute. Follows the dbpurge pattern with quartz.Clock
for testability and advisory locking for single-replica execution.

Changes:
- Migration 000453: adds last_triggered_at to automation_triggers
- SQL queries: GetActiveCronTriggers (JOIN with automations),
  UpdateAutomationTriggerLastTriggeredAt
- cron.Standard(): full 5-field cron parser without Weekly/Daily
  restrictions
- cronscheduler package: background goroutine, advisory lock,
  schedule evaluation, event creation, preview/active mode
- Handler validation: cron_schedule validated on trigger creation
- Wired into cli/server.go alongside dbpurge and autobuild
- 20 new tests (13 cron parser + 7 scheduler with quartz mock)
2026-03-26 13:45:05 +00:00
Kyle Carberry b117acfaaf fix(coderd): round-3 review fixes — formatting, validation, cleanup
H1: Regenerate typesGenerated.ts to match SDK changes
  (organization_id, webhook_secret, TestAutomationRequest).

H2: Fix gofmt alignment in AutomationTrigger struct.

M1: Fix safePayload data loss — truncate inner body before
  wrapping in JSON envelope, not after. Prevents complete data
  loss when non-JSON payloads are close to the 64KB limit.

M2: Validate filter (must be JSON object) and label_paths
  (must be map[string]string) in postAutomationTrigger. Prevents
  silent webhook filtering from malformed trigger config.

M3: Replace inline anonymous struct in testAutomation with
  codersdk.TestAutomationRequest to prevent type drift.

M4: Remove unused accessURL parameter from db2sdk.Automation()
  and all 4 call sites.

M5: Add format:"uuid" tags to MCPServerIDs on Automation,
  CreateAutomationRequest, and UpdateAutomationRequest.

L2: Catch FK violations in postAutomation — return 400 with
  descriptive message instead of generic 500.
2026-03-25 23:20:27 +00:00
Kyle Carberry 8d6143bbef fix(coderd): require explicit org, gate webhook, expose secret, handle non-JSON
M1: Add experiment gate in postAutomationWebhook — checks
api.Experiments.Enabled(ExperimentAgents) and returns 200 early
if disabled. Done in-handler (not middleware) to preserve the
always-200 contract.

M2: Replace nondeterministic orgs[0] selection with explicit
OrganizationID field on CreateAutomationRequest. Returns 400
if not provided. Follows the established pattern from chats.

M7: Add WebhookSecret field (omitempty) to AutomationTrigger
SDK type. Populated only in postAutomationTrigger and
regenerateAutomationTriggerSecret responses so the user can
configure their webhook provider. Never returned from list/get.

L2: Add safePayload() that wraps non-JSON webhook bodies in a
valid JSON envelope before storing. Preserves the audit trail
when webhook providers send form-encoded or XML payloads.
2026-03-25 21:12:03 +00:00
Kyle Carberry c6f384cc94 fix(coderd): address round-2 review issues in automations
High fixes:
- MatchFilter: use reflect.DeepEqual instead of != to prevent
  runtime panics on array/object filter values. Add test cases
  for booleans, arrays, and nested objects.
- postAutomation/patchAutomation: handle unique constraint
  violations with 409 Conflict instead of leaking raw Postgres
  errors as 500s. Uses database.IsUniqueViolation pattern.

Medium fixes:
- regenerateAutomationTriggerSecret: reject non-webhook triggers
  with 400 instead of silently setting an unused secret.
- coderd.go: fix mangled indentation on automations, chats, webhook,
  and deployment route blocks. Split r.Use(apiKeyMiddleware) onto
  its own line in /deployment route.
- Add session_test.go with 8 test cases for ResolveLabels covering
  nil/empty input, path extraction, missing paths, and type coercion.

Low fixes:
- testAutomation: return 400 on malformed label_paths instead of
  silently swallowing the parse error.
- Validate rate limit fields (1-1000 range) in both create and
  update handlers to prevent DB check constraint 500s.
2026-03-25 20:36:56 +00:00
Kyle Carberry 17e73f1711 fix(coderd): address critical and high review issues in automations
Critical fixes:
- postAutomationWebhook: parse trigger UUID manually to always return
  200 (never leak 400 to webhook sources)
- deleteAutomationTrigger/regenerateAutomationTriggerSecret: verify
  trigger.AutomationID matches automation from middleware (prevents
  cross-user trigger manipulation)
- truncatePayload: return valid JSON stub instead of byte-slicing
  (which produced broken JSON at 64KB boundary)
- InsertAutomationEvent: scope RBAC check to specific automation
  instead of bare resource type
- Add webhook_secret_key_id to trigger insert/update SQL queries
  (enables dbcrypt key tracking)

High fixes:
- Add index on chats.automation_id for ON DELETE SET NULL performance
- Add unique constraint on automations(owner_id, org_id, name)
- Add index on automation_events(received_at) for purge query
- patchAutomation: validate status against allowed values, reject
  empty name
- postAutomationWebhook: verify trigger.Type is 'webhook' before
  processing
- DeleteAutomationTriggerByID: use ActionUpdate (not ActionDelete)
  on parent automation, consistent with child-entity RBAC pattern
- TestAutomation SDK client: send proper struct with payload/filter/
  label_paths instead of raw JSON
- Remove dead UpdateAutomationTriggerRequest type
- postAutomationTrigger: validate trigger type against known values
- CountAutomationMessagesInWindow: count both 'created' and
  'continued' events toward message rate limit
- Scope CountAutomationChatCreatesInWindow and
  CountAutomationMessagesInWindow RBAC to specific automation
2026-03-25 20:17:36 +00:00
Kyle Carberry 39a5af04bf refactor: extract automation triggers into separate table
Trigger-specific fields (webhook_secret, cron_schedule, filter,
label_paths) move from the automations table into a new
automation_triggers table. An automation can have multiple triggers,
each with its own type (webhook or cron), secret, and filter config.

Each webhook trigger gets its own URL and HMAC secret, so external
systems configure per-trigger webhooks that can be independently
revoked.

The event log table is renamed from automation_webhook_events to
automation_events (events can now come from cron triggers too) and
gains a trigger_id FK to trace which trigger fired.

New CRUD endpoints for triggers nested under
/api/experimental/automations/{automation}/triggers. The webhook
ingestion endpoint moves to /api/v2/automations/triggers/{trigger_id}/webhook.
2026-03-25 19:46:04 +00:00
Kyle Carberry 13f77b2b27 refactor: address review feedback on automations
- Move convertAutomation/convertWebhookEvent to db2sdk package
- Rename session_labels → label_paths (clearer: maps label keys to gjson paths)
- Rename system_prompt → instructions (sent as user message, not system prompt)
- Remove workspace_id column (chat creation handles workspace selection)
- Add cron_schedule column for scheduled automations (nullable, mutually
  exclusive with webhook_secret in v1)
- Make webhook_secret nullable (cron automations don't need it)
2026-03-25 19:22:18 +00:00
Kyle Carberry 13ab3ad058 feat: add automations backend — webhooks to chat bridge
Adds the backend infrastructure for automations, a user-owned resource
that bridges external webhooks to Coder chats. When a webhook arrives,
the system verifies the HMAC-SHA256 signature, evaluates a gjson-based
filter, resolves a chat session via labels, and logs the event.

Database: migration 000452 adds automations table, automation_webhook_events
table, and chats.automation_id FK. RBAC: new "automation" resource type
with CRUD actions. Routes: authenticated CRUD under /api/experimental/automations
(experiment-gated), unauthenticated webhook at /api/v2/automations/{id}/webhook.

Active-mode chat creation through chatd.Server is deferred to a follow-up.
Audit logging requires a resource_type enum addition (separate migration).
2026-03-25 18:40:44 +00:00
1617 changed files with 30097 additions and 82147 deletions
+4 -6
View File
@@ -111,8 +111,8 @@ Tier 2 file filters:
- **Modernization Reviewer**: one instance per language present in the diff. Filter by extension:
- Go: `*.go` — reference `.claude/docs/GO.md` before reviewing.
- TypeScript: `*.ts` `*.tsx`: reference `.agents/skills/deep-review/references/typescript.md` before reviewing.
- React: `*.tsx` `*.jsx`: reference `.agents/skills/deep-review/references/react.md` before reviewing.
- TypeScript: `*.ts` `*.tsx`
- React: `*.tsx` `*.jsx`
`.tsx` files match both TypeScript and React filters. Spawn both instances when the diff contains `.tsx` changes — TS covers language-level patterns; React covers component and hooks patterns. Before spawning, verify each instance's filter produces a non-empty diff. Skip instances whose filtered diff is empty.
@@ -155,11 +155,9 @@ File scope: {filter from step 2}.
Output file: {REVIEW_DIR}/{role-name}.md
```
For Modernization Reviewer instances, add the language reference after the methodology line:
For the Modernization Reviewer (Go), add after the methodology line:
- **Go:** `Read .claude/docs/GO.md as your Go language reference before reviewing.`
- **TypeScript:** `Read .agents/skills/deep-review/references/typescript.md as your TypeScript language reference before reviewing.`
- **React:** `Read .agents/skills/deep-review/references/react.md as your React language reference before reviewing.`
> Read `.claude/docs/GO.md` as your Go language reference before reviewing.
For re-reviews, append to both Tier 1 and Tier 2 prompts:
@@ -1,305 +0,0 @@
# Modern React (1819.2) + Compiler 1.0 — Reference
Reference for writing idiomatic React. Covers what changed, what it replaced, and what to reach for. Includes React Compiler patterns — what the compiler handles automatically, what it changes semantically, and how to verify its behavior empirically. Scope: client-side SPA patterns only. Server Components, `use server`, and `use client` directives are framework-specific and omitted. Check the project's React version and compiler config before reaching for newer APIs.
## How modern React thinks differently
**Concurrent rendering** (18): React can now pause, interrupt, and resume renders. This is the foundation everything else builds on. Most existing code "just works," but components that produce side effects during render (mutations, subscriptions, network calls in the render body) are unsafe and will misbehave. Concurrent features are opt-in — they only activate when you use a concurrent API like `startTransition` or `useDeferredValue`.
**Urgent vs. non-urgent updates** (18): The `startTransition` / `useTransition` API introduces a formal split between updates that must feel immediate (typing, clicking) and updates that can be interrupted (filtering a large list, navigating to a new screen). Non-urgent updates yield to urgent ones mid-render. Use this instead of `setTimeout` or manual debounce when you want the UI to stay responsive during expensive re-renders.
**Actions** (19): Async functions passed to `startTransition` are called "Actions." They automatically manage pending state, error handling, and optimistic updates as a unit. The `useActionState` hook and `<form action={fn}>` prop are built on this. The pattern replaces the hand-rolled `isPending/setIsPending` + `try/catch` + `setError` boilerplate that was previously necessary for every data mutation.
**Automatic batching** (18): State updates are now batched everywhere — inside `setTimeout`, `Promise.then`, native event handlers, etc. Previously batching only happened inside React-managed event handlers. If you genuinely need a synchronous flush, use `flushSync`.
**Automatic memoization** (Compiler 1.0): React Compiler is a build-time Babel plugin that automatically inserts memoization into components and hooks. It replaces manual `useMemo`, `useCallback`, and `React.memo` — including conditional memoization and memoization after early returns, which manual APIs cannot express. The compiler only processes components and hooks, not standalone functions. It understands data flow and mutability through its own HIR (High-level Intermediate Representation), so it can memoize more granularly than a human would. Projects adopt it incrementally — typically via path-based Babel overrides or the `"use memo"` directive. Components that violate the Rules of React are silently skipped (no build error), so the automated lint tools that check compiler compatibility matter.
## Replace these patterns
The left column reflects patterns common before React 18/19. Write the right column instead. The "Since" column tells you the minimum React version required.
| Old pattern | Modern replacement | Since |
| ----------------------------------------------------------------- | ------------------------------------------------------------------------------ | ----- |
| `ReactDOM.render(<App />, el)` | `createRoot(el).render(<App />)` | 18 |
| `ReactDOM.hydrate(<App />, el)` | `hydrateRoot(el, <App />)` | 18 |
| `ReactDOM.unmountComponentAtNode(el)` | `root.unmount()` | 18 |
| `ReactDOM.findDOMNode(this)` | DOM ref: `const ref = useRef(); ref.current` | 18 |
| `<Context.Provider value={v}>` | `<Context value={v}>` | 19 |
| `React.forwardRef((props, ref) => ...)` | `function Comp({ ref, ...props }) { ... }` (ref as a regular prop) | 19 |
| String ref `ref="input"` in class components | Callback ref or `createRef()` | 19 |
| `Heading.propTypes = { ... }` | TypeScript / ES6 type annotations | 19 |
| `Component.defaultProps = { ... }` on function components | ES6 default parameters `({ text = 'Hi' })` | 19 |
| Legacy Context: `contextTypes` + `getChildContext` | `React.createContext()` + `contextType` | 19 |
| `import { act } from 'react-dom/test-utils'` | `import { act } from 'react'` | 19 |
| `import ShallowRenderer from 'react-test-renderer/shallow'` | `import ShallowRenderer from 'react-shallow-renderer'` | 19 |
| Manual `isPending` state around async calls | `const [isPending, startTransition] = useTransition()` | 18 |
| Manual optimistic state + revert logic | `useOptimistic(currentValue)` | 19 |
| `useEffect` to subscribe to external stores | `useSyncExternalStore(subscribe, getSnapshot)` | 18 |
| Hand-rolled unique ID (counter, random, index) | `useId()` — SSR-safe, hydration-safe | 18 |
| `useEffect` to inject `<title>` or `<meta>` / `react-helmet` | Render `<title>`, `<meta>`, `<link>` directly in components; React hoists them | 19 |
| `ReactDOM.useFormState(action, initial)` (Canary name) | `useActionState(action, initial)` | 19 |
| `useReducer<React.Reducer<State, Action>>(reducer)` | `useReducer(reducer)` — infers from the reducer function | 19 |
| `<div ref={current => (instance = current)} />` (implicit return) | `<div ref={current => { instance = current }} />` (explicit block body) | 19 |
| `useRef<T>()` with no argument | `useRef<T>(undefined)` or `useRef<T \| null>(null)` — argument is now required | 19 |
| `MutableRefObject<T>` type annotation | `RefObject<T>` — all refs are mutable now; `MutableRefObject` is deprecated | 19 |
| `React.createFactory('button')` | `<button />` JSX | 19 |
| `useMemo(() => expr, [deps])` in compiled components | `const val = expr;` — compiler memoizes automatically | C 1.0 |
| `useCallback(fn, [deps])` in compiled components | `const fn = () => { ... };` — compiler memoizes automatically | C 1.0 |
| `React.memo(Component)` in compiled components | Plain component — compiler skips re-render when props are unchanged | C 1.0 |
| `eslint-plugin-react-compiler` (standalone) | `eslint-plugin-react-hooks@latest` (compiler rules merged into recommended) | C 1.0 |
| `useRef` + `useLayoutEffect` for stable callbacks | `useEffectEvent(fn)` — compiler handles both, but `useEffectEvent` is clearer | 19.2 |
## New capabilities
These enable things that weren't practical before. Reach for them in the described situations.
| What | Since | When to use it |
| -------------------------------------------------------------------- | ------- | --------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
| `useTransition()` / `startTransition()` | 18 | Mark a state update as non-urgent so React can interrupt it to handle clicks or keystrokes. The `isPending` boolean lets you show a loading indicator without blocking the UI. |
| `useDeferredValue(value, initialValue?)` | 18 / 19 | Defer re-rendering a slow subtree: pass the deferred value as a prop, wrap the expensive child in `memo`. Unlike debounce, uses no fixed timeout — renders as soon as the browser is idle. The `initialValue` arg (19) avoids a flash on first render. |
| `useId()` | 18 | Generate a stable, SSR-consistent ID for accessibility attributes (`htmlFor`, `aria-describedby`). Do not use for list keys. |
| `useSyncExternalStore(subscribe, getSnapshot, getServerSnapshot?)` | 18 | Subscribe to external (non-React) state stores safely under concurrent rendering. Preferred over `useEffect`-based subscriptions in libraries. |
| `useActionState(action, initialState)` | 19 | Manage an async mutation: returns `[state, wrappedAction, isPending]`. Handles pending, result, and error state as a unit. Replaces the manual `isPending` + `try/catch` + `setError` pattern. |
| `useOptimistic(currentValue)` | 19 | Show a speculative value while an async Action is in flight. Returns `[optimisticValue, setOptimistic]`. React automatically reverts to `currentValue` when the transition settles. |
| `use(promiseOrContext)` | 19 | Read a promise or Context value inside a component or custom hook. Unlike hooks, `use` can be called conditionally (after early returns). Promises must come from a cache — do not create them during render. |
| `useFormStatus()` (from `react-dom`) | 19 | Read `{ pending, data, method, action }` of the nearest parent `<form>` Action. Works across component boundaries without prop drilling — useful for submit buttons inside design-system components. |
| `useEffectEvent(fn)` | 19.2 | Extract a non-reactive callback from an effect. The function sees the latest props/state without being listed in deps, and is never stale. Replaces the `useRef`-and-mutate-in-layout-effect workaround for stable event-like callbacks. The compiler has built-in knowledge of this hook and correctly prunes its return value from effect dependency arrays. Both `useEffectEvent` and the old ref workaround compile cleanly; `useEffectEvent` is preferred for clarity. |
| `<Activity>` | 19.2 | Hide part of the UI while preserving its state and DOM. React deprioritizes updates to hidden content. Use via framework APIs for route prerendering or tab preservation — not a direct replacement for CSS `visibility`. |
| `captureOwnerStack()` | 19.1 | Dev-only API that returns a string showing which components are responsible for rendering the current component (owner stack, not call stack). Useful for custom error overlays. Returns `null` in production. |
| `<form action={fn}>` | 19 | Pass an async function as a form's `action` prop. React handles submission, pending state, and automatic form reset on success. Works with `useActionState` and `useFormStatus`. |
| Ref cleanup function | 19 | Return a cleanup function from a ref callback: `ref={el => { ...; return () => cleanup(); }}`. React calls it on unmount. Replaces the pattern of checking `el === null` in the callback. |
| `<link rel="stylesheet" precedence="default">` | 19 | Declare a stylesheet next to the component that needs it. React deduplicates and inserts it in the correct order before revealing Suspense content. |
| `preinit`, `preload`, `prefetchDNS`, `preconnect` (from `react-dom`) | 19 | Imperatively hint the browser to load resources early. Call from render or event handlers. React deduplicates hints across the component tree. |
| React Compiler (`babel-plugin-react-compiler`) | C 1.0 | Build-time automatic memoization for components and hooks. Install, add to Babel/Vite pipeline. Projects typically start with path-based overrides to compile a subset of files. |
| `"use memo"` directive | C 1.0 | Opt a single function into compilation when using `compilationMode: 'annotation'`. Place at the start of the function body. Module-level `"use memo"` at the top of a file compiles all functions in that file. |
| `"use no memo"` directive | C 1.0 | Temporary escape hatch — skip compilation for a specific component or hook that causes a runtime regression. Not a permanent solution. Place at the start of the function body. |
| Compiler-powered ESLint rules | C 1.0 | Rules for purity, refs, set-state-in-render, immutability, etc. now ship in `eslint-plugin-react-hooks` recommended preset. Surface Rules-of-React violations even without the compiler installed. Note: some projects use Biome instead — check project lint config. |
## Key APIs
### `useTransition` and `startTransition` (18)
`useTransition` returns `[isPending, startTransition]`. Wrap any state update that is not directly tied to the user's current gesture inside `startTransition`. React will render the old UI while computing the new one, and `isPending` is `true` during that window.
In React 19, `startTransition` can accept an async function (an "Action"). React sets `isPending` to `true` for the entire duration of the async work, not just during the synchronous part.
```tsx
// 18: synchronous transition
const [isPending, startTransition] = useTransition();
startTransition(() => setQuery(input));
// 19: async Action — isPending stays true until the await settles
startTransition(async () => {
const err = await updateName(name);
if (err) setError(err);
});
```
Use `startTransition` (the module-level export) when you cannot use the hook (outside a component, in a router callback, etc.).
### `useDeferredValue` (18 / 19)
Creates a "lagging" copy of a value. Pass it to a memoized, expensive component so that React can render the stale UI while computing the updated one.
```tsx
// 19: initialValue shows '' on first render; avoids loading flash
const deferred = useDeferredValue(searchQuery, "");
return <Results query={deferred} />; // Results wrapped in memo
```
`deferred !== searchQuery` while the deferred render is in progress — use this to show a "stale" indicator.
### `useActionState` (19)
Replaces the `useState` + `isPending` + `try/catch` + `setError` boilerplate for any async operation that can be retried or submitted as a form.
```tsx
const [error, submitAction, isPending] = useActionState(
async (prevState, formData) => {
const err = await updateName(formData.get("name"));
if (err) return err; // returned value becomes next state
redirect("/profile");
return null;
},
null, // initialState
);
// Use submitAction as the form's action prop or call it directly
<form action={submitAction}>
<input name="name" />
<button disabled={isPending}>Save</button>
{error && <p>{error}</p>}
</form>;
```
### `useOptimistic` (19)
Shows a speculative value immediately while an async Action is in progress. React automatically reverts to the server-confirmed value when the Action resolves or rejects.
```tsx
const [optimisticName, setOptimisticName] = useOptimistic(currentName);
const submit = async (formData) => {
const newName = formData.get("name");
setOptimisticName(newName); // shows immediately
await updateName(newName); // reverts if this throws
};
```
### `use()` (19)
Unlike hooks, `use` can appear after conditional statements. Two primary uses:
**Reading a promise** (must be stable — from a cache, not created inline):
```tsx
function Comments({ commentsPromise }) {
const comments = use(commentsPromise); // suspends until resolved
return comments.map((c) => <p key={c.id}>{c.text}</p>);
}
```
**Reading context after an early return** (hooks cannot appear after `return`):
```tsx
function Heading({ children }) {
if (!children) return null;
const theme = use(ThemeContext); // valid here; hooks would not be
return <h1 style={{ color: theme.color }}>{children}</h1>;
}
```
### `useSyncExternalStore` (18)
The correct way for libraries (and app code) to subscribe to non-React state. Prevents tearing under concurrent rendering.
```tsx
const value = useSyncExternalStore(
store.subscribe, // called when store changes
store.getSnapshot, // returns current value (must be stable reference if unchanged)
store.getServerSnapshot, // optional: for SSR
);
```
## Verifying compiler behavior
The compiler is a black box unless you inspect its output. When reviewing code in compiled paths, run the compiler on the specific code to see what it actually does. Do not guess — verify.
**Run the compiler on a code snippet:**
```sh
cd site && node -e "
const {transformSync} = require('@babel/core');
const code = \`<paste component here>\`;
const diagnostics = [];
const result = transformSync(code, {
plugins: [
['@babel/plugin-syntax-typescript', {isTSX: true}],
['babel-plugin-react-compiler', {
logger: {
logEvent(_, event) {
if (event.kind === 'CompileError' || event.kind === 'CompileSkip') {
diagnostics.push(event.detail?.toString?.()?.substring(0, 200));
}
},
},
}],
],
filename: 'test.tsx',
});
console.log('Compiled:', result.code.includes('_c('));
if (diagnostics.length) console.log('Diagnostics:', diagnostics);
console.log(result.code);
"
```
**Reading compiled output:**
- `const $ = _c(N)` — allocates N memoization cache slots.
- `if ($[n] !== dep)` — cache invalidation guard. Re-computes when `dep` changes (referential equality).
- `if ($[n] === Symbol.for("react.memo_cache_sentinel"))` — one-time initialization. Runs once on first render, cached forever after. This is how the compiler handles expressions with no reactive dependencies.
- `_temp` functions — pure callbacks the compiler hoisted out of the component body.
**Check all compiled files at once:**
```sh
cd site && pnpm run lint:compiler
```
This runs the compiler on every file in the compiled paths and reports CompileError / CompileSkip diagnostics. Zero diagnostics means all functions compiled cleanly.
**What the compiler catches vs. what it does not:**
The compiler emits `CompileError` for mutations of props, state, or hook arguments during render, and for `ref.current` access during render. The project's lint pipeline catches these automatically — do not flag them in review.
The compiler does **not** flag impure function calls during render (`Math.random()`, `Date.now()`, `new Date()`). Instead it silently memoizes them with a sentinel guard, freezing the value after first render. This changes semantics without any diagnostic. Verify suspicious calls by running the compiler and checking for sentinel guards in the output.
## Pitfalls
Things that are easy to get wrong even when you know the modern API exists. Check your output against these.
**Effects run twice in development with StrictMode.** React 18 intentionally mounts → unmounts → remounts every component in dev to surface effects that are not resilient to remounting. This is not a bug. If an effect breaks on the second mount, it is missing a cleanup function. Write `return () => cleanup()` from every effect that sets up a subscription, timer, or external resource.
**Concurrent rendering can call render multiple times.** The render function (component body) may be called more than once before React commits to the DOM. Side effects (mutations, subscriptions, logging) in the render body will run multiple times. Move them into `useEffect` or event handlers.
**Do not create promises during render and pass them to `use()`.** A new promise is created every render, causing an infinite suspend-retry loop. Create the promise outside the component (module level), or use a caching library (SWR, React Query, `cache()` from React) to stabilize it.
**`useOptimistic` reverts automatically — do not fight it.** The optimistic value is a presentation layer only. When the Action settles, React replaces it with the real `currentValue` you passed in. Do not try to sync optimistic state back to your real state; let React handle the revert.
**`flushSync` opts out of automatic batching.** If third-party code or a browser API (e.g. `ResizeObserver`) calls `setState` and you need synchronous DOM flushing, wrap with `flushSync(() => setState(...))`. This is a last resort; prefer letting React batch.
**`forwardRef` still works in React 19 but will be deprecated.** Function components accept `ref` as a plain prop now. New code should use the prop directly. Existing `forwardRef` wrappers continue to work without changes; migrate when convenient.
**`<Activity>` does not unmount.** Content inside a hidden `<Activity>` boundary stays mounted. Effects keep running. Use it for preserving scroll position or form state, not for preventing expensive mounts — use lazy loading for that.
**TypeScript: implicit returns from ref callbacks are now type errors.** In React 19, returning anything other than a cleanup function (or nothing) from a ref callback is rejected by the TypeScript types. The most common case is arrow-function refs that implicitly return the DOM node:
```tsx
// Error in React 19 types:
<div ref={el => (instance = el)} />
// Fix — use a block body:
<div ref={el => { instance = el; }} />
```
**TypeScript: `useRef` now requires an argument.** `useRef<T>()` with no argument is a type error. Pass `undefined` for mutable refs or `null` for DOM refs you initialize on mount: `useRef<T>(undefined)` / `useRef<HTMLDivElement | null>(null)`.
**`useId` output format changed across versions.** React 18 produced `:r0:`. React 19.1 changed it to `«r0»`. React 19.2 changed it again to `_r0`. Do not parse or depend on the specific format — treat it as an opaque string.
**`useFormStatus` reads the nearest parent `<form>` with a function `action`.** It does not reflect native HTML form submissions — only React Actions. A submit button that is a sibling of `<form>` (rather than a descendant) will not see the form's status.
**Context as a provider (`<Context>`) requires React 19; `<Context.Provider>` still works.** Do not use `<Context>` shorthand in a codebase that needs to support React 18. The two forms can coexist during migration.
**Compiler freezes impure expressions silently.** `Math.random()`, `Date.now()`, `new Date()`, and `window.innerWidth` in a component body all compile without diagnostics. The compiler wraps them in a sentinel guard (`Symbol.for("react.memo_cache_sentinel")`) that runs the expression once and caches the result forever. The value never updates on re-render. Fix: move to a `useState` initializer (`useState(() => Math.random())`), `useEffect`, or event handler.
**Component granularity affects compiler optimization.** When one pattern in a component causes a `CompileError` (e.g., a necessary `ref.current` read during render), the compiler skips the **entire** component. If the rest of the component would benefit from compilation, extract the non-compilable pattern into a small child component. This keeps the parent compiled.
**The compiler only memoizes components and hooks.** Standalone utility functions (even expensive ones called during render) are not compiled. If a utility function is truly expensive, it still needs its own caching strategy outside of React (e.g., a module-level cache, `WeakMap`, etc.).
**Changing memoization can shift `useEffect` firing.** A value that was unstable before compilation may become stable after, causing an effect that depended on it to fire less often. Conversely, future compiler changes may alter memoization granularity. Effects that use memoized values as dependencies should be resilient to these changes — they should be true synchronization effects, not "run this when X changes" hacks.
## Behavioral changes that affect code
- **Automatic batching** (18): State updates in `setTimeout`, `Promise.then`, `addEventListener` callbacks, etc. are now batched into a single re-render. Previously only React synthetic event handlers were batched. Code that relied on unbatched updates (reading DOM synchronously after each `setState`) must use `flushSync`.
- **StrictMode double-invoke** (18): In development, every component is mounted → unmounted → remounted with the previous state. Every effect runs cleanup → setup twice on initial mount. `useMemo` and `useCallback` also double-invoke their functions. Production behavior is unchanged. If a test or component breaks under this, the component had a latent cleanup bug.
- **StrictMode ref double-invoke** (19): In development, ref callbacks are also invoked twice on mount (attach → detach → attach). Return a cleanup function from the ref callback to handle detach correctly.
- **StrictMode memoization reuse** (19): During the second pass of double-rendering, `useMemo` and `useCallback` now reuse the cached result from the first pass instead of calling the function again. Components that are already StrictMode-compatible should not notice a difference.
- **Suspense fallback commits immediately** (19): When a component suspends, React now commits the nearest `<Suspense>` fallback without waiting for sibling trees to finish rendering. After the fallback is shown, React "pre-warms" suspended siblings in the background. This makes fallbacks appear faster but changes the order of rendering work.
- **Error re-throwing removed** (19): Errors that are not caught by an Error Boundary are now reported to `window.reportError` (not re-thrown). Errors caught by an Error Boundary go to `console.error` once. If your production monitoring relied on the re-thrown error, add handlers to `createRoot`: `createRoot(el, { onUncaughtError, onCaughtError })`.
- **Transitions in `popstate` are synchronous** (19): Browser back/forward navigation triggers synchronous transition flushing. This ensures the URL and UI update together atomically during history navigation.
- **`useEffect` from discrete events flushes synchronously** (18): Effects triggered by a click or keydown (discrete events) are now flushed synchronously before the browser paints, consistent with `useLayoutEffect` for those cases.
- **Hydration mismatches treated as errors** (18 / improved in 19): Text content mismatches between server HTML and client render revert to client rendering up to the nearest `<Suspense>` boundary. React 19 logs a single diff instead of multiple warnings, making mismatches much easier to diagnose.
- **New JSX transform required** (19): The automatic JSX runtime introduced in 2020 (`react/jsx-runtime`) is now mandatory. The classic transform (which required `import React from 'react'` in every file) is no longer supported. Most toolchains have already shipped the new transform; check your Babel or TypeScript config if you see warnings.
- **UMD builds removed** (19): React no longer ships UMD bundles. Load via npm and a bundler, or use an ESM CDN (`import React from "https://esm.sh/react@19"`).
- **React Compiler automatic memoization** (Compiler 1.0): Build-time Babel plugin that inserts memoization into components and hooks. Components that follow the Rules of React are automatically memoized; components that violate them are silently skipped (no build error, no runtime change). The compiler can memoize conditionally and after early returns — things impossible with manual `useMemo`/`useCallback`. Works with React 17+ via `react-compiler-runtime`; best with React 19+. Projects adopt incrementally via path-based Babel overrides, `compilationMode: 'annotation'`, or the `"use memo"` / `"use no memo"` directives. Check the project's Vite/Babel config to know which paths are compiled. Compiled components show a "Memo ✨" badge in React DevTools.
@@ -1,199 +0,0 @@
# Modern TypeScript (5.06.0 RC) — Reference
Reference for writing idiomatic TypeScript. Covers what changed, what it replaced, and what to reach for. Respect the project's minimum TypeScript version: don't emit features from a version newer than what the project targets. Check `package.json` and `tsconfig.json` before writing code.
## How modern TypeScript thinks differently
The 5.x era resolves years of module system ambiguity and cleans house on legacy options. Three themes dominate:
**Module semantics are explicit.** `--verbatimModuleSyntax` (5.0) makes import/export intent visible in source: type imports must carry `type`, value imports stay. Combined with `--module preserve` or `--moduleResolution bundler`, the compiler now accurately models what bundlers and modern runtimes actually do. `import defer` (5.9) extends the model to deferred evaluation.
**Resource lifetimes are first-class.** `using` and `await using` (5.2) provide deterministic cleanup without `try/finally`. Any object implementing `Symbol.dispose` participates. `DisposableStack` handles ad-hoc multi-resource cleanup in functions where creating a full class is overkill.
**Inference is smarter about what it knows.** Inferred type predicates (5.5) let `.filter(x => x !== undefined)` produce `T[]` instead of `(T | undefined)[]` automatically. `NoInfer<T>` (5.4) gives library authors precise control over which parameters drive inference. Narrowing now survives closures after last assignment, constant indexed accesses, and `switch (true)` patterns.
**TypeScript 6.0 is a transition release toward 7.0** (the Go-native port). It turns years of soft deprecations into errors and changes several defaults. Most impactful: `types` defaults to `[]` (must list `@types` packages explicitly), `rootDir` defaults to `.`, `strict` defaults to `true`, `module` defaults to `esnext`. Projects relying on implicit behavior need explicit config. Check the deprecations section before upgrading.
## Replace these patterns
The left column reflects patterns still common before TypeScript 5.x. Write the right column instead. The "Since" column tells you the minimum TypeScript version required.
| Old pattern | Modern replacement | Since |
| ---------------------------------------------------------------------------------------------------------------- | ---------------------------------------------------------------------------------------------------------------------------- | -------------------------------- | ------ |
| `--experimentalDecorators` + legacy decorator signatures | Standard decorators (TC39): `function dec(target, context: ClassMethodDecoratorContext)` — no flag needed | 5.0 |
| Requiring callers to add `as const` at call sites | `<const T extends HasNames>(arg: T)``const` modifier on type parameter | 5.0 |
| `--importsNotUsedAsValues` + `--preserveValueImports` | `--verbatimModuleSyntax` | 5.0 |
| `import { Foo } from "..."` when `Foo` is only used as a type | `import { type Foo } from "..."` or `import type { Foo } from "..."` | 5.0 |
| `"extends": "@tsconfig/strictest/tsconfig.json"` chain | `"extends": ["@tsconfig/strictest/tsconfig.json", "./tsconfig.base.json"]` (array form) | 5.0 |
| `try { ... } finally { resource.close(); resource.delete(); }` | `using resource = acquireResource()` — calls `[Symbol.dispose]()` automatically | 5.2 |
| `try { ... } finally { await resource.close() }` | `await using resource = acquireAsyncResource()` | 5.2 |
| Ad-hoc cleanup with multiple `try/finally` blocks | `using cleanup = new DisposableStack(); cleanup.defer(() => ...)` | 5.2 |
| `import data from "./data.json" assert { type: "json" }` | `import data from "./data.json" with { type: "json" }` | 5.3 |
| `.filter(Boolean)` or `.filter(x => !!x)` to remove nulls | `.filter(x => x !== undefined)` or `.filter(x => x !== null)` (infers type predicate) | 5.5 |
| Extra phantom type param to block inference bleed: `<C extends string, D extends C>` | `NoInfer<C>` on the parameter you don't want to drive inference | 5.4 |
| `/** @typedef {import("./types").Foo} Foo */` in JS files | `/** @import { Foo } from "./types" */` (JSDoc `@import` tag) | 5.5 |
| `myArray.reverse()` mutating in place | `myArray.toReversed()` (returns new array) | 5.2 |
| `myArray.sort(cmp)` mutating in place | `myArray.toSorted(cmp)` (returns new array) | 5.2 |
| `const copy = [...arr]; copy[i] = v` | `arr.with(i, v)` (returns new array) | 5.2 |
| Manual `has`/`get`/`set` pattern on `Map` | `map.getOrInsert(key, defaultValue)` or `getOrInsertComputed(key, fn)` | 6.0 RC |
| `new RegExp(str.replace(/[.\*+?^${}() | [\]\\]/g, '\\$&'))` | `new RegExp(RegExp.escape(str))` | 6.0 RC |
| `--moduleResolution node` (node10) | `--moduleResolution nodenext` (Node.js) or `--moduleResolution bundler` (bundlers/Bun) | 6.0 RC |
| `"baseUrl": "./src"` + `"@app/*": ["app/*"]` in paths | Remove `baseUrl`; use `"@app/*": ["./src/app/*"]` in paths directly | 6.0 RC |
| `module Foo { export const x = 1; }` | `namespace Foo { export const x = 1; }` | 6.0 RC |
| `export * from "..."` when all re-exported members are types | `export type * from "..."` (or `export type * as ns from "..."`) | 5.0 |
| `function f(): undefined { return undefined; }` — explicit return required in `: undefined`-returning function | Remove the `return` entirely; `undefined`-returning functions no longer require any return statement | 5.1 |
| Manual type predicate annotation on a simple arrow: `(x: T \| undefined): x is T => x !== undefined` | Remove the annotation; TypeScript infers `x is T` from `!== null/undefined` and `instanceof` checks automatically | 5.5 |
| `const val = obj[key]; if (typeof val === "string") { use(val); }` — extract to const to narrow indexed access | `if (typeof obj[key] === "string") { obj[key].toUpperCase(); }` directly — both `obj` and `key` must be effectively constant | 5.5 |
| Copy narrowed `let`/param to a `const`, or restructure code to escape stale closure narrowing after reassignment | Remove the copy; narrowing survives into closures created after the last assignment to the variable | 5.4 |
| `(arr as string[]).filter(...)` or restructure to avoid "not callable" errors on `string[] \| number[]` | Call `.filter`, `.find`, `.some`, `.every`, `.reduce` directly on union-of-array types | 5.2 |
| `if`/`else` chain used to work around lack of narrowing inside a `switch (true)` body | `switch (true)` — each `case` condition now narrows the tested variable in its clause | 5.3 |
## New capabilities
These enable things that weren't practical before. Reach for them in the described situations.
| What | Since | When to use it |
| ----------------------------------------------- | ------ | --------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
| `using` / `await using` declarations | 5.2 | Any resource needing deterministic cleanup (file handles, DB connections, locks, event listeners). Object must implement `Symbol.dispose` / `Symbol.asyncDispose`. |
| `DisposableStack` / `AsyncDisposableStack` | 5.2 | Ad-hoc multi-resource cleanup without creating a class. Call `.defer(fn)` right after acquiring each resource. Stack disposes in LIFO order. |
| `const` modifier on type parameters | 5.0 | Force `const`-like (literal/readonly tuple) inference at call sites without requiring callers to write `as const`. Constraint must use `readonly` arrays. |
| Decorator metadata (`Symbol.metadata`) | 5.2 | Attach and read per-class metadata from decorators via `context.metadata`. Retrieved as `MyClass[Symbol.metadata]`. Requires `Symbol.metadata ??= Symbol(...)` polyfill. |
| `NoInfer<T>` utility type | 5.4 | Prevent a parameter from contributing inference candidates for `T`. Use when one argument should be the "source of truth" and others should only be checked against it. |
| Inferred type predicates | 5.5 | Filter callbacks that test for `!== null` or `instanceof` now automatically produce a type predicate. `Array.prototype.filter` then narrows the result array type. |
| `--isolatedDeclarations` | 5.5 | Require explicit return types on exported declarations. Unlocks parallel declaration emit by external tooling (esbuild, oxc, etc.) without needing a full type-checker pass. |
| `${configDir}` in tsconfig paths | 5.5 | Anchor `typeRoots`, `paths`, `outDir`, etc. in a shared base tsconfig to the _consuming_ project's directory, not the shared file's location. |
| Always-truthy/nullish check errors | 5.6 | Catches regex literals in `if`, arrow functions as comparators, `?? 100` on non-nullable left side, misplaced parentheses. No API to call; existing bugs now surface as errors. |
| Iterator helper methods (`IteratorObject`) | 5.6 | Built-in iterators from `Map`, `Set`, generators, etc. now have `.map()`, `.filter()`, `.take()`, `.drop()`, `.flatMap()`, `.toArray()`, `.reduce()`, etc. Use `Iterator.from(iterable)` to wrap any iterable. |
| `--noUncheckedSideEffectImports` | 5.6 | Error when a side-effect import (`import "..."`) resolves to nothing. Catches typos in polyfill or CSS imports. |
| `--noCheck` | 5.6 | Skip type checking entirely during emit. Useful for separating "fast emit" from "thorough check" pipeline stages, especially with `--isolatedDeclarations`. |
| `--rewriteRelativeImportExtensions` | 5.7 | Rewrite `.ts``.js`, `.tsx``.jsx`, `.mts``.mjs`, `.cts``.cjs` in relative imports during emit. Required when writing `.ts` imports for Node.js strip-types mode and still needing `.js` output for library distribution. |
| `--erasableSyntaxOnly` | 5.8 | Error on constructs that can't be type-stripped by Node.js `--experimental-strip-types`: `enum`, `namespace` with code, parameter properties, `import =` aliases. |
| `require()` of ESM under `--module nodenext` | 5.8 | Node.js 22+ allows CJS to `require()` ESM files (no top-level `await`). TypeScript now allows this under `nodenext` without error. |
| `import defer * as ns from "..."` | 5.9 | Defer module _evaluation_ (not loading) until first property access. Module is loaded and verified at import time; side-effects are delayed. Only works with `--module preserve` or `esnext`. |
| `Set` algebra methods | 5.5 | Non-mutating: `union`, `intersection`, `difference`, `symmetricDifference` → new `Set`. Predicate: `isSubsetOf`, `isSupersetOf`, `isDisjointFrom``boolean`. Requires `esnext` or `es2025` lib. |
| `Object.groupBy` / `Map.groupBy` | 5.4 | Group an iterable into buckets by key function. Return type has all keys as optional (not every key is guaranteed present). Requires `esnext` or `es2024`+ lib. |
| `Temporal` API types | 6.0 RC | `Temporal.Now`, `Temporal.Instant`, `Temporal.PlainDate`, etc. Available under `esnext` or `esnext.temporal` lib. Usable in runtimes that already ship it (V8 118+, SpiderMonkey, etc.). |
| `@satisfies` in JSDoc | 5.0 | Validates that a JS expression satisfies a type without widening it — the TS `satisfies` operator for `.js` files. Write `/** @satisfies {MyType} */` above the declaration or inline on a parenthesized expression. |
| `@overload` in JSDoc | 5.0 | Declare multiple call signatures for a JS function. Each JSDoc comment tagged `@overload` is treated as a distinct overload; the final JSDoc comment (without `@overload`) describes the implementation signature. |
| Getter/setter with completely unrelated types | 5.1 | `get style(): CSSStyleDeclaration` and `set style(v: string)` can now have fully unrelated types, provided both have explicit type annotations. Previously the getter type was required to be a subtype of the setter type. |
| `instanceof` narrowing via `Symbol.hasInstance` | 5.3 | When a class defines `static [Symbol.hasInstance](val: unknown): val is T`, the `instanceof` operator now narrows to the predicate type `T`, not the class type itself. Useful when the runtime check and the structural type differ. |
| Regex literal syntax checking | 5.5 | TypeScript validates regex literal syntax: malformed groups, nonexistent backreferences, named capture mismatches, and features not available at the current `--target`. No API needed; existing latent bugs surface as errors automatically. |
| `--build` continues past intermediate errors | 5.6 | `tsc --build` no longer stops at the first failing project. All projects are built and errors reported together. Use `--stopOnBuildErrors` to restore the old stop-on-first-error behavior. Useful for monorepos during upgrades. |
| `--module node18` | 5.8 | Stable `--module` flag for Node.js 18 semantics: disallows `require()` of ESM (unlike `nodenext`) and still allows import assertions. Use when pinned to Node 18 and not ready for `nodenext` behavior changes. |
| `--module node20` | 5.9 | Stable `--module` flag for Node.js 20 semantics: permits `require()` of ESM, rejects import assertions. Implies `--target es2023` (unlike `nodenext`, which floats to `esnext`). |
## Key APIs
### `Disposable` / `AsyncDisposable` / stacks (5.2)
Global types provided by TypeScript's lib (requires `esnext.disposable` or `esnext` in `lib`):
- `Disposable``{ [Symbol.dispose](): void }`
- `AsyncDisposable``{ [Symbol.asyncDispose](): PromiseLike<void> }`
- `DisposableStack``defer(fn)`, `use(resource)`, `adopt(value, disposeFn)`, `move()`. Is itself `Disposable`.
- `AsyncDisposableStack` — async equivalent. Is itself `AsyncDisposable`.
- `SuppressedError` — thrown when both the scope body and a `[Symbol.dispose]` throw. `.error` holds the dispose-phase error; `.suppressed` holds the original error.
Polyfill the symbols in older runtimes:
```ts
Symbol.dispose ??= Symbol("Symbol.dispose");
Symbol.asyncDispose ??= Symbol("Symbol.asyncDispose");
```
### Decorator context types (5.0)
Each decorator kind receives a typed context object as its second parameter:
- `ClassDecoratorContext`
- `ClassMethodDecoratorContext`
- `ClassGetterDecoratorContext`
- `ClassSetterDecoratorContext`
- `ClassFieldDecoratorContext`
- `ClassAccessorDecoratorContext`
All context objects have `.name`, `.kind`, `.static`, `.private`, and `.metadata`. Method/getter/setter/accessor contexts also have `.addInitializer(fn)` for running code at construction time.
### `IteratorObject` (5.6)
`IteratorObject<T, TReturn, TNext>` is the new type for built-in iterable iterators. Key methods: `map`, `filter`, `take`, `drop`, `flatMap`, `forEach`, `reduce`, `some`, `every`, `find`, `toArray`. Not the same as the pre-existing structural `Iterator<T>` protocol.
- Generators produce `Generator<T>` which extends `IteratorObject`.
- `Map.prototype.entries()` returns `MapIterator<[K, V]>`, `Set.prototype.values()` returns `SetIterator<T>`, etc.
- `Iterator.from(iterable)` converts any `Iterable` to an `IteratorObject`.
- `AsyncIteratorObject` exists for async parity.
- `--strictBuiltinIteratorReturn` (new `--strict`-mode flag in 5.6) makes the return type of `BuiltinIteratorReturn` be `undefined` instead of `any`, catching unchecked `done` access.
### Array copying methods (5.2)
Declared on `Array`, `ReadonlyArray`, and all `TypedArray` types. Use these instead of the mutating variants when you need to preserve the original:
| Mutating | Non-mutating copy |
| ---------------------------------- | ------------------------------------- |
| `arr.sort(cmp)` | `arr.toSorted(cmp)` |
| `arr.reverse()` | `arr.toReversed()` |
| `arr.splice(start, del, ...items)` | `arr.toSpliced(start, del, ...items)` |
| `arr[i] = v` | `arr.with(i, v)` |
## Pitfalls
Things easy to get wrong even when you know the modern API exists. Check your output against these.
**tsconfig defaults changed hard in 6.0.** `types: []` means no `@types/*` packages load implicitly. If you see floods of "cannot find name 'process'" or "cannot find module 'fs'" after upgrading to 6.0, add `"types": ["node"]` (or whatever you need) to `compilerOptions`. `rootDir: "."` means a project with source in `src/` will emit to `dist/src/` instead of `dist/` — add `"rootDir": "./src"` explicitly. `strict: true` by default means projects with loose code see new errors.
**`using` requires a runtime polyfill on older runtimes.** `Symbol.dispose` and `Symbol.asyncDispose` don't exist before Node.js 18.x / Chrome 120. Add the two-line polyfill at your entry point. `DisposableStack` and `AsyncDisposableStack` need a more substantial polyfill (e.g. from `@microsoft/using-polyfill`).
**`using` disposes in LIFO order.** Resources declared later in a scope are disposed first. Declare in the order you want reversed cleanup (acquisition order). `DisposableStack.defer` also runs in LIFO order.
**Inferred type predicates have if-and-only-if semantics.** `x => !!x` does NOT infer `x is NonNullable<T>` because `0`, `""`, and `false` are falsy but not absent. TypeScript correctly refuses the predicate. Use `x => x !== undefined` or `x => x !== null` for precise null/undefined filters. If a predicate isn't being inferred, the false branch is probably ambiguous.
**`--verbatimModuleSyntax` breaks CJS `require` emit.** Under this flag ESM `import`/`export` is emitted verbatim. You cannot produce `require()` calls from standard `import` syntax. For CJS output you must use `import foo = require("foo")` and `export = { ... }` syntax explicitly.
**`NoInfer<T>` doesn't prevent `T` from being resolved, only from being contributed at that position.** Other parameters can still infer `T`. It means "don't use me as an inference candidate", not "block `T` from being resolved".
**`--isolatedDeclarations` requires explicit return types on all exports.** Exported arrow functions, function declarations, and class methods all need annotations if their return type isn't trivially inferrable from a literal or type assertion. Editor quick-fixes can add them automatically.
**Standard decorators are incompatible with `--experimentalDecorators`.** Different type signatures, metadata model, and emit. A decorator written for one will not work with the other. `--emitDecoratorMetadata` is not supported with standard decorators. Don't mix the two systems in one project.
**`import defer` does not downlevel.** TypeScript does not transform `import defer` to polyfill-compatible code. The module is still _loaded_ eagerly (must exist); only _evaluation_ is deferred. Only use it under `--module preserve` or `esnext` with a runtime or bundler that supports it.
**`--erasableSyntaxOnly` prohibits parameter properties.** `constructor(public x: number)` is not allowed. Expand to an explicit field declaration plus assignment in the constructor body.
**Closure narrowing is invalidated if the variable is assigned anywhere in a nested function.** TypeScript cannot know when a nested function will run, so any assignment to a `let`/param inside a nested function — even a no-op like `value = value` — invalidates narrowing for all closures in the outer scope. Only the outer "no further assignments after this point" pattern is safe.
**Constant indexed access narrowing requires both `obj` and `key` to be unmodified between the check and the use.** If either is a `let` that could be reassigned, TypeScript will not narrow `obj[key]`. Extract the value to a `const` in that case.
**`switch (true)` narrowing does not carry across fall-through cases.** In a `switch (true)`, each `case` condition narrows independently. A variable narrowed in `case typeof x === "string":` that falls through to the next case will have its narrowing widened by the next condition, not accumulated from the previous one.
**`const` type parameter modifier falls back when constraint is mutable.** `<const T extends string[]>(args: T)` falls back to `string[]` because `readonly ["a", "b"]` isn't assignable to `string[]`. Use `<const T extends readonly string[]>` for arrays.
**`assert` import syntax errors under `--module nodenext` since 5.8.** Any remaining `import x from "..." assert { ... }` must be updated to `import x from "..." with { ... }`.
**`Array.prototype.filter(x => x !== null)` now narrows to non-null (5.5).** This is almost always correct, but if you intentionally needed the nullable type downstream, add an explicit annotation: `const items: (T | null)[] = arr.filter(x => x !== null)`.
## Behavioral changes that affect code
- **All enums are union enums** (5.0): Every enum member gets its own literal type. Out-of-domain literal assignment to an enum type now errors. Cross-enum assignment between enums with identical names but differing values now errors.
- **Relational operators no longer allow implicit string/number coercions** (5.0): `ns > 4` where `ns: number | string` is a type error. Use `+ns > 4` to explicitly coerce.
- **`--module`/`--moduleResolution` must agree on node flavor** (5.2): Mixing `--module nodenext` with `--moduleResolution bundler` is an error. Use `--module nodenext` alone or `--module esnext --moduleResolution bundler`.
- **Deprecations from 5.0 become hard errors in 5.5**: `--importsNotUsedAsValues`, `--preserveValueImports`, `--target ES3`, `--out`, and several others are fully removed in 5.5. They can no longer be specified, even with `"ignoreDeprecations": "5.0"`. Migrate to `--verbatimModuleSyntax` for the import flags.
- **Type-only imports conflicting with local values** (5.4): Under `--isolatedModules`, `import { Foo } from "..."` where a local `let Foo` also exists now errors. Use `import type { Foo }` or `import { type Foo }`.
- **Reference directives no longer synthesized or preserved in declaration emit** (5.5): `/// <reference types="node" />` TypeScript used to add automatically is no longer emitted. User-written directives are dropped unless they carry `preserve="true"`. Update library `tsconfig.json` if you relied on this.
- **`.mts` files never emit CJS; `.cts` files never emit ESM** (5.6): Regardless of `--module` setting. Previously the extension was ignored in some modes.
- **JSON imports under `--module nodenext` require `with { type: "json" }`** (5.7): `import data from "./config.json"` without the attribute is now a type error.
- **`TypedArray`s are now generic** (5.7): `Uint8Array` is `Uint8Array<TArrayBuffer extends ArrayBufferLike = ArrayBufferLike>`. Code passing `Buffer` (from `@types/node`) to typed-array parameters may see new errors. Update `@types/node` to a version that matches.
- **`import assert { ... }` is an error under `--module nodenext`** (5.8): Node.js 22 dropped support for the old syntax. Use `with { ... }`.
- **`types` defaults to `[]` in 6.0**: All implicit `@types/*` loading stops. Add an explicit `"types": ["node"]` or the array will remain empty. Using `"types": ["*"]` restores the 5.x behavior.
- **`rootDir` defaults to `.` (the tsconfig directory) in 6.0**: Previously inferred from the common ancestor of all source files. Projects with `"include": ["./src"]` and no explicit `rootDir` will now emit into `dist/src/` instead of `dist/`. Add `"rootDir": "./src"` to fix.
- **`strict` defaults to `true` in 6.0**: Projects that were implicitly not strict will see new errors. Set `"strict": false` explicitly if you're not ready to fix them.
- **`--baseUrl` deprecated in 6.0** and no longer acts as a module resolution root. Add explicit prefixes to your `paths` entries instead.
- **`--moduleResolution node` (node10) deprecated in 6.0**: Removed in 7.0. Migrate to `nodenext` or `bundler`.
- **`amd`, `umd`, `systemjs`, `none` module targets deprecated in 6.0**: Removed in 7.0. Migrate to a bundler.
- **`--outFile` removed in 6.0**: Use a bundler (esbuild, Rollup, Webpack, etc.).
- **`module Foo { }` syntax removed in 6.0**: Rename all such declarations to `namespace Foo { }`.
- **`--esModuleInterop false` and `--allowSyntheticDefaultImports false` removed in 6.0**: Safe interop is now always on. Default imports from CJS modules (`import express from "express"`) are always valid.
- **Explicit `typeRoots` disables upward `node_modules/@types` fallback** (5.1): When `typeRoots` is specified and a lookup fails in those directories, TypeScript no longer walks parent directories for `@types`. If you relied on the fallback, add `"./node_modules/@types"` explicitly to your `typeRoots` array.
- **`super.` on instance field properties is a type error** (5.3): Calling `super.foo()` where `foo` is a class field (arrow function assigned in the constructor) rather than a prototype method now errors. Instance fields don't exist on the prototype; `super.field` is `undefined` at runtime.
- **`--build` always emits `.tsbuildinfo`** (5.6): Previously only written when `--incremental` or `--composite` was set. Now written unconditionally in any `--build` invocation. Update `.gitignore` or CI artifact management if needed.
- **`.mts`/`.cts` extensions and `package.json` `"type"` respected in all module modes** (5.6): Format-specific extensions and the `"type"` field inside `node_modules` are now honored regardless of `--module` setting (except `amd`, `umd`, `system`). A `.mts` file will never emit CJS output even under `--module commonjs`.
- **Granular return expression checking** (5.8): Each branch of a conditional expression (`cond ? a : b`) directly inside a `return` statement is now checked individually against the declared return type. Previously an `any`-typed branch could silently suppress type errors in the other branch.
+2 -2
View File
@@ -5,6 +5,6 @@ runs:
using: "composite"
steps:
- name: Install syft
uses: anchore/sbom-action/download-syft@e22c389904149dbc22b58101806040fa8d37a610 # v0.24.0
uses: anchore/sbom-action/download-syft@f325610c9f50a54015d37c8d16cb3b0e2c8f4de0 # v0.18.0
with:
syft-version: "v1.26.1"
syft-version: "v1.20.0"
+1 -1
View File
@@ -4,7 +4,7 @@ description: |
inputs:
version:
description: "The Go version to use."
default: "1.25.8"
default: "1.25.7"
use-cache:
description: "Whether to use the cache."
default: "true"
+6 -4
View File
@@ -31,8 +31,7 @@ updates:
patterns:
- "golang.org/x/*"
ignore:
# Patch updates are handled by the security-patch-prs workflow so this
# lane stays focused on broader dependency updates.
# Ignore patch updates for all dependencies
- dependency-name: "*"
update-types:
- version-update:semver-patch
@@ -57,7 +56,7 @@ updates:
labels: []
ignore:
# We need to coordinate terraform updates with the version hardcoded in
# our Go code. These are handled by the security-patch-prs workflow.
# our Go code.
- dependency-name: "terraform"
- package-ecosystem: "npm"
@@ -83,6 +82,9 @@ updates:
mui:
patterns:
- "@mui*"
radix:
patterns:
- "@radix-ui/*"
react:
patterns:
- "react"
@@ -118,11 +120,11 @@ updates:
interval: "weekly"
commit-message:
prefix: "chore"
labels: []
groups:
coder-modules:
patterns:
- "coder/*/coder"
labels: []
ignore:
- dependency-name: "*"
update-types:
+118 -46
View File
@@ -35,7 +35,7 @@ jobs:
tailnet-integration: ${{ steps.filter.outputs.tailnet-integration }}
steps:
- name: Harden Runner
uses: step-security/harden-runner@fe104658747b27e96e4f7e80cd0a94068e53901d # v2.16.1
uses: step-security/harden-runner@fa2e9d605c4eeb9fcad4c99c224cee0c6c7f3594 # v2.16.0
with:
egress-policy: audit
@@ -157,7 +157,7 @@ jobs:
runs-on: ${{ github.repository_owner == 'coder' && 'depot-ubuntu-22.04-8' || 'ubuntu-latest' }}
steps:
- name: Harden Runner
uses: step-security/harden-runner@fe104658747b27e96e4f7e80cd0a94068e53901d # v2.16.1
uses: step-security/harden-runner@fa2e9d605c4eeb9fcad4c99c224cee0c6c7f3594 # v2.16.0
with:
egress-policy: audit
@@ -181,7 +181,7 @@ jobs:
echo "LINT_CACHE_DIR=$dir" >> "$GITHUB_ENV"
- name: golangci-lint cache
uses: actions/cache@668228422ae6a00e4ad889ee87cd7109ec5666a7 # v5.0.4
uses: actions/cache@cdf6c1fa76f9f475f3d7449005a359c84ca0f306 # v5.0.3
with:
path: |
${{ env.LINT_CACHE_DIR }}
@@ -204,7 +204,7 @@ jobs:
# Needed for helm chart linting
- name: Install helm
uses: azure/setup-helm@dda3372f752e03dde6b3237bc9431cdc2f7a02a2 # v5.0.0
uses: azure/setup-helm@1a275c3b69536ee54be43f2070a358922e12c8d4 # v4.3.1
with:
version: v3.9.2
continue-on-error: true
@@ -247,7 +247,7 @@ jobs:
runs-on: ${{ github.repository_owner == 'coder' && 'depot-ubuntu-22.04-8' || 'ubuntu-latest' }}
steps:
- name: Harden Runner
uses: step-security/harden-runner@fe104658747b27e96e4f7e80cd0a94068e53901d # v2.16.1
uses: step-security/harden-runner@fa2e9d605c4eeb9fcad4c99c224cee0c6c7f3594 # v2.16.0
with:
egress-policy: audit
@@ -272,7 +272,7 @@ jobs:
if: ${{ !cancelled() }}
steps:
- name: Harden Runner
uses: step-security/harden-runner@fe104658747b27e96e4f7e80cd0a94068e53901d # v2.16.1
uses: step-security/harden-runner@fa2e9d605c4eeb9fcad4c99c224cee0c6c7f3594 # v2.16.0
with:
egress-policy: audit
@@ -327,7 +327,7 @@ jobs:
timeout-minutes: 20
steps:
- name: Harden Runner
uses: step-security/harden-runner@fe104658747b27e96e4f7e80cd0a94068e53901d # v2.16.1
uses: step-security/harden-runner@fa2e9d605c4eeb9fcad4c99c224cee0c6c7f3594 # v2.16.0
with:
egress-policy: audit
@@ -379,7 +379,7 @@ jobs:
- windows-2022
steps:
- name: Harden Runner
uses: step-security/harden-runner@fe104658747b27e96e4f7e80cd0a94068e53901d # v2.16.1
uses: step-security/harden-runner@fa2e9d605c4eeb9fcad4c99c224cee0c6c7f3594 # v2.16.0
with:
egress-policy: audit
@@ -575,7 +575,7 @@ jobs:
timeout-minutes: 25
steps:
- name: Harden Runner
uses: step-security/harden-runner@fe104658747b27e96e4f7e80cd0a94068e53901d # v2.16.1
uses: step-security/harden-runner@fa2e9d605c4eeb9fcad4c99c224cee0c6c7f3594 # v2.16.0
with:
egress-policy: audit
@@ -637,7 +637,7 @@ jobs:
timeout-minutes: 25
steps:
- name: Harden Runner
uses: step-security/harden-runner@fe104658747b27e96e4f7e80cd0a94068e53901d # v2.16.1
uses: step-security/harden-runner@fa2e9d605c4eeb9fcad4c99c224cee0c6c7f3594 # v2.16.0
with:
egress-policy: audit
@@ -709,7 +709,7 @@ jobs:
timeout-minutes: 20
steps:
- name: Harden Runner
uses: step-security/harden-runner@fe104658747b27e96e4f7e80cd0a94068e53901d # v2.16.1
uses: step-security/harden-runner@fa2e9d605c4eeb9fcad4c99c224cee0c6c7f3594 # v2.16.0
with:
egress-policy: audit
@@ -736,7 +736,7 @@ jobs:
timeout-minutes: 20
steps:
- name: Harden Runner
uses: step-security/harden-runner@fe104658747b27e96e4f7e80cd0a94068e53901d # v2.16.1
uses: step-security/harden-runner@fa2e9d605c4eeb9fcad4c99c224cee0c6c7f3594 # v2.16.0
with:
egress-policy: audit
@@ -769,7 +769,7 @@ jobs:
name: ${{ matrix.variant.name }}
steps:
- name: Harden Runner
uses: step-security/harden-runner@fe104658747b27e96e4f7e80cd0a94068e53901d # v2.16.1
uses: step-security/harden-runner@fa2e9d605c4eeb9fcad4c99c224cee0c6c7f3594 # v2.16.0
with:
egress-policy: audit
@@ -849,7 +849,7 @@ jobs:
if: needs.changes.outputs.site == 'true' || needs.changes.outputs.ci == 'true'
steps:
- name: Harden Runner
uses: step-security/harden-runner@fe104658747b27e96e4f7e80cd0a94068e53901d # v2.16.1
uses: step-security/harden-runner@fa2e9d605c4eeb9fcad4c99c224cee0c6c7f3594 # v2.16.0
with:
egress-policy: audit
@@ -870,7 +870,7 @@ jobs:
# the check to pass. This is desired in PRs, but not in mainline.
- name: Publish to Chromatic (non-mainline)
if: github.ref != 'refs/heads/main' && github.repository_owner == 'coder'
uses: chromaui/action@f191a0224b10e1a38b2091cefb7b7a2337009116 # v16.0.0
uses: chromaui/action@07791f8243f4cb2698bf4d00426baf4b2d1cb7e0 # v13.3.5
env:
NODE_OPTIONS: "--max_old_space_size=4096"
STORYBOOK: true
@@ -902,7 +902,7 @@ jobs:
# infinitely "in progress" in mainline unless we re-review each build.
- name: Publish to Chromatic (mainline)
if: github.ref == 'refs/heads/main' && github.repository_owner == 'coder'
uses: chromaui/action@f191a0224b10e1a38b2091cefb7b7a2337009116 # v16.0.0
uses: chromaui/action@07791f8243f4cb2698bf4d00426baf4b2d1cb7e0 # v13.3.5
env:
NODE_OPTIONS: "--max_old_space_size=4096"
STORYBOOK: true
@@ -930,7 +930,7 @@ jobs:
steps:
- name: Harden Runner
uses: step-security/harden-runner@fe104658747b27e96e4f7e80cd0a94068e53901d # v2.16.1
uses: step-security/harden-runner@fa2e9d605c4eeb9fcad4c99c224cee0c6c7f3594 # v2.16.0
with:
egress-policy: audit
@@ -1005,7 +1005,7 @@ jobs:
if: always()
steps:
- name: Harden Runner
uses: step-security/harden-runner@fe104658747b27e96e4f7e80cd0a94068e53901d # v2.16.1
uses: step-security/harden-runner@fa2e9d605c4eeb9fcad4c99c224cee0c6c7f3594 # v2.16.0
with:
egress-policy: audit
@@ -1043,7 +1043,7 @@ jobs:
runs-on: ${{ github.repository_owner == 'coder' && 'depot-ubuntu-22.04-8' || 'ubuntu-latest' }}
steps:
- name: Harden Runner
uses: step-security/harden-runner@fe104658747b27e96e4f7e80cd0a94068e53901d # v2.16.1
uses: step-security/harden-runner@fa2e9d605c4eeb9fcad4c99c224cee0c6c7f3594 # v2.16.0
with:
egress-policy: audit
@@ -1097,7 +1097,7 @@ jobs:
IMAGE: ghcr.io/coder/coder-preview:${{ steps.build-docker.outputs.tag }}
steps:
- name: Harden Runner
uses: step-security/harden-runner@fe104658747b27e96e4f7e80cd0a94068e53901d # v2.16.1
uses: step-security/harden-runner@fa2e9d605c4eeb9fcad4c99c224cee0c6c7f3594 # v2.16.0
with:
egress-policy: audit
@@ -1316,50 +1316,122 @@ jobs:
"${IMAGE}"
done
- name: Resolve Docker image digests for attestation
id: docker_digests
if: github.ref == 'refs/heads/main'
continue-on-error: true
env:
IMAGE_BASE: ghcr.io/coder/coder-preview
BUILD_TAG: ${{ steps.build-docker.outputs.tag }}
run: |
set -euxo pipefail
main_digest=$(docker buildx imagetools inspect --raw "${IMAGE_BASE}:main" | sha256sum | awk '{print "sha256:"$1}')
echo "main_digest=${main_digest}" >> "$GITHUB_OUTPUT"
latest_digest=$(docker buildx imagetools inspect --raw "${IMAGE_BASE}:latest" | sha256sum | awk '{print "sha256:"$1}')
echo "latest_digest=${latest_digest}" >> "$GITHUB_OUTPUT"
version_digest=$(docker buildx imagetools inspect --raw "${IMAGE_BASE}:${BUILD_TAG}" | sha256sum | awk '{print "sha256:"$1}')
echo "version_digest=${version_digest}" >> "$GITHUB_OUTPUT"
# GitHub attestation provides SLSA provenance for the Docker images, establishing a verifiable
# record that these images were built in GitHub Actions with specific inputs and environment.
# This complements our existing cosign attestations which focus on SBOMs.
#
# We attest each tag separately to ensure all tags have proper provenance records.
# TODO: Consider refactoring these steps to use a matrix strategy or composite action to reduce duplication
# while maintaining the required functionality for each tag.
- name: GitHub Attestation for Docker image
id: attest_main
if: github.ref == 'refs/heads/main' && steps.docker_digests.outputs.main_digest != ''
if: github.ref == 'refs/heads/main'
continue-on-error: true
uses: actions/attest@59d89421af93a897026c735860bf21b6eb4f7b26 # v4.1.0
with:
subject-name: ghcr.io/coder/coder-preview
subject-digest: ${{ steps.docker_digests.outputs.main_digest }}
subject-name: "ghcr.io/coder/coder-preview:main"
predicate-type: "https://slsa.dev/provenance/v1"
predicate: |
{
"buildType": "https://github.com/actions/runner-images/",
"builder": {
"id": "https://github.com/${{ github.repository }}/actions/runs/${{ github.run_id }}"
},
"invocation": {
"configSource": {
"uri": "git+https://github.com/${{ github.repository }}@${{ github.ref }}",
"digest": {
"sha1": "${{ github.sha }}"
},
"entryPoint": ".github/workflows/ci.yaml"
},
"environment": {
"github_workflow": "${{ github.workflow }}",
"github_run_id": "${{ github.run_id }}"
}
},
"metadata": {
"buildInvocationID": "${{ github.run_id }}",
"completeness": {
"environment": true,
"materials": true
}
}
}
push-to-registry: true
- name: GitHub Attestation for Docker image (latest tag)
id: attest_latest
if: github.ref == 'refs/heads/main' && steps.docker_digests.outputs.latest_digest != ''
if: github.ref == 'refs/heads/main'
continue-on-error: true
uses: actions/attest@59d89421af93a897026c735860bf21b6eb4f7b26 # v4.1.0
with:
subject-name: ghcr.io/coder/coder-preview
subject-digest: ${{ steps.docker_digests.outputs.latest_digest }}
subject-name: "ghcr.io/coder/coder-preview:latest"
predicate-type: "https://slsa.dev/provenance/v1"
predicate: |
{
"buildType": "https://github.com/actions/runner-images/",
"builder": {
"id": "https://github.com/${{ github.repository }}/actions/runs/${{ github.run_id }}"
},
"invocation": {
"configSource": {
"uri": "git+https://github.com/${{ github.repository }}@${{ github.ref }}",
"digest": {
"sha1": "${{ github.sha }}"
},
"entryPoint": ".github/workflows/ci.yaml"
},
"environment": {
"github_workflow": "${{ github.workflow }}",
"github_run_id": "${{ github.run_id }}"
}
},
"metadata": {
"buildInvocationID": "${{ github.run_id }}",
"completeness": {
"environment": true,
"materials": true
}
}
}
push-to-registry: true
- name: GitHub Attestation for version-specific Docker image
id: attest_version
if: github.ref == 'refs/heads/main' && steps.docker_digests.outputs.version_digest != ''
if: github.ref == 'refs/heads/main'
continue-on-error: true
uses: actions/attest@59d89421af93a897026c735860bf21b6eb4f7b26 # v4.1.0
with:
subject-name: ghcr.io/coder/coder-preview
subject-digest: ${{ steps.docker_digests.outputs.version_digest }}
subject-name: "ghcr.io/coder/coder-preview:${{ steps.build-docker.outputs.tag }}"
predicate-type: "https://slsa.dev/provenance/v1"
predicate: |
{
"buildType": "https://github.com/actions/runner-images/",
"builder": {
"id": "https://github.com/${{ github.repository }}/actions/runs/${{ github.run_id }}"
},
"invocation": {
"configSource": {
"uri": "git+https://github.com/${{ github.repository }}@${{ github.ref }}",
"digest": {
"sha1": "${{ github.sha }}"
},
"entryPoint": ".github/workflows/ci.yaml"
},
"environment": {
"github_workflow": "${{ github.workflow }}",
"github_run_id": "${{ github.run_id }}"
}
},
"metadata": {
"buildInvocationID": "${{ github.run_id }}",
"completeness": {
"environment": true,
"materials": true
}
}
}
push-to-registry: true
# Report attestation failures but don't fail the workflow
@@ -1479,7 +1551,7 @@ jobs:
if: needs.changes.outputs.db == 'true' || needs.changes.outputs.ci == 'true' || github.ref == 'refs/heads/main'
steps:
- name: Harden Runner
uses: step-security/harden-runner@fe104658747b27e96e4f7e80cd0a94068e53901d # v2.16.1
uses: step-security/harden-runner@fa2e9d605c4eeb9fcad4c99c224cee0c6c7f3594 # v2.16.0
with:
egress-policy: audit
+1 -1
View File
@@ -23,7 +23,7 @@ jobs:
steps:
- name: Dependabot metadata
id: metadata
uses: dependabot/fetch-metadata@ffa630c65fa7e0ecfa0625b5ceda64399aea1b36 # v3.0.0
uses: dependabot/fetch-metadata@21025c705c08248db411dc16f3619e6b5f9ea21a # v2.5.0
with:
github-token: "${{ secrets.GITHUB_TOKEN }}"
+4 -4
View File
@@ -36,7 +36,7 @@ jobs:
verdict: ${{ steps.check.outputs.verdict }} # DEPLOY or NOOP
steps:
- name: Harden Runner
uses: step-security/harden-runner@fe104658747b27e96e4f7e80cd0a94068e53901d # v2.16.1
uses: step-security/harden-runner@fa2e9d605c4eeb9fcad4c99c224cee0c6c7f3594 # v2.16.0
with:
egress-policy: audit
@@ -65,7 +65,7 @@ jobs:
packages: write # to retag image as dogfood
steps:
- name: Harden Runner
uses: step-security/harden-runner@fe104658747b27e96e4f7e80cd0a94068e53901d # v2.16.1
uses: step-security/harden-runner@fa2e9d605c4eeb9fcad4c99c224cee0c6c7f3594 # v2.16.0
with:
egress-policy: audit
@@ -95,7 +95,7 @@ jobs:
AWS_DOGFOOD_DEPLOY_REGION: ${{ vars.AWS_DOGFOOD_DEPLOY_REGION }}
- name: Set up Flux CLI
uses: fluxcd/flux2/action@871be9b40d53627786d3a3835a3ddba1e3234bd2 # v2.8.3
uses: fluxcd/flux2/action@8454b02a32e48d775b9f563cb51fdcb1787b5b93 # v2.7.5
with:
# Keep this and the github action up to date with the version of flux installed in dogfood cluster
version: "2.8.2"
@@ -142,7 +142,7 @@ jobs:
needs: deploy
steps:
- name: Harden Runner
uses: step-security/harden-runner@fe104658747b27e96e4f7e80cd0a94068e53901d # v2.16.1
uses: step-security/harden-runner@fa2e9d605c4eeb9fcad4c99c224cee0c6c7f3594 # v2.16.0
with:
egress-policy: audit
+9 -29
View File
@@ -240,7 +240,6 @@ jobs:
- name: Create Coder Task for Documentation Check
if: steps.check-secrets.outputs.skip != 'true'
id: create_task
continue-on-error: true
uses: ./.github/actions/create-task-action
with:
coder-url: ${{ secrets.DOC_CHECK_CODER_URL }}
@@ -255,21 +254,8 @@ jobs:
github-issue-url: ${{ steps.determine-context.outputs.pr_url }}
comment-on-issue: false
- name: Handle Task Creation Failure
if: steps.check-secrets.outputs.skip != 'true' && steps.create_task.outcome != 'success'
run: |
{
echo "## Documentation Check Task"
echo ""
echo "⚠️ The external Coder task service was unavailable, so this"
echo "advisory documentation check did not run."
echo ""
echo "Maintainers can rerun the workflow or trigger it manually"
echo "after the service recovers."
} >> "${GITHUB_STEP_SUMMARY}"
- name: Write Task Info
if: steps.check-secrets.outputs.skip != 'true' && steps.create_task.outcome == 'success'
if: steps.check-secrets.outputs.skip != 'true'
env:
TASK_CREATED: ${{ steps.create_task.outputs.task-created }}
TASK_NAME: ${{ steps.create_task.outputs.task-name }}
@@ -287,7 +273,7 @@ jobs:
} >> "${GITHUB_STEP_SUMMARY}"
- name: Wait for Task Completion
if: steps.check-secrets.outputs.skip != 'true' && steps.create_task.outcome == 'success'
if: steps.check-secrets.outputs.skip != 'true'
id: wait_task
env:
TASK_NAME: ${{ steps.create_task.outputs.task-name }}
@@ -377,7 +363,7 @@ jobs:
fi
- name: Fetch Task Logs
if: always() && steps.check-secrets.outputs.skip != 'true' && steps.create_task.outcome == 'success'
if: always() && steps.check-secrets.outputs.skip != 'true'
env:
TASK_NAME: ${{ steps.create_task.outputs.task-name }}
run: |
@@ -390,7 +376,7 @@ jobs:
echo "::endgroup::"
- name: Cleanup Task
if: always() && steps.check-secrets.outputs.skip != 'true' && steps.create_task.outcome == 'success'
if: always() && steps.check-secrets.outputs.skip != 'true'
env:
TASK_NAME: ${{ steps.create_task.outputs.task-name }}
run: |
@@ -404,7 +390,6 @@ jobs:
- name: Write Final Summary
if: always() && steps.check-secrets.outputs.skip != 'true'
env:
CREATE_TASK_OUTCOME: ${{ steps.create_task.outcome }}
TASK_NAME: ${{ steps.create_task.outputs.task-name }}
TASK_MESSAGE: ${{ steps.wait_task.outputs.task_message }}
RESULT_URI: ${{ steps.wait_task.outputs.result_uri }}
@@ -415,15 +400,10 @@ jobs:
echo "---"
echo "### Result"
echo ""
if [[ "${CREATE_TASK_OUTCOME}" == "success" ]]; then
echo "**Status:** ${TASK_MESSAGE:-Task completed}"
if [[ -n "${RESULT_URI}" ]]; then
echo "**Comment:** ${RESULT_URI}"
fi
echo ""
echo "Task \`${TASK_NAME}\` has been cleaned up."
else
echo "**Status:** Skipped because the external Coder task"
echo "service was unavailable."
echo "**Status:** ${TASK_MESSAGE:-Task completed}"
if [[ -n "${RESULT_URI}" ]]; then
echo "**Comment:** ${RESULT_URI}"
fi
echo ""
echo "Task \`${TASK_NAME}\` has been cleaned up."
} >> "${GITHUB_STEP_SUMMARY}"
+1 -1
View File
@@ -38,7 +38,7 @@ jobs:
if: github.repository_owner == 'coder'
steps:
- name: Harden Runner
uses: step-security/harden-runner@fe104658747b27e96e4f7e80cd0a94068e53901d # v2.16.1
uses: step-security/harden-runner@fa2e9d605c4eeb9fcad4c99c224cee0c6c7f3594 # v2.16.0
with:
egress-policy: audit
+2 -2
View File
@@ -26,7 +26,7 @@ jobs:
runs-on: ${{ github.repository_owner == 'coder' && 'depot-ubuntu-22.04-4' || 'ubuntu-latest' }}
steps:
- name: Harden Runner
uses: step-security/harden-runner@fe104658747b27e96e4f7e80cd0a94068e53901d # v2.16.1
uses: step-security/harden-runner@fa2e9d605c4eeb9fcad4c99c224cee0c6c7f3594 # v2.16.0
with:
egress-policy: audit
@@ -125,7 +125,7 @@ jobs:
id-token: write
steps:
- name: Harden Runner
uses: step-security/harden-runner@fe104658747b27e96e4f7e80cd0a94068e53901d # v2.16.1
uses: step-security/harden-runner@fa2e9d605c4eeb9fcad4c99c224cee0c6c7f3594 # v2.16.0
with:
egress-policy: audit
+28 -73
View File
@@ -4,20 +4,23 @@ on:
push:
branches:
- main
- "release/2.[0-9]+"
# This event reads the workflow from the default branch (main), not the
# release branch. No cherry-pick needed.
# https://docs.github.com/en/actions/using-workflows/events-that-trigger-workflows#release
release:
types: [published]
permissions:
contents: read
concurrency:
group: ${{ github.workflow }}-${{ github.ref }}
# Queue rather than cancel so back-to-back pushes to main don't cancel the first sync.
cancel-in-progress: false
cancel-in-progress: true
jobs:
sync-main:
name: Sync issues to next Linear release
if: github.event_name == 'push' && github.ref_name == 'main'
sync:
name: Sync issues to Linear release
if: github.event_name == 'push'
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
@@ -25,86 +28,38 @@ jobs:
fetch-depth: 0
persist-credentials: false
- name: Detect next release version
id: version
# Find the highest release/2.X branch (exact pattern, no suffixes
# like release/2.31_hotfix) and derive the next minor version for
# the release currently in development on main.
run: |
LATEST_MINOR=$(git branch -r | grep -E '^\s*origin/release/2\.[0-9]+$' | \
sed 's/.*release\/2\.//' | sort -n | tail -1)
if [ -z "$LATEST_MINOR" ]; then
echo "No release branch found, skipping sync."
echo "skip=true" >> "$GITHUB_OUTPUT"
exit 0
fi
NEXT="2.$((LATEST_MINOR + 1))"
echo "version=$NEXT" >> "$GITHUB_OUTPUT"
echo "skip=false" >> "$GITHUB_OUTPUT"
echo "Detected next release: $NEXT"
- name: Sync issues
id: sync
if: steps.version.outputs.skip != 'true'
uses: linear/linear-release-action@755d50b5adb7dd42b976ee9334952745d62ceb2d # v0.6.0
uses: linear/linear-release-action@5cbaabc187ceb63eee9d446e62e68e5c29a03ae8 # v0.5.0
with:
access_key: ${{ secrets.LINEAR_ACCESS_KEY }}
command: sync
version: ${{ steps.version.outputs.version }}
name: ${{ steps.version.outputs.version }}
timeout: 300
sync-release-branch:
name: Sync backports to Linear release
if: github.event_name == 'push' && startsWith(github.ref_name, 'release/')
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
with:
fetch-depth: 0
persist-credentials: false
- name: Print release URL
if: steps.sync.outputs.release-url
run: echo "Synced to $RELEASE_URL"
env:
RELEASE_URL: ${{ steps.sync.outputs.release-url }}
- name: Extract release version
id: version
# The trigger only allows exact release/2.X branch names.
run: |
echo "version=${GITHUB_REF_NAME#release/}" >> "$GITHUB_OUTPUT"
- name: Sync issues
id: sync
uses: linear/linear-release-action@755d50b5adb7dd42b976ee9334952745d62ceb2d # v0.6.0
with:
access_key: ${{ secrets.LINEAR_ACCESS_KEY }}
command: sync
version: ${{ steps.version.outputs.version }}
name: ${{ steps.version.outputs.version }}
timeout: 300
code-freeze:
name: Move Linear release to Code Freeze
needs: sync-release-branch
if: >
github.event_name == 'push' &&
startsWith(github.ref_name, 'release/') &&
github.event.created == true
complete:
name: Complete Linear release
if: github.event_name == 'release'
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
with:
persist-credentials: false
- name: Extract release version
id: version
run: |
echo "version=${GITHUB_REF_NAME#release/}" >> "$GITHUB_OUTPUT"
- name: Move to Code Freeze
id: update
uses: linear/linear-release-action@755d50b5adb7dd42b976ee9334952745d62ceb2d # v0.6.0
- name: Complete release
id: complete
uses: linear/linear-release-action@5cbaabc187ceb63eee9d446e62e68e5c29a03ae8 # v0
with:
access_key: ${{ secrets.LINEAR_ACCESS_KEY }}
command: update
stage: Code Freeze
version: ${{ steps.version.outputs.version }}
timeout: 300
command: complete
version: ${{ github.event.release.tag_name }}
- name: Print release URL
if: steps.complete.outputs.release-url
run: echo "Completed $RELEASE_URL"
env:
RELEASE_URL: ${{ steps.complete.outputs.release-url }}
+1 -1
View File
@@ -28,7 +28,7 @@ jobs:
- windows-2022
steps:
- name: Harden Runner
uses: step-security/harden-runner@fe104658747b27e96e4f7e80cd0a94068e53901d # v2.16.1
uses: step-security/harden-runner@fa2e9d605c4eeb9fcad4c99c224cee0c6c7f3594 # v2.16.0
with:
egress-policy: audit
+1 -1
View File
@@ -15,7 +15,7 @@ jobs:
runs-on: ubuntu-latest
steps:
- name: Harden Runner
uses: step-security/harden-runner@fe104658747b27e96e4f7e80cd0a94068e53901d # v2.16.1
uses: step-security/harden-runner@fa2e9d605c4eeb9fcad4c99c224cee0c6c7f3594 # v2.16.0
with:
egress-policy: audit
+1 -1
View File
@@ -19,7 +19,7 @@ jobs:
packages: write
steps:
- name: Harden Runner
uses: step-security/harden-runner@fe104658747b27e96e4f7e80cd0a94068e53901d # v2.16.1
uses: step-security/harden-runner@fa2e9d605c4eeb9fcad4c99c224cee0c6c7f3594 # v2.16.0
with:
egress-policy: audit
+5 -5
View File
@@ -39,7 +39,7 @@ jobs:
PR_OPEN: ${{ steps.check_pr.outputs.pr_open }}
steps:
- name: Harden Runner
uses: step-security/harden-runner@fe104658747b27e96e4f7e80cd0a94068e53901d # v2.16.1
uses: step-security/harden-runner@fa2e9d605c4eeb9fcad4c99c224cee0c6c7f3594 # v2.16.0
with:
egress-policy: audit
@@ -76,7 +76,7 @@ jobs:
runs-on: "ubuntu-latest"
steps:
- name: Harden Runner
uses: step-security/harden-runner@fe104658747b27e96e4f7e80cd0a94068e53901d # v2.16.1
uses: step-security/harden-runner@fa2e9d605c4eeb9fcad4c99c224cee0c6c7f3594 # v2.16.0
with:
egress-policy: audit
@@ -184,7 +184,7 @@ jobs:
pull-requests: write # needed for commenting on PRs
steps:
- name: Harden Runner
uses: step-security/harden-runner@fe104658747b27e96e4f7e80cd0a94068e53901d # v2.16.1
uses: step-security/harden-runner@fa2e9d605c4eeb9fcad4c99c224cee0c6c7f3594 # v2.16.0
with:
egress-policy: audit
@@ -228,7 +228,7 @@ jobs:
CODER_IMAGE_TAG: ${{ needs.get_info.outputs.CODER_IMAGE_TAG }}
steps:
- name: Harden Runner
uses: step-security/harden-runner@fe104658747b27e96e4f7e80cd0a94068e53901d # v2.16.1
uses: step-security/harden-runner@fa2e9d605c4eeb9fcad4c99c224cee0c6c7f3594 # v2.16.0
with:
egress-policy: audit
@@ -288,7 +288,7 @@ jobs:
PR_HOSTNAME: "pr${{ needs.get_info.outputs.PR_NUMBER }}.${{ secrets.PR_DEPLOYMENTS_DOMAIN }}"
steps:
- name: Harden Runner
uses: step-security/harden-runner@fe104658747b27e96e4f7e80cd0a94068e53901d # v2.16.1
uses: step-security/harden-runner@fa2e9d605c4eeb9fcad4c99c224cee0c6c7f3594 # v2.16.0
with:
egress-policy: audit
+1 -1
View File
@@ -14,7 +14,7 @@ jobs:
steps:
- name: Harden Runner
uses: step-security/harden-runner@fe104658747b27e96e4f7e80cd0a94068e53901d # v2.16.1
uses: step-security/harden-runner@fa2e9d605c4eeb9fcad4c99c224cee0c6c7f3594 # v2.16.0
with:
egress-policy: audit
+114 -81
View File
@@ -9,7 +9,6 @@ on:
options:
- mainline
- stable
- rc
release_notes:
description: Release notes for the publishing the release. This is required to create a release.
dry_run:
@@ -81,7 +80,7 @@ jobs:
version: ${{ steps.version.outputs.version }}
steps:
- name: Harden Runner
uses: step-security/harden-runner@fe104658747b27e96e4f7e80cd0a94068e53901d # v2.16.1
uses: step-security/harden-runner@fa2e9d605c4eeb9fcad4c99c224cee0c6c7f3594 # v2.16.0
with:
egress-policy: audit
@@ -120,19 +119,9 @@ jobs:
exit 1
fi
# Derive the release branch from the version tag.
# Standard: 2.10.2 -> release/2.10
# RC: 2.32.0-rc.0 -> release/2.32-rc.0
# 2.10.2 -> release/2.10
version="$(./scripts/version.sh)"
if [[ "$version" == *-rc.* ]]; then
# Extract major.minor and rc suffix from e.g. 2.32.0-rc.0
base_version="${version%%-rc.*}" # 2.32.0
major_minor="${base_version%.*}" # 2.32
rc_suffix="${version##*-rc.}" # 0
release_branch="release/${major_minor}-rc.${rc_suffix}"
else
release_branch=release/${version%.*}
fi
release_branch=release/${version%.*}
branch_contains_tag=$(git branch --remotes --contains "${GITHUB_REF}" --list "*/${release_branch}" --format='%(refname)')
if [[ -z "${branch_contains_tag}" ]]; then
echo "Ref tag must exist in a branch named ${release_branch} when creating a release, did you use scripts/release.sh?"
@@ -313,7 +302,6 @@ jobs:
# This uses OIDC authentication, so no auth variables are required.
- name: Build base Docker image via depot.dev
id: build_base_image
if: steps.image-base-tag.outputs.tag != ''
uses: depot/build-push-action@5f3b3c2e5a00f0093de47f657aeaefcedff27d18 # v1.17.0
with:
@@ -361,14 +349,48 @@ jobs:
env:
IMAGE_TAG: ${{ steps.image-base-tag.outputs.tag }}
# GitHub attestation provides SLSA provenance for Docker images, establishing a verifiable
# record that these images were built in GitHub Actions with specific inputs and environment.
# This complements our existing cosign attestations (which focus on SBOMs) by adding
# GitHub-specific build provenance to enhance our supply chain security.
#
# TODO: Consider refactoring these attestation steps to use a matrix strategy or composite action
# to reduce duplication while maintaining the required functionality for each distinct image tag.
- name: GitHub Attestation for Base Docker image
id: attest_base
if: ${{ !inputs.dry_run && steps.build_base_image.outputs.digest != '' }}
if: ${{ !inputs.dry_run && steps.image-base-tag.outputs.tag != '' }}
continue-on-error: true
uses: actions/attest@59d89421af93a897026c735860bf21b6eb4f7b26 # v4.1.0
with:
subject-name: ghcr.io/coder/coder-base
subject-digest: ${{ steps.build_base_image.outputs.digest }}
subject-name: ${{ steps.image-base-tag.outputs.tag }}
predicate-type: "https://slsa.dev/provenance/v1"
predicate: |
{
"buildType": "https://github.com/actions/runner-images/",
"builder": {
"id": "https://github.com/${{ github.repository }}/actions/runs/${{ github.run_id }}"
},
"invocation": {
"configSource": {
"uri": "git+https://github.com/${{ github.repository }}@${{ github.ref }}",
"digest": {
"sha1": "${{ github.sha }}"
},
"entryPoint": ".github/workflows/release.yaml"
},
"environment": {
"github_workflow": "${{ github.workflow }}",
"github_run_id": "${{ github.run_id }}"
}
},
"metadata": {
"buildInvocationID": "${{ github.run_id }}",
"completeness": {
"environment": true,
"materials": true
}
}
}
push-to-registry: true
- name: Build Linux Docker images
@@ -391,6 +413,7 @@ jobs:
# being pushed so will automatically push them.
make push/build/coder_"$version"_linux.tag
# Save multiarch image tag for attestation
multiarch_image="$(./scripts/image_tag.sh)"
echo "multiarch_image=${multiarch_image}" >> "$GITHUB_OUTPUT"
@@ -401,14 +424,12 @@ jobs:
# version in the repo, also create a multi-arch image as ":latest" and
# push it
if [[ "$(git tag | grep '^v' | grep -vE '(rc|dev|-|\+|\/)' | sort -r --version-sort | head -n1)" == "v$(./scripts/version.sh)" ]]; then
latest_target="$(./scripts/image_tag.sh --version latest)"
# shellcheck disable=SC2046
./scripts/build_docker_multiarch.sh \
--push \
--target "${latest_target}" \
--target "$(./scripts/image_tag.sh --version latest)" \
$(cat build/coder_"$version"_linux_{amd64,arm64,armv7}.tag)
echo "created_latest_tag=true" >> "$GITHUB_OUTPUT"
echo "latest_target=${latest_target}" >> "$GITHUB_OUTPUT"
else
echo "created_latest_tag=false" >> "$GITHUB_OUTPUT"
fi
@@ -429,6 +450,7 @@ jobs:
echo "Generating SBOM for multi-arch image: ${MULTIARCH_IMAGE}"
syft "${MULTIARCH_IMAGE}" -o spdx-json > "coder_${VERSION}_sbom.spdx.json"
# Attest SBOM to multi-arch image
echo "Attesting SBOM to multi-arch image: ${MULTIARCH_IMAGE}"
cosign clean --force=true "${MULTIARCH_IMAGE}"
cosign attest --type spdxjson \
@@ -450,42 +472,85 @@ jobs:
"${latest_tag}"
fi
- name: Resolve Docker image digests for attestation
id: docker_digests
if: ${{ !inputs.dry_run }}
continue-on-error: true
env:
MULTIARCH_IMAGE: ${{ steps.build_docker.outputs.multiarch_image }}
LATEST_TARGET: ${{ steps.build_docker.outputs.latest_target }}
run: |
set -euxo pipefail
if [[ -n "${MULTIARCH_IMAGE}" ]]; then
multiarch_digest=$(docker buildx imagetools inspect --raw "${MULTIARCH_IMAGE}" | sha256sum | awk '{print "sha256:"$1}')
echo "multiarch_digest=${multiarch_digest}" >> "$GITHUB_OUTPUT"
fi
if [[ -n "${LATEST_TARGET}" ]]; then
latest_digest=$(docker buildx imagetools inspect --raw "${LATEST_TARGET}" | sha256sum | awk '{print "sha256:"$1}')
echo "latest_digest=${latest_digest}" >> "$GITHUB_OUTPUT"
fi
- name: GitHub Attestation for Docker image
id: attest_main
if: ${{ !inputs.dry_run && steps.docker_digests.outputs.multiarch_digest != '' }}
if: ${{ !inputs.dry_run }}
continue-on-error: true
uses: actions/attest@59d89421af93a897026c735860bf21b6eb4f7b26 # v4.1.0
with:
subject-name: ghcr.io/coder/coder
subject-digest: ${{ steps.docker_digests.outputs.multiarch_digest }}
subject-name: ${{ steps.build_docker.outputs.multiarch_image }}
predicate-type: "https://slsa.dev/provenance/v1"
predicate: |
{
"buildType": "https://github.com/actions/runner-images/",
"builder": {
"id": "https://github.com/${{ github.repository }}/actions/runs/${{ github.run_id }}"
},
"invocation": {
"configSource": {
"uri": "git+https://github.com/${{ github.repository }}@${{ github.ref }}",
"digest": {
"sha1": "${{ github.sha }}"
},
"entryPoint": ".github/workflows/release.yaml"
},
"environment": {
"github_workflow": "${{ github.workflow }}",
"github_run_id": "${{ github.run_id }}"
}
},
"metadata": {
"buildInvocationID": "${{ github.run_id }}",
"completeness": {
"environment": true,
"materials": true
}
}
}
push-to-registry: true
# Get the latest tag name for attestation
- name: Get latest tag name
id: latest_tag
if: ${{ !inputs.dry_run && steps.build_docker.outputs.created_latest_tag == 'true' }}
run: echo "tag=$(./scripts/image_tag.sh --version latest)" >> "$GITHUB_OUTPUT"
# If this is the highest version according to semver, also attest the "latest" tag
- name: GitHub Attestation for "latest" Docker image
id: attest_latest
if: ${{ !inputs.dry_run && steps.docker_digests.outputs.latest_digest != '' }}
if: ${{ !inputs.dry_run && steps.build_docker.outputs.created_latest_tag == 'true' }}
continue-on-error: true
uses: actions/attest@59d89421af93a897026c735860bf21b6eb4f7b26 # v4.1.0
with:
subject-name: ghcr.io/coder/coder
subject-digest: ${{ steps.docker_digests.outputs.latest_digest }}
subject-name: ${{ steps.latest_tag.outputs.tag }}
predicate-type: "https://slsa.dev/provenance/v1"
predicate: |
{
"buildType": "https://github.com/actions/runner-images/",
"builder": {
"id": "https://github.com/${{ github.repository }}/actions/runs/${{ github.run_id }}"
},
"invocation": {
"configSource": {
"uri": "git+https://github.com/${{ github.repository }}@${{ github.ref }}",
"digest": {
"sha1": "${{ github.sha }}"
},
"entryPoint": ".github/workflows/release.yaml"
},
"environment": {
"github_workflow": "${{ github.workflow }}",
"github_run_id": "${{ github.run_id }}"
}
},
"metadata": {
"buildInvocationID": "${{ github.run_id }}",
"completeness": {
"environment": true,
"materials": true
}
}
}
push-to-registry: true
# Report attestation failures but don't fail the workflow
@@ -542,9 +607,6 @@ jobs:
if [[ $CODER_RELEASE_CHANNEL == "stable" ]]; then
publish_args+=(--stable)
fi
if [[ $CODER_RELEASE_CHANNEL == "rc" ]]; then
publish_args+=(--rc)
fi
if [[ $CODER_DRY_RUN == *t* ]]; then
publish_args+=(--dry-run)
fi
@@ -577,35 +639,6 @@ jobs:
VERSION: ${{ steps.version.outputs.version }}
CREATED_LATEST_TAG: ${{ steps.build_docker.outputs.created_latest_tag }}
# Mark the Linear release as shipped.
- name: Extract Linear release version
if: ${{ !inputs.dry_run }}
id: linear_version
run: |
# Skip RC releases — they must not complete the Linear release.
if [[ "$VERSION" == *-rc* ]]; then
echo "RC release (${VERSION}), skipping Linear release completion."
echo "skip=true" >> "$GITHUB_OUTPUT"
exit 0
fi
# Strip patch to get the Linear release version (e.g. 2.32.0 -> 2.32).
linear_version=$(echo "$VERSION" | cut -d. -f1,2)
echo "version=$linear_version" >> "$GITHUB_OUTPUT"
echo "skip=false" >> "$GITHUB_OUTPUT"
echo "Completing Linear release ${linear_version}"
env:
VERSION: ${{ steps.version.outputs.version }}
- name: Complete Linear release
if: ${{ !inputs.dry_run && steps.linear_version.outputs.skip != 'true' }}
continue-on-error: true
uses: linear/linear-release-action@755d50b5adb7dd42b976ee9334952745d62ceb2d # v0.6.0
with:
access_key: ${{ secrets.LINEAR_ACCESS_KEY }}
command: complete
version: ${{ steps.linear_version.outputs.version }}
timeout: 300
- name: Authenticate to Google Cloud
uses: google-github-actions/auth@7c6bc770dae815cd3e89ee6cdf493a5fab2cc093 # v3.0.0
with:
@@ -657,7 +690,7 @@ jobs:
retention-days: 7
- name: Send repository-dispatch event
if: ${{ !inputs.dry_run && inputs.release_channel != 'rc' }}
if: ${{ !inputs.dry_run }}
uses: peter-evans/repository-dispatch@28959ce8df70de7be546dd1250a005dd32156697 # v4.0.1
with:
token: ${{ secrets.CDRCI_GITHUB_TOKEN }}
@@ -673,7 +706,7 @@ jobs:
steps:
- name: Harden Runner
uses: step-security/harden-runner@fe104658747b27e96e4f7e80cd0a94068e53901d # v2.16.1
uses: step-security/harden-runner@fa2e9d605c4eeb9fcad4c99c224cee0c6c7f3594 # v2.16.0
with:
egress-policy: audit
@@ -745,11 +778,11 @@ jobs:
name: Publish to winget-pkgs
runs-on: windows-latest
needs: release
if: ${{ !inputs.dry_run && inputs.release_channel != 'rc' }}
if: ${{ !inputs.dry_run }}
steps:
- name: Harden Runner
uses: step-security/harden-runner@fe104658747b27e96e4f7e80cd0a94068e53901d # v2.16.1
uses: step-security/harden-runner@fa2e9d605c4eeb9fcad4c99c224cee0c6c7f3594 # v2.16.0
with:
egress-policy: audit
+2 -2
View File
@@ -20,7 +20,7 @@ jobs:
steps:
- name: Harden Runner
uses: step-security/harden-runner@fe104658747b27e96e4f7e80cd0a94068e53901d # v2.16.1
uses: step-security/harden-runner@fa2e9d605c4eeb9fcad4c99c224cee0c6c7f3594 # v2.16.0
with:
egress-policy: audit
@@ -47,6 +47,6 @@ jobs:
# Upload the results to GitHub's code scanning dashboard.
- name: "Upload to code-scanning"
uses: github/codeql-action/upload-sarif@c10b8064de6f491fea524254123dbe5e09572f13 # v3.29.5
uses: github/codeql-action/upload-sarif@5d4e8d1aca955e8d8589aabd499c5cae939e33c7 # v3.29.5
with:
sarif_file: results.sarif
-354
View File
@@ -1,354 +0,0 @@
name: security-backport
on:
pull_request_target:
types:
- labeled
- unlabeled
- closed
workflow_dispatch:
inputs:
pull_request:
description: Pull request number to backport.
required: true
type: string
permissions:
contents: write
pull-requests: write
issues: write
concurrency:
group: ${{ github.workflow }}-${{ github.event.pull_request.number || inputs.pull_request }}
cancel-in-progress: false
env:
LATEST_BRANCH: release/2.31
STABLE_BRANCH: release/2.30
STABLE_1_BRANCH: release/2.29
jobs:
label-policy:
if: github.event_name == 'pull_request_target'
runs-on: ubuntu-latest
steps:
- name: Harden Runner
uses: step-security/harden-runner@fa2e9d605c4eeb9fcad4c99c224cee0c6c7f3594 # v2.16.0
with:
egress-policy: audit
- name: Apply security backport label policy
shell: bash
env:
GH_TOKEN: ${{ secrets.GITHUB_TOKEN }}
run: |
set -euo pipefail
pr_number="$(jq -r '.pull_request.number' "${GITHUB_EVENT_PATH}")"
pr_json="$(gh pr view "${pr_number}" --json number,title,url,baseRefName,labels)"
PR_JSON="${pr_json}" \
python3 - <<'PY'
import json
import os
import subprocess
import sys
pr = json.loads(os.environ["PR_JSON"])
pr_number = pr["number"]
labels = [label["name"] for label in pr.get("labels", [])]
def has(label: str) -> bool:
return label in labels
def ensure_label(label: str) -> None:
if not has(label):
subprocess.run(
["gh", "pr", "edit", str(pr_number), "--add-label", label],
check=False,
)
def remove_label(label: str) -> None:
if has(label):
subprocess.run(
["gh", "pr", "edit", str(pr_number), "--remove-label", label],
check=False,
)
def comment(body: str) -> None:
subprocess.run(
["gh", "pr", "comment", str(pr_number), "--body", body],
check=True,
)
if not has("security:patch"):
remove_label("status:needs-severity")
sys.exit(0)
severity_labels = [
label
for label in ("severity:medium", "severity:high", "severity:critical")
if has(label)
]
if len(severity_labels) == 0:
ensure_label("status:needs-severity")
comment(
"This PR is labeled `security:patch` but is missing a severity "
"label. Add one of `severity:medium`, `severity:high`, or "
"`severity:critical` before backport automation can proceed."
)
sys.exit(0)
if len(severity_labels) > 1:
comment(
"This PR has multiple severity labels. Keep exactly one of "
"`severity:medium`, `severity:high`, or `severity:critical`."
)
sys.exit(1)
remove_label("status:needs-severity")
target_labels = [
label
for label in ("backport:stable", "backport:stable-1")
if has(label)
]
has_none = has("backport:none")
if has_none and target_labels:
comment(
"`backport:none` cannot be combined with other backport labels. "
"Remove `backport:none` or remove the explicit backport targets."
)
sys.exit(1)
if not has_none and not target_labels:
ensure_label("backport:stable")
ensure_label("backport:stable-1")
comment(
"Applied default backport labels `backport:stable` and "
"`backport:stable-1` for a qualifying security patch."
)
PY
backport:
if: >
github.event_name == 'workflow_dispatch' ||
(
github.event_name == 'pull_request_target' &&
github.event.pull_request.merged == true
)
runs-on: ubuntu-latest
steps:
- name: Harden Runner
uses: step-security/harden-runner@fa2e9d605c4eeb9fcad4c99c224cee0c6c7f3594 # v2.16.0
with:
egress-policy: audit
- name: Checkout
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
with:
fetch-depth: 0
persist-credentials: false
- name: Resolve PR metadata
id: metadata
shell: bash
env:
GH_TOKEN: ${{ secrets.GITHUB_TOKEN }}
INPUT_PR_NUMBER: ${{ inputs.pull_request }}
LATEST_BRANCH: ${{ env.LATEST_BRANCH }}
STABLE_BRANCH: ${{ env.STABLE_BRANCH }}
STABLE_1_BRANCH: ${{ env.STABLE_1_BRANCH }}
run: |
set -euo pipefail
if [ "${GITHUB_EVENT_NAME}" = "workflow_dispatch" ]; then
pr_number="${INPUT_PR_NUMBER}"
else
pr_number="$(jq -r '.pull_request.number' "${GITHUB_EVENT_PATH}")"
fi
case "${pr_number}" in
''|*[!0-9]*)
echo "A valid pull request number is required."
exit 1
;;
esac
pr_json="$(gh pr view "${pr_number}" --json number,title,url,mergeCommit,baseRefName,labels,mergedAt,author)"
PR_JSON="${pr_json}" \
python3 - <<'PY'
import json
import os
import sys
pr = json.loads(os.environ["PR_JSON"])
github_output = os.environ["GITHUB_OUTPUT"]
labels = [label["name"] for label in pr.get("labels", [])]
if "security:patch" not in labels:
print("Not a security patch PR; skipping.")
sys.exit(0)
severity_labels = [
label
for label in ("severity:medium", "severity:high", "severity:critical")
if label in labels
]
if len(severity_labels) != 1:
raise SystemExit(
"Merged security patch PR must have exactly one severity label."
)
if not pr.get("mergedAt"):
raise SystemExit(f"PR #{pr['number']} is not merged.")
if "backport:none" in labels:
target_pairs = []
else:
mapping = {
"backport:stable": os.environ["STABLE_BRANCH"],
"backport:stable-1": os.environ["STABLE_1_BRANCH"],
}
target_pairs = []
for label_name, branch in mapping.items():
if label_name in labels and branch and branch != pr["baseRefName"]:
target_pairs.append({"label": label_name, "branch": branch})
with open(github_output, "a", encoding="utf-8") as f:
f.write(f"pr_number={pr['number']}\n")
f.write(f"merge_sha={pr['mergeCommit']['oid']}\n")
f.write(f"title={pr['title']}\n")
f.write(f"url={pr['url']}\n")
f.write(f"author={pr['author']['login']}\n")
f.write(f"severity_label={severity_labels[0]}\n")
f.write(f"target_pairs={json.dumps(target_pairs)}\n")
PY
- name: Backport to release branches
if: ${{ steps.metadata.outputs.target_pairs != '[]' }}
env:
GH_TOKEN: ${{ secrets.GITHUB_TOKEN }}
PR_NUMBER: ${{ steps.metadata.outputs.pr_number }}
MERGE_SHA: ${{ steps.metadata.outputs.merge_sha }}
PR_TITLE: ${{ steps.metadata.outputs.title }}
PR_URL: ${{ steps.metadata.outputs.url }}
PR_AUTHOR: ${{ steps.metadata.outputs.author }}
SEVERITY_LABEL: ${{ steps.metadata.outputs.severity_label }}
TARGET_PAIRS: ${{ steps.metadata.outputs.target_pairs }}
run: |
set -euo pipefail
git config user.name "github-actions[bot]"
git config user.email "41898282+github-actions[bot]@users.noreply.github.com"
git remote set-url origin "https://x-access-token:${GH_TOKEN}@github.com/${GITHUB_REPOSITORY}.git"
git fetch origin --prune
merge_parent_count="$(git rev-list --parents -n 1 "${MERGE_SHA}" | awk '{print NF-1}')"
failures=()
successes=()
while IFS=$'\t' read -r backport_label target_branch; do
[ -n "${target_branch}" ] || continue
safe_branch_name="${target_branch//\//-}"
head_branch="backport/${safe_branch_name}/pr-${PR_NUMBER}"
existing_pr="$(gh pr list \
--base "${target_branch}" \
--head "${head_branch}" \
--state all \
--json number,url \
--jq '.[0]')"
if [ -n "${existing_pr}" ] && [ "${existing_pr}" != "null" ]; then
pr_url="$(printf '%s' "${existing_pr}" | jq -r '.url')"
successes+=("${target_branch}:existing:${pr_url}")
continue
fi
git checkout -B "${head_branch}" "origin/${target_branch}"
if [ "${merge_parent_count}" -gt 1 ]; then
cherry_pick_args=(-m 1 "${MERGE_SHA}")
else
cherry_pick_args=("${MERGE_SHA}")
fi
if ! git cherry-pick -x "${cherry_pick_args[@]}"; then
git cherry-pick --abort || true
gh pr edit "${PR_NUMBER}" --add-label "backport:conflict" || true
gh pr comment "${PR_NUMBER}" --body \
"Automatic backport to \`${target_branch}\` conflicted. The original author or release manager should resolve it manually."
failures+=("${target_branch}:cherry-pick failed")
continue
fi
git push --force-with-lease origin "${head_branch}"
body_file="$(mktemp)"
printf '%s\n' \
"Automated backport of [#${PR_NUMBER}](${PR_URL})." \
"" \
"- Source PR: #${PR_NUMBER}" \
"- Source commit: ${MERGE_SHA}" \
"- Target branch: ${target_branch}" \
"- Severity: ${SEVERITY_LABEL}" \
> "${body_file}"
pr_url="$(gh pr create \
--base "${target_branch}" \
--head "${head_branch}" \
--title "${PR_TITLE} (backport to ${target_branch})" \
--body-file "${body_file}")"
backport_pr_number="$(gh pr list \
--base "${target_branch}" \
--head "${head_branch}" \
--state open \
--json number \
--jq '.[0].number')"
gh pr edit "${backport_pr_number}" \
--add-label "security:patch" \
--add-label "${SEVERITY_LABEL}" \
--add-label "${backport_label}" || true
successes+=("${target_branch}:created:${pr_url}")
done < <(
python3 - <<'PY'
import json
import os
for pair in json.loads(os.environ["TARGET_PAIRS"]):
print(f"{pair['label']}\t{pair['branch']}")
PY
)
summary_file="$(mktemp)"
{
echo "## Security backport summary"
echo
if [ "${#successes[@]}" -gt 0 ]; then
echo "### Created or existing"
for entry in "${successes[@]}"; do
echo "- ${entry}"
done
echo
fi
if [ "${#failures[@]}" -gt 0 ]; then
echo "### Failures"
for entry in "${failures[@]}"; do
echo "- ${entry}"
done
fi
} | tee -a "${GITHUB_STEP_SUMMARY}" > "${summary_file}"
gh pr comment "${PR_NUMBER}" --body-file "${summary_file}"
if [ "${#failures[@]}" -gt 0 ]; then
printf 'Backport failures:\n%s\n' "${failures[@]}" >&2
exit 1
fi
-214
View File
@@ -1,214 +0,0 @@
name: security-patch-prs
on:
workflow_dispatch:
schedule:
- cron: "0 3 * * 1-5"
permissions:
contents: write
pull-requests: write
jobs:
patch:
strategy:
fail-fast: false
matrix:
lane:
- gomod
- terraform
runs-on: ubuntu-latest
steps:
- name: Harden Runner
uses: step-security/harden-runner@fa2e9d605c4eeb9fcad4c99c224cee0c6c7f3594 # v2.16.0
with:
egress-policy: audit
- name: Checkout
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
with:
persist-credentials: false
- name: Setup Go
uses: ./.github/actions/setup-go
- name: Patch Go dependencies
if: matrix.lane == 'gomod'
shell: bash
run: |
set -euo pipefail
go get -u=patch ./...
go mod tidy
# Guardrail: do not auto-edit replace directives.
if git diff --unified=0 -- go.mod | grep -E '^[+-]replace '; then
echo "Refusing to auto-edit go.mod replace directives"
exit 1
fi
# Guardrail: only go.mod / go.sum may change.
extra="$(git diff --name-only | grep -Ev '^(go\.mod|go\.sum)$' || true)"
test -z "$extra" || { echo "Unexpected files changed:"; echo "$extra"; exit 1; }
- name: Patch bundled Terraform
if: matrix.lane == 'terraform'
shell: bash
run: |
set -euo pipefail
current="$(
grep -oE 'NewVersion\("[0-9]+\.[0-9]+\.[0-9]+"\)' \
provisioner/terraform/install.go \
| head -1 \
| grep -oE '[0-9]+\.[0-9]+\.[0-9]+'
)"
series="$(echo "$current" | cut -d. -f1,2)"
latest="$(
curl -fsSL https://releases.hashicorp.com/terraform/index.json \
| jq -r --arg series "$series" '
.versions
| keys[]
| select(startswith($series + "."))
' \
| sort -V \
| tail -1
)"
test -n "$latest"
[ "$latest" != "$current" ] || exit 0
CURRENT_TERRAFORM_VERSION="$current" \
LATEST_TERRAFORM_VERSION="$latest" \
python3 - <<'PY'
from pathlib import Path
import os
current = os.environ["CURRENT_TERRAFORM_VERSION"]
latest = os.environ["LATEST_TERRAFORM_VERSION"]
updates = {
"scripts/Dockerfile.base": (
f"terraform/{current}/",
f"terraform/{latest}/",
),
"provisioner/terraform/install.go": (
f'NewVersion("{current}")',
f'NewVersion("{latest}")',
),
"install.sh": (
f'TERRAFORM_VERSION="{current}"',
f'TERRAFORM_VERSION="{latest}"',
),
}
for path_str, (before, after) in updates.items():
path = Path(path_str)
content = path.read_text()
if before not in content:
raise SystemExit(f"did not find expected text in {path_str}: {before}")
path.write_text(content.replace(before, after))
PY
# Guardrail: only the Terraform-version files may change.
extra="$(git diff --name-only | grep -Ev '^(scripts/Dockerfile.base|provisioner/terraform/install.go|install.sh)$' || true)"
test -z "$extra" || { echo "Unexpected files changed:"; echo "$extra"; exit 1; }
- name: Validate Go dependency patch
if: matrix.lane == 'gomod'
shell: bash
run: |
set -euo pipefail
go test ./...
- name: Validate Terraform patch
if: matrix.lane == 'terraform'
shell: bash
run: |
set -euo pipefail
go test ./provisioner/terraform/...
docker build -f scripts/Dockerfile.base .
- name: Skip PR creation when there are no changes
id: changes
shell: bash
run: |
set -euo pipefail
if git diff --quiet; then
echo "has_changes=false" >> "${GITHUB_OUTPUT}"
else
echo "has_changes=true" >> "${GITHUB_OUTPUT}"
fi
- name: Commit changes
if: steps.changes.outputs.has_changes == 'true'
shell: bash
run: |
set -euo pipefail
git config user.name "github-actions[bot]"
git config user.email "41898282+github-actions[bot]@users.noreply.github.com"
git checkout -B "secpatch/${{ matrix.lane }}"
git add -A
git commit -m "security: patch ${{ matrix.lane }}"
- name: Push branch
if: steps.changes.outputs.has_changes == 'true'
shell: bash
env:
GH_TOKEN: ${{ secrets.GITHUB_TOKEN }}
run: |
set -euo pipefail
git push --force-with-lease \
"https://x-access-token:${GH_TOKEN}@github.com/${{ github.repository }}.git" \
"HEAD:refs/heads/secpatch/${{ matrix.lane }}"
- name: Create or update PR
if: steps.changes.outputs.has_changes == 'true'
shell: bash
env:
GH_TOKEN: ${{ secrets.GITHUB_TOKEN }}
run: |
set -euo pipefail
branch="secpatch/${{ matrix.lane }}"
title="security: patch ${{ matrix.lane }}"
body="$(cat <<'EOF'
Automated security patch PR for `${{ matrix.lane }}`.
Scope:
- gomod: patch-level Go dependency updates only
- terraform: bundled Terraform patch updates only
Guardrails:
- no application-code edits
- no auto-editing of go.mod replace directives
- CI must pass
EOF
)"
existing_pr="$(gh pr list --head "${branch}" --base main --json number --jq '.[0].number // empty')"
if [[ -n "${existing_pr}" ]]; then
gh pr edit "${existing_pr}" \
--title "${title}" \
--body "${body}"
pr_number="${existing_pr}"
else
gh pr create \
--base main \
--head "${branch}" \
--title "${title}" \
--body "${body}"
pr_number="$(gh pr list --head "${branch}" --base main --json number --jq '.[0].number // empty')"
fi
for label in security dependencies automated-pr; do
if gh label list --json name --jq '.[].name' | grep -Fxq "${label}"; then
gh pr edit "${pr_number}" --add-label "${label}"
fi
done
+3 -3
View File
@@ -27,7 +27,7 @@ jobs:
runs-on: ${{ github.repository_owner == 'coder' && 'depot-ubuntu-22.04-8' || 'ubuntu-latest' }}
steps:
- name: Harden Runner
uses: step-security/harden-runner@fe104658747b27e96e4f7e80cd0a94068e53901d # v2.16.1
uses: step-security/harden-runner@fa2e9d605c4eeb9fcad4c99c224cee0c6c7f3594 # v2.16.0
with:
egress-policy: audit
@@ -40,7 +40,7 @@ jobs:
uses: ./.github/actions/setup-go
- name: Initialize CodeQL
uses: github/codeql-action/init@c10b8064de6f491fea524254123dbe5e09572f13 # v3.29.5
uses: github/codeql-action/init@5d4e8d1aca955e8d8589aabd499c5cae939e33c7 # v3.29.5
with:
languages: go, javascript
@@ -50,7 +50,7 @@ jobs:
rm Makefile
- name: Perform CodeQL Analysis
uses: github/codeql-action/analyze@c10b8064de6f491fea524254123dbe5e09572f13 # v3.29.5
uses: github/codeql-action/analyze@5d4e8d1aca955e8d8589aabd499c5cae939e33c7 # v3.29.5
- name: Send Slack notification on failure
if: ${{ failure() }}
+5 -5
View File
@@ -18,7 +18,7 @@ jobs:
pull-requests: write
steps:
- name: Harden Runner
uses: step-security/harden-runner@fe104658747b27e96e4f7e80cd0a94068e53901d # v2.16.1
uses: step-security/harden-runner@fa2e9d605c4eeb9fcad4c99c224cee0c6c7f3594 # v2.16.0
with:
egress-policy: audit
@@ -96,7 +96,7 @@ jobs:
contents: write
steps:
- name: Harden Runner
uses: step-security/harden-runner@fe104658747b27e96e4f7e80cd0a94068e53901d # v2.16.1
uses: step-security/harden-runner@fa2e9d605c4eeb9fcad4c99c224cee0c6c7f3594 # v2.16.0
with:
egress-policy: audit
@@ -120,12 +120,12 @@ jobs:
actions: write
steps:
- name: Harden Runner
uses: step-security/harden-runner@fe104658747b27e96e4f7e80cd0a94068e53901d # v2.16.1
uses: step-security/harden-runner@fa2e9d605c4eeb9fcad4c99c224cee0c6c7f3594 # v2.16.0
with:
egress-policy: audit
- name: Delete PR Cleanup workflow runs
uses: Mattraks/delete-workflow-runs@b3018382ca039b53d238908238bd35d1fb14f8ee # v2.1.0
uses: Mattraks/delete-workflow-runs@5bf9a1dac5c4d041c029f0a8370ddf0c5cb5aeb7 # v2.1.0
with:
token: ${{ github.token }}
repository: ${{ github.repository }}
@@ -134,7 +134,7 @@ jobs:
delete_workflow_pattern: pr-cleanup.yaml
- name: Delete PR Deploy workflow skipped runs
uses: Mattraks/delete-workflow-runs@b3018382ca039b53d238908238bd35d1fb14f8ee # v2.1.0
uses: Mattraks/delete-workflow-runs@5bf9a1dac5c4d041c029f0a8370ddf0c5cb5aeb7 # v2.1.0
with:
token: ${{ github.token }}
repository: ${{ github.repository }}
+2 -10
View File
@@ -21,7 +21,7 @@ jobs:
pull-requests: write # required to post PR review comments by the action
steps:
- name: Harden Runner
uses: step-security/harden-runner@fe104658747b27e96e4f7e80cd0a94068e53901d # v2.16.1
uses: step-security/harden-runner@fa2e9d605c4eeb9fcad4c99c224cee0c6c7f3594 # v2.16.0
with:
egress-policy: audit
@@ -46,16 +46,8 @@ jobs:
echo " replacement: \"https://github.com/coder/coder/tree/${HEAD_SHA}/\""
} >> .github/.linkspector.yml
# TODO: Remove this workaround once action-linkspector sets
# package-manager-cache: false in its internal setup-node step.
# See: https://github.com/UmbrellaDocs/action-linkspector/issues/54
- name: Enable corepack and create pnpm store
run: |
corepack enable pnpm
mkdir -p "$(pnpm store path --silent)"
- name: Check Markdown links
uses: umbrelladocs/action-linkspector@37c85bcde51b30bf929936502bac6bfb7e8f0a4d # v1.4.1
uses: umbrelladocs/action-linkspector@652f85bc57bb1e7d4327260decc10aa68f7694c3 # v1.4.0
id: markdown-link-check
# checks all markdown files from /docs including all subfolders
with:
-1
View File
@@ -54,7 +54,6 @@ site/stats/
*.tfstate.backup
*.tfplan
*.lock.hcl
!provisioner/terraform/testdata/resources/.terraform.lock.hcl
.terraform/
!coderd/testdata/parameters/modules/.terraform/
!provisioner/terraform/testdata/modules-source-caching/.terraform/
-3
View File
@@ -110,9 +110,6 @@ app, err := api.Database.GetOAuth2ProviderAppByClientID(ctx, clientID)
- For experimental or unstable API paths, skip public doc generation with
`// @x-apidocgen {"skip": true}` after the `@Router` annotation. This
keeps them out of the published API reference until they stabilize.
- Experimental chat endpoints in `coderd/exp_chats.go` omit swagger
annotations entirely. Do not add `@Summary`, `@Router`, or other
swagger comments to handlers in that file.
### Database Query Naming
+2 -13
View File
@@ -988,7 +988,6 @@ coderd/httpmw/loggermw/loggermock/loggermock.go: coderd/httpmw/loggermw/logger.g
codersdk/workspacesdk/agentconnmock/agentconnmock.go: codersdk/workspacesdk/agentconn.go
go generate ./codersdk/workspacesdk/agentconnmock/
./scripts/format_go_file.sh "$@"
touch "$@"
$(AIBRIDGED_MOCKS): enterprise/aibridged/client.go enterprise/aibridged/pool.go
@@ -1261,21 +1260,11 @@ provisioner/terraform/testdata/.gen-golden: $(wildcard provisioner/terraform/tes
touch "$@"
provisioner/terraform/testdata/version:
@tf_match=true; \
if [[ "$$(cat provisioner/terraform/testdata/version.txt)" != \
"$$(terraform version -json | jq -r '.terraform_version')" ]]; then \
tf_match=false; \
fi; \
if ! $$tf_match || \
! ./provisioner/terraform/testdata/generate.sh --check; then \
./provisioner/terraform/testdata/generate.sh; \
if [[ "$(shell cat provisioner/terraform/testdata/version.txt)" != "$(shell terraform version -json | jq -r '.terraform_version')" ]]; then
./provisioner/terraform/testdata/generate.sh
fi
.PHONY: provisioner/terraform/testdata/version
update-terraform-testdata:
./provisioner/terraform/testdata/generate.sh --upgrade
.PHONY: update-terraform-testdata
# Set the retry flags if TEST_RETRIES is set
ifdef TEST_RETRIES
GOTESTSUM_RETRY_FLAGS := --rerun-fails=$(TEST_RETRIES)
+6 -31
View File
@@ -38,7 +38,7 @@ import (
"cdr.dev/slog/v3"
"github.com/coder/clistat"
"github.com/coder/coder/v2/agent/agentcontainers"
"github.com/coder/coder/v2/agent/agentcontextconfig"
"github.com/coder/coder/v2/agent/agentdesktop"
"github.com/coder/coder/v2/agent/agentexec"
"github.com/coder/coder/v2/agent/agentfiles"
"github.com/coder/coder/v2/agent/agentgit"
@@ -50,8 +50,6 @@ import (
"github.com/coder/coder/v2/agent/proto"
"github.com/coder/coder/v2/agent/proto/resourcesmonitor"
"github.com/coder/coder/v2/agent/reconnectingpty"
"github.com/coder/coder/v2/agent/x/agentdesktop"
"github.com/coder/coder/v2/agent/x/agentmcp"
"github.com/coder/coder/v2/buildinfo"
"github.com/coder/coder/v2/cli/gitauth"
"github.com/coder/coder/v2/coderd/database/dbtime"
@@ -309,13 +307,10 @@ type agent struct {
containerAPI *agentcontainers.API
gitAPIOptions []agentgit.Option
filesAPI *agentfiles.API
gitAPI *agentgit.API
processAPI *agentproc.API
desktopAPI *agentdesktop.API
mcpManager *agentmcp.Manager
mcpAPI *agentmcp.API
contextConfigAPI *agentcontextconfig.API
filesAPI *agentfiles.API
gitAPI *agentgit.API
processAPI *agentproc.API
desktopAPI *agentdesktop.API
socketServerEnabled bool
socketPath string
@@ -398,17 +393,9 @@ func (a *agent) init() {
gitOpts := append([]agentgit.Option{agentgit.WithClock(a.clock)}, a.gitAPIOptions...)
a.gitAPI = agentgit.NewAPI(a.logger.Named("git"), pathStore, gitOpts...)
desktop := agentdesktop.NewPortableDesktop(
a.logger.Named("desktop"), a.execer, a.scriptRunner.ScriptBinDir(), nil,
a.logger.Named("desktop"), a.execer, a.scriptRunner.ScriptBinDir(),
)
a.desktopAPI = agentdesktop.NewAPI(a.logger.Named("desktop"), desktop, a.clock)
a.mcpManager = agentmcp.NewManager(a.logger.Named("mcp"))
a.mcpAPI = agentmcp.NewAPI(a.logger.Named("mcp"), a.mcpManager)
a.contextConfigAPI = agentcontextconfig.NewAPI(func() string {
if m := a.manifest.Load(); m != nil {
return m.Directory
}
return ""
})
a.reconnectingPTYServer = reconnectingpty.NewServer(
a.logger.Named("reconnecting-pty"),
a.sshServer,
@@ -1361,14 +1348,6 @@ func (a *agent) handleManifest(manifestOK *checkpoint) func(ctx context.Context,
}
a.metrics.startupScriptSeconds.WithLabelValues(label).Set(dur)
a.scriptRunner.StartCron()
// Connect to workspace MCP servers after the
// lifecycle transition to avoid delaying Ready.
// This runs inside the tracked goroutine so it
// is properly awaited on shutdown.
if mcpErr := a.mcpManager.Connect(a.gracefulCtx, a.contextConfigAPI.MCPConfigFiles()); mcpErr != nil {
a.logger.Warn(ctx, "failed to connect to workspace MCP servers", slog.Error(mcpErr))
}
})
if err != nil {
return xerrors.Errorf("track conn goroutine: %w", err)
@@ -2091,10 +2070,6 @@ func (a *agent) Close() error {
a.logger.Error(a.hardCtx, "desktop API close", slog.Error(err))
}
if err := a.mcpManager.Close(); err != nil {
a.logger.Error(a.hardCtx, "mcp manager close", slog.Error(err))
}
if a.boundaryLogProxy != nil {
err = a.boundaryLogProxy.Close()
if err != nil {
-52
View File
@@ -1,8 +1,6 @@
package agent
import (
"path/filepath"
"runtime"
"testing"
"github.com/google/uuid"
@@ -10,22 +8,10 @@ import (
"cdr.dev/slog/v3"
"cdr.dev/slog/v3/sloggers/slogtest"
"github.com/coder/coder/v2/agent/agentcontextconfig"
"github.com/coder/coder/v2/agent/proto"
agentsdk "github.com/coder/coder/v2/codersdk/agentsdk"
"github.com/coder/coder/v2/testutil"
)
// platformAbsPath constructs an absolute path that is valid
// on the current platform. On Windows, paths must include a
// drive letter to be considered absolute.
func platformAbsPath(parts ...string) string {
if runtime.GOOS == "windows" {
return `C:\` + filepath.Join(parts...)
}
return "/" + filepath.Join(parts...)
}
// TestReportConnectionEmpty tests that reportConnection() doesn't choke if given an empty IP string, which is what we
// send if we cannot get the remote address.
func TestReportConnectionEmpty(t *testing.T) {
@@ -56,41 +42,3 @@ func TestReportConnectionEmpty(t *testing.T) {
require.Equal(t, proto.Connection_DISCONNECT, req1.GetConnection().GetAction())
require.Equal(t, "because", req1.GetConnection().GetReason())
}
func TestContextConfigAPI_InitOnce(t *testing.T) {
// Not parallel: uses t.Setenv to clear env vars.
// Clear env vars so defaults are used and the test is
// hermetic regardless of the surrounding environment.
t.Setenv(agentcontextconfig.EnvInstructionsDirs, "")
t.Setenv(agentcontextconfig.EnvInstructionsFile, "")
t.Setenv(agentcontextconfig.EnvSkillsDirs, "")
t.Setenv(agentcontextconfig.EnvSkillMetaFile, "")
t.Setenv(agentcontextconfig.EnvMCPConfigFiles, "")
// After the fix, contextConfigAPI is set once in init() and
// never reassigned. Config() evaluates lazily via the
// manifest, so there is no concurrent write to race with.
dir1 := platformAbsPath("dir1")
dir2 := platformAbsPath("dir2")
a := &agent{}
a.manifest.Store(&agentsdk.Manifest{Directory: dir1})
a.contextConfigAPI = agentcontextconfig.NewAPI(func() string {
if m := a.manifest.Load(); m != nil {
return m.Directory
}
return ""
})
mcpFiles1 := a.contextConfigAPI.MCPConfigFiles()
require.NotEmpty(t, mcpFiles1)
require.Contains(t, mcpFiles1[0], dir1)
// Simulate manifest update on reconnection -- no field
// reassignment needed, the lazy closure picks it up.
a.manifest.Store(&agentsdk.Manifest{Directory: dir2})
mcpFiles2 := a.contextConfigAPI.MCPConfigFiles()
require.NotEmpty(t, mcpFiles2)
require.Contains(t, mcpFiles2[0], dir2)
}
+8 -15
View File
@@ -3007,7 +3007,7 @@ func TestAgent_Speedtest(t *testing.T) {
func TestAgent_Reconnect(t *testing.T) {
t.Parallel()
ctx := testutil.Context(t, testutil.WaitLong)
ctx := testutil.Context(t, testutil.WaitShort)
logger := testutil.Logger(t)
// After the agent is disconnected from a coordinator, it's supposed
// to reconnect!
@@ -3020,8 +3020,7 @@ func TestAgent_Reconnect(t *testing.T) {
logger,
agentID,
agentsdk.Manifest{
DERPMap: derpMap,
Directory: "/test/workspace",
DERPMap: derpMap,
},
statsCh,
fCoordinator,
@@ -3034,19 +3033,13 @@ func TestAgent_Reconnect(t *testing.T) {
})
defer closer.Close()
// Each iteration forces the agent to reconnect by closing
// the current coordinate call while the tracked HTTP server
// goroutine (from connection 1's createTailnet) is still
// alive, widening the race window.
const reconnections = 5
for i := range reconnections {
call := testutil.RequireReceive(ctx, t, fCoordinator.CoordinateCalls)
require.Equal(t, i+1, client.GetNumRefreshTokenCalls())
close(call.Resps) // hang up — triggers reconnect
}
// Verify final reconnect succeeds.
call1 := testutil.RequireReceive(ctx, t, fCoordinator.CoordinateCalls)
require.Equal(t, client.GetNumRefreshTokenCalls(), 1)
close(call1.Resps) // hang up
// expect reconnect
testutil.RequireReceive(ctx, t, fCoordinator.CoordinateCalls)
require.Equal(t, reconnections+1, client.GetNumRefreshTokenCalls())
// Check that the agent refreshes the token when it reconnects.
require.Equal(t, client.GetNumRefreshTokenCalls(), 2)
closer.Close()
}
@@ -159,6 +159,7 @@ func TestConvertDockerVolume(t *testing.T) {
func TestConvertDockerInspect(t *testing.T) {
t.Parallel()
//nolint:paralleltest // variable recapture no longer required
for _, tt := range []struct {
name string
expect []codersdk.WorkspaceAgentContainer
@@ -387,6 +388,7 @@ func TestConvertDockerInspect(t *testing.T) {
},
},
} {
// nolint:paralleltest // variable recapture no longer required
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
bs, err := os.ReadFile(filepath.Join("testdata", tt.name, "docker_inspect.json"))
+2
View File
@@ -166,6 +166,7 @@ func TestDockerEnvInfoer(t *testing.T) {
pool, err := dockertest.NewPool("")
require.NoError(t, err, "Could not connect to docker")
// nolint:paralleltest // variable recapture no longer required
for idx, tt := range []struct {
image string
labels map[string]string
@@ -222,6 +223,7 @@ func TestDockerEnvInfoer(t *testing.T) {
expectedUserShell: "/bin/bash",
},
} {
//nolint:paralleltest // variable recapture no longer required
t.Run(fmt.Sprintf("#%d", idx), func(t *testing.T) {
// Start a container with the given image
// and environment variables
-313
View File
@@ -1,313 +0,0 @@
package agentcontextconfig
import (
"cmp"
"io"
"net/http"
"os"
"path/filepath"
"regexp"
"strings"
"github.com/go-chi/chi/v5"
"github.com/coder/coder/v2/coderd/httpapi"
"github.com/coder/coder/v2/codersdk"
"github.com/coder/coder/v2/codersdk/workspacesdk"
)
// Env var names for context configuration. Prefixed with EXP_
// to indicate these are experimental and may change.
const (
EnvInstructionsDirs = "CODER_AGENT_EXP_INSTRUCTIONS_DIRS"
EnvInstructionsFile = "CODER_AGENT_EXP_INSTRUCTIONS_FILE"
EnvSkillsDirs = "CODER_AGENT_EXP_SKILLS_DIRS"
EnvSkillMetaFile = "CODER_AGENT_EXP_SKILL_META_FILE"
EnvMCPConfigFiles = "CODER_AGENT_EXP_MCP_CONFIG_FILES"
)
const (
maxInstructionFileBytes = 64 * 1024
maxSkillMetaBytes = 64 * 1024
)
// markdownCommentPattern strips HTML comments from instruction
// file content for security (prevents hidden prompt injection).
var markdownCommentPattern = regexp.MustCompile(`<!--[\s\S]*?-->`)
// invisibleRunePattern strips invisible Unicode characters that
// could be used for prompt injection.
//
//nolint:gocritic // Non-ASCII char ranges are intentional for invisible Unicode stripping.
var invisibleRunePattern = regexp.MustCompile(
"[\u00ad\u034f\u061c\u070f" +
"\u115f\u1160\u17b4\u17b5" +
"\u180b-\u180f" +
"\u200b\u200d\u200e\u200f" +
"\u202a-\u202e" +
"\u2060-\u206f" +
"\u3164" +
"\ufe00-\ufe0f" +
"\ufeff" +
"\uffa0" +
"\ufff0-\ufff8]",
)
// skillNamePattern validates kebab-case skill names.
var skillNamePattern = regexp.MustCompile(
`^[a-z0-9]+(-[a-z0-9]+)*$`,
)
// Default values for agent-internal configuration. These are
// used when the corresponding env vars are unset.
const (
DefaultInstructionsDir = "~/.coder"
DefaultInstructionsFile = "AGENTS.md"
DefaultSkillsDir = ".agents/skills"
DefaultSkillMetaFile = "SKILL.md"
DefaultMCPConfigFile = ".mcp.json"
)
// API exposes the resolved context configuration through the
// agent's HTTP API.
type API struct {
workingDir func() string
}
// NewAPI accepts a closure that returns the working directory.
// The directory is evaluated lazily on each call to Config(),
// so the caller can update it after construction.
func NewAPI(workingDir func() string) *API {
if workingDir == nil {
workingDir = func() string { return "" }
}
return &API{workingDir: workingDir}
}
// Config reads env vars, resolves paths, reads instruction files,
// and discovers skills. Returns the HTTP response and the resolved
// MCP config file paths (used only agent-internally). Exported
// for use by tests.
func Config(workingDir string) (workspacesdk.ContextConfigResponse, []string) {
// TrimSpace all env vars before cmp.Or so that a
// whitespace-only value falls through to the default
// consistently. ResolvePaths also trims each comma-
// separated entry, but without pre-trimming here a
// bare " " would bypass cmp.Or and produce nil.
instructionsDir := cmp.Or(strings.TrimSpace(os.Getenv(EnvInstructionsDirs)), DefaultInstructionsDir)
instructionsFile := cmp.Or(strings.TrimSpace(os.Getenv(EnvInstructionsFile)), DefaultInstructionsFile)
skillsDir := cmp.Or(strings.TrimSpace(os.Getenv(EnvSkillsDirs)), DefaultSkillsDir)
skillMetaFile := cmp.Or(strings.TrimSpace(os.Getenv(EnvSkillMetaFile)), DefaultSkillMetaFile)
mcpConfigFile := cmp.Or(strings.TrimSpace(os.Getenv(EnvMCPConfigFiles)), DefaultMCPConfigFile)
resolvedInstructionsDirs := ResolvePaths(instructionsDir, workingDir)
resolvedSkillsDirs := ResolvePaths(skillsDir, workingDir)
// Read instruction files from each configured directory.
parts := readInstructionFiles(resolvedInstructionsDirs, instructionsFile)
// Also check the working directory for the instruction file,
// unless it was already covered by InstructionsDirs.
if workingDir != "" {
seenDirs := make(map[string]struct{}, len(resolvedInstructionsDirs))
for _, d := range resolvedInstructionsDirs {
seenDirs[d] = struct{}{}
}
if _, ok := seenDirs[workingDir]; !ok {
if entry, found := readInstructionFileFromDir(workingDir, instructionsFile); found {
parts = append(parts, entry)
}
}
}
// Discover skills from each configured skills directory.
skillParts := discoverSkills(resolvedSkillsDirs, skillMetaFile)
parts = append(parts, skillParts...)
// Guarantee non-nil slice to signal agent support.
if parts == nil {
parts = []codersdk.ChatMessagePart{}
}
return workspacesdk.ContextConfigResponse{
Parts: parts,
}, ResolvePaths(mcpConfigFile, workingDir)
}
// MCPConfigFiles returns the resolved MCP configuration file
// paths for the agent's MCP manager.
func (api *API) MCPConfigFiles() []string {
_, mcpFiles := Config(api.workingDir())
return mcpFiles
}
// Routes returns the HTTP handler for the context config
// endpoint.
func (api *API) Routes() http.Handler {
r := chi.NewRouter()
r.Get("/", api.handleGet)
return r
}
func (api *API) handleGet(rw http.ResponseWriter, r *http.Request) {
response, _ := Config(api.workingDir())
httpapi.Write(r.Context(), rw, http.StatusOK, response)
}
// readInstructionFiles reads instruction files from each given
// directory. Missing directories are silently skipped. Duplicate
// directories are deduplicated.
func readInstructionFiles(dirs []string, fileName string) []codersdk.ChatMessagePart {
var parts []codersdk.ChatMessagePart
seen := make(map[string]struct{}, len(dirs))
for _, dir := range dirs {
if _, ok := seen[dir]; ok {
continue
}
seen[dir] = struct{}{}
if part, found := readInstructionFileFromDir(dir, fileName); found {
parts = append(parts, part)
}
}
return parts
}
// readInstructionFileFromDir scans a directory for a file matching
// fileName (case-insensitive) and reads its contents.
func readInstructionFileFromDir(dir, fileName string) (codersdk.ChatMessagePart, bool) {
dirEntries, err := os.ReadDir(dir)
if err != nil {
return codersdk.ChatMessagePart{}, false
}
for _, e := range dirEntries {
if e.IsDir() {
continue
}
if strings.EqualFold(strings.TrimSpace(e.Name()), fileName) {
filePath := filepath.Join(dir, e.Name())
content, truncated, ok := readAndSanitizeFile(filePath, maxInstructionFileBytes)
if !ok {
return codersdk.ChatMessagePart{}, false
}
if content == "" {
return codersdk.ChatMessagePart{}, false
}
return codersdk.ChatMessagePart{
Type: codersdk.ChatMessagePartTypeContextFile,
ContextFilePath: filePath,
ContextFileContent: content,
ContextFileTruncated: truncated,
}, true
}
}
return codersdk.ChatMessagePart{}, false
}
// readAndSanitizeFile reads the file at path, capping the read
// at maxBytes to avoid unbounded memory allocation. It sanitizes
// the content (strips HTML comments and invisible Unicode) and
// returns the result. Returns false if the file cannot be read.
func readAndSanitizeFile(path string, maxBytes int64) (content string, truncated bool, ok bool) {
f, err := os.Open(path)
if err != nil {
return "", false, false
}
defer f.Close()
// Read at most maxBytes+1 to detect truncation without
// allocating the entire file into memory.
raw, err := io.ReadAll(io.LimitReader(f, maxBytes+1))
if err != nil {
return "", false, false
}
truncated = int64(len(raw)) > maxBytes
if truncated {
raw = raw[:maxBytes]
}
s := sanitizeInstructionMarkdown(string(raw))
if s == "" {
return "", truncated, true
}
return s, truncated, true
}
// sanitizeInstructionMarkdown strips HTML comments, invisible
// Unicode characters, and CRLF line endings from instruction
// file content.
func sanitizeInstructionMarkdown(content string) string {
content = strings.ReplaceAll(content, "\r\n", "\n")
content = strings.ReplaceAll(content, "\r", "\n")
content = markdownCommentPattern.ReplaceAllString(content, "")
content = invisibleRunePattern.ReplaceAllString(content, "")
return strings.TrimSpace(content)
}
// discoverSkills walks the given skills directories and returns
// metadata for every valid skill it finds. Body and supporting
// file lists are NOT included; chatd fetches those on demand
// via read_skill. Missing directories or individual errors are
// silently skipped.
func discoverSkills(skillsDirs []string, metaFile string) []codersdk.ChatMessagePart {
seen := make(map[string]struct{})
var parts []codersdk.ChatMessagePart
for _, skillsDir := range skillsDirs {
entries, err := os.ReadDir(skillsDir)
if err != nil {
continue
}
for _, entry := range entries {
if !entry.IsDir() {
continue
}
metaPath := filepath.Join(skillsDir, entry.Name(), metaFile)
f, err := os.Open(metaPath)
if err != nil {
continue
}
raw, err := io.ReadAll(io.LimitReader(f, maxSkillMetaBytes+1))
_ = f.Close()
if err != nil {
continue
}
if int64(len(raw)) > maxSkillMetaBytes {
raw = raw[:maxSkillMetaBytes]
}
name, description, _, err := workspacesdk.ParseSkillFrontmatter(string(raw))
if err != nil {
continue
}
// The directory name must match the declared name.
if name != entry.Name() {
continue
}
if !skillNamePattern.MatchString(name) {
continue
}
// First occurrence wins across directories.
if _, ok := seen[name]; ok {
continue
}
seen[name] = struct{}{}
skillDir := filepath.Join(skillsDir, entry.Name())
parts = append(parts, codersdk.ChatMessagePart{
Type: codersdk.ChatMessagePartTypeSkill,
SkillName: name,
SkillDescription: description,
SkillDir: skillDir,
ContextFileSkillMetaFile: metaFile,
})
}
}
return parts
}
-439
View File
@@ -1,439 +0,0 @@
package agentcontextconfig_test
import (
"os"
"path/filepath"
"strings"
"testing"
"github.com/stretchr/testify/require"
"github.com/coder/coder/v2/agent/agentcontextconfig"
"github.com/coder/coder/v2/codersdk"
)
// filterParts returns only the parts matching the given type.
func filterParts(parts []codersdk.ChatMessagePart, t codersdk.ChatMessagePartType) []codersdk.ChatMessagePart {
var out []codersdk.ChatMessagePart
for _, p := range parts {
if p.Type == t {
out = append(out, p)
}
}
return out
}
func TestConfig(t *testing.T) {
t.Run("Defaults", func(t *testing.T) {
fakeHome := t.TempDir()
t.Setenv("HOME", fakeHome)
t.Setenv("USERPROFILE", fakeHome)
// Clear all env vars so defaults are used.
t.Setenv(agentcontextconfig.EnvInstructionsDirs, "")
t.Setenv(agentcontextconfig.EnvInstructionsFile, "")
t.Setenv(agentcontextconfig.EnvSkillsDirs, "")
t.Setenv(agentcontextconfig.EnvSkillMetaFile, "")
t.Setenv(agentcontextconfig.EnvMCPConfigFiles, "")
workDir := platformAbsPath("work")
cfg, mcpFiles := agentcontextconfig.Config(workDir)
// Parts is always non-nil.
require.NotNil(t, cfg.Parts)
// Default MCP config file is ".mcp.json" (relative),
// resolved against the working directory.
require.Equal(t, []string{filepath.Join(workDir, ".mcp.json")}, mcpFiles)
})
t.Run("CustomEnvVars", func(t *testing.T) {
fakeHome := t.TempDir()
t.Setenv("HOME", fakeHome)
t.Setenv("USERPROFILE", fakeHome)
optInstructions := t.TempDir()
optSkills := t.TempDir()
optMCP := platformAbsPath("opt", "mcp.json")
t.Setenv(agentcontextconfig.EnvInstructionsDirs, optInstructions)
t.Setenv(agentcontextconfig.EnvInstructionsFile, "CUSTOM.md")
t.Setenv(agentcontextconfig.EnvSkillsDirs, optSkills)
t.Setenv(agentcontextconfig.EnvSkillMetaFile, "META.yaml")
t.Setenv(agentcontextconfig.EnvMCPConfigFiles, optMCP)
// Create files matching the custom names so we can
// verify the env vars actually change lookup behavior.
require.NoError(t, os.WriteFile(filepath.Join(optInstructions, "CUSTOM.md"), []byte("custom instructions"), 0o600))
skillDir := filepath.Join(optSkills, "my-skill")
require.NoError(t, os.MkdirAll(skillDir, 0o755))
require.NoError(t, os.WriteFile(
filepath.Join(skillDir, "META.yaml"),
[]byte("---\nname: my-skill\ndescription: custom meta\n---\n"),
0o600,
))
workDir := platformAbsPath("work")
cfg, mcpFiles := agentcontextconfig.Config(workDir)
require.Equal(t, []string{optMCP}, mcpFiles)
ctxFiles := filterParts(cfg.Parts, codersdk.ChatMessagePartTypeContextFile)
require.Len(t, ctxFiles, 1)
require.Equal(t, "custom instructions", ctxFiles[0].ContextFileContent)
skillParts := filterParts(cfg.Parts, codersdk.ChatMessagePartTypeSkill)
require.Len(t, skillParts, 1)
require.Equal(t, "my-skill", skillParts[0].SkillName)
require.Equal(t, "META.yaml", skillParts[0].ContextFileSkillMetaFile)
})
t.Run("WhitespaceInFileNames", func(t *testing.T) {
fakeHome := t.TempDir()
t.Setenv("HOME", fakeHome)
t.Setenv("USERPROFILE", fakeHome)
t.Setenv(agentcontextconfig.EnvInstructionsDirs, fakeHome)
t.Setenv(agentcontextconfig.EnvInstructionsFile, " CLAUDE.md ")
t.Setenv(agentcontextconfig.EnvSkillsDirs, "")
t.Setenv(agentcontextconfig.EnvSkillMetaFile, "")
t.Setenv(agentcontextconfig.EnvMCPConfigFiles, "")
workDir := t.TempDir()
// Create a file matching the trimmed name.
require.NoError(t, os.WriteFile(filepath.Join(fakeHome, "CLAUDE.md"), []byte("hello"), 0o600))
cfg, _ := agentcontextconfig.Config(workDir)
ctxFiles := filterParts(cfg.Parts, codersdk.ChatMessagePartTypeContextFile)
require.Len(t, ctxFiles, 1)
require.Equal(t, "hello", ctxFiles[0].ContextFileContent)
})
t.Run("CommaSeparatedDirs", func(t *testing.T) {
fakeHome := t.TempDir()
t.Setenv("HOME", fakeHome)
t.Setenv("USERPROFILE", fakeHome)
a := t.TempDir()
b := t.TempDir()
t.Setenv(agentcontextconfig.EnvInstructionsDirs, a+","+b)
t.Setenv(agentcontextconfig.EnvInstructionsFile, "")
t.Setenv(agentcontextconfig.EnvSkillsDirs, "")
t.Setenv(agentcontextconfig.EnvSkillMetaFile, "")
t.Setenv(agentcontextconfig.EnvMCPConfigFiles, "")
// Put instruction files in both dirs.
require.NoError(t, os.WriteFile(filepath.Join(a, "AGENTS.md"), []byte("from a"), 0o600))
require.NoError(t, os.WriteFile(filepath.Join(b, "AGENTS.md"), []byte("from b"), 0o600))
workDir := t.TempDir()
cfg, _ := agentcontextconfig.Config(workDir)
ctxFiles := filterParts(cfg.Parts, codersdk.ChatMessagePartTypeContextFile)
require.Len(t, ctxFiles, 2)
require.Equal(t, "from a", ctxFiles[0].ContextFileContent)
require.Equal(t, "from b", ctxFiles[1].ContextFileContent)
})
t.Run("ReadsInstructionFiles", func(t *testing.T) {
t.Setenv(agentcontextconfig.EnvInstructionsDirs, "")
t.Setenv(agentcontextconfig.EnvInstructionsFile, "")
t.Setenv(agentcontextconfig.EnvSkillsDirs, "")
t.Setenv(agentcontextconfig.EnvSkillMetaFile, "")
t.Setenv(agentcontextconfig.EnvMCPConfigFiles, "")
workDir := t.TempDir()
fakeHome := t.TempDir()
t.Setenv("HOME", fakeHome)
t.Setenv("USERPROFILE", fakeHome)
// Create ~/.coder/AGENTS.md
coderDir := filepath.Join(fakeHome, ".coder")
require.NoError(t, os.MkdirAll(coderDir, 0o755))
require.NoError(t, os.WriteFile(
filepath.Join(coderDir, "AGENTS.md"),
[]byte("home instructions"),
0o600,
))
cfg, _ := agentcontextconfig.Config(workDir)
ctxFiles := filterParts(cfg.Parts, codersdk.ChatMessagePartTypeContextFile)
require.NotNil(t, cfg.Parts)
require.Len(t, ctxFiles, 1)
require.Equal(t, "home instructions", ctxFiles[0].ContextFileContent)
require.Equal(t, filepath.Join(coderDir, "AGENTS.md"), ctxFiles[0].ContextFilePath)
require.False(t, ctxFiles[0].ContextFileTruncated)
})
t.Run("ReadsWorkingDirInstructionFile", func(t *testing.T) {
fakeHome := t.TempDir()
t.Setenv("HOME", fakeHome)
t.Setenv("USERPROFILE", fakeHome)
t.Setenv(agentcontextconfig.EnvInstructionsDirs, "")
t.Setenv(agentcontextconfig.EnvInstructionsFile, "")
t.Setenv(agentcontextconfig.EnvSkillsDirs, "")
t.Setenv(agentcontextconfig.EnvSkillMetaFile, "")
t.Setenv(agentcontextconfig.EnvMCPConfigFiles, "")
workDir := t.TempDir()
// Create AGENTS.md in the working directory.
require.NoError(t, os.WriteFile(
filepath.Join(workDir, "AGENTS.md"),
[]byte("project instructions"),
0o600,
))
cfg, _ := agentcontextconfig.Config(workDir)
// Should find the working dir file (not in instruction dirs).
ctxFiles := filterParts(cfg.Parts, codersdk.ChatMessagePartTypeContextFile)
require.NotNil(t, cfg.Parts)
require.Len(t, ctxFiles, 1)
require.Equal(t, "project instructions", ctxFiles[0].ContextFileContent)
require.Equal(t, filepath.Join(workDir, "AGENTS.md"), ctxFiles[0].ContextFilePath)
})
t.Run("TruncatesLargeInstructionFile", func(t *testing.T) {
fakeHome := t.TempDir()
t.Setenv("HOME", fakeHome)
t.Setenv("USERPROFILE", fakeHome)
t.Setenv(agentcontextconfig.EnvInstructionsDirs, "")
t.Setenv(agentcontextconfig.EnvInstructionsFile, "")
t.Setenv(agentcontextconfig.EnvSkillsDirs, "")
t.Setenv(agentcontextconfig.EnvSkillMetaFile, "")
t.Setenv(agentcontextconfig.EnvMCPConfigFiles, "")
workDir := t.TempDir()
largeContent := strings.Repeat("a", 64*1024+100)
require.NoError(t, os.WriteFile(filepath.Join(workDir, "AGENTS.md"), []byte(largeContent), 0o600))
cfg, _ := agentcontextconfig.Config(workDir)
ctxFiles := filterParts(cfg.Parts, codersdk.ChatMessagePartTypeContextFile)
require.Len(t, ctxFiles, 1)
require.True(t, ctxFiles[0].ContextFileTruncated)
require.Len(t, ctxFiles[0].ContextFileContent, 64*1024)
})
t.Run("SanitizesHTMLComments", func(t *testing.T) {
fakeHome := t.TempDir()
t.Setenv("HOME", fakeHome)
t.Setenv("USERPROFILE", fakeHome)
t.Setenv(agentcontextconfig.EnvInstructionsDirs, "")
t.Setenv(agentcontextconfig.EnvInstructionsFile, "")
t.Setenv(agentcontextconfig.EnvSkillsDirs, "")
t.Setenv(agentcontextconfig.EnvSkillMetaFile, "")
t.Setenv(agentcontextconfig.EnvMCPConfigFiles, "")
workDir := t.TempDir()
require.NoError(t, os.WriteFile(
filepath.Join(workDir, "AGENTS.md"),
[]byte("visible\n<!-- hidden -->content"),
0o600,
))
cfg, _ := agentcontextconfig.Config(workDir)
ctxFiles := filterParts(cfg.Parts, codersdk.ChatMessagePartTypeContextFile)
require.Len(t, ctxFiles, 1)
require.Equal(t, "visible\ncontent", ctxFiles[0].ContextFileContent)
})
t.Run("SanitizesInvisibleUnicode", func(t *testing.T) {
fakeHome := t.TempDir()
t.Setenv("HOME", fakeHome)
t.Setenv("USERPROFILE", fakeHome)
t.Setenv(agentcontextconfig.EnvInstructionsDirs, "")
t.Setenv(agentcontextconfig.EnvInstructionsFile, "")
t.Setenv(agentcontextconfig.EnvSkillsDirs, "")
t.Setenv(agentcontextconfig.EnvSkillMetaFile, "")
t.Setenv(agentcontextconfig.EnvMCPConfigFiles, "")
workDir := t.TempDir()
// U+200B (zero-width space) should be stripped.
require.NoError(t, os.WriteFile(
filepath.Join(workDir, "AGENTS.md"),
[]byte("before\u200bafter"),
0o600,
))
cfg, _ := agentcontextconfig.Config(workDir)
ctxFiles := filterParts(cfg.Parts, codersdk.ChatMessagePartTypeContextFile)
require.Len(t, ctxFiles, 1)
require.Equal(t, "beforeafter", ctxFiles[0].ContextFileContent)
})
t.Run("NormalizesCRLF", func(t *testing.T) {
fakeHome := t.TempDir()
t.Setenv("HOME", fakeHome)
t.Setenv("USERPROFILE", fakeHome)
t.Setenv(agentcontextconfig.EnvInstructionsDirs, "")
t.Setenv(agentcontextconfig.EnvInstructionsFile, "")
t.Setenv(agentcontextconfig.EnvSkillsDirs, "")
t.Setenv(agentcontextconfig.EnvSkillMetaFile, "")
t.Setenv(agentcontextconfig.EnvMCPConfigFiles, "")
workDir := t.TempDir()
require.NoError(t, os.WriteFile(
filepath.Join(workDir, "AGENTS.md"),
[]byte("line1\r\nline2\rline3"),
0o600,
))
cfg, _ := agentcontextconfig.Config(workDir)
ctxFiles := filterParts(cfg.Parts, codersdk.ChatMessagePartTypeContextFile)
require.Len(t, ctxFiles, 1)
require.Equal(t, "line1\nline2\nline3", ctxFiles[0].ContextFileContent)
})
t.Run("DiscoversSkills", func(t *testing.T) {
fakeHome := t.TempDir()
t.Setenv("HOME", fakeHome)
t.Setenv("USERPROFILE", fakeHome)
t.Setenv(agentcontextconfig.EnvInstructionsDirs, fakeHome)
t.Setenv(agentcontextconfig.EnvInstructionsFile, "")
t.Setenv(agentcontextconfig.EnvSkillMetaFile, "")
t.Setenv(agentcontextconfig.EnvMCPConfigFiles, "")
workDir := t.TempDir()
skillsDir := filepath.Join(workDir, ".agents", "skills")
t.Setenv(agentcontextconfig.EnvSkillsDirs, skillsDir)
// Create a valid skill.
skillDir := filepath.Join(skillsDir, "my-skill")
require.NoError(t, os.MkdirAll(skillDir, 0o755))
require.NoError(t, os.WriteFile(
filepath.Join(skillDir, "SKILL.md"),
[]byte("---\nname: my-skill\ndescription: A test skill\n---\nSkill body"),
0o600,
))
cfg, _ := agentcontextconfig.Config(workDir)
skillParts := filterParts(cfg.Parts, codersdk.ChatMessagePartTypeSkill)
require.Len(t, skillParts, 1)
require.Equal(t, "my-skill", skillParts[0].SkillName)
require.Equal(t, "A test skill", skillParts[0].SkillDescription)
require.Equal(t, skillDir, skillParts[0].SkillDir)
require.Equal(t, "SKILL.md", skillParts[0].ContextFileSkillMetaFile)
})
t.Run("SkipsMissingDirs", func(t *testing.T) {
fakeHome := t.TempDir()
t.Setenv("HOME", fakeHome)
t.Setenv("USERPROFILE", fakeHome)
nonExistent := filepath.Join(t.TempDir(), "does-not-exist")
t.Setenv(agentcontextconfig.EnvInstructionsDirs, nonExistent)
t.Setenv(agentcontextconfig.EnvInstructionsFile, "")
t.Setenv(agentcontextconfig.EnvSkillsDirs, nonExistent)
t.Setenv(agentcontextconfig.EnvSkillMetaFile, "")
t.Setenv(agentcontextconfig.EnvMCPConfigFiles, "")
workDir := t.TempDir()
cfg, _ := agentcontextconfig.Config(workDir)
// Non-nil empty slice (signals agent supports new format).
require.NotNil(t, cfg.Parts)
require.Empty(t, cfg.Parts)
})
t.Run("MCPConfigFilesResolvedSeparately", func(t *testing.T) {
fakeHome := t.TempDir()
t.Setenv("HOME", fakeHome)
t.Setenv("USERPROFILE", fakeHome)
t.Setenv(agentcontextconfig.EnvInstructionsDirs, fakeHome)
t.Setenv(agentcontextconfig.EnvInstructionsFile, "")
t.Setenv(agentcontextconfig.EnvSkillsDirs, "")
t.Setenv(agentcontextconfig.EnvSkillMetaFile, "")
optMCP := platformAbsPath("opt", "custom.json")
t.Setenv(agentcontextconfig.EnvMCPConfigFiles, optMCP)
workDir := t.TempDir()
_, mcpFiles := agentcontextconfig.Config(workDir)
require.Equal(t, []string{optMCP}, mcpFiles)
})
t.Run("SkillNameMustMatchDir", func(t *testing.T) {
fakeHome := t.TempDir()
t.Setenv("HOME", fakeHome)
t.Setenv("USERPROFILE", fakeHome)
t.Setenv(agentcontextconfig.EnvInstructionsDirs, fakeHome)
t.Setenv(agentcontextconfig.EnvInstructionsFile, "")
t.Setenv(agentcontextconfig.EnvSkillMetaFile, "")
t.Setenv(agentcontextconfig.EnvMCPConfigFiles, "")
workDir := t.TempDir()
skillsDir := filepath.Join(workDir, "skills")
t.Setenv(agentcontextconfig.EnvSkillsDirs, skillsDir)
// Skill name in frontmatter doesn't match directory name.
skillDir := filepath.Join(skillsDir, "wrong-dir-name")
require.NoError(t, os.MkdirAll(skillDir, 0o755))
require.NoError(t, os.WriteFile(
filepath.Join(skillDir, "SKILL.md"),
[]byte("---\nname: actual-name\ndescription: mismatch\n---\n"),
0o600,
))
cfg, _ := agentcontextconfig.Config(workDir)
skillParts := filterParts(cfg.Parts, codersdk.ChatMessagePartTypeSkill)
require.Empty(t, skillParts)
})
t.Run("DuplicateSkillsFirstWins", func(t *testing.T) {
fakeHome := t.TempDir()
t.Setenv("HOME", fakeHome)
t.Setenv("USERPROFILE", fakeHome)
t.Setenv(agentcontextconfig.EnvInstructionsDirs, fakeHome)
t.Setenv(agentcontextconfig.EnvInstructionsFile, "")
t.Setenv(agentcontextconfig.EnvSkillMetaFile, "")
t.Setenv(agentcontextconfig.EnvMCPConfigFiles, "")
workDir := t.TempDir()
skillsDir1 := filepath.Join(workDir, "skills1")
skillsDir2 := filepath.Join(workDir, "skills2")
t.Setenv(agentcontextconfig.EnvSkillsDirs, skillsDir1+","+skillsDir2)
// Same skill name in both directories.
for _, dir := range []string{skillsDir1, skillsDir2} {
skillDir := filepath.Join(dir, "dup-skill")
require.NoError(t, os.MkdirAll(skillDir, 0o755))
require.NoError(t, os.WriteFile(
filepath.Join(skillDir, "SKILL.md"),
[]byte("---\nname: dup-skill\ndescription: from "+filepath.Base(dir)+"\n---\n"),
0o600,
))
}
cfg, _ := agentcontextconfig.Config(workDir)
skillParts := filterParts(cfg.Parts, codersdk.ChatMessagePartTypeSkill)
require.Len(t, skillParts, 1)
require.Equal(t, "from skills1", skillParts[0].SkillDescription)
})
}
func TestNewAPI_LazyDirectory(t *testing.T) {
t.Setenv(agentcontextconfig.EnvInstructionsDirs, "")
t.Setenv(agentcontextconfig.EnvInstructionsFile, "")
t.Setenv(agentcontextconfig.EnvSkillsDirs, "")
t.Setenv(agentcontextconfig.EnvSkillMetaFile, "")
t.Setenv(agentcontextconfig.EnvMCPConfigFiles, "")
dir := ""
api := agentcontextconfig.NewAPI(func() string { return dir })
// Before directory is set, MCP paths resolve to nothing.
mcpFiles := api.MCPConfigFiles()
require.Empty(t, mcpFiles)
// After setting the directory, MCPConfigFiles() picks it up.
dir = platformAbsPath("work")
mcpFiles = api.MCPConfigFiles()
require.NotEmpty(t, mcpFiles)
require.Equal(t, []string{filepath.Join(dir, ".mcp.json")}, mcpFiles)
}
-55
View File
@@ -1,55 +0,0 @@
package agentcontextconfig
import (
"os"
"path/filepath"
"strings"
)
// ResolvePath resolves a single path that may be absolute,
// home-relative (~/ or ~), or relative to the given base
// directory. Returns an absolute path. Empty input returns empty.
func ResolvePath(raw, baseDir string) string {
raw = strings.TrimSpace(raw)
if raw == "" {
return ""
}
switch {
case raw == "~":
home, err := os.UserHomeDir()
if err != nil {
return ""
}
return home
case strings.HasPrefix(raw, "~/"):
home, err := os.UserHomeDir()
if err != nil {
return ""
}
return filepath.Join(home, raw[2:])
case filepath.IsAbs(raw):
return raw
default:
if baseDir == "" {
return ""
}
return filepath.Join(baseDir, raw)
}
}
// ResolvePaths splits a comma-separated list of paths and
// resolves each entry independently. Empty entries and entries
// that resolve to empty strings are skipped.
func ResolvePaths(raw, baseDir string) []string {
if strings.TrimSpace(raw) == "" {
return nil
}
parts := strings.Split(raw, ",")
out := make([]string, 0, len(parts))
for _, p := range parts {
if resolved := ResolvePath(p, baseDir); resolved != "" {
out = append(out, resolved)
}
}
return out
}
-152
View File
@@ -1,152 +0,0 @@
package agentcontextconfig_test
import (
"path/filepath"
"runtime"
"testing"
"github.com/stretchr/testify/require"
"github.com/coder/coder/v2/agent/agentcontextconfig"
)
// platformAbsPath constructs an absolute path that is valid
// on the current platform. On Windows paths must include a
// drive letter to be considered absolute.
func platformAbsPath(parts ...string) string {
if runtime.GOOS == "windows" {
return `C:\` + filepath.Join(parts...)
}
return "/" + filepath.Join(parts...)
}
func TestResolvePath(t *testing.T) { //nolint:tparallel // subtests using t.Setenv cannot be parallel
t.Run("EmptyInput", func(t *testing.T) {
t.Parallel()
require.Equal(t, "", agentcontextconfig.ResolvePath("", platformAbsPath("base")))
})
t.Run("WhitespaceOnly", func(t *testing.T) {
t.Parallel()
require.Equal(t, "", agentcontextconfig.ResolvePath(" ", platformAbsPath("base")))
})
// Tests that use t.Setenv cannot be parallel.
t.Run("TildeAlone", func(t *testing.T) {
fakeHome := t.TempDir()
t.Setenv("HOME", fakeHome)
t.Setenv("USERPROFILE", fakeHome)
got := agentcontextconfig.ResolvePath("~", platformAbsPath("base"))
require.Equal(t, fakeHome, got)
})
t.Run("TildeSlashPath", func(t *testing.T) {
fakeHome := t.TempDir()
t.Setenv("HOME", fakeHome)
t.Setenv("USERPROFILE", fakeHome)
got := agentcontextconfig.ResolvePath("~/docs/readme", platformAbsPath("base"))
require.Equal(t, filepath.Join(fakeHome, "docs", "readme"), got)
})
t.Run("AbsolutePath", func(t *testing.T) {
t.Parallel()
p := platformAbsPath("etc", "coder")
got := agentcontextconfig.ResolvePath(p, platformAbsPath("base"))
require.Equal(t, p, got)
})
t.Run("RelativePath", func(t *testing.T) {
t.Parallel()
base := platformAbsPath("work")
got := agentcontextconfig.ResolvePath("foo/bar", base)
require.Equal(t, filepath.Join(base, "foo", "bar"), got)
})
t.Run("RelativePathWithWhitespace", func(t *testing.T) {
t.Parallel()
base := platformAbsPath("work")
got := agentcontextconfig.ResolvePath(" foo/bar ", base)
require.Equal(t, filepath.Join(base, "foo", "bar"), got)
})
t.Run("RelativePathWithEmptyBaseDir", func(t *testing.T) {
t.Parallel()
got := agentcontextconfig.ResolvePath(".agents/skills", "")
require.Equal(t, "", got)
})
}
func TestResolvePath_HomeUnset(t *testing.T) {
// Cannot be parallel — modifies HOME env var.
t.Setenv("HOME", "")
// Also clear USERPROFILE for Windows compatibility.
t.Setenv("USERPROFILE", "")
require.Equal(t, "", agentcontextconfig.ResolvePath("~", platformAbsPath("base")))
require.Equal(t, "", agentcontextconfig.ResolvePath("~/docs", platformAbsPath("base")))
}
func TestResolvePaths(t *testing.T) { //nolint:tparallel // subtests using t.Setenv cannot be parallel
t.Run("EmptyString", func(t *testing.T) {
t.Parallel()
require.Nil(t, agentcontextconfig.ResolvePaths("", platformAbsPath("base")))
})
t.Run("WhitespaceOnly", func(t *testing.T) {
t.Parallel()
require.Nil(t, agentcontextconfig.ResolvePaths(" ", platformAbsPath("base")))
})
t.Run("SingleEntry", func(t *testing.T) {
t.Parallel()
p := platformAbsPath("abs", "path")
got := agentcontextconfig.ResolvePaths(p, platformAbsPath("base"))
require.Equal(t, []string{p}, got)
})
// Tests that use t.Setenv cannot be parallel.
t.Run("MultipleEntries", func(t *testing.T) {
fakeHome := t.TempDir()
t.Setenv("HOME", fakeHome)
t.Setenv("USERPROFILE", fakeHome)
b := platformAbsPath("b")
base := platformAbsPath("base")
got := agentcontextconfig.ResolvePaths("~/a,"+b+",rel", base)
require.Equal(t, []string{
filepath.Join(fakeHome, "a"),
b,
filepath.Join(base, "rel"),
}, got)
})
t.Run("TrimsWhitespace", func(t *testing.T) {
t.Parallel()
a := platformAbsPath("a")
b := platformAbsPath("b")
got := agentcontextconfig.ResolvePaths(" "+a+" , "+b+" ", platformAbsPath("base"))
require.Equal(t, []string{a, b}, got)
})
t.Run("SkipsEmptyEntries", func(t *testing.T) {
t.Parallel()
a := platformAbsPath("a")
b := platformAbsPath("b")
got := agentcontextconfig.ResolvePaths(a+",,"+b+",", platformAbsPath("base"))
require.Equal(t, []string{a, b}, got)
})
t.Run("TrailingComma", func(t *testing.T) {
t.Parallel()
p := platformAbsPath("only")
got := agentcontextconfig.ResolvePaths(p+",", platformAbsPath("base"))
require.Equal(t, []string{p}, got)
})
t.Run("RelativePathSkippedWhenBaseDirEmpty", func(t *testing.T) {
fakeHome := t.TempDir()
t.Setenv("HOME", fakeHome)
t.Setenv("USERPROFILE", fakeHome)
got := agentcontextconfig.ResolvePaths("~/.coder,.agents/skills", "")
require.Equal(t, []string{filepath.Join(fakeHome, ".coder")}, got)
})
}
@@ -1,17 +1,12 @@
package agentdesktop
import (
"context"
"encoding/json"
"errors"
"io"
"net/http"
"strconv"
"sync"
"time"
"github.com/go-chi/chi/v5"
"github.com/google/uuid"
"cdr.dev/slog/v3"
"github.com/coder/coder/v2/agent/agentssh"
@@ -52,9 +47,6 @@ type API struct {
logger slog.Logger
desktop Desktop
clock quartz.Clock
closeMu sync.Mutex
closed bool
}
// NewAPI creates a new desktop streaming API.
@@ -74,10 +66,6 @@ func (a *API) Routes() http.Handler {
r := chi.NewRouter()
r.Get("/vnc", a.handleDesktopVNC)
r.Post("/action", a.handleAction)
r.Route("/recording", func(r chi.Router) {
r.Post("/start", a.handleRecordingStart)
r.Post("/stop", a.handleRecordingStop)
})
return r
}
@@ -128,9 +116,6 @@ func (a *API) handleAction(rw http.ResponseWriter, r *http.Request) {
ctx := r.Context()
handlerStart := a.clock.Now()
// Update last desktop action timestamp for idle recording monitor.
a.desktop.RecordActivity()
// Ensure the desktop is running and grab native dimensions.
cfg, err := a.desktop.Start(ctx)
if err != nil {
@@ -495,150 +480,9 @@ func (a *API) handleAction(rw http.ResponseWriter, r *http.Request) {
// Close shuts down the desktop session if one is running.
func (a *API) Close() error {
a.closeMu.Lock()
if a.closed {
a.closeMu.Unlock()
return nil
}
a.closed = true
a.closeMu.Unlock()
return a.desktop.Close()
}
// decodeRecordingRequest decodes and validates a recording request
// from the HTTP body, returning the recording ID. Returns false if
// the request was invalid and an error response was already written.
func (*API) decodeRecordingRequest(rw http.ResponseWriter, r *http.Request) (string, bool) {
ctx := r.Context()
var req struct {
RecordingID string `json:"recording_id"`
}
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
Message: "Failed to decode request body.",
Detail: err.Error(),
})
return "", false
}
if req.RecordingID == "" {
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
Message: "Missing recording_id.",
})
return "", false
}
if _, err := uuid.Parse(req.RecordingID); err != nil {
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
Message: "Invalid recording_id format.",
Detail: "recording_id must be a valid UUID.",
})
return "", false
}
return req.RecordingID, true
}
func (a *API) handleRecordingStart(rw http.ResponseWriter, r *http.Request) {
ctx := r.Context()
recordingID, ok := a.decodeRecordingRequest(rw, r)
if !ok {
return
}
a.closeMu.Lock()
if a.closed {
a.closeMu.Unlock()
httpapi.Write(ctx, rw, http.StatusServiceUnavailable, codersdk.Response{
Message: "Desktop API is shutting down.",
})
return
}
a.closeMu.Unlock()
if err := a.desktop.StartRecording(ctx, recordingID); err != nil {
if errors.Is(err, ErrDesktopClosed) {
httpapi.Write(ctx, rw, http.StatusServiceUnavailable, codersdk.Response{
Message: "Desktop API is shutting down.",
})
return
}
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
Message: "Failed to start recording.",
Detail: err.Error(),
})
return
}
httpapi.Write(ctx, rw, http.StatusOK, codersdk.Response{
Message: "Recording started.",
})
}
func (a *API) handleRecordingStop(rw http.ResponseWriter, r *http.Request) {
ctx := r.Context()
recordingID, ok := a.decodeRecordingRequest(rw, r)
if !ok {
return
}
a.closeMu.Lock()
if a.closed {
a.closeMu.Unlock()
httpapi.Write(ctx, rw, http.StatusServiceUnavailable, codersdk.Response{
Message: "Desktop API is shutting down.",
})
return
}
a.closeMu.Unlock()
// Stop recording (idempotent).
// Use a context detached from the HTTP request so that if the
// connection drops, the recording process can still shut down
// gracefully. WithoutCancel preserves request-scoped values.
stopCtx, stopCancel := context.WithTimeout(context.WithoutCancel(r.Context()), 30*time.Second)
defer stopCancel()
artifact, err := a.desktop.StopRecording(stopCtx, recordingID)
if err != nil {
if errors.Is(err, ErrUnknownRecording) {
httpapi.Write(ctx, rw, http.StatusNotFound, codersdk.Response{
Message: "Recording not found.",
Detail: err.Error(),
})
return
}
if errors.Is(err, ErrRecordingCorrupted) {
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
Message: "Recording is corrupted.",
Detail: err.Error(),
})
return
}
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
Message: "Failed to stop recording.",
Detail: err.Error(),
})
return
}
defer artifact.Reader.Close()
if artifact.Size > workspacesdk.MaxRecordingSize {
a.logger.Warn(ctx, "recording file exceeds maximum size",
slog.F("recording_id", recordingID),
slog.F("size", artifact.Size),
slog.F("max_size", workspacesdk.MaxRecordingSize),
)
httpapi.Write(ctx, rw, http.StatusRequestEntityTooLarge, codersdk.Response{
Message: "Recording file exceeds maximum allowed size.",
})
return
}
rw.Header().Set("Content-Type", "video/mp4")
rw.Header().Set("Content-Length", strconv.FormatInt(artifact.Size, 10))
rw.WriteHeader(http.StatusOK)
_, _ = io.Copy(rw, artifact.Reader)
}
// coordFromAction extracts the coordinate pair from a DesktopAction,
// returning an error if the coordinate field is missing.
func coordFromAction(action DesktopAction) (x, y int, err error) {
+576
View File
@@ -0,0 +1,576 @@
package agentdesktop_test
import (
"bytes"
"context"
"encoding/json"
"net"
"net/http"
"net/http/httptest"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"golang.org/x/xerrors"
"cdr.dev/slog/v3/sloggers/slogtest"
"github.com/coder/coder/v2/agent/agentdesktop"
"github.com/coder/coder/v2/codersdk"
"github.com/coder/coder/v2/codersdk/workspacesdk"
"github.com/coder/quartz"
)
// Ensure fakeDesktop satisfies the Desktop interface at compile time.
var _ agentdesktop.Desktop = (*fakeDesktop)(nil)
// fakeDesktop is a minimal Desktop implementation for unit tests.
type fakeDesktop struct {
startErr error
cursorPos [2]int
startCfg agentdesktop.DisplayConfig
vncConnErr error
screenshotErr error
screenshotRes agentdesktop.ScreenshotResult
lastShotOpts agentdesktop.ScreenshotOptions
closed bool
// Track calls for assertions.
lastMove [2]int
lastClick [3]int // x, y, button
lastScroll [4]int // x, y, dx, dy
lastKey string
lastTyped string
lastKeyDown string
lastKeyUp string
}
func (f *fakeDesktop) Start(context.Context) (agentdesktop.DisplayConfig, error) {
return f.startCfg, f.startErr
}
func (f *fakeDesktop) VNCConn(context.Context) (net.Conn, error) {
return nil, f.vncConnErr
}
func (f *fakeDesktop) Screenshot(_ context.Context, opts agentdesktop.ScreenshotOptions) (agentdesktop.ScreenshotResult, error) {
f.lastShotOpts = opts
return f.screenshotRes, f.screenshotErr
}
func (f *fakeDesktop) Move(_ context.Context, x, y int) error {
f.lastMove = [2]int{x, y}
return nil
}
func (f *fakeDesktop) Click(_ context.Context, x, y int, _ agentdesktop.MouseButton) error {
f.lastClick = [3]int{x, y, 1}
return nil
}
func (f *fakeDesktop) DoubleClick(_ context.Context, x, y int, _ agentdesktop.MouseButton) error {
f.lastClick = [3]int{x, y, 2}
return nil
}
func (*fakeDesktop) ButtonDown(context.Context, agentdesktop.MouseButton) error { return nil }
func (*fakeDesktop) ButtonUp(context.Context, agentdesktop.MouseButton) error { return nil }
func (f *fakeDesktop) Scroll(_ context.Context, x, y, dx, dy int) error {
f.lastScroll = [4]int{x, y, dx, dy}
return nil
}
func (*fakeDesktop) Drag(context.Context, int, int, int, int) error { return nil }
func (f *fakeDesktop) KeyPress(_ context.Context, key string) error {
f.lastKey = key
return nil
}
func (f *fakeDesktop) KeyDown(_ context.Context, key string) error {
f.lastKeyDown = key
return nil
}
func (f *fakeDesktop) KeyUp(_ context.Context, key string) error {
f.lastKeyUp = key
return nil
}
func (f *fakeDesktop) Type(_ context.Context, text string) error {
f.lastTyped = text
return nil
}
func (f *fakeDesktop) CursorPosition(context.Context) (x int, y int, err error) {
return f.cursorPos[0], f.cursorPos[1], nil
}
func (f *fakeDesktop) Close() error {
f.closed = true
return nil
}
func TestHandleDesktopVNC_StartError(t *testing.T) {
t.Parallel()
logger := slogtest.Make(t, nil)
fake := &fakeDesktop{startErr: xerrors.New("no desktop")}
api := agentdesktop.NewAPI(logger, fake, nil)
defer api.Close()
rr := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodGet, "/vnc", nil)
handler := api.Routes()
handler.ServeHTTP(rr, req)
assert.Equal(t, http.StatusInternalServerError, rr.Code)
var resp codersdk.Response
err := json.NewDecoder(rr.Body).Decode(&resp)
require.NoError(t, err)
assert.Equal(t, "Failed to start desktop session.", resp.Message)
}
func TestHandleAction_Screenshot(t *testing.T) {
t.Parallel()
logger := slogtest.Make(t, nil)
geometry := workspacesdk.DefaultDesktopGeometry()
fake := &fakeDesktop{
startCfg: agentdesktop.DisplayConfig{
Width: geometry.NativeWidth,
Height: geometry.NativeHeight,
},
screenshotRes: agentdesktop.ScreenshotResult{Data: "base64data"},
}
api := agentdesktop.NewAPI(logger, fake, nil)
defer api.Close()
body := agentdesktop.DesktopAction{Action: "screenshot"}
b, err := json.Marshal(body)
require.NoError(t, err)
rr := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodPost, "/action", bytes.NewReader(b))
req.Header.Set("Content-Type", "application/json")
handler := api.Routes()
handler.ServeHTTP(rr, req)
assert.Equal(t, http.StatusOK, rr.Code)
var result agentdesktop.DesktopActionResponse
err = json.NewDecoder(rr.Body).Decode(&result)
require.NoError(t, err)
assert.Equal(t, "screenshot", result.Output)
assert.Equal(t, "base64data", result.ScreenshotData)
assert.Equal(t, geometry.NativeWidth, result.ScreenshotWidth)
assert.Equal(t, geometry.NativeHeight, result.ScreenshotHeight)
assert.Equal(t, agentdesktop.ScreenshotOptions{
TargetWidth: geometry.NativeWidth,
TargetHeight: geometry.NativeHeight,
}, fake.lastShotOpts)
}
func TestHandleAction_ScreenshotUsesDeclaredDimensionsFromRequest(t *testing.T) {
t.Parallel()
logger := slogtest.Make(t, nil)
fake := &fakeDesktop{
startCfg: agentdesktop.DisplayConfig{Width: 1920, Height: 1080},
screenshotRes: agentdesktop.ScreenshotResult{Data: "base64data"},
}
api := agentdesktop.NewAPI(logger, fake, nil)
defer api.Close()
sw := 1280
sh := 720
body := agentdesktop.DesktopAction{
Action: "screenshot",
ScaledWidth: &sw,
ScaledHeight: &sh,
}
b, err := json.Marshal(body)
require.NoError(t, err)
rr := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodPost, "/action", bytes.NewReader(b))
req.Header.Set("Content-Type", "application/json")
handler := api.Routes()
handler.ServeHTTP(rr, req)
assert.Equal(t, http.StatusOK, rr.Code)
assert.Equal(t, agentdesktop.ScreenshotOptions{TargetWidth: 1280, TargetHeight: 720}, fake.lastShotOpts)
var result agentdesktop.DesktopActionResponse
err = json.NewDecoder(rr.Body).Decode(&result)
require.NoError(t, err)
assert.Equal(t, 1280, result.ScreenshotWidth)
assert.Equal(t, 720, result.ScreenshotHeight)
}
func TestHandleAction_LeftClick(t *testing.T) {
t.Parallel()
logger := slogtest.Make(t, nil)
fake := &fakeDesktop{
startCfg: agentdesktop.DisplayConfig{Width: 1920, Height: 1080},
}
api := agentdesktop.NewAPI(logger, fake, nil)
defer api.Close()
body := agentdesktop.DesktopAction{
Action: "left_click",
Coordinate: &[2]int{100, 200},
}
b, err := json.Marshal(body)
require.NoError(t, err)
rr := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodPost, "/action", bytes.NewReader(b))
req.Header.Set("Content-Type", "application/json")
handler := api.Routes()
handler.ServeHTTP(rr, req)
assert.Equal(t, http.StatusOK, rr.Code)
var resp agentdesktop.DesktopActionResponse
err = json.NewDecoder(rr.Body).Decode(&resp)
require.NoError(t, err)
assert.Equal(t, "left_click action performed", resp.Output)
assert.Equal(t, [3]int{100, 200, 1}, fake.lastClick)
}
func TestHandleAction_UnknownAction(t *testing.T) {
t.Parallel()
logger := slogtest.Make(t, nil)
fake := &fakeDesktop{
startCfg: agentdesktop.DisplayConfig{Width: 1920, Height: 1080},
}
api := agentdesktop.NewAPI(logger, fake, nil)
defer api.Close()
body := agentdesktop.DesktopAction{Action: "explode"}
b, err := json.Marshal(body)
require.NoError(t, err)
rr := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodPost, "/action", bytes.NewReader(b))
req.Header.Set("Content-Type", "application/json")
handler := api.Routes()
handler.ServeHTTP(rr, req)
assert.Equal(t, http.StatusBadRequest, rr.Code)
}
func TestHandleAction_KeyAction(t *testing.T) {
t.Parallel()
logger := slogtest.Make(t, nil)
fake := &fakeDesktop{
startCfg: agentdesktop.DisplayConfig{Width: 1920, Height: 1080},
}
api := agentdesktop.NewAPI(logger, fake, nil)
defer api.Close()
text := "Return"
body := agentdesktop.DesktopAction{
Action: "key",
Text: &text,
}
b, err := json.Marshal(body)
require.NoError(t, err)
rr := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodPost, "/action", bytes.NewReader(b))
req.Header.Set("Content-Type", "application/json")
handler := api.Routes()
handler.ServeHTTP(rr, req)
assert.Equal(t, http.StatusOK, rr.Code)
assert.Equal(t, "Return", fake.lastKey)
}
func TestHandleAction_TypeAction(t *testing.T) {
t.Parallel()
logger := slogtest.Make(t, nil)
fake := &fakeDesktop{
startCfg: agentdesktop.DisplayConfig{Width: 1920, Height: 1080},
}
api := agentdesktop.NewAPI(logger, fake, nil)
defer api.Close()
text := "hello world"
body := agentdesktop.DesktopAction{
Action: "type",
Text: &text,
}
b, err := json.Marshal(body)
require.NoError(t, err)
rr := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodPost, "/action", bytes.NewReader(b))
req.Header.Set("Content-Type", "application/json")
handler := api.Routes()
handler.ServeHTTP(rr, req)
assert.Equal(t, http.StatusOK, rr.Code)
assert.Equal(t, "hello world", fake.lastTyped)
}
func TestHandleAction_HoldKey(t *testing.T) {
t.Parallel()
logger := slogtest.Make(t, nil)
fake := &fakeDesktop{
startCfg: agentdesktop.DisplayConfig{Width: 1920, Height: 1080},
}
mClk := quartz.NewMock(t)
trap := mClk.Trap().NewTimer("agentdesktop", "hold_key")
defer trap.Close()
api := agentdesktop.NewAPI(logger, fake, mClk)
defer api.Close()
text := "Shift_L"
dur := 100
body := agentdesktop.DesktopAction{
Action: "hold_key",
Text: &text,
Duration: &dur,
}
b, err := json.Marshal(body)
require.NoError(t, err)
rr := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodPost, "/action", bytes.NewReader(b))
req.Header.Set("Content-Type", "application/json")
handler := api.Routes()
done := make(chan struct{})
go func() {
defer close(done)
handler.ServeHTTP(rr, req)
}()
trap.MustWait(req.Context()).MustRelease(req.Context())
mClk.Advance(time.Duration(dur) * time.Millisecond).MustWait(req.Context())
<-done
assert.Equal(t, http.StatusOK, rr.Code)
var resp agentdesktop.DesktopActionResponse
err = json.NewDecoder(rr.Body).Decode(&resp)
require.NoError(t, err)
assert.Equal(t, "hold_key action performed", resp.Output)
assert.Equal(t, "Shift_L", fake.lastKeyDown)
assert.Equal(t, "Shift_L", fake.lastKeyUp)
}
func TestHandleAction_HoldKeyMissingText(t *testing.T) {
t.Parallel()
logger := slogtest.Make(t, nil)
fake := &fakeDesktop{
startCfg: agentdesktop.DisplayConfig{Width: 1920, Height: 1080},
}
api := agentdesktop.NewAPI(logger, fake, nil)
defer api.Close()
body := agentdesktop.DesktopAction{Action: "hold_key"}
b, err := json.Marshal(body)
require.NoError(t, err)
rr := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodPost, "/action", bytes.NewReader(b))
req.Header.Set("Content-Type", "application/json")
handler := api.Routes()
handler.ServeHTTP(rr, req)
assert.Equal(t, http.StatusBadRequest, rr.Code)
var resp codersdk.Response
err = json.NewDecoder(rr.Body).Decode(&resp)
require.NoError(t, err)
assert.Equal(t, "Missing \"text\" for hold_key action.", resp.Message)
}
func TestHandleAction_ScrollDown(t *testing.T) {
t.Parallel()
logger := slogtest.Make(t, nil)
fake := &fakeDesktop{
startCfg: agentdesktop.DisplayConfig{Width: 1920, Height: 1080},
}
api := agentdesktop.NewAPI(logger, fake, nil)
defer api.Close()
dir := "down"
amount := 5
body := agentdesktop.DesktopAction{
Action: "scroll",
Coordinate: &[2]int{500, 400},
ScrollDirection: &dir,
ScrollAmount: &amount,
}
b, err := json.Marshal(body)
require.NoError(t, err)
rr := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodPost, "/action", bytes.NewReader(b))
req.Header.Set("Content-Type", "application/json")
handler := api.Routes()
handler.ServeHTTP(rr, req)
assert.Equal(t, http.StatusOK, rr.Code)
assert.Equal(t, [4]int{500, 400, 0, 5}, fake.lastScroll)
}
func TestHandleAction_CoordinateScaling(t *testing.T) {
t.Parallel()
logger := slogtest.Make(t, nil)
fake := &fakeDesktop{
startCfg: agentdesktop.DisplayConfig{Width: 1920, Height: 1080},
}
api := agentdesktop.NewAPI(logger, fake, nil)
defer api.Close()
sw := 1280
sh := 720
body := agentdesktop.DesktopAction{
Action: "mouse_move",
Coordinate: &[2]int{640, 360},
ScaledWidth: &sw,
ScaledHeight: &sh,
}
b, err := json.Marshal(body)
require.NoError(t, err)
rr := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodPost, "/action", bytes.NewReader(b))
req.Header.Set("Content-Type", "application/json")
handler := api.Routes()
handler.ServeHTTP(rr, req)
assert.Equal(t, http.StatusOK, rr.Code)
assert.Equal(t, 960, fake.lastMove[0])
assert.Equal(t, 540, fake.lastMove[1])
}
func TestHandleAction_CoordinateScalingClampsToLastPixel(t *testing.T) {
t.Parallel()
logger := slogtest.Make(t, nil)
fake := &fakeDesktop{
startCfg: agentdesktop.DisplayConfig{Width: 1920, Height: 1080},
}
api := agentdesktop.NewAPI(logger, fake, nil)
defer api.Close()
sw := 1366
sh := 768
body := agentdesktop.DesktopAction{
Action: "mouse_move",
Coordinate: &[2]int{1365, 767},
ScaledWidth: &sw,
ScaledHeight: &sh,
}
b, err := json.Marshal(body)
require.NoError(t, err)
rr := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodPost, "/action", bytes.NewReader(b))
req.Header.Set("Content-Type", "application/json")
handler := api.Routes()
handler.ServeHTTP(rr, req)
assert.Equal(t, http.StatusOK, rr.Code)
assert.Equal(t, 1919, fake.lastMove[0])
assert.Equal(t, 1079, fake.lastMove[1])
}
func TestClose_DelegatesToDesktop(t *testing.T) {
t.Parallel()
logger := slogtest.Make(t, nil)
fake := &fakeDesktop{}
api := agentdesktop.NewAPI(logger, fake, nil)
err := api.Close()
require.NoError(t, err)
assert.True(t, fake.closed)
}
func TestClose_PreventsNewSessions(t *testing.T) {
t.Parallel()
logger := slogtest.Make(t, nil)
fake := &fakeDesktop{}
api := agentdesktop.NewAPI(logger, fake, nil)
err := api.Close()
require.NoError(t, err)
fake.startErr = xerrors.New("desktop is closed")
rr := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodGet, "/vnc", nil)
handler := api.Routes()
handler.ServeHTTP(rr, req)
assert.Equal(t, http.StatusInternalServerError, rr.Code)
}
func TestHandleAction_CursorPositionReturnsDeclaredCoordinates(t *testing.T) {
t.Parallel()
logger := slogtest.Make(t, nil)
fake := &fakeDesktop{
startCfg: agentdesktop.DisplayConfig{Width: 1920, Height: 1080},
cursorPos: [2]int{960, 540},
}
api := agentdesktop.NewAPI(logger, fake, nil)
defer api.Close()
sw := 1280
sh := 720
body := agentdesktop.DesktopAction{
Action: "cursor_position",
ScaledWidth: &sw,
ScaledHeight: &sh,
}
b, err := json.Marshal(body)
require.NoError(t, err)
rr := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodPost, "/action", bytes.NewReader(b))
req.Header.Set("Content-Type", "application/json")
handler := api.Routes()
handler.ServeHTTP(rr, req)
assert.Equal(t, http.StatusOK, rr.Code)
var resp agentdesktop.DesktopActionResponse
err = json.NewDecoder(rr.Body).Decode(&resp)
require.NoError(t, err)
// Native (960,540) in 1920x1080 should map to declared space in 1280x720.
assert.Equal(t, "x=640,y=360", resp.Output)
}
@@ -2,10 +2,7 @@ package agentdesktop
import (
"context"
"io"
"net"
"golang.org/x/xerrors"
)
// Desktop abstracts a virtual desktop session running inside a workspace.
@@ -61,52 +58,10 @@ type Desktop interface {
// CursorPosition returns the current cursor coordinates.
CursorPosition(ctx context.Context) (x, y int, err error)
// RecordActivity marks the desktop as having received user
// interaction, resetting the idle-recording timer.
RecordActivity()
// StartRecording begins recording the desktop to an MP4 file
// using the caller-provided recording ID. Safe to call
// repeatedly - active recordings continue unchanged, stopped
// recordings are discarded and restarted. Concurrent recordings
// are supported.
StartRecording(ctx context.Context, recordingID string) error
// StopRecording finalizes the recording identified by the given
// ID. Idempotent - safe to call on an already-stopped recording.
// Returns a RecordingArtifact that the caller can stream. The
// caller must close the artifact when done. Returns an error if
// the recording ID is unknown.
StopRecording(ctx context.Context, recordingID string) (*RecordingArtifact, error)
// Close shuts down the desktop session and cleans up resources.
Close() error
}
// ErrUnknownRecording is returned by StopRecording when the
// recording ID is not recognized.
var ErrUnknownRecording = xerrors.New("unknown recording ID")
// ErrDesktopClosed is returned when an operation is attempted on a
// closed desktop session.
var ErrDesktopClosed = xerrors.New("desktop closed")
// ErrRecordingCorrupted is returned by StopRecording when the
// recording process was force-killed and the artifact is likely
// incomplete or corrupt.
var ErrRecordingCorrupted = xerrors.New("recording corrupted: process was force-killed")
// RecordingArtifact is a finalized recording returned by StopRecording.
// The caller streams the artifact and must call Close when done. The
// artifact remains valid even if the same recording ID is restarted
// or the desktop is closed while the caller is reading.
type RecordingArtifact struct {
// Reader is the MP4 content. Callers must close it when done.
Reader io.ReadCloser
// Size is the byte length of the MP4 content.
Size int64
}
// DisplayConfig describes a running desktop session.
type DisplayConfig struct {
Width int // native width in pixels
+399
View File
@@ -0,0 +1,399 @@
package agentdesktop
import (
"context"
"encoding/json"
"fmt"
"net"
"os"
"os/exec"
"path/filepath"
"runtime"
"strconv"
"sync"
"time"
"golang.org/x/xerrors"
"cdr.dev/slog/v3"
"github.com/coder/coder/v2/agent/agentexec"
"github.com/coder/coder/v2/codersdk/workspacesdk"
)
// portableDesktopOutput is the JSON output from
// `portabledesktop up --json`.
type portableDesktopOutput struct {
VNCPort int `json:"vncPort"`
Geometry string `json:"geometry"` // e.g. "1920x1080"
}
// desktopSession tracks a running portabledesktop process.
type desktopSession struct {
cmd *exec.Cmd
vncPort int
width int // native width, parsed from geometry
height int // native height, parsed from geometry
display int // X11 display number, -1 if not available
cancel context.CancelFunc
}
// cursorOutput is the JSON output from `portabledesktop cursor --json`.
type cursorOutput struct {
X int `json:"x"`
Y int `json:"y"`
}
// screenshotOutput is the JSON output from
// `portabledesktop screenshot --json`.
type screenshotOutput struct {
Data string `json:"data"`
}
// portableDesktop implements Desktop by shelling out to the
// portabledesktop CLI via agentexec.Execer.
type portableDesktop struct {
logger slog.Logger
execer agentexec.Execer
scriptBinDir string // coder script bin directory
mu sync.Mutex
session *desktopSession // nil until started
binPath string // resolved path to binary, cached
closed bool
}
// NewPortableDesktop creates a Desktop backed by the portabledesktop
// CLI binary, using execer to spawn child processes. scriptBinDir is
// the coder script bin directory checked for the binary.
func NewPortableDesktop(
logger slog.Logger,
execer agentexec.Execer,
scriptBinDir string,
) Desktop {
return &portableDesktop{
logger: logger,
execer: execer,
scriptBinDir: scriptBinDir,
}
}
// Start launches the desktop session (idempotent).
func (p *portableDesktop) Start(ctx context.Context) (DisplayConfig, error) {
p.mu.Lock()
defer p.mu.Unlock()
if p.closed {
return DisplayConfig{}, xerrors.New("desktop is closed")
}
if err := p.ensureBinary(ctx); err != nil {
return DisplayConfig{}, xerrors.Errorf("ensure portabledesktop binary: %w", err)
}
// If we have an existing session, check if it's still alive.
if p.session != nil {
if !(p.session.cmd.ProcessState != nil && p.session.cmd.ProcessState.Exited()) {
return DisplayConfig{
Width: p.session.width,
Height: p.session.height,
VNCPort: p.session.vncPort,
Display: p.session.display,
}, nil
}
// Process died — clean up and recreate.
p.logger.Warn(ctx, "portabledesktop process died, recreating session")
p.session.cancel()
p.session = nil
}
// Spawn portabledesktop up --json.
sessionCtx, sessionCancel := context.WithCancel(context.Background())
//nolint:gosec // portabledesktop is a trusted binary resolved via ensureBinary.
cmd := p.execer.CommandContext(sessionCtx, p.binPath, "up", "--json",
"--geometry", fmt.Sprintf("%dx%d", workspacesdk.DesktopNativeWidth, workspacesdk.DesktopNativeHeight))
stdout, err := cmd.StdoutPipe()
if err != nil {
sessionCancel()
return DisplayConfig{}, xerrors.Errorf("create stdout pipe: %w", err)
}
if err := cmd.Start(); err != nil {
sessionCancel()
return DisplayConfig{}, xerrors.Errorf("start portabledesktop: %w", err)
}
// Parse the JSON output to get VNC port and geometry.
var output portableDesktopOutput
if err := json.NewDecoder(stdout).Decode(&output); err != nil {
sessionCancel()
_ = cmd.Process.Kill()
_ = cmd.Wait()
return DisplayConfig{}, xerrors.Errorf("parse portabledesktop output: %w", err)
}
if output.VNCPort == 0 {
sessionCancel()
_ = cmd.Process.Kill()
_ = cmd.Wait()
return DisplayConfig{}, xerrors.New("portabledesktop returned port 0")
}
var w, h int
if output.Geometry != "" {
if _, err := fmt.Sscanf(output.Geometry, "%dx%d", &w, &h); err != nil {
p.logger.Warn(ctx, "failed to parse geometry, using defaults",
slog.F("geometry", output.Geometry),
slog.Error(err),
)
}
}
p.logger.Info(ctx, "started portabledesktop session",
slog.F("vnc_port", output.VNCPort),
slog.F("width", w),
slog.F("height", h),
slog.F("pid", cmd.Process.Pid),
)
p.session = &desktopSession{
cmd: cmd,
vncPort: output.VNCPort,
width: w,
height: h,
display: -1,
cancel: sessionCancel,
}
return DisplayConfig{
Width: w,
Height: h,
VNCPort: output.VNCPort,
Display: -1,
}, nil
}
// VNCConn dials the desktop's VNC server and returns a raw
// net.Conn carrying RFB binary frames.
func (p *portableDesktop) VNCConn(_ context.Context) (net.Conn, error) {
p.mu.Lock()
session := p.session
p.mu.Unlock()
if session == nil {
return nil, xerrors.New("desktop session not started")
}
return net.Dial("tcp", fmt.Sprintf("127.0.0.1:%d", session.vncPort))
}
// Screenshot captures the current framebuffer as a base64-encoded PNG.
func (p *portableDesktop) Screenshot(ctx context.Context, opts ScreenshotOptions) (ScreenshotResult, error) {
args := []string{"screenshot", "--json"}
if opts.TargetWidth > 0 {
args = append(args, "--target-width", strconv.Itoa(opts.TargetWidth))
}
if opts.TargetHeight > 0 {
args = append(args, "--target-height", strconv.Itoa(opts.TargetHeight))
}
out, err := p.runCmd(ctx, args...)
if err != nil {
return ScreenshotResult{}, err
}
var result screenshotOutput
if err := json.Unmarshal([]byte(out), &result); err != nil {
return ScreenshotResult{}, xerrors.Errorf("parse screenshot output: %w", err)
}
return ScreenshotResult(result), nil
}
// Move moves the mouse cursor to absolute coordinates.
func (p *portableDesktop) Move(ctx context.Context, x, y int) error {
_, err := p.runCmd(ctx, "mouse", "move", strconv.Itoa(x), strconv.Itoa(y))
return err
}
// Click performs a mouse button click at the given coordinates.
func (p *portableDesktop) Click(ctx context.Context, x, y int, button MouseButton) error {
if _, err := p.runCmd(ctx, "mouse", "move", strconv.Itoa(x), strconv.Itoa(y)); err != nil {
return err
}
_, err := p.runCmd(ctx, "mouse", "click", string(button))
return err
}
// DoubleClick performs a double-click at the given coordinates.
func (p *portableDesktop) DoubleClick(ctx context.Context, x, y int, button MouseButton) error {
if _, err := p.runCmd(ctx, "mouse", "move", strconv.Itoa(x), strconv.Itoa(y)); err != nil {
return err
}
if _, err := p.runCmd(ctx, "mouse", "click", string(button)); err != nil {
return err
}
_, err := p.runCmd(ctx, "mouse", "click", string(button))
return err
}
// ButtonDown presses and holds a mouse button.
func (p *portableDesktop) ButtonDown(ctx context.Context, button MouseButton) error {
_, err := p.runCmd(ctx, "mouse", "down", string(button))
return err
}
// ButtonUp releases a mouse button.
func (p *portableDesktop) ButtonUp(ctx context.Context, button MouseButton) error {
_, err := p.runCmd(ctx, "mouse", "up", string(button))
return err
}
// Scroll scrolls by (dx, dy) clicks at the given coordinates.
func (p *portableDesktop) Scroll(ctx context.Context, x, y, dx, dy int) error {
if _, err := p.runCmd(ctx, "mouse", "move", strconv.Itoa(x), strconv.Itoa(y)); err != nil {
return err
}
_, err := p.runCmd(ctx, "mouse", "scroll", strconv.Itoa(dx), strconv.Itoa(dy))
return err
}
// Drag moves from (startX,startY) to (endX,endY) while holding the
// left mouse button.
func (p *portableDesktop) Drag(ctx context.Context, startX, startY, endX, endY int) error {
if _, err := p.runCmd(ctx, "mouse", "move", strconv.Itoa(startX), strconv.Itoa(startY)); err != nil {
return err
}
if _, err := p.runCmd(ctx, "mouse", "down", string(MouseButtonLeft)); err != nil {
return err
}
if _, err := p.runCmd(ctx, "mouse", "move", strconv.Itoa(endX), strconv.Itoa(endY)); err != nil {
return err
}
_, err := p.runCmd(ctx, "mouse", "up", string(MouseButtonLeft))
return err
}
// KeyPress sends a key-down then key-up for a key combo string.
func (p *portableDesktop) KeyPress(ctx context.Context, keys string) error {
_, err := p.runCmd(ctx, "keyboard", "key", keys)
return err
}
// KeyDown presses and holds a key.
func (p *portableDesktop) KeyDown(ctx context.Context, key string) error {
_, err := p.runCmd(ctx, "keyboard", "down", key)
return err
}
// KeyUp releases a key.
func (p *portableDesktop) KeyUp(ctx context.Context, key string) error {
_, err := p.runCmd(ctx, "keyboard", "up", key)
return err
}
// Type types a string of text character-by-character.
func (p *portableDesktop) Type(ctx context.Context, text string) error {
_, err := p.runCmd(ctx, "keyboard", "type", text)
return err
}
// CursorPosition returns the current cursor coordinates.
func (p *portableDesktop) CursorPosition(ctx context.Context) (x int, y int, err error) {
out, err := p.runCmd(ctx, "cursor", "--json")
if err != nil {
return 0, 0, err
}
var result cursorOutput
if err := json.Unmarshal([]byte(out), &result); err != nil {
return 0, 0, xerrors.Errorf("parse cursor output: %w", err)
}
return result.X, result.Y, nil
}
// Close shuts down the desktop session and cleans up resources.
func (p *portableDesktop) Close() error {
p.mu.Lock()
defer p.mu.Unlock()
p.closed = true
if p.session != nil {
p.session.cancel()
// Xvnc is a child process — killing it cleans up the X
// session.
_ = p.session.cmd.Process.Kill()
_ = p.session.cmd.Wait()
p.session = nil
}
return nil
}
// runCmd executes a portabledesktop subcommand and returns combined
// output. The caller must have previously called ensureBinary.
func (p *portableDesktop) runCmd(ctx context.Context, args ...string) (string, error) {
start := time.Now()
//nolint:gosec // args are constructed by the caller, not user input.
cmd := p.execer.CommandContext(ctx, p.binPath, args...)
out, err := cmd.CombinedOutput()
elapsed := time.Since(start)
if err != nil {
p.logger.Warn(ctx, "portabledesktop command failed",
slog.F("args", args),
slog.F("elapsed_ms", elapsed.Milliseconds()),
slog.Error(err),
slog.F("output", string(out)),
)
return "", xerrors.Errorf("portabledesktop %s: %w: %s", args[0], err, string(out))
}
if elapsed > 5*time.Second {
p.logger.Warn(ctx, "portabledesktop command slow",
slog.F("args", args),
slog.F("elapsed_ms", elapsed.Milliseconds()),
)
} else {
p.logger.Debug(ctx, "portabledesktop command completed",
slog.F("args", args),
slog.F("elapsed_ms", elapsed.Milliseconds()),
)
}
return string(out), nil
}
// ensureBinary resolves the portabledesktop binary from PATH or the
// coder script bin directory. It must be called while p.mu is held.
func (p *portableDesktop) ensureBinary(ctx context.Context) error {
if p.binPath != "" {
return nil
}
// 1. Check PATH.
if path, err := exec.LookPath("portabledesktop"); err == nil {
p.logger.Info(ctx, "found portabledesktop in PATH",
slog.F("path", path),
)
p.binPath = path
return nil
}
// 2. Check the coder script bin directory.
scriptBinPath := filepath.Join(p.scriptBinDir, "portabledesktop")
if info, err := os.Stat(scriptBinPath); err == nil && !info.IsDir() {
// On Windows, permission bits don't indicate executability,
// so accept any regular file.
if runtime.GOOS == "windows" || info.Mode()&0o111 != 0 {
p.logger.Info(ctx, "found portabledesktop in script bin directory",
slog.F("path", scriptBinPath),
)
p.binPath = scriptBinPath
return nil
}
p.logger.Warn(ctx, "portabledesktop found in script bin directory but not executable",
slog.F("path", scriptBinPath),
slog.F("mode", info.Mode().String()),
)
}
return xerrors.New("portabledesktop binary not found in PATH or script bin directory")
}
@@ -9,17 +9,13 @@ import (
"strings"
"sync"
"testing"
"time"
"github.com/google/uuid"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"cdr.dev/slog/v3/sloggers/slogtest"
"github.com/coder/coder/v2/agent/agentexec"
"github.com/coder/coder/v2/pty"
"github.com/coder/coder/v2/testutil"
"github.com/coder/quartz"
)
// recordedExecer implements agentexec.Execer by recording every
@@ -90,7 +86,6 @@ func TestPortableDesktop_Start_ParsesOutput(t *testing.T) {
execer: rec,
scriptBinDir: t.TempDir(),
binPath: "portabledesktop", // pre-set so ensureBinary is a no-op
clock: quartz.NewReal(),
}
ctx := t.Context()
@@ -122,7 +117,6 @@ func TestPortableDesktop_Start_Idempotent(t *testing.T) {
execer: rec,
scriptBinDir: t.TempDir(),
binPath: "portabledesktop",
clock: quartz.NewReal(),
}
ctx := t.Context()
@@ -165,7 +159,6 @@ func TestPortableDesktop_Screenshot(t *testing.T) {
execer: rec,
scriptBinDir: t.TempDir(),
binPath: "portabledesktop",
clock: quartz.NewReal(),
}
ctx := t.Context()
@@ -191,7 +184,6 @@ func TestPortableDesktop_Screenshot_WithTargetDimensions(t *testing.T) {
execer: rec,
scriptBinDir: t.TempDir(),
binPath: "portabledesktop",
clock: quartz.NewReal(),
}
ctx := t.Context()
@@ -290,7 +282,6 @@ func TestPortableDesktop_MouseMethods(t *testing.T) {
execer: rec,
scriptBinDir: t.TempDir(),
binPath: "portabledesktop",
clock: quartz.NewReal(),
}
err := tt.invoke(t.Context(), pd)
@@ -298,6 +289,7 @@ func TestPortableDesktop_MouseMethods(t *testing.T) {
cmds := rec.allCommands()
require.NotEmpty(t, cmds, "expected at least one command")
// Find at least one recorded command that contains
// all expected argument substrings.
found := false
@@ -375,7 +367,6 @@ func TestPortableDesktop_KeyboardMethods(t *testing.T) {
execer: rec,
scriptBinDir: t.TempDir(),
binPath: "portabledesktop",
clock: quartz.NewReal(),
}
err := tt.invoke(t.Context(), pd)
@@ -432,7 +423,6 @@ func TestPortableDesktop_Close(t *testing.T) {
execer: rec,
scriptBinDir: t.TempDir(),
binPath: "portabledesktop",
clock: quartz.NewReal(),
}
ctx := t.Context()
@@ -455,7 +445,7 @@ func TestPortableDesktop_Close(t *testing.T) {
// Subsequent Start must fail.
_, err = pd.Start(ctx)
require.Error(t, err)
assert.Contains(t, err.Error(), "desktop closed")
assert.Contains(t, err.Error(), "desktop is closed")
}
// --- ensureBinary tests ---
@@ -549,410 +539,7 @@ func TestEnsureBinary_NotFound(t *testing.T) {
assert.Contains(t, err.Error(), "not found")
}
func TestPortableDesktop_StartRecording(t *testing.T) {
t.Parallel()
logger := slogtest.Make(t, nil)
rec := &recordedExecer{
scripts: map[string]string{
"record": `trap 'exit 0' INT; sleep 120 & wait`,
"up": `printf '{"vncPort":5901,"geometry":"1920x1080"}\n' && sleep 120`,
},
}
clk := quartz.NewReal()
pd := &portableDesktop{
logger: logger,
execer: rec,
scriptBinDir: t.TempDir(),
clock: clk,
binPath: "portabledesktop",
recordings: make(map[string]*recordingProcess),
}
pd.lastDesktopActionAt.Store(clk.Now().UnixNano())
ctx := t.Context()
recID := uuid.New().String()
err := pd.StartRecording(ctx, recID)
require.NoError(t, err)
cmds := rec.allCommands()
require.NotEmpty(t, cmds)
// Find the record command (not the up command).
found := false
for _, cmd := range cmds {
joined := strings.Join(cmd, " ")
if strings.Contains(joined, "record") && strings.Contains(joined, "coder-recording-"+recID) {
found = true
break
}
}
assert.True(t, found, "expected a record command with the recording ID")
require.NoError(t, pd.Close())
}
func TestPortableDesktop_StartRecording_ConcurrentLimit(t *testing.T) {
t.Parallel()
logger := slogtest.Make(t, nil)
rec := &recordedExecer{
scripts: map[string]string{
"record": `trap 'exit 0' INT; sleep 120 & wait`,
"up": `printf '{"vncPort":5901,"geometry":"1920x1080"}\n' && sleep 120`,
},
}
clk := quartz.NewReal()
pd := &portableDesktop{
logger: logger,
execer: rec,
scriptBinDir: t.TempDir(),
clock: clk,
binPath: "portabledesktop",
recordings: make(map[string]*recordingProcess),
}
pd.lastDesktopActionAt.Store(clk.Now().UnixNano())
ctx := t.Context()
for i := range maxConcurrentRecordings {
err := pd.StartRecording(ctx, uuid.New().String())
require.NoError(t, err, "recording %d should succeed", i)
}
err := pd.StartRecording(ctx, uuid.New().String())
require.Error(t, err)
assert.Contains(t, err.Error(), "too many concurrent recordings")
require.NoError(t, pd.Close())
}
func TestPortableDesktop_StopRecording_ReturnsArtifact(t *testing.T) {
t.Parallel()
logger := slogtest.Make(t, nil)
rec := &recordedExecer{
scripts: map[string]string{
"record": `trap 'exit 0' INT; sleep 120 & wait`,
"up": `printf '{"vncPort":5901,"geometry":"1920x1080"}\n' && sleep 120`,
},
}
clk := quartz.NewReal()
pd := &portableDesktop{
logger: logger,
execer: rec,
scriptBinDir: t.TempDir(),
clock: clk,
binPath: "portabledesktop",
recordings: make(map[string]*recordingProcess),
}
pd.lastDesktopActionAt.Store(clk.Now().UnixNano())
ctx := t.Context()
recID := uuid.New().String()
err := pd.StartRecording(ctx, recID)
require.NoError(t, err)
// Write a dummy MP4 file at the expected path so StopRecording
// can open it as an artifact.
filePath := filepath.Join(os.TempDir(), "coder-recording-"+recID+".mp4")
require.NoError(t, os.WriteFile(filePath, []byte("fake-mp4-data"), 0o600))
t.Cleanup(func() { _ = os.Remove(filePath) })
artifact, err := pd.StopRecording(ctx, recID)
require.NoError(t, err)
defer artifact.Reader.Close()
assert.Equal(t, int64(len("fake-mp4-data")), artifact.Size)
require.NoError(t, pd.Close())
}
func TestPortableDesktop_StopRecording_UnknownID(t *testing.T) {
t.Parallel()
logger := slogtest.Make(t, nil)
rec := &recordedExecer{
scripts: map[string]string{
"record": `trap 'exit 0' INT; sleep 120 & wait`,
},
}
clk := quartz.NewReal()
pd := &portableDesktop{
logger: logger,
execer: rec,
scriptBinDir: t.TempDir(),
clock: clk,
binPath: "portabledesktop",
recordings: make(map[string]*recordingProcess),
}
pd.lastDesktopActionAt.Store(clk.Now().UnixNano())
ctx := t.Context()
_, err := pd.StopRecording(ctx, uuid.New().String())
require.ErrorIs(t, err, ErrUnknownRecording)
require.NoError(t, pd.Close())
}
// Ensure that portableDesktop satisfies the Desktop interface at
// compile time. This uses the unexported type so it lives in the
// internal test package.
var _ Desktop = (*portableDesktop)(nil)
func TestPortableDesktop_IdleTimeout_StopsRecordings(t *testing.T) {
t.Parallel()
logger := slogtest.Make(t, nil)
rec := &recordedExecer{
scripts: map[string]string{
"record": `trap 'exit 0' INT; sleep 120 & wait`,
"up": `printf '{"vncPort":5901,"geometry":"1920x1080"}\n' && sleep 120`,
},
}
clk := quartz.NewMock(t)
pd := &portableDesktop{
logger: logger,
execer: rec,
scriptBinDir: t.TempDir(),
clock: clk,
binPath: "portabledesktop",
recordings: make(map[string]*recordingProcess),
}
pd.lastDesktopActionAt.Store(clk.Now().UnixNano())
ctx := t.Context()
recID := uuid.New().String()
// Install the trap before StartRecording so it is guaranteed
// to catch the idle monitor's NewTimer call regardless of
// goroutine scheduling.
trap := clk.Trap().NewTimer("agentdesktop", "recording_idle")
err := pd.StartRecording(ctx, recID)
require.NoError(t, err)
// Verify recording is active.
pd.mu.Lock()
require.False(t, pd.recordings[recID].stopped)
pd.mu.Unlock()
// Wait for the idle monitor timer to be created and release
// it so the monitor enters its select loop.
trap.MustWait(ctx).MustRelease(ctx)
trap.Close()
// The stop-all path calls lockedStopRecordingProcess which
// creates a per-recording 15s stop_timeout timer.
stopTrap := clk.Trap().NewTimer("agentdesktop", "stop_timeout")
// Advance past idle timeout to trigger the stop-all.
clk.Advance(idleTimeout)
// Wait for the stop timer to be created, then release it.
stopTrap.MustWait(ctx).MustRelease(ctx)
stopTrap.Close()
// The recording process should now be stopped.
require.Eventually(t, func() bool {
pd.mu.Lock()
defer pd.mu.Unlock()
rec, ok := pd.recordings[recID]
return ok && rec.stopped
}, testutil.WaitShort, testutil.IntervalFast)
require.NoError(t, pd.Close())
}
func TestPortableDesktop_IdleTimeout_ActivityResetsTimer(t *testing.T) {
t.Parallel()
logger := slogtest.Make(t, nil)
rec := &recordedExecer{
scripts: map[string]string{
"record": `trap 'exit 0' INT; sleep 120 & wait`,
"up": `printf '{"vncPort":5901,"geometry":"1920x1080"}\n' && sleep 120`,
},
}
clk := quartz.NewMock(t)
pd := &portableDesktop{
logger: logger,
execer: rec,
scriptBinDir: t.TempDir(),
clock: clk,
binPath: "portabledesktop",
recordings: make(map[string]*recordingProcess),
}
pd.lastDesktopActionAt.Store(clk.Now().UnixNano())
ctx := t.Context()
recID := uuid.New().String()
// Install the trap before StartRecording so it is guaranteed
// to catch the idle monitor's NewTimer call regardless of
// goroutine scheduling.
trap := clk.Trap().NewTimer("agentdesktop", "recording_idle")
err := pd.StartRecording(ctx, recID)
require.NoError(t, err)
// Wait for the idle monitor timer to be created.
trap.MustWait(ctx).MustRelease(ctx)
trap.Close()
// Advance most of the way but not past the timeout.
clk.Advance(idleTimeout - time.Minute)
// Record activity to reset the timer.
pd.RecordActivity()
// Trap the Reset call that the idle monitor makes when it
// sees recent activity.
resetTrap := clk.Trap().TimerReset("agentdesktop", "recording_idle")
// Advance past the original idle timeout deadline. The
// monitor should see the recent activity and reset instead
// of stopping.
clk.Advance(time.Minute)
resetTrap.MustWait(ctx).MustRelease(ctx)
resetTrap.Close()
// Recording should still be active because activity was
// recorded.
pd.mu.Lock()
require.False(t, pd.recordings[recID].stopped)
pd.mu.Unlock()
require.NoError(t, pd.Close())
}
func TestPortableDesktop_IdleTimeout_MultipleRecordings(t *testing.T) {
t.Parallel()
logger := slogtest.Make(t, nil)
rec := &recordedExecer{
scripts: map[string]string{
"record": `trap 'exit 0' INT; sleep 120 & wait`,
"up": `printf '{"vncPort":5901,"geometry":"1920x1080"}\n' && sleep 120`,
},
}
clk := quartz.NewMock(t)
pd := &portableDesktop{
logger: logger,
execer: rec,
scriptBinDir: t.TempDir(),
clock: clk,
binPath: "portabledesktop",
recordings: make(map[string]*recordingProcess),
}
pd.lastDesktopActionAt.Store(clk.Now().UnixNano())
ctx := t.Context()
recID1 := uuid.New().String()
recID2 := uuid.New().String()
// Trap idle timer creation for both recordings.
trap := clk.Trap().NewTimer("agentdesktop", "recording_idle")
err := pd.StartRecording(ctx, recID1)
require.NoError(t, err)
// Wait for first recording's idle timer.
trap.MustWait(ctx).MustRelease(ctx)
err = pd.StartRecording(ctx, recID2)
require.NoError(t, err)
// Wait for second recording's idle timer.
trap.MustWait(ctx).MustRelease(ctx)
trap.Close()
// Trap the stop timers that will be created when idle fires.
stopTrap := clk.Trap().NewTimer("agentdesktop", "stop_timeout")
// Advance past idle timeout.
clk.Advance(idleTimeout)
// Wait for both stop timers.
stopTrap.MustWait(ctx).MustRelease(ctx)
stopTrap.MustWait(ctx).MustRelease(ctx)
stopTrap.Close()
// Both recordings should be stopped.
require.Eventually(t, func() bool {
pd.mu.Lock()
defer pd.mu.Unlock()
r1, ok1 := pd.recordings[recID1]
r2, ok2 := pd.recordings[recID2]
return ok1 && r1.stopped && ok2 && r2.stopped
}, testutil.WaitShort, testutil.IntervalFast)
require.NoError(t, pd.Close())
}
func TestPortableDesktop_StartRecording_ReturnsErrDesktopClosed(t *testing.T) {
t.Parallel()
logger := slogtest.Make(t, nil)
rec := &recordedExecer{
scripts: map[string]string{
"up": `printf '{"vncPort":5901,"geometry":"1920x1080"}\n' && sleep 120`,
},
}
clk := quartz.NewReal()
pd := &portableDesktop{
logger: logger,
execer: rec,
scriptBinDir: t.TempDir(),
clock: clk,
binPath: "portabledesktop",
recordings: make(map[string]*recordingProcess),
}
pd.lastDesktopActionAt.Store(clk.Now().UnixNano())
// Start and close the desktop so it's in the closed state.
ctx := t.Context()
_, err := pd.Start(ctx)
require.NoError(t, err)
require.NoError(t, pd.Close())
// StartRecording should now return ErrDesktopClosed.
err = pd.StartRecording(ctx, uuid.New().String())
require.ErrorIs(t, err, ErrDesktopClosed)
}
func TestPortableDesktop_Start_ReturnsErrDesktopClosed(t *testing.T) {
t.Parallel()
logger := slogtest.Make(t, nil)
rec := &recordedExecer{
scripts: map[string]string{
"up": `printf '{"vncPort":5901,"geometry":"1920x1080"}\n' && sleep 120`,
},
}
pd := &portableDesktop{
logger: logger,
execer: rec,
scriptBinDir: t.TempDir(),
clock: quartz.NewReal(),
binPath: "portabledesktop",
recordings: make(map[string]*recordingProcess),
}
pd.lastDesktopActionAt.Store(pd.clock.Now().UnixNano())
ctx := t.Context()
_, err := pd.Start(ctx)
require.NoError(t, err)
require.NoError(t, pd.Close())
_, err = pd.Start(ctx)
require.ErrorIs(t, err, ErrDesktopClosed)
}
-5
View File
@@ -148,11 +148,6 @@ func (m *manager) start(req workspacesdk.StartProcessRequest, chatID string) (*p
for k, v := range req.Env {
cmd.Env = append(cmd.Env, fmt.Sprintf("%s=%s", k, v))
}
// Propagate the chat ID so child processes (e.g.
// GIT_ASKPASS) can send it back to the server.
if chatID != "" {
cmd.Env = append(cmd.Env, fmt.Sprintf("CODER_CHAT_ID=%s", chatID))
}
if err := cmd.Start(); err != nil {
cancel()
+1 -1
View File
@@ -211,7 +211,7 @@ func TestServer_X11_EvictionLRU(t *testing.T) {
require.NoError(t, err)
stderr, err := sess.StderrPipe()
require.NoError(t, err)
require.NoError(t, sess.Start("sh"))
require.NoError(t, sess.Shell())
// The SSH server lazily starts the session. We need to write a command
// and read back to ensure the X11 forwarding is started.
-2
View File
@@ -31,8 +31,6 @@ func (a *agent) apiHandler() http.Handler {
r.Mount("/api/v0/git", a.gitAPI.Routes())
r.Mount("/api/v0/processes", a.processAPI.Routes())
r.Mount("/api/v0/desktop", a.desktopAPI.Routes())
r.Mount("/api/v0/mcp", a.mcpAPI.Routes())
r.Mount("/api/v0/context-config", a.contextConfigAPI.Routes())
if a.devcontainers {
r.Mount("/api/v0/containers", a.containerAPI.Routes())
File diff suppressed because it is too large Load Diff
-768
View File
@@ -1,768 +0,0 @@
package agentdesktop
import (
"context"
"encoding/json"
"errors"
"fmt"
"net"
"os"
"os/exec"
"path/filepath"
"runtime"
"strconv"
"sync"
"sync/atomic"
"time"
"golang.org/x/xerrors"
"cdr.dev/slog/v3"
"github.com/coder/coder/v2/agent/agentexec"
"github.com/coder/coder/v2/codersdk/workspacesdk"
"github.com/coder/quartz"
)
// portableDesktopOutput is the JSON output from
// `portabledesktop up --json`.
type portableDesktopOutput struct {
VNCPort int `json:"vncPort"`
Geometry string `json:"geometry"` // e.g. "1920x1080"
}
// desktopSession tracks a running portabledesktop process.
type desktopSession struct {
cmd *exec.Cmd
vncPort int
width int // native width, parsed from geometry
height int // native height, parsed from geometry
display int // X11 display number, -1 if not available
cancel context.CancelFunc
}
// cursorOutput is the JSON output from `portabledesktop cursor --json`.
type cursorOutput struct {
X int `json:"x"`
Y int `json:"y"`
}
// screenshotOutput is the JSON output from
// `portabledesktop screenshot --json`.
type screenshotOutput struct {
Data string `json:"data"`
}
// recordingProcess tracks a single desktop recording subprocess.
type recordingProcess struct {
cmd *exec.Cmd
filePath string
stopped bool
killed bool // true when the process was SIGKILLed
done chan struct{} // closed when cmd.Wait() returns
waitErr error // set before done is closed
stopOnce sync.Once
idleCancel context.CancelFunc // cancels the per-recording idle goroutine
idleDone chan struct{} // closed when idle goroutine exits
}
// maxConcurrentRecordings is the maximum number of active (non-stopped)
// recordings allowed at once. This prevents resource exhaustion.
const maxConcurrentRecordings = 5
// idleTimeout is the duration of desktop inactivity after which all
// active recordings are automatically stopped.
const idleTimeout = 10 * time.Minute
// portableDesktop implements Desktop by shelling out to the
// portabledesktop CLI via agentexec.Execer.
type portableDesktop struct {
logger slog.Logger
execer agentexec.Execer
scriptBinDir string // coder script bin directory
clock quartz.Clock
mu sync.Mutex
session *desktopSession // nil until started
binPath string // resolved path to binary, cached
closed bool
recordings map[string]*recordingProcess // guarded by mu
lastDesktopActionAt atomic.Int64
}
// NewPortableDesktop creates a Desktop backed by the portabledesktop
// CLI binary, using execer to spawn child processes. scriptBinDir is
// the coder script bin directory checked for the binary. If clk is
// nil, a real clock is used.
func NewPortableDesktop(
logger slog.Logger,
execer agentexec.Execer,
scriptBinDir string,
clk quartz.Clock,
) Desktop {
if clk == nil {
clk = quartz.NewReal()
}
pd := &portableDesktop{
logger: logger,
execer: execer,
scriptBinDir: scriptBinDir,
clock: clk,
recordings: make(map[string]*recordingProcess),
}
pd.lastDesktopActionAt.Store(clk.Now().UnixNano())
return pd
}
// Start launches the desktop session (idempotent).
func (p *portableDesktop) Start(ctx context.Context) (DisplayConfig, error) {
p.mu.Lock()
defer p.mu.Unlock()
if p.closed {
return DisplayConfig{}, ErrDesktopClosed
}
if err := p.ensureBinary(ctx); err != nil {
return DisplayConfig{}, xerrors.Errorf("ensure portabledesktop binary: %w", err)
}
// If we have an existing session, check if it's still alive.
if p.session != nil {
if !(p.session.cmd.ProcessState != nil && p.session.cmd.ProcessState.Exited()) {
return DisplayConfig{
Width: p.session.width,
Height: p.session.height,
VNCPort: p.session.vncPort,
Display: p.session.display,
}, nil
}
// Process died — clean up and recreate.
p.logger.Warn(ctx, "portabledesktop process died, recreating session")
p.session.cancel()
p.session = nil
}
// Spawn portabledesktop up --json.
sessionCtx, sessionCancel := context.WithCancel(context.Background())
//nolint:gosec // portabledesktop is a trusted binary resolved via ensureBinary.
cmd := p.execer.CommandContext(sessionCtx, p.binPath, "up", "--json",
"--geometry", fmt.Sprintf("%dx%d", workspacesdk.DesktopNativeWidth, workspacesdk.DesktopNativeHeight))
stdout, err := cmd.StdoutPipe()
if err != nil {
sessionCancel()
return DisplayConfig{}, xerrors.Errorf("create stdout pipe: %w", err)
}
if err := cmd.Start(); err != nil {
sessionCancel()
return DisplayConfig{}, xerrors.Errorf("start portabledesktop: %w", err)
}
// Parse the JSON output to get VNC port and geometry.
var output portableDesktopOutput
if err := json.NewDecoder(stdout).Decode(&output); err != nil {
sessionCancel()
_ = cmd.Process.Kill()
_ = cmd.Wait()
return DisplayConfig{}, xerrors.Errorf("parse portabledesktop output: %w", err)
}
if output.VNCPort == 0 {
sessionCancel()
_ = cmd.Process.Kill()
_ = cmd.Wait()
return DisplayConfig{}, xerrors.New("portabledesktop returned port 0")
}
var w, h int
if output.Geometry != "" {
if _, err := fmt.Sscanf(output.Geometry, "%dx%d", &w, &h); err != nil {
p.logger.Warn(ctx, "failed to parse geometry, using defaults",
slog.F("geometry", output.Geometry),
slog.Error(err),
)
}
}
p.logger.Info(ctx, "started portabledesktop session",
slog.F("vnc_port", output.VNCPort),
slog.F("width", w),
slog.F("height", h),
slog.F("pid", cmd.Process.Pid),
)
p.session = &desktopSession{
cmd: cmd,
vncPort: output.VNCPort,
width: w,
height: h,
display: -1,
cancel: sessionCancel,
}
return DisplayConfig{
Width: w,
Height: h,
VNCPort: output.VNCPort,
Display: -1,
}, nil
}
// VNCConn dials the desktop's VNC server and returns a raw
// net.Conn carrying RFB binary frames.
func (p *portableDesktop) VNCConn(_ context.Context) (net.Conn, error) {
p.mu.Lock()
session := p.session
p.mu.Unlock()
if session == nil {
return nil, xerrors.New("desktop session not started")
}
return net.Dial("tcp", fmt.Sprintf("127.0.0.1:%d", session.vncPort))
}
// Screenshot captures the current framebuffer as a base64-encoded PNG.
func (p *portableDesktop) Screenshot(ctx context.Context, opts ScreenshotOptions) (ScreenshotResult, error) {
args := []string{"screenshot", "--json"}
if opts.TargetWidth > 0 {
args = append(args, "--target-width", strconv.Itoa(opts.TargetWidth))
}
if opts.TargetHeight > 0 {
args = append(args, "--target-height", strconv.Itoa(opts.TargetHeight))
}
out, err := p.runCmd(ctx, args...)
if err != nil {
return ScreenshotResult{}, err
}
var result screenshotOutput
if err := json.Unmarshal([]byte(out), &result); err != nil {
return ScreenshotResult{}, xerrors.Errorf("parse screenshot output: %w", err)
}
return ScreenshotResult(result), nil
}
// Move moves the mouse cursor to absolute coordinates.
func (p *portableDesktop) Move(ctx context.Context, x, y int) error {
_, err := p.runCmd(ctx, "mouse", "move", strconv.Itoa(x), strconv.Itoa(y))
return err
}
// Click performs a mouse button click at the given coordinates.
func (p *portableDesktop) Click(ctx context.Context, x, y int, button MouseButton) error {
if _, err := p.runCmd(ctx, "mouse", "move", strconv.Itoa(x), strconv.Itoa(y)); err != nil {
return err
}
_, err := p.runCmd(ctx, "mouse", "click", string(button))
return err
}
// DoubleClick performs a double-click at the given coordinates.
func (p *portableDesktop) DoubleClick(ctx context.Context, x, y int, button MouseButton) error {
if _, err := p.runCmd(ctx, "mouse", "move", strconv.Itoa(x), strconv.Itoa(y)); err != nil {
return err
}
if _, err := p.runCmd(ctx, "mouse", "click", string(button)); err != nil {
return err
}
_, err := p.runCmd(ctx, "mouse", "click", string(button))
return err
}
// ButtonDown presses and holds a mouse button.
func (p *portableDesktop) ButtonDown(ctx context.Context, button MouseButton) error {
_, err := p.runCmd(ctx, "mouse", "down", string(button))
return err
}
// ButtonUp releases a mouse button.
func (p *portableDesktop) ButtonUp(ctx context.Context, button MouseButton) error {
_, err := p.runCmd(ctx, "mouse", "up", string(button))
return err
}
// Scroll scrolls by (dx, dy) clicks at the given coordinates.
func (p *portableDesktop) Scroll(ctx context.Context, x, y, dx, dy int) error {
if _, err := p.runCmd(ctx, "mouse", "move", strconv.Itoa(x), strconv.Itoa(y)); err != nil {
return err
}
_, err := p.runCmd(ctx, "mouse", "scroll", strconv.Itoa(dx), strconv.Itoa(dy))
return err
}
// Drag moves from (startX,startY) to (endX,endY) while holding the
// left mouse button.
func (p *portableDesktop) Drag(ctx context.Context, startX, startY, endX, endY int) error {
if _, err := p.runCmd(ctx, "mouse", "move", strconv.Itoa(startX), strconv.Itoa(startY)); err != nil {
return err
}
if _, err := p.runCmd(ctx, "mouse", "down", string(MouseButtonLeft)); err != nil {
return err
}
if _, err := p.runCmd(ctx, "mouse", "move", strconv.Itoa(endX), strconv.Itoa(endY)); err != nil {
return err
}
_, err := p.runCmd(ctx, "mouse", "up", string(MouseButtonLeft))
return err
}
// KeyPress sends a key-down then key-up for a key combo string.
func (p *portableDesktop) KeyPress(ctx context.Context, keys string) error {
_, err := p.runCmd(ctx, "keyboard", "key", keys)
return err
}
// KeyDown presses and holds a key.
func (p *portableDesktop) KeyDown(ctx context.Context, key string) error {
_, err := p.runCmd(ctx, "keyboard", "down", key)
return err
}
// KeyUp releases a key.
func (p *portableDesktop) KeyUp(ctx context.Context, key string) error {
_, err := p.runCmd(ctx, "keyboard", "up", key)
return err
}
// Type types a string of text character-by-character.
func (p *portableDesktop) Type(ctx context.Context, text string) error {
_, err := p.runCmd(ctx, "keyboard", "type", text)
return err
}
// CursorPosition returns the current cursor coordinates.
func (p *portableDesktop) CursorPosition(ctx context.Context) (x int, y int, err error) {
out, err := p.runCmd(ctx, "cursor", "--json")
if err != nil {
return 0, 0, err
}
var result cursorOutput
if err := json.Unmarshal([]byte(out), &result); err != nil {
return 0, 0, xerrors.Errorf("parse cursor output: %w", err)
}
return result.X, result.Y, nil
}
// StartRecording begins recording the desktop to an MP4 file.
// Three-state idempotency: active recordings are no-ops,
// completed recordings are discarded and restarted.
func (p *portableDesktop) StartRecording(ctx context.Context, recordingID string) error {
// Ensure the desktop session is running before acquiring the
// recording lock. Start is independently locked and idempotent.
if _, err := p.Start(ctx); err != nil {
return xerrors.Errorf("ensure desktop session: %w", err)
}
p.mu.Lock()
defer p.mu.Unlock()
if p.closed {
return ErrDesktopClosed
}
// Three-state idempotency:
// - Active recording → no-op, continue recording.
// - Completed recording → discard old file, start fresh.
// - Unknown ID → fall through to start a new recording.
if rec, ok := p.recordings[recordingID]; ok {
if !rec.stopped {
select {
case <-rec.done:
// Process exited unexpectedly; treat as completed
// so we fall through to discard the old file and
// restart.
default:
// Active recording - no-op, continue recording.
return nil
}
}
// Completed recording - discard old file, start fresh.
if err := os.Remove(rec.filePath); err != nil && !os.IsNotExist(err) {
p.logger.Warn(ctx, "failed to remove old recording file",
slog.F("recording_id", recordingID),
slog.F("file_path", rec.filePath),
slog.Error(err),
)
}
delete(p.recordings, recordingID)
}
// Check concurrent recording limit.
if p.lockedActiveRecordingCount() >= maxConcurrentRecordings {
return xerrors.Errorf("too many concurrent recordings (max %d)", maxConcurrentRecordings)
}
// GC sweep: remove stopped recordings with stale files.
p.lockedCleanStaleRecordings(ctx)
if err := p.ensureBinary(ctx); err != nil {
return xerrors.Errorf("ensure portabledesktop binary: %w", err)
}
filePath := filepath.Join(os.TempDir(), "coder-recording-"+recordingID+".mp4")
// Use a background context so the process outlives the HTTP
// request that triggered it.
procCtx, procCancel := context.WithCancel(context.Background())
//nolint:gosec // portabledesktop is a trusted binary resolved via ensureBinary.
cmd := p.execer.CommandContext(procCtx, p.binPath, "record",
// The following options are used to speed up the recording when the desktop is idle.
// They were taken out of an example in the portabledesktop repo.
// There's likely room for improvement to optimize the values.
"--idle-speedup", "20",
"--idle-min-duration", "0.35",
"--idle-noise-tolerance", "-38dB",
filePath)
if err := cmd.Start(); err != nil {
procCancel()
return xerrors.Errorf("start recording process: %w", err)
}
rec := &recordingProcess{
cmd: cmd,
filePath: filePath,
done: make(chan struct{}),
}
go func() {
rec.waitErr = cmd.Wait()
close(rec.done)
// avoid a context resource leak by canceling the context
procCancel()
}()
p.recordings[recordingID] = rec
p.logger.Info(ctx, "started desktop recording",
slog.F("recording_id", recordingID),
slog.F("file_path", filePath),
slog.F("pid", cmd.Process.Pid),
)
// Record activity so a recording started on an already-idle
// desktop does not stop immediately.
p.lastDesktopActionAt.Store(p.clock.Now().UnixNano())
// Spawn a per-recording idle goroutine.
idleCtx, idleCancel := context.WithCancel(context.Background())
rec.idleCancel = idleCancel
rec.idleDone = make(chan struct{})
go func() {
defer close(rec.idleDone)
p.monitorRecordingIdle(idleCtx, rec)
}()
return nil
}
// StopRecording finalizes the recording. Idempotent - safe to call
// on an already-stopped recording. Returns a RecordingArtifact
// that the caller can stream. The caller must close the Reader
// on the returned artifact to avoid leaking file descriptors.
func (p *portableDesktop) StopRecording(ctx context.Context, recordingID string) (*RecordingArtifact, error) {
p.mu.Lock()
rec, ok := p.recordings[recordingID]
if !ok {
p.mu.Unlock()
return nil, ErrUnknownRecording
}
p.lockedStopRecordingProcess(ctx, rec, false)
killed := rec.killed
p.mu.Unlock()
p.logger.Info(ctx, "stopped desktop recording",
slog.F("recording_id", recordingID),
slog.F("file_path", rec.filePath),
)
if killed {
return nil, ErrRecordingCorrupted
}
// Open the file and return an artifact. Each call opens a fresh
// file descriptor so the caller is insulated from restarts and
// desktop close.
f, err := os.Open(rec.filePath)
if err != nil {
return nil, xerrors.Errorf("open recording artifact: %w", err)
}
info, err := f.Stat()
if err != nil {
_ = f.Close()
return nil, xerrors.Errorf("stat recording artifact: %w", err)
}
return &RecordingArtifact{
Reader: f,
Size: info.Size(),
}, nil
}
// lockedStopRecordingProcess stops a single recording via stopOnce.
// It sends SIGINT, waits up to 15 seconds for graceful exit, then
// SIGKILLs. When force is true the process is SIGKILLed immediately
// without attempting a graceful shutdown. Must be called while p.mu
// is held; the lock is held for the full duration so that no
// concurrent StopRecording caller can read rec.stopped = true
// before the process has finished writing the MP4 file.
//
//nolint:revive // force flag keeps shared stopOnce/cleanup logic in one place.
func (p *portableDesktop) lockedStopRecordingProcess(ctx context.Context, rec *recordingProcess, force bool) {
rec.stopOnce.Do(func() {
if force {
_ = rec.cmd.Process.Kill()
rec.killed = true
} else {
_ = interruptRecordingProcess(rec.cmd.Process)
timer := p.clock.NewTimer(15*time.Second, "agentdesktop", "stop_timeout")
defer timer.Stop()
select {
case <-rec.done:
case <-ctx.Done():
_ = rec.cmd.Process.Kill()
rec.killed = true
case <-timer.C:
_ = rec.cmd.Process.Kill()
rec.killed = true
}
}
rec.stopped = true
if rec.idleCancel != nil {
rec.idleCancel()
}
})
// NOTE: We intentionally do not wait on rec.done here.
// If goleak is added to this package's tests, this may
// need revisiting to avoid flakes.
}
// lockedActiveRecordingCount returns the number of recordings that
// are still actively running. Must be called while p.mu is held.
// The max concurrency is low (maxConcurrentRecordings = 5), so a
// full scan is cheap and avoids maintaining a separate counter.
func (p *portableDesktop) lockedActiveRecordingCount() int {
active := 0
for _, rec := range p.recordings {
if rec.stopped {
continue
}
select {
case <-rec.done:
default:
active++
}
}
return active
}
// lockedCleanStaleRecordings removes stopped recordings whose temp
// files are older than one hour. Must be called while p.mu is held.
func (p *portableDesktop) lockedCleanStaleRecordings(ctx context.Context) {
for id, rec := range p.recordings {
if !rec.stopped {
continue
}
info, err := os.Stat(rec.filePath)
if err != nil {
// File already removed or inaccessible; drop entry.
delete(p.recordings, id)
continue
}
if p.clock.Since(info.ModTime()) > time.Hour {
if err := os.Remove(rec.filePath); err != nil && !os.IsNotExist(err) {
p.logger.Warn(ctx, "failed to remove stale recording file",
slog.F("recording_id", id),
slog.F("file_path", rec.filePath),
slog.Error(err),
)
}
delete(p.recordings, id)
}
}
}
// Close shuts down the desktop session and cleans up resources.
func (p *portableDesktop) Close() error {
p.mu.Lock()
p.closed = true
// Force-kill all active recordings. The stopOnce inside
// lockedStopRecordingProcess makes this safe for
// already-stopped recordings.
for _, rec := range p.recordings {
p.lockedStopRecordingProcess(context.Background(), rec, true)
}
// Snapshot recording file paths and idle goroutine channels
// for cleanup, then clear the map.
type recEntry struct {
id string
filePath string
idleDone chan struct{}
}
var allRecs []recEntry
for id, rec := range p.recordings {
allRecs = append(allRecs, recEntry{id: id, filePath: rec.filePath, idleDone: rec.idleDone})
delete(p.recordings, id)
}
session := p.session
p.session = nil
p.mu.Unlock()
// Wait for all per-recording idle goroutines to exit.
for _, entry := range allRecs {
if entry.idleDone != nil {
<-entry.idleDone
}
}
// Remove all recording files and wait for the session to
// exit with a timeout so a slow filesystem or hung process
// cannot block agent shutdown indefinitely.
cleanupDone := make(chan struct{})
go func() {
defer close(cleanupDone)
for _, entry := range allRecs {
if err := os.Remove(entry.filePath); err != nil && !os.IsNotExist(err) {
p.logger.Warn(context.Background(), "failed to remove recording file on close",
slog.F("recording_id", entry.id),
slog.F("file_path", entry.filePath),
slog.Error(err),
)
}
}
if session != nil {
session.cancel()
if err := session.cmd.Process.Kill(); err != nil {
p.logger.Warn(context.Background(), "failed to kill portabledesktop process",
slog.Error(err),
)
}
if err := session.cmd.Wait(); err != nil {
var exitErr *exec.ExitError
if !errors.As(err, &exitErr) {
p.logger.Warn(context.Background(), "portabledesktop process exited with error",
slog.Error(err),
)
}
}
}
}()
timer := p.clock.NewTimer(15*time.Second, "agentdesktop", "close_cleanup_timeout")
defer timer.Stop()
select {
case <-cleanupDone:
case <-timer.C:
p.logger.Warn(context.Background(), "timed out waiting for close cleanup")
}
return nil
}
// RecordActivity marks the desktop as having received user
// interaction, resetting the idle-recording timer.
func (p *portableDesktop) RecordActivity() {
p.lastDesktopActionAt.Store(p.clock.Now().UnixNano())
}
// runCmd executes a portabledesktop subcommand and returns combined
// output. The caller must have previously called ensureBinary.
func (p *portableDesktop) runCmd(ctx context.Context, args ...string) (string, error) {
start := time.Now()
//nolint:gosec // args are constructed by the caller, not user input.
cmd := p.execer.CommandContext(ctx, p.binPath, args...)
out, err := cmd.CombinedOutput()
elapsed := time.Since(start)
if err != nil {
p.logger.Warn(ctx, "portabledesktop command failed",
slog.F("args", args),
slog.F("elapsed_ms", elapsed.Milliseconds()),
slog.Error(err),
slog.F("output", string(out)),
)
return "", xerrors.Errorf("portabledesktop %s: %w: %s", args[0], err, string(out))
}
if elapsed > 5*time.Second {
p.logger.Warn(ctx, "portabledesktop command slow",
slog.F("args", args),
slog.F("elapsed_ms", elapsed.Milliseconds()),
)
} else {
p.logger.Debug(ctx, "portabledesktop command completed",
slog.F("args", args),
slog.F("elapsed_ms", elapsed.Milliseconds()),
)
}
return string(out), nil
}
// ensureBinary resolves the portabledesktop binary from PATH or the
// coder script bin directory. It must be called while p.mu is held.
func (p *portableDesktop) ensureBinary(ctx context.Context) error {
if p.binPath != "" {
return nil
}
// 1. Check PATH.
if path, err := exec.LookPath("portabledesktop"); err == nil {
p.logger.Info(ctx, "found portabledesktop in PATH",
slog.F("path", path),
)
p.binPath = path
return nil
}
// 2. Check the coder script bin directory.
scriptBinPath := filepath.Join(p.scriptBinDir, "portabledesktop")
if info, err := os.Stat(scriptBinPath); err == nil && !info.IsDir() {
// On Windows, permission bits don't indicate executability,
// so accept any regular file.
if runtime.GOOS == "windows" || info.Mode()&0o111 != 0 {
p.logger.Info(ctx, "found portabledesktop in script bin directory",
slog.F("path", scriptBinPath),
)
p.binPath = scriptBinPath
return nil
}
p.logger.Warn(ctx, "portabledesktop found in script bin directory but not executable",
slog.F("path", scriptBinPath),
slog.F("mode", info.Mode().String()),
)
}
return xerrors.New("portabledesktop binary not found in PATH or script bin directory")
}
// monitorRecordingIdle watches for desktop inactivity and stops the
// given recording when the idle timeout is reached.
func (p *portableDesktop) monitorRecordingIdle(ctx context.Context, rec *recordingProcess) {
timer := p.clock.NewTimer(idleTimeout, "agentdesktop", "recording_idle")
defer timer.Stop()
for {
select {
case <-timer.C:
lastNano := p.lastDesktopActionAt.Load()
lastAction := time.Unix(0, lastNano)
elapsed := p.clock.Since(lastAction)
if elapsed >= idleTimeout {
p.mu.Lock()
p.lockedStopRecordingProcess(context.Background(), rec, false)
p.mu.Unlock()
return
}
// Activity happened; reset with remaining budget.
timer.Reset(idleTimeout-elapsed, "agentdesktop", "recording_idle")
case <-rec.done:
return
case <-ctx.Done():
return
}
}
}
@@ -1,12 +0,0 @@
//go:build !windows
package agentdesktop
import "os"
// interruptRecordingProcess sends a SIGINT to the recording process
// for graceful shutdown. On Unix, os.Interrupt is delivered as
// SIGINT which lets the recorder finalize the MP4 container.
func interruptRecordingProcess(p *os.Process) error {
return p.Signal(os.Interrupt)
}
@@ -1,10 +0,0 @@
package agentdesktop
import "os"
// interruptRecordingProcess kills the recording process directly
// because os.Process.Signal(os.Interrupt) is not supported on
// Windows and returns an error without delivering a signal.
func interruptRecordingProcess(p *os.Process) error {
return p.Kill()
}
-88
View File
@@ -1,88 +0,0 @@
package agentmcp
import (
"errors"
"net/http"
"github.com/go-chi/chi/v5"
"cdr.dev/slog/v3"
"github.com/coder/coder/v2/coderd/httpapi"
"github.com/coder/coder/v2/codersdk"
"github.com/coder/coder/v2/codersdk/workspacesdk"
)
// API exposes MCP tool discovery and call proxying through the
// agent.
type API struct {
logger slog.Logger
manager *Manager
}
// NewAPI creates a new MCP API handler backed by the given
// manager.
func NewAPI(logger slog.Logger, manager *Manager) *API {
return &API{
logger: logger,
manager: manager,
}
}
// Routes returns the HTTP handler for MCP-related routes.
func (api *API) Routes() http.Handler {
r := chi.NewRouter()
r.Get("/tools", api.handleListTools)
r.Post("/call-tool", api.handleCallTool)
return r
}
// handleListTools returns the cached MCP tool definitions,
// optionally refreshing them first if ?refresh=true is set.
func (api *API) handleListTools(rw http.ResponseWriter, r *http.Request) {
ctx := r.Context()
// Allow callers to force a tool re-scan before listing.
if r.URL.Query().Get("refresh") == "true" {
if err := api.manager.RefreshTools(ctx); err != nil {
api.logger.Warn(ctx, "failed to refresh MCP tools", slog.Error(err))
}
}
tools := api.manager.Tools()
// Ensure non-nil so JSON serialization returns [] not null.
if tools == nil {
tools = []workspacesdk.MCPToolInfo{}
}
httpapi.Write(ctx, rw, http.StatusOK, workspacesdk.ListMCPToolsResponse{
Tools: tools,
})
}
// handleCallTool proxies a tool invocation to the appropriate
// MCP server based on the tool name prefix.
func (api *API) handleCallTool(rw http.ResponseWriter, r *http.Request) {
ctx := r.Context()
var req workspacesdk.CallMCPToolRequest
if !httpapi.Read(ctx, rw, r, &req) {
return
}
resp, err := api.manager.CallTool(ctx, req)
if err != nil {
status := http.StatusBadGateway
if errors.Is(err, ErrInvalidToolName) {
status = http.StatusBadRequest
} else if errors.Is(err, ErrUnknownServer) {
status = http.StatusNotFound
}
httpapi.Write(ctx, rw, status, codersdk.Response{
Message: "MCP tool call failed.",
Detail: err.Error(),
})
return
}
httpapi.Write(ctx, rw, http.StatusOK, resp)
}
-115
View File
@@ -1,115 +0,0 @@
package agentmcp
import (
"encoding/json"
"os"
"slices"
"strings"
"golang.org/x/xerrors"
)
// ServerConfig describes a single MCP server parsed from a .mcp.json file.
type ServerConfig struct {
Name string `json:"name"`
Transport string `json:"type"`
Command string `json:"command"`
Args []string `json:"args"`
Env map[string]string `json:"env"`
URL string `json:"url"`
Headers map[string]string `json:"headers"`
}
// mcpConfigFile mirrors the on-disk .mcp.json schema.
type mcpConfigFile struct {
MCPServers map[string]json.RawMessage `json:"mcpServers"`
}
// mcpServerEntry is a single server block inside mcpServers.
type mcpServerEntry struct {
Command string `json:"command"`
Args []string `json:"args"`
Env map[string]string `json:"env"`
Type string `json:"type"`
URL string `json:"url"`
Headers map[string]string `json:"headers"`
}
// ParseConfig reads a .mcp.json file at path and returns the declared
// MCP servers sorted by name. It returns an empty slice when the
// mcpServers key is missing or empty.
func ParseConfig(path string) ([]ServerConfig, error) {
data, err := os.ReadFile(path)
if err != nil {
return nil, xerrors.Errorf("read mcp config %q: %w", path, err)
}
var cfg mcpConfigFile
if err := json.Unmarshal(data, &cfg); err != nil {
return nil, xerrors.Errorf("parse mcp config %q: %w", path, err)
}
if len(cfg.MCPServers) == 0 {
return []ServerConfig{}, nil
}
servers := make([]ServerConfig, 0, len(cfg.MCPServers))
for name, raw := range cfg.MCPServers {
var entry mcpServerEntry
if err := json.Unmarshal(raw, &entry); err != nil {
return nil, xerrors.Errorf("parse server %q in %q: %w", name, path, err)
}
if strings.Contains(name, ToolNameSep) || strings.HasPrefix(name, "_") || strings.HasSuffix(name, "_") {
return nil, xerrors.Errorf("server name %q in %q contains reserved separator %q or leading/trailing underscore", name, path, ToolNameSep)
}
transport := inferTransport(entry)
if transport == "" {
return nil, xerrors.Errorf("server %q in %q has no command or url", name, path)
}
resolveEnvVars(entry.Env)
servers = append(servers, ServerConfig{
Name: name,
Transport: transport,
Command: entry.Command,
Args: entry.Args,
Env: entry.Env,
URL: entry.URL,
Headers: entry.Headers,
})
}
slices.SortFunc(servers, func(a, b ServerConfig) int {
return strings.Compare(a.Name, b.Name)
})
return servers, nil
}
// inferTransport determines the transport type for a server entry.
// An explicit "type" field takes priority; otherwise the presence
// of "command" implies stdio and "url" implies http.
func inferTransport(e mcpServerEntry) string {
if e.Type != "" {
return e.Type
}
if e.Command != "" {
return "stdio"
}
if e.URL != "" {
return "http"
}
return ""
}
// resolveEnvVars expands ${VAR} references in env map values
// using the current process environment.
func resolveEnvVars(env map[string]string) {
for k, v := range env {
env[k] = os.Expand(v, os.Getenv)
}
}
-254
View File
@@ -1,254 +0,0 @@
package agentmcp_test
import (
"encoding/json"
"os"
"path/filepath"
"testing"
"github.com/stretchr/testify/require"
"github.com/coder/coder/v2/agent/x/agentmcp"
)
func TestParseConfig(t *testing.T) {
t.Parallel()
tests := []struct {
name string
content string
expected []agentmcp.ServerConfig
expectError bool
}{
{
name: "StdioServer",
content: mustJSON(t, map[string]any{
"mcpServers": map[string]any{
"my-server": map[string]any{
"command": "npx",
"args": []string{"-y", "@example/mcp-server"},
"env": map[string]string{"FOO": "bar"},
},
},
}),
expected: []agentmcp.ServerConfig{
{
Name: "my-server",
Transport: "stdio",
Command: "npx",
Args: []string{"-y", "@example/mcp-server"},
Env: map[string]string{"FOO": "bar"},
},
},
},
{
name: "HTTPServer",
content: mustJSON(t, map[string]any{
"mcpServers": map[string]any{
"remote": map[string]any{
"url": "https://example.com/mcp",
"headers": map[string]string{"Authorization": "Bearer tok"},
},
},
}),
expected: []agentmcp.ServerConfig{
{
Name: "remote",
Transport: "http",
URL: "https://example.com/mcp",
Headers: map[string]string{"Authorization": "Bearer tok"},
},
},
},
{
name: "SSEServer",
content: mustJSON(t, map[string]any{
"mcpServers": map[string]any{
"events": map[string]any{
"type": "sse",
"url": "https://example.com/sse",
},
},
}),
expected: []agentmcp.ServerConfig{
{
Name: "events",
Transport: "sse",
URL: "https://example.com/sse",
},
},
},
{
name: "ExplicitTypeOverridesInference",
content: mustJSON(t, map[string]any{
"mcpServers": map[string]any{
"hybrid": map[string]any{
"command": "some-binary",
"type": "http",
},
},
}),
expected: []agentmcp.ServerConfig{
{
Name: "hybrid",
Transport: "http",
Command: "some-binary",
},
},
},
{
name: "EnvVarPassthrough",
content: mustJSON(t, map[string]any{
"mcpServers": map[string]any{
"srv": map[string]any{
"command": "run",
"env": map[string]string{"PLAIN": "literal-value"},
},
},
}),
expected: []agentmcp.ServerConfig{
{
Name: "srv",
Transport: "stdio",
Command: "run",
Env: map[string]string{"PLAIN": "literal-value"},
},
},
},
{
name: "EmptyMCPServers",
content: mustJSON(t, map[string]any{
"mcpServers": map[string]any{},
}),
expected: []agentmcp.ServerConfig{},
},
{
name: "MalformedJSON",
content: `{not valid json`,
expectError: true,
},
{
name: "ServerNameContainsSeparator",
content: mustJSON(t, map[string]any{
"mcpServers": map[string]any{
"bad__name": map[string]any{"command": "run"},
},
}),
expectError: true,
},
{
name: "ServerNameTrailingUnderscore",
content: mustJSON(t, map[string]any{
"mcpServers": map[string]any{
"server_": map[string]any{"command": "run"},
},
}),
expectError: true,
},
{
name: "ServerNameLeadingUnderscore",
content: mustJSON(t, map[string]any{
"mcpServers": map[string]any{
"_server": map[string]any{"command": "run"},
},
}),
expectError: true,
},
{
name: "EmptyTransport", content: mustJSON(t, map[string]any{
"mcpServers": map[string]any{
"empty": map[string]any{},
},
}),
expectError: true,
},
{
name: "MissingMCPServersKey",
content: mustJSON(t, map[string]any{
"servers": map[string]any{},
}),
expected: []agentmcp.ServerConfig{},
},
{
name: "MultipleServersSortedByName",
content: mustJSON(t, map[string]any{
"mcpServers": map[string]any{
"zeta": map[string]any{"command": "z"},
"alpha": map[string]any{"command": "a"},
"mu": map[string]any{"command": "m"},
},
}),
expected: []agentmcp.ServerConfig{
{Name: "alpha", Transport: "stdio", Command: "a"},
{Name: "mu", Transport: "stdio", Command: "m"},
{Name: "zeta", Transport: "stdio", Command: "z"},
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
dir := t.TempDir()
path := filepath.Join(dir, ".mcp.json")
err := os.WriteFile(path, []byte(tt.content), 0o600)
require.NoError(t, err)
got, err := agentmcp.ParseConfig(path)
if tt.expectError {
require.Error(t, err)
return
}
require.NoError(t, err)
require.Equal(t, tt.expected, got)
})
}
}
// TestParseConfig_EnvVarInterpolation verifies that ${VAR} references
// in env values are resolved from the process environment. This test
// cannot be parallel because t.Setenv is incompatible with t.Parallel.
func TestParseConfig_EnvVarInterpolation(t *testing.T) {
t.Setenv("TEST_MCP_TOKEN", "secret123")
content := mustJSON(t, map[string]any{
"mcpServers": map[string]any{
"srv": map[string]any{
"command": "run",
"env": map[string]string{"TOKEN": "${TEST_MCP_TOKEN}"},
},
},
})
dir := t.TempDir()
path := filepath.Join(dir, ".mcp.json")
err := os.WriteFile(path, []byte(content), 0o600)
require.NoError(t, err)
got, err := agentmcp.ParseConfig(path)
require.NoError(t, err)
require.Equal(t, []agentmcp.ServerConfig{
{
Name: "srv",
Transport: "stdio",
Command: "run",
Env: map[string]string{"TOKEN": "secret123"},
},
}, got)
}
func TestParseConfig_FileNotFound(t *testing.T) {
t.Parallel()
_, err := agentmcp.ParseConfig(filepath.Join(t.TempDir(), "nonexistent.json"))
require.Error(t, err)
}
// mustJSON marshals v to a JSON string, failing the test on error.
func mustJSON(t *testing.T, v any) string {
t.Helper()
data, err := json.Marshal(v)
require.NoError(t, err)
return string(data)
}
-474
View File
@@ -1,474 +0,0 @@
package agentmcp
import (
"context"
"errors"
"fmt"
"io/fs"
"os"
"slices"
"strings"
"sync"
"time"
"github.com/mark3labs/mcp-go/client"
"github.com/mark3labs/mcp-go/client/transport"
"github.com/mark3labs/mcp-go/mcp"
"golang.org/x/sync/errgroup"
"golang.org/x/xerrors"
"cdr.dev/slog/v3"
"github.com/coder/coder/v2/buildinfo"
"github.com/coder/coder/v2/codersdk/workspacesdk"
)
// ToolNameSep separates the server name from the original tool name
// in prefixed tool names. Double underscore avoids collisions with
// tool names that may contain single underscores.
const ToolNameSep = "__"
// connectTimeout bounds how long we wait for a single MCP server
// to start its transport and complete initialization.
const connectTimeout = 30 * time.Second
// toolCallTimeout bounds how long a single tool invocation may
// take before being canceled.
const toolCallTimeout = 60 * time.Second
var (
// ErrInvalidToolName is returned when the tool name format
// is not "server__tool".
ErrInvalidToolName = xerrors.New("invalid tool name format")
// ErrUnknownServer is returned when no MCP server matches
// the prefix in the tool name.
ErrUnknownServer = xerrors.New("unknown MCP server")
)
// Manager manages connections to MCP servers discovered from a
// workspace's .mcp.json file. It caches the aggregated tool list
// and proxies tool calls to the appropriate server.
type Manager struct {
mu sync.RWMutex
logger slog.Logger
closed bool
servers map[string]*serverEntry // keyed by server name
tools []workspacesdk.MCPToolInfo
}
// serverEntry pairs a server config with its connected client.
type serverEntry struct {
config ServerConfig
client *client.Client
}
// NewManager creates a new MCP client manager.
func NewManager(logger slog.Logger) *Manager {
return &Manager{
logger: logger,
servers: make(map[string]*serverEntry),
}
}
// Connect reads MCP config files at the given absolute paths and
// connects to all configured servers. Failed servers are logged
// and skipped. Missing config files are silently skipped.
func (m *Manager) Connect(ctx context.Context, mcpConfigFiles []string) error {
var allConfigs []ServerConfig
for _, configPath := range mcpConfigFiles {
configs, err := ParseConfig(configPath)
if err != nil {
if errors.Is(err, fs.ErrNotExist) {
continue
}
m.logger.Warn(ctx, "failed to parse MCP config",
slog.F("path", configPath),
slog.Error(err),
)
continue
}
allConfigs = append(allConfigs, configs...)
}
// Deduplicate by server name; first occurrence wins.
seen := make(map[string]struct{})
deduped := make([]ServerConfig, 0, len(allConfigs))
for _, cfg := range allConfigs {
if _, ok := seen[cfg.Name]; ok {
continue
}
seen[cfg.Name] = struct{}{}
deduped = append(deduped, cfg)
}
allConfigs = deduped
if len(allConfigs) == 0 {
return nil
}
// Connect to servers in parallel without holding the
// lock, since each connectServer call may block on
// network I/O for up to connectTimeout.
type connectedServer struct {
name string
config ServerConfig
client *client.Client
}
var (
mu sync.Mutex
connected []connectedServer
)
var eg errgroup.Group
for _, cfg := range allConfigs {
eg.Go(func() error {
c, err := m.connectServer(ctx, cfg)
if err != nil {
m.logger.Warn(ctx, "skipping MCP server",
slog.F("server", cfg.Name),
slog.F("transport", cfg.Transport),
slog.Error(err),
)
return nil // Don't fail the group.
}
mu.Lock()
connected = append(connected, connectedServer{
name: cfg.Name, config: cfg, client: c,
})
mu.Unlock()
return nil
})
}
_ = eg.Wait()
m.mu.Lock()
if m.closed {
m.mu.Unlock()
// Close the freshly-connected clients since we're
// shutting down.
for _, cs := range connected {
_ = cs.client.Close()
}
return xerrors.New("manager closed")
}
// Close previous connections to avoid leaking child
// processes on agent reconnect.
for _, entry := range m.servers {
_ = entry.client.Close()
}
m.servers = make(map[string]*serverEntry, len(connected))
for _, cs := range connected {
m.servers[cs.name] = &serverEntry{
config: cs.config,
client: cs.client,
}
}
m.mu.Unlock()
// Refresh tools outside the lock to avoid blocking
// concurrent reads during network I/O.
if err := m.RefreshTools(ctx); err != nil {
m.logger.Warn(ctx, "failed to refresh MCP tools after connect", slog.Error(err))
}
return nil
}
// connectServer establishes a connection to a single MCP server
// and returns the connected client. It does not modify any Manager
// state.
func (*Manager) connectServer(ctx context.Context, cfg ServerConfig) (*client.Client, error) {
tr, err := createTransport(cfg)
if err != nil {
return nil, xerrors.Errorf("create transport for %q: %w", cfg.Name, err)
}
c := client.NewClient(tr)
connectCtx, cancel := context.WithTimeout(ctx, connectTimeout)
defer cancel()
// Use the parent ctx (not connectCtx) so the subprocess outlives
// the connect/initialize handshake. connectCtx bounds only the
// Initialize call below. The subprocess is cleaned up when the
// Manager is closed or ctx is canceled.
if err := c.Start(ctx); err != nil {
_ = c.Close()
return nil, xerrors.Errorf("start %q: %w", cfg.Name, err)
}
_, err = c.Initialize(connectCtx, mcp.InitializeRequest{
Params: mcp.InitializeParams{
ProtocolVersion: mcp.LATEST_PROTOCOL_VERSION,
ClientInfo: mcp.Implementation{
Name: "coder-agent",
Version: buildinfo.Version(),
},
},
})
if err != nil {
_ = c.Close()
return nil, xerrors.Errorf("initialize %q: %w", cfg.Name, err)
}
return c, nil
}
// createTransport builds the mcp-go transport for a server config.
func createTransport(cfg ServerConfig) (transport.Interface, error) {
switch cfg.Transport {
case "stdio":
return transport.NewStdio(
cfg.Command,
buildEnv(cfg.Env),
cfg.Args...,
), nil
case "http", "":
return transport.NewStreamableHTTP(
cfg.URL,
transport.WithHTTPHeaders(cfg.Headers),
)
case "sse":
return transport.NewSSE(
cfg.URL,
transport.WithHeaders(cfg.Headers),
)
default:
return nil, xerrors.Errorf("unsupported transport %q", cfg.Transport)
}
}
// buildEnv merges the current process environment with explicit
// overrides, returning the result as KEY=VALUE strings suitable
// for the stdio transport.
func buildEnv(explicit map[string]string) []string {
env := os.Environ()
if len(explicit) == 0 {
return env
}
// Index existing env so explicit keys can override in-place.
existing := make(map[string]int, len(env))
for i, kv := range env {
if k, _, ok := strings.Cut(kv, "="); ok {
existing[k] = i
}
}
for k, v := range explicit {
entry := k + "=" + v
if idx, ok := existing[k]; ok {
env[idx] = entry
} else {
env = append(env, entry)
}
}
return env
}
// Tools returns the cached tool list. Thread-safe.
func (m *Manager) Tools() []workspacesdk.MCPToolInfo {
m.mu.RLock()
defer m.mu.RUnlock()
return slices.Clone(m.tools)
}
// CallTool proxies a tool call to the appropriate MCP server.
func (m *Manager) CallTool(ctx context.Context, req workspacesdk.CallMCPToolRequest) (workspacesdk.CallMCPToolResponse, error) {
serverName, originalName, err := splitToolName(req.ToolName)
if err != nil {
return workspacesdk.CallMCPToolResponse{}, err
}
m.mu.RLock()
entry, ok := m.servers[serverName]
m.mu.RUnlock()
if !ok {
return workspacesdk.CallMCPToolResponse{}, xerrors.Errorf("%w: %q", ErrUnknownServer, serverName)
}
callCtx, cancel := context.WithTimeout(ctx, toolCallTimeout)
defer cancel()
result, err := entry.client.CallTool(callCtx, mcp.CallToolRequest{
Params: mcp.CallToolParams{
Name: originalName,
Arguments: req.Arguments,
},
})
if err != nil {
return workspacesdk.CallMCPToolResponse{}, xerrors.Errorf("call tool %q on %q: %w", originalName, serverName, err)
}
return convertResult(result), nil
}
// splitToolName extracts the server name and original tool name
// from a prefixed tool name like "server__tool".
func splitToolName(prefixed string) (serverName, toolName string, err error) {
server, tool, ok := strings.Cut(prefixed, ToolNameSep)
if !ok || server == "" || tool == "" {
return "", "", xerrors.Errorf("%w: expected format \"server%stool\", got %q", ErrInvalidToolName, ToolNameSep, prefixed)
}
return server, tool, nil
}
// convertResult translates an MCP CallToolResult into a
// workspacesdk.CallMCPToolResponse. It iterates over content
// items and maps each recognized type.
func convertResult(result *mcp.CallToolResult) workspacesdk.CallMCPToolResponse {
if result == nil {
return workspacesdk.CallMCPToolResponse{}
}
var content []workspacesdk.MCPToolContent
for _, item := range result.Content {
switch c := item.(type) {
case mcp.TextContent:
content = append(content, workspacesdk.MCPToolContent{
Type: "text",
Text: c.Text,
})
case mcp.ImageContent:
content = append(content, workspacesdk.MCPToolContent{
Type: "image",
Data: c.Data,
MediaType: c.MIMEType,
})
case mcp.AudioContent:
content = append(content, workspacesdk.MCPToolContent{
Type: "audio",
Data: c.Data,
MediaType: c.MIMEType,
})
case mcp.EmbeddedResource:
content = append(content, workspacesdk.MCPToolContent{
Type: "resource",
Text: fmt.Sprintf("[embedded resource: %T]", c.Resource),
})
case mcp.ResourceLink:
content = append(content, workspacesdk.MCPToolContent{
Type: "resource",
Text: fmt.Sprintf("[resource link: %s]", c.URI),
})
default:
content = append(content, workspacesdk.MCPToolContent{
Type: "text",
Text: fmt.Sprintf("[unsupported content type: %T]", item),
})
}
}
return workspacesdk.CallMCPToolResponse{
Content: content,
IsError: result.IsError,
}
}
// RefreshTools re-fetches tool lists from all connected servers
// in parallel and rebuilds the cache. On partial failure, tools
// from servers that responded successfully are merged with the
// existing cached tools for servers that failed, so a single
// dead server doesn't block updates from healthy ones.
func (m *Manager) RefreshTools(ctx context.Context) error {
// Snapshot servers under read lock.
m.mu.RLock()
servers := make(map[string]*serverEntry, len(m.servers))
for k, v := range m.servers {
servers[k] = v
}
m.mu.RUnlock()
// Fetch tool lists in parallel without holding any lock.
type serverTools struct {
name string
tools []workspacesdk.MCPToolInfo
}
var (
mu sync.Mutex
results []serverTools
failed []string
errs []error
)
var eg errgroup.Group
for name, entry := range servers {
eg.Go(func() error {
listCtx, cancel := context.WithTimeout(ctx, connectTimeout)
result, err := entry.client.ListTools(listCtx, mcp.ListToolsRequest{})
cancel()
if err != nil {
m.logger.Warn(ctx, "failed to list tools from MCP server",
slog.F("server", name),
slog.Error(err),
)
mu.Lock()
errs = append(errs, xerrors.Errorf("list tools from %q: %w", name, err))
failed = append(failed, name)
mu.Unlock()
return nil
}
var tools []workspacesdk.MCPToolInfo
for _, tool := range result.Tools {
tools = append(tools, workspacesdk.MCPToolInfo{
ServerName: name,
Name: name + ToolNameSep + tool.Name,
Description: tool.Description,
Schema: tool.InputSchema.Properties,
Required: tool.InputSchema.Required,
})
}
mu.Lock()
results = append(results, serverTools{name: name, tools: tools})
mu.Unlock()
return nil
})
}
_ = eg.Wait()
// Build the new tool list. For servers that failed, preserve
// their tools from the existing cache so a single dead server
// doesn't remove healthy tools.
var merged []workspacesdk.MCPToolInfo
for _, st := range results {
merged = append(merged, st.tools...)
}
if len(failed) > 0 {
failedSet := make(map[string]struct{}, len(failed))
for _, f := range failed {
failedSet[f] = struct{}{}
}
m.mu.RLock()
for _, t := range m.tools {
if _, ok := failedSet[t.ServerName]; ok {
merged = append(merged, t)
}
}
m.mu.RUnlock()
}
slices.SortFunc(merged, func(a, b workspacesdk.MCPToolInfo) int {
return strings.Compare(a.Name, b.Name)
})
m.mu.Lock()
m.tools = merged
m.mu.Unlock()
return errors.Join(errs...)
}
// Close terminates all MCP server connections and child
// processes.
func (m *Manager) Close() error {
m.mu.Lock()
defer m.mu.Unlock()
m.closed = true
var errs []error
for _, entry := range m.servers {
errs = append(errs, entry.client.Close())
}
m.servers = make(map[string]*serverEntry)
m.tools = nil
return errors.Join(errs...)
}
-316
View File
@@ -1,316 +0,0 @@
package agentmcp
import (
"bufio"
"context"
"encoding/json"
"fmt"
"os"
"testing"
"github.com/mark3labs/mcp-go/mcp"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/coder/coder/v2/codersdk/workspacesdk"
"github.com/coder/coder/v2/testutil"
)
func TestSplitToolName(t *testing.T) {
t.Parallel()
tests := []struct {
name string
input string
wantServer string
wantTool string
wantErr bool
}{
{
name: "Valid",
input: "server__tool",
wantServer: "server",
wantTool: "tool",
},
{
name: "ValidWithUnderscoresInTool",
input: "server__my_tool",
wantServer: "server",
wantTool: "my_tool",
},
{
name: "MissingSeparator",
input: "servertool",
wantErr: true,
},
{
name: "EmptyServer",
input: "__tool",
wantErr: true,
},
{
name: "EmptyTool",
input: "server__",
wantErr: true,
},
{
name: "JustSeparator",
input: "__",
wantErr: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
server, tool, err := splitToolName(tt.input)
if tt.wantErr {
require.Error(t, err)
assert.ErrorIs(t, err, ErrInvalidToolName)
return
}
require.NoError(t, err)
assert.Equal(t, tt.wantServer, server)
assert.Equal(t, tt.wantTool, tool)
})
}
}
func TestConvertResult(t *testing.T) {
t.Parallel()
tests := []struct {
name string
// input is a pointer so we can test nil.
input *mcp.CallToolResult
want workspacesdk.CallMCPToolResponse
}{
{
name: "NilInput",
input: nil,
want: workspacesdk.CallMCPToolResponse{},
},
{
name: "TextContent",
input: &mcp.CallToolResult{
Content: []mcp.Content{
mcp.TextContent{Type: "text", Text: "hello"},
},
},
want: workspacesdk.CallMCPToolResponse{
Content: []workspacesdk.MCPToolContent{
{Type: "text", Text: "hello"},
},
},
},
{
name: "ImageContent",
input: &mcp.CallToolResult{
Content: []mcp.Content{
mcp.ImageContent{
Type: "image",
Data: "base64data",
MIMEType: "image/png",
},
},
},
want: workspacesdk.CallMCPToolResponse{
Content: []workspacesdk.MCPToolContent{
{Type: "image", Data: "base64data", MediaType: "image/png"},
},
},
},
{
name: "AudioContent",
input: &mcp.CallToolResult{
Content: []mcp.Content{
mcp.AudioContent{
Type: "audio",
Data: "base64audio",
MIMEType: "audio/mp3",
},
},
},
want: workspacesdk.CallMCPToolResponse{
Content: []workspacesdk.MCPToolContent{
{Type: "audio", Data: "base64audio", MediaType: "audio/mp3"},
},
},
},
{
name: "IsErrorPropagation",
input: &mcp.CallToolResult{
Content: []mcp.Content{
mcp.TextContent{Type: "text", Text: "fail"},
},
IsError: true,
},
want: workspacesdk.CallMCPToolResponse{
Content: []workspacesdk.MCPToolContent{
{Type: "text", Text: "fail"},
},
IsError: true,
},
},
{
name: "MultipleContentItems",
input: &mcp.CallToolResult{
Content: []mcp.Content{
mcp.TextContent{Type: "text", Text: "caption"},
mcp.ImageContent{
Type: "image",
Data: "imgdata",
MIMEType: "image/jpeg",
},
},
},
want: workspacesdk.CallMCPToolResponse{
Content: []workspacesdk.MCPToolContent{
{Type: "text", Text: "caption"},
{Type: "image", Data: "imgdata", MediaType: "image/jpeg"},
},
},
},
{
name: "ResourceLink",
input: &mcp.CallToolResult{
Content: []mcp.Content{
mcp.ResourceLink{
Type: "resource_link",
URI: "file:///tmp/test.txt",
},
},
},
want: workspacesdk.CallMCPToolResponse{
Content: []workspacesdk.MCPToolContent{
{Type: "resource", Text: "[resource link: file:///tmp/test.txt]"},
},
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
got := convertResult(tt.input)
assert.Equal(t, tt.want, got)
})
}
}
// TestConnectServer_StdioProcessSurvivesConnect verifies that a stdio MCP
// server subprocess remains alive after connectServer returns. This is a
// regression test for a bug where the subprocess was tied to a short-lived
// connectCtx and killed as soon as the context was canceled.
func TestConnectServer_StdioProcessSurvivesConnect(t *testing.T) {
t.Parallel()
if os.Getenv("TEST_MCP_FAKE_SERVER") == "1" {
// Child process: act as a minimal MCP server over stdio.
runFakeMCPServer()
return
}
// Get the path to the test binary so we can re-exec ourselves
// as a fake MCP server subprocess.
testBin, err := os.Executable()
require.NoError(t, err)
cfg := ServerConfig{
Name: "fake",
Transport: "stdio",
Command: testBin,
Args: []string{"-test.run=^TestConnectServer_StdioProcessSurvivesConnect$"},
Env: map[string]string{"TEST_MCP_FAKE_SERVER": "1"},
}
ctx := testutil.Context(t, testutil.WaitLong)
m := &Manager{}
client, err := m.connectServer(ctx, cfg)
require.NoError(t, err, "connectServer should succeed")
t.Cleanup(func() { _ = client.Close() })
// At this point connectServer has returned and its internal
// connectCtx has been canceled. The subprocess must still be
// alive. Verify by listing tools (requires a live server).
listCtx, listCancel := context.WithTimeout(ctx, testutil.WaitShort)
defer listCancel()
result, err := client.ListTools(listCtx, mcp.ListToolsRequest{})
require.NoError(t, err, "ListTools should succeed — server must be alive after connect")
require.Len(t, result.Tools, 1)
assert.Equal(t, "echo", result.Tools[0].Name)
}
// runFakeMCPServer implements a minimal JSON-RPC / MCP server over
// stdin/stdout, just enough for initialize + tools/list.
func runFakeMCPServer() {
scanner := bufio.NewScanner(os.Stdin)
for scanner.Scan() {
line := scanner.Bytes()
var req struct {
JSONRPC string `json:"jsonrpc"`
ID json.RawMessage `json:"id"`
Method string `json:"method"`
}
if err := json.Unmarshal(line, &req); err != nil {
continue
}
var resp any
switch req.Method {
case "initialize":
resp = map[string]any{
"jsonrpc": "2.0",
"id": req.ID,
"result": map[string]any{
"protocolVersion": "2025-03-26",
"capabilities": map[string]any{
"tools": map[string]any{},
},
"serverInfo": map[string]any{
"name": "fake-server",
"version": "0.0.1",
},
},
}
case "notifications/initialized":
// No response needed for notifications.
continue
case "tools/list":
resp = map[string]any{
"jsonrpc": "2.0",
"id": req.ID,
"result": map[string]any{
"tools": []map[string]any{
{
"name": "echo",
"description": "echoes input",
"inputSchema": map[string]any{
"type": "object",
"properties": map[string]any{},
},
},
},
},
}
default:
resp = map[string]any{
"jsonrpc": "2.0",
"id": req.ID,
"error": map[string]any{
"code": -32601,
"message": "method not found",
},
}
}
out, err := json.Marshal(resp)
if err != nil {
continue
}
_, _ = fmt.Fprintf(os.Stdout, "%s\n", out)
}
}
+19 -28
View File
@@ -3,13 +3,11 @@
"enabled": true,
"clientKind": "git",
"useIgnoreFile": true,
"defaultBranch": "main"
"defaultBranch": "main",
},
"files": {
// static/*.html are Go templates with {{ }} directives that
// Biome's HTML parser does not support.
"includes": ["**", "!**/pnpm-lock.yaml", "!**/static/*.html"],
"ignoreUnknown": true
"includes": ["**", "!**/pnpm-lock.yaml"],
"ignoreUnknown": true,
},
"linter": {
"rules": {
@@ -17,7 +15,7 @@
"noSvgWithoutTitle": "off",
"useButtonType": "off",
"useSemanticElements": "off",
"noStaticElementInteractions": "off"
"noStaticElementInteractions": "off",
},
"correctness": {
"noUnusedImports": "warn",
@@ -26,9 +24,9 @@
"noUnusedVariables": {
"level": "warn",
"options": {
"ignoreRestSiblings": true
}
}
"ignoreRestSiblings": true,
},
},
},
"style": {
"noNonNullAssertion": "off",
@@ -49,7 +47,7 @@
"paths": {
"react": {
"message": "React 19 no longer requires forwardRef. Use ref as a prop instead.",
"importNames": ["forwardRef"]
"importNames": ["forwardRef"],
},
// "@mui/material/Alert": "Use components/Alert/Alert instead.",
// "@mui/material/AlertTitle": "Use components/Alert/Alert instead.",
@@ -117,10 +115,10 @@
"@emotion/styled": "Use Tailwind CSS instead.",
// "@emotion/cache": "Use Tailwind CSS instead.",
// "components/Stack/Stack": "Use Tailwind flex utilities instead (e.g., <div className='flex flex-col gap-4'>).",
"lodash": "Use lodash/<name> instead."
}
}
}
"lodash": "Use lodash/<name> instead.",
},
},
},
},
"suspicious": {
"noArrayIndexKey": "off",
@@ -131,21 +129,14 @@
"noConsole": {
"level": "error",
"options": {
"allow": ["error", "info", "warn"]
}
}
"allow": ["error", "info", "warn"],
},
},
},
"complexity": {
"noImportantStyles": "off" // TODO: check and fix !important styles
}
}
"noImportantStyles": "off", // TODO: check and fix !important styles
},
},
},
"css": {
"parser": {
// Biome 2.3+ requires opt-in for @apply and other
// Tailwind directives.
"tailwindDirectives": true
}
},
"$schema": "./node_modules/@biomejs/biome/configuration_schema.json"
"$schema": "./node_modules/@biomejs/biome/configuration_schema.json",
}
+6 -33
View File
@@ -17,7 +17,6 @@ import (
"strings"
"time"
"github.com/google/uuid"
"github.com/prometheus/client_golang/prometheus"
"golang.org/x/xerrors"
"gopkg.in/natefinch/lumberjack.v2"
@@ -273,14 +272,11 @@ func workspaceAgent() *serpent.Command {
logger.Info(ctx, "agent devcontainer detection not enabled")
}
reinitCtx, reinitCancel := context.WithCancel(ctx)
defer reinitCancel()
reinitEvents := agentsdk.WaitForReinitLoop(reinitCtx, logger, client)
reinitEvents := agentsdk.WaitForReinitLoop(ctx, logger, client)
var (
lastOwnerID uuid.UUID
lastErr error
mustExit bool
lastErr error
mustExit bool
)
for {
prometheusRegistry := prometheus.NewRegistry()
@@ -347,32 +343,9 @@ func workspaceAgent() *serpent.Command {
case <-ctx.Done():
logger.Info(ctx, "agent shutting down", slog.Error(context.Cause(ctx)))
mustExit = true
case event, ok := <-reinitEvents:
switch {
case !ok:
// Channel closed — the reinit loop exited
// (terminal 409 or context expired). Keep
// running the current agent until the parent
// context is canceled.
logger.Info(ctx, "reinit channel closed, running without reinit capability")
reinitEvents = nil
<-ctx.Done()
mustExit = true
case event.OwnerID != uuid.Nil && event.OwnerID == lastOwnerID:
// Duplicate reinit for same owner — already
// reinitialized. Cancel the reinit loop
// goroutine and keep the current agent.
logger.Info(ctx, "skipping redundant reinit, owner unchanged",
slog.F("owner_id", event.OwnerID))
reinitCancel()
reinitEvents = nil
<-ctx.Done()
mustExit = true
default:
lastOwnerID = event.OwnerID
logger.Info(ctx, "agent received instruction to reinitialize",
slog.F("workspace_id", event.WorkspaceID), slog.F("reason", event.Reason))
}
case event := <-reinitEvents:
logger.Info(ctx, "agent received instruction to reinitialize",
slog.F("workspace_id", event.WorkspaceID), slog.F("reason", event.Reason))
}
lastErr = agnt.Close()
+1 -1
View File
@@ -104,7 +104,7 @@ func (b *Builder) Build(inv *serpent.Invocation) (log slog.Logger, closeLog func
addSinkIfProvided := func(sinkFn func(io.Writer) slog.Sink, loc string) error {
switch loc {
case "", "/dev/null":
case "":
case "/dev/stdout":
sinks = append(sinks, sinkFn(inv.Stdout))
+1 -4
View File
@@ -173,10 +173,7 @@ func Start(t *testing.T, inv *serpent.Invocation) {
StartWithAssert(t, inv, nil)
}
// StartWithAssert starts the given invocation and calls assertCallback
// with the resulting error when the invocation completes. If assertCallback
// is nil, expected shutdown errors are silently tolerated.
func StartWithAssert(t *testing.T, inv *serpent.Invocation, assertCallback func(t *testing.T, err error)) {
func StartWithAssert(t *testing.T, inv *serpent.Invocation, assertCallback func(t *testing.T, err error)) { //nolint:revive
t.Helper()
closeCh := make(chan struct{})
+2
View File
@@ -173,6 +173,7 @@ func (selectModel) Init() tea.Cmd {
return nil
}
//nolint:revive // The linter complains about modifying 'm' but this is typical practice for bubbletea
func (m selectModel) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
var cmd tea.Cmd
@@ -462,6 +463,7 @@ func (multiSelectModel) Init() tea.Cmd {
return nil
}
//nolint:revive // For same reason as previous Update definition
func (m multiSelectModel) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
var cmd tea.Cmd
-3
View File
@@ -1401,9 +1401,6 @@ func (r *RootCmd) scaletestWorkspaceTraffic() *serpent.Command {
// Setup our workspace agent connection.
config := workspacetraffic.Config{
AgentID: agent.ID,
WorkspaceID: ws.ID,
WorkspaceName: ws.Name,
AgentName: agent.Name,
BytesPerTick: bytesPerTick,
Duration: strategy.timeout,
TickInterval: tickInterval,
+1
View File
@@ -1414,6 +1414,7 @@ func tailLineStyle() pretty.Style {
return pretty.Style{pretty.Nop}
}
//nolint:unused
func SlimUnsupported(w io.Writer, cmd string) {
_, _ = fmt.Fprintf(w, "You are using a 'slim' build of Coder, which does not support the %s subcommand.\n", pretty.Sprint(cliui.DefaultStyles.Code, cmd))
_, _ = fmt.Fprintln(w, "")
+2 -13
View File
@@ -352,6 +352,8 @@ func TestScheduleOverride(t *testing.T) {
require.NoError(t, err, "invalid schedule")
ownerClient, _, _, ws := setupTestSchedule(t, sched)
now := time.Now()
// To avoid the likelihood of time-related flakes, only matching up to the hour.
expectedDeadline := now.In(loc).Add(10 * time.Hour).Format("2006-01-02T15:")
// When: we override the stop schedule
inv, root := clitest.New(t,
@@ -362,19 +364,6 @@ func TestScheduleOverride(t *testing.T) {
pty := ptytest.New(t).Attach(inv)
require.NoError(t, inv.Run())
// Fetch the workspace to get the actual deadline set by the
// server. Computing our own expected deadline from a separately
// captured time.Now() is racy: the CLI command calls time.Now()
// internally, and with the Asia/Kolkata +05:30 offset the hour
// boundary falls at :30 UTC minutes. A small delay between our
// time.Now() and the command's is enough to land in different
// hours.
updated, err := ownerClient.Workspace(context.Background(), ws[0].ID)
require.NoError(t, err)
require.False(t, updated.LatestBuild.Deadline.IsZero(), "deadline should be set after extend")
require.WithinDuration(t, now.Add(10*time.Hour), updated.LatestBuild.Deadline.Time, 5*time.Minute)
expectedDeadline := updated.LatestBuild.Deadline.Time.In(loc).Format(time.RFC3339)
// Then: the updated schedule should be shown
pty.ExpectMatch(ws[0].OwnerName + "/" + ws[0].Name)
pty.ExpectMatch(sched.Humanize())
+11 -4
View File
@@ -63,6 +63,8 @@ import (
"github.com/coder/coder/v2/cli/config"
"github.com/coder/coder/v2/coderd"
"github.com/coder/coder/v2/coderd/autobuild"
"github.com/coder/coder/v2/coderd/automations"
"github.com/coder/coder/v2/coderd/automations/cronscheduler"
"github.com/coder/coder/v2/coderd/database"
"github.com/coder/coder/v2/coderd/database/awsiamrds"
"github.com/coder/coder/v2/coderd/database/dbauthz"
@@ -305,6 +307,7 @@ func enablePrometheus(
}
options.ProvisionerdServerMetrics = provisionerdserverMetrics
//nolint:revive
return ServeHandler(
ctx, logger, promhttp.InstrumentMetricHandler(
options.PrometheusRegistry, promhttp.HandlerFor(options.PrometheusRegistry, promhttp.HandlerOpts{}),
@@ -1156,8 +1159,10 @@ func (r *RootCmd) Server(newAPI func(context.Context, *coderd.Options) (*coderd.
jobReaper.Start()
defer jobReaper.Close()
waitForProvisionerJobs := false
// Currently there is no way to ask the server to shut
// Evaluates cron-based automation triggers every minute.
automationCron := cronscheduler.New(ctx, logger.Named("automation_cron"), options.Database, quartz.NewReal(), &automations.ChatdAdapter{Server: coderAPI.ChatDaemon()})
defer automationCron.Close()
waitForProvisionerJobs := false // Currently there is no way to ask the server to shut
// itself down, so any exit signal will result in a non-zero
// exit of the server.
var exitErr error
@@ -1636,6 +1641,8 @@ var defaultCipherSuites = func() []uint16 {
// configureServerTLS returns the TLS config used for the Coderd server
// connections to clients. A logger is passed in to allow printing warning
// messages that do not block startup.
//
//nolint:revive
func configureServerTLS(ctx context.Context, logger slog.Logger, tlsMinVersion, tlsClientAuth string, tlsCertFiles, tlsKeyFiles []string, tlsClientCAFile string, ciphers []string, allowInsecureCiphers bool) (*tls.Config, error) {
tlsConfig := &tls.Config{
MinVersion: tls.VersionTLS12,
@@ -2052,6 +2059,7 @@ func getGithubOAuth2ConfigParams(ctx context.Context, db database.Store, vals *c
return &params, nil
}
//nolint:revive // Ignore flag-parameter: parameter 'allowEveryone' seems to be a control flag, avoid control coupling (revive)
func configureGithubOAuth2(instrument *promoauth.Factory, params *githubOAuth2ConfigParams) (*coderd.GithubOAuth2Config, error) {
redirectURL, err := params.accessURL.Parse("/api/v2/users/oauth2/github/callback")
if err != nil {
@@ -2327,8 +2335,7 @@ func ConfigureHTTPClient(ctx context.Context, clientCertFile, clientKeyFile stri
return ctx, nil, err
}
tlsClientConfig := &tls.Config{
MinVersion: tls.VersionTLS12,
tlsClientConfig := &tls.Config{ //nolint:gosec
Certificates: certificates,
NextProtos: []string{"h2", "http/1.1"},
}
+1
View File
@@ -2123,6 +2123,7 @@ func TestServer_TelemetryDisable(t *testing.T) {
// Set the default telemetry to true (normally disabled in tests).
t.Setenv("CODER_TEST_TELEMETRY_DEFAULT_ENABLE", "true")
//nolint:paralleltest // No need to reinitialise the variable tt (Go version).
for _, tt := range []struct {
key string
val string
-31
View File
@@ -165,37 +165,6 @@ func TestSyncCommands_Golden(t *testing.T) {
clitest.TestGoldenFile(t, "TestSyncCommands_Golden/want_success", outBuf.Bytes(), nil)
})
t.Run("want_multiple_deps", func(t *testing.T) {
t.Parallel()
path, cleanup := setupSocketServer(t)
defer cleanup()
ctx := testutil.Context(t, testutil.WaitShort)
var outBuf bytes.Buffer
inv, _ := clitest.New(t, "exp", "sync", "want", "test-unit", "dep-1", "dep-2", "dep-3", "--socket-path", path)
inv.Stdout = &outBuf
inv.Stderr = &outBuf
err := inv.WithContext(ctx).Run()
require.NoError(t, err)
// Verify all dependencies were registered by checking status.
outBuf.Reset()
inv, _ = clitest.New(t, "exp", "sync", "status", "test-unit", "--socket-path", path, "--output", "json")
inv.Stdout = &outBuf
inv.Stderr = &outBuf
err = inv.WithContext(ctx).Run()
require.NoError(t, err)
// The output should mention all three dependencies.
output := outBuf.String()
require.Contains(t, output, "dep-1")
require.Contains(t, output, "dep-2")
require.Contains(t, output, "dep-3")
})
t.Run("complete", func(t *testing.T) {
t.Parallel()
path, cleanup := setupSocketServer(t)
+8 -9
View File
@@ -11,16 +11,17 @@ import (
func (*RootCmd) syncWant(socketPath *string) *serpent.Command {
cmd := &serpent.Command{
Use: "want <unit> <depends-on> [depends-on...]",
Short: "Declare that a unit depends on other units completing before it can start",
Long: "Declare that a unit depends on one or more other units completing before it can start. The unit specified first will not start until all subsequent units have signaled that they have completed.",
Use: "want <unit> <depends-on>",
Short: "Declare that a unit depends on another unit completing before it can start",
Long: "Declare that a unit depends on another unit completing before it can start. The unit specified first will not start until the second has signaled that it has completed.",
Handler: func(i *serpent.Invocation) error {
ctx := i.Context()
if len(i.Args) < 2 {
return xerrors.New("at least two arguments are required: unit and one or more depends-on")
if len(i.Args) != 2 {
return xerrors.New("exactly two arguments are required: unit and depends-on")
}
dependentUnit := unit.ID(i.Args[0])
dependsOn := unit.ID(i.Args[1])
opts := []agentsocket.Option{}
if *socketPath != "" {
@@ -33,10 +34,8 @@ func (*RootCmd) syncWant(socketPath *string) *serpent.Command {
}
defer client.Close()
for _, dep := range i.Args[1:] {
if err := client.SyncWant(ctx, dependentUnit, unit.ID(dep)); err != nil {
return xerrors.Errorf("declare dependency failed: %w", err)
}
if err := client.SyncWant(ctx, dependentUnit, dependsOn); err != nil {
return xerrors.Errorf("declare dependency failed: %w", err)
}
cliui.Info(i.Stdout, "Success")
+2 -2
View File
@@ -828,7 +828,7 @@ func TestTemplateEdit(t *testing.T) {
"--require-active-version",
}
inv, root := clitest.New(t, cmdArgs...)
//nolint:gocritic // Using owner client is required for template editing.
//nolint
clitest.SetupConfig(t, client, root)
ctx := testutil.Context(t, testutil.WaitLong)
@@ -858,7 +858,7 @@ func TestTemplateEdit(t *testing.T) {
"--name", "something-new",
}
inv, root := clitest.New(t, cmdArgs...)
//nolint:gocritic // Using owner client is required for template editing.
//nolint
clitest.SetupConfig(t, client, root)
ctx := testutil.Context(t, testutil.WaitLong)
+1 -1
View File
@@ -16,7 +16,7 @@ SUBCOMMANDS:
ping Test agent socket connectivity and health
start Wait until all unit dependencies are satisfied
status Show unit status and dependency state
want Declare that a unit depends on other units completing before it
want Declare that a unit depends on another unit completing before it
can start
OPTIONS:
+5 -5
View File
@@ -1,13 +1,13 @@
coder v0.0.0-devel
USAGE:
coder exp sync want <unit> <depends-on> [depends-on...]
coder exp sync want <unit> <depends-on>
Declare that a unit depends on other units completing before it can start
Declare that a unit depends on another unit completing before it can start
Declare that a unit depends on one or more other units completing before it
can start. The unit specified first will not start until all subsequent units
have signaled that they have completed.
Declare that a unit depends on another unit completing before it can start.
The unit specified first will not start until the second has signaled that it
has completed.
———
Run `coder --help` for a list of global options.
+2 -4
View File
@@ -17,8 +17,7 @@
"name": "owner",
"display_name": "Owner"
}
],
"has_ai_seat": false
]
},
{
"id": "==========[second user ID]==========",
@@ -32,7 +31,6 @@
"organization_ids": [
"===========[first org ID]==========="
],
"roles": [],
"has_ai_seat": false
"roles": []
}
]
+2 -7
View File
@@ -857,18 +857,13 @@ aibridgeproxy:
# Comma-separated list of AI provider domains for which HTTPS traffic will be
# decrypted and routed through AI Bridge. Requests to other domains will be
# tunneled directly without decryption. Supported domains: api.anthropic.com,
# api.openai.com, api.individual.githubcopilot.com,
# api.business.githubcopilot.com, api.enterprise.githubcopilot.com, chatgpt.com.
# (default:
# api.anthropic.com,api.openai.com,api.individual.githubcopilot.com,api.business.githubcopilot.com,api.enterprise.githubcopilot.com,chatgpt.com,
# api.openai.com, api.individual.githubcopilot.com.
# (default: api.anthropic.com,api.openai.com,api.individual.githubcopilot.com,
# type: string-array)
domain_allowlist:
- api.anthropic.com
- api.openai.com
- api.individual.githubcopilot.com
- api.business.githubcopilot.com
- api.enterprise.githubcopilot.com
- chatgpt.com
# URL of an upstream HTTP proxy to chain tunneled (non-allowlisted) requests
# through. Format: http://[user:pass@]host:port or https://[user:pass@]host:port.
# (default: <unset>, type: string)
+1 -1
View File
@@ -85,7 +85,7 @@ func (a *ConnLogAPI) ReportConnection(ctx context.Context, req *agentproto.Repor
AgentName: a.AgentName,
Type: connectionType,
Code: code,
IP: logIP,
Ip: logIP,
ConnectionID: uuid.NullUUID{
UUID: connectionID,
Valid: true,
+2 -1
View File
@@ -101,6 +101,7 @@ func TestConnectionLog(t *testing.T) {
reason: "because error says so",
},
}
//nolint:paralleltest // No longer necessary to reinitialise the variable tt.
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
@@ -152,7 +153,7 @@ func TestConnectionLog(t *testing.T) {
Int32: tt.status,
Valid: *tt.action == agentproto.Connection_DISCONNECT,
},
IP: expectedIP,
Ip: expectedIP,
Type: agentProtoConnectionTypeToConnectionLog(t, *tt.typ),
DisconnectReason: sql.NullString{
String: tt.reason,
-3
View File
@@ -3,7 +3,6 @@ package agentapi
import (
"context"
"fmt"
"strings"
"time"
"github.com/google/uuid"
@@ -61,8 +60,6 @@ func (a *MetadataAPI) BatchUpdateMetadata(ctx context.Context, req *agentproto.B
}
)
for _, md := range req.Metadata {
md.Result.Value = strings.TrimSpace(md.Result.Value)
md.Result.Error = strings.TrimSpace(md.Result.Error)
metadataError := md.Result.Error
allKeysLen += len(md.Key)
+4 -32
View File
@@ -57,44 +57,16 @@ func TestBatchUpdateMetadata(t *testing.T) {
CollectedAt: timestamppb.New(now.Add(-3 * time.Second)),
Age: 3,
Value: "",
Error: "\t uncool error ",
Error: "uncool value",
},
},
},
}
batchSize := len(req.Metadata)
// This test sends 2 metadata entries (one clean, one with
// whitespace padding). With batch size 2 we expect exactly
// 1 capacity flush. The matcher verifies that stored values
// are trimmed while clean values pass through unchanged.
expectedValues := map[string]string{
"awesome key": "awesome value",
"uncool key": "",
}
expectedErrors := map[string]string{
"awesome key": "",
"uncool key": "uncool error",
}
// This test sends 2 metadata entries. With batch size 2, we expect
// exactly 1 capacity flush.
store.EXPECT().
BatchUpdateWorkspaceAgentMetadata(
gomock.Any(),
gomock.Cond(func(arg database.BatchUpdateWorkspaceAgentMetadataParams) bool {
if len(arg.Key) != len(expectedValues) {
return false
}
for i, key := range arg.Key {
expVal, ok := expectedValues[key]
if !ok || arg.Value[i] != expVal {
return false
}
expErr, ok := expectedErrors[key]
if !ok || arg.Error[i] != expErr {
return false
}
}
return true
}),
).
BatchUpdateWorkspaceAgentMetadata(gomock.Any(), gomock.Any()).
Return(nil).
Times(1)
+10 -39
View File
@@ -6,47 +6,18 @@ import (
"strings"
)
// HeaderCoderToken is a header set by clients opting into BYOK
// (Bring Your Own Key) mode. It carries the Coder token so
// that Authorization and X-Api-Key can carry the user's own LLM
// credentials. When present, AI Bridge forwards the user's LLM
// headers unchanged instead of injecting the centralized key.
//
// The AI Bridge proxy also sets this header automatically for clients
// that use per-user LLM credentials but cannot set custom headers.
const HeaderCoderToken = "X-Coder-AI-Governance-Token" //nolint:gosec // This is a header name, not a credential.
// HeaderCoderAuth is an internal header used to pass the Coder token
// from AI Proxy to AI Bridge for authentication. This header is stripped
// by AI Bridge before forwarding requests to upstream providers.
const HeaderCoderAuth = "X-Coder-Token"
// HeaderCoderRequestID is a header set by aibridgeproxyd on each
// request forwarded to aibridged for cross-service log correlation.
const HeaderCoderRequestID = "X-Coder-AI-Governance-Request-Id"
// Copilot provider.
const (
ProviderCopilotBusiness = "copilot-business"
HostCopilotBusiness = "api.business.githubcopilot.com"
ProviderCopilotEnterprise = "copilot-enterprise"
HostCopilotEnterprise = "api.enterprise.githubcopilot.com"
)
// ChatGPT provider.
const (
ProviderChatGPT = "chatgpt"
HostChatGPT = "chatgpt.com"
BaseURLChatGPT = "https://" + HostChatGPT + "/backend-api/codex"
)
// IsBYOK reports whether the request is using BYOK mode, determined
// by the presence of the X-Coder-AI-Governance-Token header.
func IsBYOK(header http.Header) bool {
return strings.TrimSpace(header.Get(HeaderCoderToken)) != ""
}
// ExtractAuthToken extracts a token from HTTP headers.
// It checks the BYOK header first (set by clients opting into BYOK),
// then falls back to Authorization: Bearer and X-Api-Key for direct
// centralized mode. If none are present, an empty string is returned.
// ExtractAuthToken extracts an authorization token from HTTP headers.
// It checks X-Coder-Token first (set by AI Proxy), then falls back
// to Authorization header (Bearer token) and X-Api-Key header, which represent
// the different ways clients authenticate against AI providers.
// If none are present, an empty string is returned.
func ExtractAuthToken(header http.Header) string {
if token := strings.TrimSpace(header.Get(HeaderCoderToken)); token != "" {
if token := strings.TrimSpace(header.Get(HeaderCoderAuth)); token != "" {
return token
}
if auth := strings.TrimSpace(header.Get("Authorization")); auth != "" {
+14 -326
View File
@@ -84,34 +84,6 @@ const docTemplate = `{
}
}
},
"/aibridge/clients": {
"get": {
"produces": [
"application/json"
],
"tags": [
"AI Bridge"
],
"summary": "List AI Bridge clients",
"operationId": "list-ai-bridge-clients",
"responses": {
"200": {
"description": "OK",
"schema": {
"type": "array",
"items": {
"type": "string"
}
}
}
},
"security": [
{
"CoderSessionToken": []
}
]
}
},
"/aibridge/interceptions": {
"get": {
"produces": [
@@ -242,58 +214,6 @@ const docTemplate = `{
]
}
},
"/aibridge/sessions/{session_id}": {
"get": {
"produces": [
"application/json"
],
"tags": [
"AI Bridge"
],
"summary": "Get AI Bridge session threads",
"operationId": "get-ai-bridge-session-threads",
"parameters": [
{
"type": "string",
"description": "Session ID (client_session_id or interception UUID)",
"name": "session_id",
"in": "path",
"required": true
},
{
"type": "string",
"description": "Thread pagination cursor (forward/older)",
"name": "after_id",
"in": "query"
},
{
"type": "string",
"description": "Thread pagination cursor (backward/newer)",
"name": "before_id",
"in": "query"
},
{
"type": "integer",
"description": "Number of threads per page (default 50)",
"name": "limit",
"in": "query"
}
],
"responses": {
"200": {
"description": "OK",
"schema": {
"$ref": "#/definitions/codersdk.AIBridgeSessionThreadsResponse"
}
}
},
"security": [
{
"CoderSessionToken": []
}
]
}
},
"/appearance": {
"get": {
"produces": [
@@ -10205,26 +10125,12 @@ const docTemplate = `{
],
"summary": "Get workspace agent reinitialization",
"operationId": "get-workspace-agent-reinitialization",
"parameters": [
{
"type": "boolean",
"description": "Opt in to durable reinit checks",
"name": "wait",
"in": "query"
}
],
"responses": {
"200": {
"description": "OK",
"schema": {
"$ref": "#/definitions/agentsdk.ReinitializationEvent"
}
},
"409": {
"description": "Conflict",
"schema": {
"$ref": "#/definitions/codersdk.Response"
}
}
},
"security": [
@@ -12661,16 +12567,11 @@ const docTemplate = `{
"agentsdk.ReinitializationEvent": {
"type": "object",
"properties": {
"owner_id": {
"type": "string",
"format": "uuid"
},
"reason": {
"$ref": "#/definitions/agentsdk.ReinitializationReason"
},
"workspace_id": {
"type": "string",
"format": "uuid"
"workspaceID": {
"type": "string"
}
}
},
@@ -12774,29 +12675,6 @@ const docTemplate = `{
}
}
},
"codersdk.AIBridgeAgenticAction": {
"type": "object",
"properties": {
"model": {
"type": "string"
},
"thinking": {
"type": "array",
"items": {
"$ref": "#/definitions/codersdk.AIBridgeModelThought"
}
},
"token_usage": {
"$ref": "#/definitions/codersdk.AIBridgeSessionThreadsTokenUsage"
},
"tool_calls": {
"type": "array",
"items": {
"$ref": "#/definitions/codersdk.AIBridgeToolCall"
}
}
}
},
"codersdk.AIBridgeAnthropicConfig": {
"type": "object",
"properties": {
@@ -12913,9 +12791,6 @@ const docTemplate = `{
"provider": {
"type": "string"
},
"provider_name": {
"type": "string"
},
"started_at": {
"type": "string",
"format": "date-time"
@@ -12968,14 +12843,6 @@ const docTemplate = `{
}
}
},
"codersdk.AIBridgeModelThought": {
"type": "object",
"properties": {
"text": {
"type": "string"
}
}
},
"codersdk.AIBridgeOpenAIConfig": {
"type": "object",
"properties": {
@@ -13075,91 +12942,9 @@ const docTemplate = `{
}
}
},
"codersdk.AIBridgeSessionThreadsResponse": {
"type": "object",
"properties": {
"client": {
"type": "string"
},
"ended_at": {
"type": "string",
"format": "date-time"
},
"id": {
"type": "string"
},
"initiator": {
"$ref": "#/definitions/codersdk.MinimalUser"
},
"metadata": {
"type": "object",
"additionalProperties": {}
},
"models": {
"type": "array",
"items": {
"type": "string"
}
},
"page_ended_at": {
"type": "string",
"format": "date-time"
},
"page_started_at": {
"type": "string",
"format": "date-time"
},
"providers": {
"type": "array",
"items": {
"type": "string"
}
},
"started_at": {
"type": "string",
"format": "date-time"
},
"threads": {
"type": "array",
"items": {
"$ref": "#/definitions/codersdk.AIBridgeThread"
}
},
"token_usage_summary": {
"$ref": "#/definitions/codersdk.AIBridgeSessionThreadsTokenUsage"
}
}
},
"codersdk.AIBridgeSessionThreadsTokenUsage": {
"type": "object",
"properties": {
"cache_read_input_tokens": {
"type": "integer"
},
"cache_write_input_tokens": {
"type": "integer"
},
"input_tokens": {
"type": "integer"
},
"metadata": {
"type": "object",
"additionalProperties": {}
},
"output_tokens": {
"type": "integer"
}
}
},
"codersdk.AIBridgeSessionTokenUsageSummary": {
"type": "object",
"properties": {
"cache_read_input_tokens": {
"type": "integer"
},
"cache_write_input_tokens": {
"type": "integer"
},
"input_tokens": {
"type": "integer"
},
@@ -13168,50 +12953,9 @@ const docTemplate = `{
}
}
},
"codersdk.AIBridgeThread": {
"type": "object",
"properties": {
"agentic_actions": {
"type": "array",
"items": {
"$ref": "#/definitions/codersdk.AIBridgeAgenticAction"
}
},
"ended_at": {
"type": "string",
"format": "date-time"
},
"id": {
"type": "string",
"format": "uuid"
},
"model": {
"type": "string"
},
"prompt": {
"type": "string"
},
"provider": {
"type": "string"
},
"started_at": {
"type": "string",
"format": "date-time"
},
"token_usage": {
"$ref": "#/definitions/codersdk.AIBridgeSessionThreadsTokenUsage"
}
}
},
"codersdk.AIBridgeTokenUsage": {
"type": "object",
"properties": {
"cache_read_input_tokens": {
"type": "integer"
},
"cache_write_input_tokens": {
"type": "integer"
},
"created_at": {
"type": "string",
"format": "date-time"
@@ -13239,42 +12983,6 @@ const docTemplate = `{
}
}
},
"codersdk.AIBridgeToolCall": {
"type": "object",
"properties": {
"created_at": {
"type": "string",
"format": "date-time"
},
"id": {
"type": "string",
"format": "uuid"
},
"injected": {
"type": "boolean"
},
"input": {
"type": "string"
},
"interception_id": {
"type": "string",
"format": "uuid"
},
"metadata": {
"type": "object",
"additionalProperties": {}
},
"provider_response_id": {
"type": "string"
},
"server_url": {
"type": "string"
},
"tool": {
"type": "string"
}
}
},
"codersdk.AIBridgeToolUsage": {
"type": "object",
"properties": {
@@ -13476,6 +13184,11 @@ const docTemplate = `{
"audit_log:*",
"audit_log:create",
"audit_log:read",
"automation:*",
"automation:create",
"automation:delete",
"automation:read",
"automation:update",
"boundary_usage:*",
"boundary_usage:delete",
"boundary_usage:read",
@@ -13685,6 +13398,11 @@ const docTemplate = `{
"APIKeyScopeAuditLogAll",
"APIKeyScopeAuditLogCreate",
"APIKeyScopeAuditLogRead",
"APIKeyScopeAutomationAll",
"APIKeyScopeAutomationCreate",
"APIKeyScopeAutomationDelete",
"APIKeyScopeAutomationRead",
"APIKeyScopeAutomationUpdate",
"APIKeyScopeBoundaryUsageAll",
"APIKeyScopeBoundaryUsageDelete",
"APIKeyScopeBoundaryUsageRead",
@@ -14175,9 +13893,6 @@ const docTemplate = `{
},
"count": {
"type": "integer"
},
"count_cap": {
"type": "integer"
}
}
},
@@ -14499,9 +14214,6 @@ const docTemplate = `{
},
"count": {
"type": "integer"
},
"count_cap": {
"type": "integer"
}
}
},
@@ -14589,17 +14301,6 @@ const docTemplate = `{
}
}
},
"codersdk.CreateFirstUserOnboardingInfo": {
"type": "object",
"properties": {
"newsletter_marketing": {
"type": "boolean"
},
"newsletter_releases": {
"type": "boolean"
}
}
},
"codersdk.CreateFirstUserRequest": {
"type": "object",
"required": [
@@ -14614,9 +14315,6 @@ const docTemplate = `{
"name": {
"type": "string"
},
"onboarding_info": {
"$ref": "#/definitions/codersdk.CreateFirstUserOnboardingInfo"
},
"password": {
"type": "string"
},
@@ -17738,10 +17436,6 @@ const docTemplate = `{
"$ref": "#/definitions/codersdk.SlimRole"
}
},
"has_ai_seat": {
"description": "HasAISeat intentionally omits omitempty so the API always includes the\nfield, even when false.",
"type": "boolean"
},
"is_service_account": {
"type": "boolean"
},
@@ -19096,6 +18790,7 @@ const docTemplate = `{
"assign_org_role",
"assign_role",
"audit_log",
"automation",
"boundary_usage",
"chat",
"connection_log",
@@ -19142,6 +18837,7 @@ const docTemplate = `{
"ResourceAssignOrgRole",
"ResourceAssignRole",
"ResourceAuditLog",
"ResourceAutomation",
"ResourceBoundaryUsage",
"ResourceChat",
"ResourceConnectionLog",
@@ -20538,10 +20234,6 @@ const docTemplate = `{
"type": "string",
"format": "email"
},
"has_ai_seat": {
"description": "HasAISeat intentionally omits omitempty so the API always includes the\nfield, even when false.",
"type": "boolean"
},
"id": {
"type": "string",
"format": "uuid"
@@ -21391,10 +21083,6 @@ const docTemplate = `{
"type": "string",
"format": "email"
},
"has_ai_seat": {
"description": "HasAISeat intentionally omits omitempty so the API always includes the\nfield, even when false.",
"type": "boolean"
},
"id": {
"type": "string",
"format": "uuid"
+14 -318
View File
@@ -65,30 +65,6 @@
}
}
},
"/aibridge/clients": {
"get": {
"produces": ["application/json"],
"tags": ["AI Bridge"],
"summary": "List AI Bridge clients",
"operationId": "list-ai-bridge-clients",
"responses": {
"200": {
"description": "OK",
"schema": {
"type": "array",
"items": {
"type": "string"
}
}
}
},
"security": [
{
"CoderSessionToken": []
}
]
}
},
"/aibridge/interceptions": {
"get": {
"produces": ["application/json"],
@@ -207,54 +183,6 @@
]
}
},
"/aibridge/sessions/{session_id}": {
"get": {
"produces": ["application/json"],
"tags": ["AI Bridge"],
"summary": "Get AI Bridge session threads",
"operationId": "get-ai-bridge-session-threads",
"parameters": [
{
"type": "string",
"description": "Session ID (client_session_id or interception UUID)",
"name": "session_id",
"in": "path",
"required": true
},
{
"type": "string",
"description": "Thread pagination cursor (forward/older)",
"name": "after_id",
"in": "query"
},
{
"type": "string",
"description": "Thread pagination cursor (backward/newer)",
"name": "before_id",
"in": "query"
},
{
"type": "integer",
"description": "Number of threads per page (default 50)",
"name": "limit",
"in": "query"
}
],
"responses": {
"200": {
"description": "OK",
"schema": {
"$ref": "#/definitions/codersdk.AIBridgeSessionThreadsResponse"
}
}
},
"security": [
{
"CoderSessionToken": []
}
]
}
},
"/appearance": {
"get": {
"produces": ["application/json"],
@@ -9038,26 +8966,12 @@
"tags": ["Agents"],
"summary": "Get workspace agent reinitialization",
"operationId": "get-workspace-agent-reinitialization",
"parameters": [
{
"type": "boolean",
"description": "Opt in to durable reinit checks",
"name": "wait",
"in": "query"
}
],
"responses": {
"200": {
"description": "OK",
"schema": {
"$ref": "#/definitions/agentsdk.ReinitializationEvent"
}
},
"409": {
"description": "Conflict",
"schema": {
"$ref": "#/definitions/codersdk.Response"
}
}
},
"security": [
@@ -11243,16 +11157,11 @@
"agentsdk.ReinitializationEvent": {
"type": "object",
"properties": {
"owner_id": {
"type": "string",
"format": "uuid"
},
"reason": {
"$ref": "#/definitions/agentsdk.ReinitializationReason"
},
"workspace_id": {
"type": "string",
"format": "uuid"
"workspaceID": {
"type": "string"
}
}
},
@@ -11352,29 +11261,6 @@
}
}
},
"codersdk.AIBridgeAgenticAction": {
"type": "object",
"properties": {
"model": {
"type": "string"
},
"thinking": {
"type": "array",
"items": {
"$ref": "#/definitions/codersdk.AIBridgeModelThought"
}
},
"token_usage": {
"$ref": "#/definitions/codersdk.AIBridgeSessionThreadsTokenUsage"
},
"tool_calls": {
"type": "array",
"items": {
"$ref": "#/definitions/codersdk.AIBridgeToolCall"
}
}
}
},
"codersdk.AIBridgeAnthropicConfig": {
"type": "object",
"properties": {
@@ -11491,9 +11377,6 @@
"provider": {
"type": "string"
},
"provider_name": {
"type": "string"
},
"started_at": {
"type": "string",
"format": "date-time"
@@ -11546,14 +11429,6 @@
}
}
},
"codersdk.AIBridgeModelThought": {
"type": "object",
"properties": {
"text": {
"type": "string"
}
}
},
"codersdk.AIBridgeOpenAIConfig": {
"type": "object",
"properties": {
@@ -11653,91 +11528,9 @@
}
}
},
"codersdk.AIBridgeSessionThreadsResponse": {
"type": "object",
"properties": {
"client": {
"type": "string"
},
"ended_at": {
"type": "string",
"format": "date-time"
},
"id": {
"type": "string"
},
"initiator": {
"$ref": "#/definitions/codersdk.MinimalUser"
},
"metadata": {
"type": "object",
"additionalProperties": {}
},
"models": {
"type": "array",
"items": {
"type": "string"
}
},
"page_ended_at": {
"type": "string",
"format": "date-time"
},
"page_started_at": {
"type": "string",
"format": "date-time"
},
"providers": {
"type": "array",
"items": {
"type": "string"
}
},
"started_at": {
"type": "string",
"format": "date-time"
},
"threads": {
"type": "array",
"items": {
"$ref": "#/definitions/codersdk.AIBridgeThread"
}
},
"token_usage_summary": {
"$ref": "#/definitions/codersdk.AIBridgeSessionThreadsTokenUsage"
}
}
},
"codersdk.AIBridgeSessionThreadsTokenUsage": {
"type": "object",
"properties": {
"cache_read_input_tokens": {
"type": "integer"
},
"cache_write_input_tokens": {
"type": "integer"
},
"input_tokens": {
"type": "integer"
},
"metadata": {
"type": "object",
"additionalProperties": {}
},
"output_tokens": {
"type": "integer"
}
}
},
"codersdk.AIBridgeSessionTokenUsageSummary": {
"type": "object",
"properties": {
"cache_read_input_tokens": {
"type": "integer"
},
"cache_write_input_tokens": {
"type": "integer"
},
"input_tokens": {
"type": "integer"
},
@@ -11746,50 +11539,9 @@
}
}
},
"codersdk.AIBridgeThread": {
"type": "object",
"properties": {
"agentic_actions": {
"type": "array",
"items": {
"$ref": "#/definitions/codersdk.AIBridgeAgenticAction"
}
},
"ended_at": {
"type": "string",
"format": "date-time"
},
"id": {
"type": "string",
"format": "uuid"
},
"model": {
"type": "string"
},
"prompt": {
"type": "string"
},
"provider": {
"type": "string"
},
"started_at": {
"type": "string",
"format": "date-time"
},
"token_usage": {
"$ref": "#/definitions/codersdk.AIBridgeSessionThreadsTokenUsage"
}
}
},
"codersdk.AIBridgeTokenUsage": {
"type": "object",
"properties": {
"cache_read_input_tokens": {
"type": "integer"
},
"cache_write_input_tokens": {
"type": "integer"
},
"created_at": {
"type": "string",
"format": "date-time"
@@ -11817,42 +11569,6 @@
}
}
},
"codersdk.AIBridgeToolCall": {
"type": "object",
"properties": {
"created_at": {
"type": "string",
"format": "date-time"
},
"id": {
"type": "string",
"format": "uuid"
},
"injected": {
"type": "boolean"
},
"input": {
"type": "string"
},
"interception_id": {
"type": "string",
"format": "uuid"
},
"metadata": {
"type": "object",
"additionalProperties": {}
},
"provider_response_id": {
"type": "string"
},
"server_url": {
"type": "string"
},
"tool": {
"type": "string"
}
}
},
"codersdk.AIBridgeToolUsage": {
"type": "object",
"properties": {
@@ -12046,6 +11762,11 @@
"audit_log:*",
"audit_log:create",
"audit_log:read",
"automation:*",
"automation:create",
"automation:delete",
"automation:read",
"automation:update",
"boundary_usage:*",
"boundary_usage:delete",
"boundary_usage:read",
@@ -12255,6 +11976,11 @@
"APIKeyScopeAuditLogAll",
"APIKeyScopeAuditLogCreate",
"APIKeyScopeAuditLogRead",
"APIKeyScopeAutomationAll",
"APIKeyScopeAutomationCreate",
"APIKeyScopeAutomationDelete",
"APIKeyScopeAutomationRead",
"APIKeyScopeAutomationUpdate",
"APIKeyScopeBoundaryUsageAll",
"APIKeyScopeBoundaryUsageDelete",
"APIKeyScopeBoundaryUsageRead",
@@ -12739,9 +12465,6 @@
},
"count": {
"type": "integer"
},
"count_cap": {
"type": "integer"
}
}
},
@@ -13042,9 +12765,6 @@
},
"count": {
"type": "integer"
},
"count_cap": {
"type": "integer"
}
}
},
@@ -13129,17 +12849,6 @@
}
}
},
"codersdk.CreateFirstUserOnboardingInfo": {
"type": "object",
"properties": {
"newsletter_marketing": {
"type": "boolean"
},
"newsletter_releases": {
"type": "boolean"
}
}
},
"codersdk.CreateFirstUserRequest": {
"type": "object",
"required": ["email", "password", "username"],
@@ -13150,9 +12859,6 @@
"name": {
"type": "string"
},
"onboarding_info": {
"$ref": "#/definitions/codersdk.CreateFirstUserOnboardingInfo"
},
"password": {
"type": "string"
},
@@ -16155,10 +15861,6 @@
"$ref": "#/definitions/codersdk.SlimRole"
}
},
"has_ai_seat": {
"description": "HasAISeat intentionally omits omitempty so the API always includes the\nfield, even when false.",
"type": "boolean"
},
"is_service_account": {
"type": "boolean"
},
@@ -17468,6 +17170,7 @@
"assign_org_role",
"assign_role",
"audit_log",
"automation",
"boundary_usage",
"chat",
"connection_log",
@@ -17514,6 +17217,7 @@
"ResourceAssignOrgRole",
"ResourceAssignRole",
"ResourceAuditLog",
"ResourceAutomation",
"ResourceBoundaryUsage",
"ResourceChat",
"ResourceConnectionLog",
@@ -18855,10 +18559,6 @@
"type": "string",
"format": "email"
},
"has_ai_seat": {
"description": "HasAISeat intentionally omits omitempty so the API always includes the\nfield, even when false.",
"type": "boolean"
},
"id": {
"type": "string",
"format": "uuid"
@@ -19651,10 +19351,6 @@
"type": "string",
"format": "email"
},
"has_ai_seat": {
"description": "HasAISeat intentionally omits omitempty so the API always includes the\nfield, even when false.",
"type": "boolean"
},
"id": {
"type": "string",
"format": "uuid"
-15
View File
@@ -582,20 +582,5 @@ func (api *API) createAPIKey(ctx context.Context, params apikey.CreateParams) (*
Value: sessionToken,
Path: "/",
HttpOnly: true,
// MaxAge is set so the browser persists the cookie to disk rather
// than keeping it in memory as a session cookie. Standalone PWAs
// (display: standalone) run in their own browser process, and
// mobile OSes kill that process when the app is swiped away —
// deleting in-memory cookies and forcing an unexpected login.
//
// We use a long static value (1 year) instead of the key's
// LifetimeSeconds because the server refreshes the key's
// ExpiresAt on activity but does not re-set the cookie. Tying
// MaxAge to the key lifetime would cause the cookie to expire
// client-side even when the server-side key is still valid.
//
// Security is not affected: the server validates ExpiresAt on
// every request regardless of the cookie's MaxAge.
MaxAge: int((365 * 24 * time.Hour).Seconds()),
}), &newkey, nil
}
-49
View File
@@ -394,55 +394,6 @@ func TestSessionExpiry(t *testing.T) {
}
}
// TestSessionCookieMaxAge verifies that the session cookie is a persistent
// cookie (has MaxAge set) rather than a session cookie. Standalone PWAs
// run in their own browser process and mobile OSes purge in-memory
// (session) cookies when that process is killed, so the cookie must be
// persisted to disk.
func TestSessionCookieMaxAge(t *testing.T) {
t.Parallel()
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong)
defer cancel()
client := coderdtest.New(t, nil)
// Create the first user (password-based login).
req := codersdk.CreateFirstUserRequest{
Email: "testuser@coder.com",
Username: "testuser",
Password: "SomeSecurePassword!",
}
_, err := client.CreateFirstUser(ctx, req)
require.NoError(t, err)
// Login via the raw HTTP endpoint so we can inspect the Set-Cookie header.
loginURL, err := client.URL.Parse("/api/v2/users/login")
require.NoError(t, err)
res, err := client.Request(ctx, http.MethodPost, loginURL.String(), codersdk.LoginWithPasswordRequest{
Email: req.Email,
Password: req.Password,
})
require.NoError(t, err)
defer res.Body.Close()
require.Equal(t, http.StatusCreated, res.StatusCode)
oneYear := int((365 * 24 * time.Hour).Seconds())
var found bool
for _, cookie := range res.Cookies() {
if cookie.Name == codersdk.SessionTokenCookie {
// MaxAge should be set to a long value so the browser
// persists the cookie to disk. The server handles real
// expiry via the API key's ExpiresAt field.
require.Equal(t, oneYear, cookie.MaxAge,
"Session cookie MaxAge should be set to 1 year for disk persistence")
found = true
}
}
require.True(t, found, "session cookie should be present in login response")
}
func TestAPIKey_OK(t *testing.T) {
t.Parallel()
+1 -8
View File
@@ -26,11 +26,6 @@ import (
"github.com/coder/coder/v2/codersdk"
)
// Limit the count query to avoid a slow sequential scan due to joins
// on a large table. Set to 0 to disable capping (but also see the note
// in the SQL query).
const auditLogCountCap = 2000
// @Summary Get audit logs
// @ID get-audit-logs
// @Security CoderSessionToken
@@ -71,7 +66,7 @@ func (api *API) auditLogs(rw http.ResponseWriter, r *http.Request) {
countFilter.Username = ""
}
countFilter.CountCap = auditLogCountCap
// Use the same filters to count the number of audit logs
count, err := api.Database.CountAuditLogs(ctx, countFilter)
if dbauthz.IsNotAuthorizedError(err) {
httpapi.Forbidden(rw)
@@ -86,7 +81,6 @@ func (api *API) auditLogs(rw http.ResponseWriter, r *http.Request) {
httpapi.Write(ctx, rw, http.StatusOK, codersdk.AuditLogResponse{
AuditLogs: []codersdk.AuditLog{},
Count: 0,
CountCap: auditLogCountCap,
})
return
}
@@ -104,7 +98,6 @@ func (api *API) auditLogs(rw http.ResponseWriter, r *http.Request) {
httpapi.Write(ctx, rw, http.StatusOK, codersdk.AuditLogResponse{
AuditLogs: api.convertAuditLogs(ctx, dblogs),
Count: count,
CountCap: auditLogCountCap,
})
}
+1 -1
View File
@@ -220,7 +220,7 @@ func (api *API) checkAuthorization(rw http.ResponseWriter, r *http.Request) {
Type: string(v.Object.ResourceType),
AnyOrgOwner: v.Object.AnyOrgOwner,
}
if obj.Owner == codersdk.Me {
if obj.Owner == "me" {
obj.Owner = auth.ID
}
+56
View File
@@ -0,0 +1,56 @@
package automations
import (
"context"
"github.com/google/uuid"
"github.com/coder/coder/v2/coderd/database"
chatd "github.com/coder/coder/v2/coderd/x/chatd"
"github.com/coder/coder/v2/codersdk"
)
// ChatdAdapter implements ChatCreator by delegating to chatd.Server.
type ChatdAdapter struct {
Server *chatd.Server
}
// CreateChat creates a new chat through the chatd server and returns
// the chat ID.
func (a *ChatdAdapter) CreateChat(ctx context.Context, opts CreateChatOptions) (uuid.UUID, error) {
var modelConfigID uuid.UUID
if opts.ModelConfigID.Valid {
modelConfigID = opts.ModelConfigID.UUID
}
labels := database.StringMap{}
for k, v := range opts.Labels {
labels[k] = v
}
chat, err := a.Server.CreateChat(ctx, chatd.CreateOptions{
OwnerID: opts.OwnerID,
Title: opts.Title,
ModelConfigID: modelConfigID,
MCPServerIDs: opts.MCPServerIDs,
Labels: labels,
InitialUserContent: []codersdk.ChatMessagePart{
codersdk.ChatMessageText(opts.Instructions),
},
})
if err != nil {
return uuid.Nil, err
}
return chat.ID, nil
}
// SendMessage appends a user message to an existing chat.
func (a *ChatdAdapter) SendMessage(ctx context.Context, chatID uuid.UUID, ownerID uuid.UUID, content string) error {
_, err := a.Server.SendMessage(ctx, chatd.SendMessageOptions{
ChatID: chatID,
CreatedBy: ownerID,
Content: []codersdk.ChatMessagePart{
codersdk.ChatMessageText(content),
},
})
return err
}
@@ -0,0 +1,196 @@
package cronscheduler
import (
"context"
"encoding/json"
"io"
"time"
"cdr.dev/slog/v3"
"github.com/coder/coder/v2/coderd/automations"
"github.com/coder/coder/v2/coderd/database"
"github.com/coder/coder/v2/coderd/database/dbauthz"
"github.com/coder/coder/v2/coderd/database/dbtime"
"github.com/coder/coder/v2/coderd/pproflabel"
"github.com/coder/coder/v2/coderd/schedule/cron"
"github.com/coder/quartz"
)
const tickInterval = time.Minute
// New creates a background scheduler that evaluates cron-based
// automation triggers every minute. Only one replica runs the
// scheduler at a time via an advisory lock.
func New(ctx context.Context, logger slog.Logger, db database.Store, clk quartz.Clock, chat automations.ChatCreator) io.Closer {
closed := make(chan struct{})
ctx, cancelFunc := context.WithCancel(ctx)
//nolint:gocritic // System-level background job needs broad read access.
ctx = dbauthz.AsSystemRestricted(ctx)
inst := &instance{
cancel: cancelFunc,
closed: closed,
logger: logger,
db: db,
clk: clk,
chat: chat,
}
ticker := clk.NewTicker(tickInterval)
doTick := func(ctx context.Context, now time.Time) {
defer ticker.Reset(tickInterval)
inst.tick(ctx, now)
}
pproflabel.Go(ctx, pproflabel.Service("automation-cron"), func(ctx context.Context) {
defer close(closed)
defer ticker.Stop()
// Force an initial tick.
doTick(ctx, dbtime.Time(clk.Now()).UTC())
for {
select {
case <-ctx.Done():
return
case t := <-ticker.C:
ticker.Stop()
doTick(ctx, dbtime.Time(t).UTC())
}
}
})
return inst
}
type instance struct {
cancel context.CancelFunc
closed chan struct{}
logger slog.Logger
db database.Store
clk quartz.Clock
chat automations.ChatCreator
}
func (i *instance) Close() error {
i.cancel()
<-i.closed
return nil
}
// tick runs a single scheduler iteration. It acquires an advisory
// lock so that only one replica processes cron triggers at a time.
func (i *instance) tick(ctx context.Context, now time.Time) {
err := i.db.InTx(func(tx database.Store) error {
// Only one replica should evaluate cron triggers.
ok, err := tx.TryAcquireLock(ctx, database.LockIDAutomationCron)
if err != nil {
return err
}
if !ok {
i.logger.Debug(ctx, "unable to acquire automation cron lock, skipping")
return nil
}
triggers, err := tx.GetActiveCronTriggers(ctx)
if err != nil {
return err
}
for _, t := range triggers {
if err := ctx.Err(); err != nil {
return err
}
i.processTrigger(ctx, tx, t, now)
}
return nil
}, database.DefaultTXOptions().WithID("automation_cron"))
if err != nil && ctx.Err() == nil {
i.logger.Error(ctx, "automation cron tick failed", slog.Error(err))
}
}
// processTrigger evaluates a single cron trigger and fires it if
// the schedule indicates it is due.
func (i *instance) processTrigger(ctx context.Context, db database.Store, trigger database.GetActiveCronTriggersRow, now time.Time) {
logger := i.logger.With(
slog.F("trigger_id", trigger.ID),
slog.F("automation_id", trigger.AutomationID),
)
if !trigger.CronSchedule.Valid {
return
}
sched, err := cron.Standard(trigger.CronSchedule.String)
if err != nil {
logger.Warn(ctx, "invalid cron schedule on trigger",
slog.F("schedule", trigger.CronSchedule.String),
slog.Error(err),
)
return
}
// Determine the reference time for computing "next fire".
// If the trigger has never fired, use its creation time.
ref := trigger.CreatedAt
if trigger.LastTriggeredAt.Valid {
ref = trigger.LastTriggeredAt.Time
}
next := sched.Next(ref)
if next.After(now) {
// Not yet due.
return
}
// Build a synthetic payload for the cron event.
payload, _ := json.Marshal(map[string]any{
"trigger": "cron",
"schedule": trigger.CronSchedule.String,
"scheduled_at": next.UTC().Format(time.RFC3339),
"fired_at": now.UTC().Format(time.RFC3339),
})
// Resolve labels against the synthetic payload if configured.
var resolvedLabels map[string]string
if trigger.LabelPaths.Valid {
var labelPaths map[string]string
if err := json.Unmarshal(trigger.LabelPaths.RawMessage, &labelPaths); err == nil && len(labelPaths) > 0 {
resolvedLabels = automations.ResolveLabels(string(payload), labelPaths)
}
}
// Build the FireOptions from the trigger row.
fireOpts := automations.FireOptions{
AutomationID: trigger.AutomationID,
AutomationName: trigger.AutomationName,
AutomationStatus: trigger.AutomationStatus,
AutomationOwnerID: trigger.AutomationOwnerID,
AutomationInstructions: trigger.AutomationInstructions,
AutomationModelConfigID: trigger.AutomationModelConfigID,
AutomationMCPServerIDs: trigger.AutomationMcpServerIds,
AutomationAllowedTools: trigger.AutomationAllowedTools,
MaxChatCreatesPerHour: trigger.AutomationMaxChatCreatesPerHour,
MaxMessagesPerHour: trigger.AutomationMaxMessagesPerHour,
TriggerID: trigger.ID,
Payload: payload,
FilterMatched: true,
ResolvedLabels: resolvedLabels,
}
result := automations.Fire(ctx, logger, db, i.chat, fireOpts)
// Update last_triggered_at so this trigger is not re-fired
// until the next scheduled time.
updateErr := db.UpdateAutomationTriggerLastTriggeredAt(ctx, database.UpdateAutomationTriggerLastTriggeredAtParams{
LastTriggeredAt: now,
ID: trigger.ID,
})
if updateErr != nil {
logger.Error(ctx, "failed to update last_triggered_at", slog.Error(updateErr))
}
logger.Info(ctx, "fired cron automation trigger",
slog.F("status", result.Status),
slog.F("schedule", trigger.CronSchedule.String),
slog.F("next_after_ref", next),
)
}
@@ -0,0 +1,314 @@
package cronscheduler_test
import (
"context"
"database/sql"
"testing"
"time"
"github.com/google/uuid"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"go.uber.org/goleak"
"go.uber.org/mock/gomock"
"github.com/coder/coder/v2/coderd/automations"
"github.com/coder/coder/v2/coderd/automations/cronscheduler"
"github.com/coder/coder/v2/coderd/database"
"github.com/coder/coder/v2/coderd/database/dbmock"
"github.com/coder/coder/v2/testutil"
"github.com/coder/quartz"
)
func TestMain(m *testing.M) {
goleak.VerifyTestMain(m, testutil.GoleakOptions...)
}
// awaitDoTick waits for the scheduler to complete its initial tick.
// It traps the quartz.Mock clock events in the same order the
// scheduler emits them: Now() → TickerReset.
func awaitDoTick(ctx context.Context, t *testing.T, clk *quartz.Mock) chan struct{} {
t.Helper()
ch := make(chan struct{})
trapNow := clk.Trap().Now()
trapReset := clk.Trap().TickerReset()
go func() {
defer close(ch)
defer trapReset.Close()
defer trapNow.Close()
// Wait for the initial Now() call that kicks off doTick.
trapNow.MustWait(ctx).MustRelease(ctx)
// Wait for the ticker reset that signals doTick completed.
trapReset.MustWait(ctx).MustRelease(ctx)
}()
return ch
}
// fakeChatCreator is a stub ChatCreator for tests. CreateChat
// returns a new UUID; SendMessage is a no-op.
type fakeChatCreator struct{}
func (fakeChatCreator) CreateChat(_ context.Context, _ automations.CreateChatOptions) (uuid.UUID, error) {
return uuid.New(), nil
}
func (fakeChatCreator) SendMessage(_ context.Context, _ uuid.UUID, _ uuid.UUID, _ string) error {
return nil
}
// makeTrigger builds a GetActiveCronTriggersRow with sensible
// defaults. Fields can be overridden after creation.
func makeTrigger(schedule string, status string, createdAt time.Time) database.GetActiveCronTriggersRow {
return database.GetActiveCronTriggersRow{
ID: uuid.New(),
AutomationID: uuid.New(),
Type: "cron",
CronSchedule: sql.NullString{String: schedule, Valid: true},
CreatedAt: createdAt,
UpdatedAt: createdAt,
AutomationStatus: status,
AutomationOwnerID: uuid.New(),
AutomationMaxChatCreatesPerHour: 10,
AutomationMaxMessagesPerHour: 100,
}
}
//nolint:paralleltest // Uses LockIDAutomationCron advisory lock mock.
func TestScheduler(t *testing.T) {
t.Run("NoTriggers", func(t *testing.T) {
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitShort)
defer cancel()
ctrl := gomock.NewController(t)
mDB := dbmock.NewMockStore(ctrl)
clk := quartz.NewMock(t)
// The scheduler calls InTx; execute the function against
// the same mock so inner calls are recorded.
mDB.EXPECT().InTx(gomock.Any(), gomock.Any()).DoAndReturn(
func(fn func(database.Store) error, _ *database.TxOptions) error {
return fn(mDB)
},
).Times(1)
mDB.EXPECT().TryAcquireLock(gomock.Any(), int64(database.LockIDAutomationCron)).Return(true, nil)
mDB.EXPECT().GetActiveCronTriggers(gomock.Any()).Return(nil, nil)
done := awaitDoTick(ctx, t, clk)
scheduler := cronscheduler.New(ctx, testutil.Logger(t), mDB, clk, nil)
<-done
require.NoError(t, scheduler.Close())
})
t.Run("DueTriggerFires", func(t *testing.T) {
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitShort)
defer cancel()
ctrl := gomock.NewController(t)
mDB := dbmock.NewMockStore(ctrl)
clk := quartz.NewMock(t)
now := time.Date(2025, 6, 15, 12, 0, 0, 0, time.UTC)
clk.Set(now).MustWait(ctx)
// Trigger was created 2 minutes ago with "every minute"
// schedule, so it should fire.
trigger := makeTrigger("* * * * *", "active", now.Add(-2*time.Minute))
mDB.EXPECT().InTx(gomock.Any(), gomock.Any()).DoAndReturn(
func(fn func(database.Store) error, _ *database.TxOptions) error {
return fn(mDB)
},
)
mDB.EXPECT().TryAcquireLock(gomock.Any(), int64(database.LockIDAutomationCron)).Return(true, nil)
mDB.EXPECT().GetActiveCronTriggers(gomock.Any()).Return(
[]database.GetActiveCronTriggersRow{trigger}, nil,
)
// Fire() checks rate limits before creating a chat.
mDB.EXPECT().CountAutomationChatCreatesInWindow(gomock.Any(), gomock.Any()).Return(int64(0), nil)
mDB.EXPECT().CountAutomationMessagesInWindow(gomock.Any(), gomock.Any()).Return(int64(0), nil)
// Expect the event to be inserted with status "created".
mDB.EXPECT().InsertAutomationEvent(gomock.Any(), gomock.Any()).DoAndReturn(
func(_ context.Context, arg database.InsertAutomationEventParams) (database.AutomationEvent, error) {
assert.Equal(t, trigger.AutomationID, arg.AutomationID)
assert.Equal(t, trigger.ID, arg.TriggerID.UUID)
assert.Equal(t, "created", arg.Status)
assert.True(t, arg.FilterMatched)
return database.AutomationEvent{ID: uuid.New()}, nil
},
)
// Expect last_triggered_at to be updated.
mDB.EXPECT().UpdateAutomationTriggerLastTriggeredAt(gomock.Any(), gomock.Any()).DoAndReturn(
func(_ context.Context, arg database.UpdateAutomationTriggerLastTriggeredAtParams) error {
assert.Equal(t, trigger.ID, arg.ID)
assert.Equal(t, now, arg.LastTriggeredAt)
return nil
},
)
done := awaitDoTick(ctx, t, clk)
scheduler := cronscheduler.New(ctx, testutil.Logger(t), mDB, clk, fakeChatCreator{})
<-done
require.NoError(t, scheduler.Close())
})
t.Run("NotYetDueTriggerSkipped", func(t *testing.T) {
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitShort)
defer cancel()
ctrl := gomock.NewController(t)
mDB := dbmock.NewMockStore(ctrl)
clk := quartz.NewMock(t)
now := time.Date(2025, 6, 15, 12, 0, 0, 0, time.UTC)
clk.Set(now).MustWait(ctx)
// Trigger created now with "every hour" schedule. Next fire
// is 1 hour from now, so it should NOT fire.
trigger := makeTrigger("0 * * * *", "active", now)
mDB.EXPECT().InTx(gomock.Any(), gomock.Any()).DoAndReturn(
func(fn func(database.Store) error, _ *database.TxOptions) error {
return fn(mDB)
},
)
mDB.EXPECT().TryAcquireLock(gomock.Any(), int64(database.LockIDAutomationCron)).Return(true, nil)
mDB.EXPECT().GetActiveCronTriggers(gomock.Any()).Return(
[]database.GetActiveCronTriggersRow{trigger}, nil,
)
// No InsertAutomationEvent or UpdateAutomationTriggerLastTriggeredAt
// expected — the trigger is not due.
done := awaitDoTick(ctx, t, clk)
scheduler := cronscheduler.New(ctx, testutil.Logger(t), mDB, clk, nil)
<-done
require.NoError(t, scheduler.Close())
})
t.Run("PreviewModeCreatesPreviewEvent", func(t *testing.T) {
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitShort)
defer cancel()
ctrl := gomock.NewController(t)
mDB := dbmock.NewMockStore(ctrl)
clk := quartz.NewMock(t)
now := time.Date(2025, 6, 15, 12, 0, 0, 0, time.UTC)
clk.Set(now).MustWait(ctx)
trigger := makeTrigger("* * * * *", "preview", now.Add(-2*time.Minute))
mDB.EXPECT().InTx(gomock.Any(), gomock.Any()).DoAndReturn(
func(fn func(database.Store) error, _ *database.TxOptions) error {
return fn(mDB)
},
)
mDB.EXPECT().TryAcquireLock(gomock.Any(), int64(database.LockIDAutomationCron)).Return(true, nil)
mDB.EXPECT().GetActiveCronTriggers(gomock.Any()).Return(
[]database.GetActiveCronTriggersRow{trigger}, nil,
)
mDB.EXPECT().InsertAutomationEvent(gomock.Any(), gomock.Any()).DoAndReturn(
func(_ context.Context, arg database.InsertAutomationEventParams) (database.AutomationEvent, error) {
assert.Equal(t, "preview", arg.Status)
return database.AutomationEvent{ID: uuid.New()}, nil
},
)
mDB.EXPECT().UpdateAutomationTriggerLastTriggeredAt(gomock.Any(), gomock.Any()).Return(nil)
done := awaitDoTick(ctx, t, clk)
scheduler := cronscheduler.New(ctx, testutil.Logger(t), mDB, clk, nil)
<-done
require.NoError(t, scheduler.Close())
})
t.Run("InvalidScheduleSkipped", func(t *testing.T) {
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitShort)
defer cancel()
ctrl := gomock.NewController(t)
mDB := dbmock.NewMockStore(ctrl)
clk := quartz.NewMock(t)
now := time.Date(2025, 6, 15, 12, 0, 0, 0, time.UTC)
clk.Set(now).MustWait(ctx)
trigger := makeTrigger("not a cron", "active", now.Add(-2*time.Minute))
mDB.EXPECT().InTx(gomock.Any(), gomock.Any()).DoAndReturn(
func(fn func(database.Store) error, _ *database.TxOptions) error {
return fn(mDB)
},
)
mDB.EXPECT().TryAcquireLock(gomock.Any(), int64(database.LockIDAutomationCron)).Return(true, nil)
mDB.EXPECT().GetActiveCronTriggers(gomock.Any()).Return(
[]database.GetActiveCronTriggersRow{trigger}, nil,
)
// No event insert expected — invalid schedule is skipped.
done := awaitDoTick(ctx, t, clk)
scheduler := cronscheduler.New(ctx, testutil.Logger(t), mDB, clk, nil)
<-done
require.NoError(t, scheduler.Close())
})
t.Run("LastTriggeredAtPreventsRefire", func(t *testing.T) {
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitShort)
defer cancel()
ctrl := gomock.NewController(t)
mDB := dbmock.NewMockStore(ctrl)
clk := quartz.NewMock(t)
now := time.Date(2025, 6, 15, 12, 0, 30, 0, time.UTC)
clk.Set(now).MustWait(ctx)
// Trigger with "every hour" schedule that last fired at the
// top of this hour. Next fire is 13:00, which is after now.
trigger := makeTrigger("0 * * * *", "active", now.Add(-24*time.Hour))
trigger.LastTriggeredAt = sql.NullTime{
Time: time.Date(2025, 6, 15, 12, 0, 0, 0, time.UTC),
Valid: true,
}
mDB.EXPECT().InTx(gomock.Any(), gomock.Any()).DoAndReturn(
func(fn func(database.Store) error, _ *database.TxOptions) error {
return fn(mDB)
},
)
mDB.EXPECT().TryAcquireLock(gomock.Any(), int64(database.LockIDAutomationCron)).Return(true, nil)
mDB.EXPECT().GetActiveCronTriggers(gomock.Any()).Return(
[]database.GetActiveCronTriggersRow{trigger}, nil,
)
// No event insert — last_triggered_at means next fire is
// in the future.
done := awaitDoTick(ctx, t, clk)
scheduler := cronscheduler.New(ctx, testutil.Logger(t), mDB, clk, nil)
<-done
require.NoError(t, scheduler.Close())
})
t.Run("LockNotAcquiredSkips", func(t *testing.T) {
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitShort)
defer cancel()
ctrl := gomock.NewController(t)
mDB := dbmock.NewMockStore(ctrl)
clk := quartz.NewMock(t)
mDB.EXPECT().InTx(gomock.Any(), gomock.Any()).DoAndReturn(
func(fn func(database.Store) error, _ *database.TxOptions) error {
return fn(mDB)
},
)
// Another replica holds the lock.
mDB.EXPECT().TryAcquireLock(gomock.Any(), int64(database.LockIDAutomationCron)).Return(false, nil)
// No GetActiveCronTriggers call expected.
done := awaitDoTick(ctx, t, clk)
scheduler := cronscheduler.New(ctx, testutil.Logger(t), mDB, clk, nil)
<-done
require.NoError(t, scheduler.Close())
})
}
+288
View File
@@ -0,0 +1,288 @@
package automations
import (
"context"
"database/sql"
"encoding/json"
"fmt"
"time"
"github.com/google/uuid"
"github.com/sqlc-dev/pqtype"
"cdr.dev/slog/v3"
"github.com/coder/coder/v2/coderd/database"
)
// ChatCreator abstracts the chatd.Server so the automations package
// does not depend on it directly. Both the webhook handler and the
// cron scheduler inject the real implementation.
type ChatCreator interface {
// CreateChat creates a new chat and sends the initial message.
// The returned UUID is the chat ID.
CreateChat(ctx context.Context, opts CreateChatOptions) (uuid.UUID, error)
// SendMessage appends a user message to an existing chat.
SendMessage(ctx context.Context, chatID uuid.UUID, ownerID uuid.UUID, content string) error
}
// CreateChatOptions contains everything needed to create an
// automation-initiated chat.
type CreateChatOptions struct {
OwnerID uuid.UUID
AutomationID uuid.UUID
Title string
Instructions string
ModelConfigID uuid.NullUUID
MCPServerIDs []uuid.UUID
Labels map[string]string
}
// FireResult is the outcome of an automation trigger firing.
type FireResult struct {
Status string
MatchedChatID uuid.NullUUID
CreatedChatID uuid.NullUUID
Error string
}
// FireOptions contains the inputs for firing an automation trigger.
type FireOptions struct {
// Automation fields.
AutomationID uuid.UUID
AutomationName string
AutomationStatus string
AutomationOwnerID uuid.UUID
AutomationInstructions string
AutomationModelConfigID uuid.NullUUID
AutomationMCPServerIDs []uuid.UUID
AutomationAllowedTools []string
MaxChatCreatesPerHour int32
MaxMessagesPerHour int32
// Trigger fields.
TriggerID uuid.UUID
// Resolved data.
Payload json.RawMessage
FilterMatched bool
ResolvedLabels map[string]string
}
// Fire executes the active-mode logic for an automation trigger:
// rate-limit check, find-or-create chat, send message, and record
// the event. For preview mode it only logs the event without
// creating a chat.
func Fire(
ctx context.Context,
logger slog.Logger,
db database.Store,
chat ChatCreator,
opts FireOptions,
) FireResult {
triggerUUID := uuid.NullUUID{UUID: opts.TriggerID, Valid: true}
// Resolve labels JSON for the event record.
var resolvedLabelsJSON pqtype.NullRawMessage
if len(opts.ResolvedLabels) > 0 {
if j, err := json.Marshal(opts.ResolvedLabels); err == nil {
resolvedLabelsJSON = pqtype.NullRawMessage{RawMessage: j, Valid: true}
}
}
// Preview mode: log the event, optionally look up a matching
// chat, but never create or continue one.
if opts.AutomationStatus == "preview" {
result := FireResult{Status: "preview"}
// Try to find a matching chat for the preview log.
if len(opts.ResolvedLabels) > 0 {
if chatID, ok := findChatByLabels(ctx, db, opts.AutomationOwnerID, opts.ResolvedLabels); ok {
result.MatchedChatID = uuid.NullUUID{UUID: chatID, Valid: true}
}
}
insertEvent(ctx, db, database.InsertAutomationEventParams{
AutomationID: opts.AutomationID,
TriggerID: triggerUUID,
Payload: opts.Payload,
FilterMatched: opts.FilterMatched,
ResolvedLabels: resolvedLabelsJSON,
MatchedChatID: result.MatchedChatID,
Status: "preview",
})
return result
}
// Active mode: enforce rate limits.
windowStart := time.Now().Add(-time.Hour)
chatCreates, err := db.CountAutomationChatCreatesInWindow(ctx, database.CountAutomationChatCreatesInWindowParams{
AutomationID: opts.AutomationID,
WindowStart: windowStart,
})
if err != nil {
logger.Error(ctx, "failed to count chat creates", slog.Error(err))
insertEvent(ctx, db, database.InsertAutomationEventParams{
AutomationID: opts.AutomationID,
TriggerID: triggerUUID,
Payload: opts.Payload,
FilterMatched: opts.FilterMatched,
ResolvedLabels: resolvedLabelsJSON,
Status: "error",
Error: sql.NullString{String: "failed to check rate limits", Valid: true},
})
return FireResult{Status: "error", Error: "failed to check rate limits"}
}
msgCount, err := db.CountAutomationMessagesInWindow(ctx, database.CountAutomationMessagesInWindowParams{
AutomationID: opts.AutomationID,
WindowStart: windowStart,
})
if err != nil {
logger.Error(ctx, "failed to count messages", slog.Error(err))
insertEvent(ctx, db, database.InsertAutomationEventParams{
AutomationID: opts.AutomationID,
TriggerID: triggerUUID,
Payload: opts.Payload,
FilterMatched: opts.FilterMatched,
ResolvedLabels: resolvedLabelsJSON,
Status: "error",
Error: sql.NullString{String: "failed to check rate limits", Valid: true},
})
return FireResult{Status: "error", Error: "failed to check rate limits"}
}
// Check message rate limit (applies to both create and continue).
if msgCount >= int64(opts.MaxMessagesPerHour) {
insertEvent(ctx, db, database.InsertAutomationEventParams{
AutomationID: opts.AutomationID,
TriggerID: triggerUUID,
Payload: opts.Payload,
FilterMatched: opts.FilterMatched,
ResolvedLabels: resolvedLabelsJSON,
Status: "rate_limited",
Error: sql.NullString{
String: fmt.Sprintf("message rate limit exceeded: %d/%d per hour", msgCount, opts.MaxMessagesPerHour),
Valid: true,
},
})
return FireResult{Status: "rate_limited", Error: "message rate limit exceeded"}
}
// Try to find an existing chat with matching labels to continue.
if len(opts.ResolvedLabels) > 0 {
if chatID, ok := findChatByLabels(ctx, db, opts.AutomationOwnerID, opts.ResolvedLabels); ok {
// Continue existing chat.
if err := chat.SendMessage(ctx, chatID, opts.AutomationOwnerID, opts.AutomationInstructions); err != nil {
logger.Error(ctx, "failed to send message to existing chat",
slog.F("chat_id", chatID),
slog.Error(err),
)
insertEvent(ctx, db, database.InsertAutomationEventParams{
AutomationID: opts.AutomationID,
TriggerID: triggerUUID,
Payload: opts.Payload,
FilterMatched: opts.FilterMatched,
ResolvedLabels: resolvedLabelsJSON,
MatchedChatID: uuid.NullUUID{UUID: chatID, Valid: true},
Status: "error",
Error: sql.NullString{String: "failed to send message to chat", Valid: true},
})
return FireResult{Status: "error", Error: "failed to send message"}
}
insertEvent(ctx, db, database.InsertAutomationEventParams{
AutomationID: opts.AutomationID,
TriggerID: triggerUUID,
Payload: opts.Payload,
FilterMatched: opts.FilterMatched,
ResolvedLabels: resolvedLabelsJSON,
MatchedChatID: uuid.NullUUID{UUID: chatID, Valid: true},
Status: "continued",
})
return FireResult{
Status: "continued",
MatchedChatID: uuid.NullUUID{UUID: chatID, Valid: true},
}
}
}
// No matching chat found — create a new one.
// Check chat creation rate limit.
if chatCreates >= int64(opts.MaxChatCreatesPerHour) {
insertEvent(ctx, db, database.InsertAutomationEventParams{
AutomationID: opts.AutomationID,
TriggerID: triggerUUID,
Payload: opts.Payload,
FilterMatched: opts.FilterMatched,
ResolvedLabels: resolvedLabelsJSON,
Status: "rate_limited",
Error: sql.NullString{
String: fmt.Sprintf("chat creation rate limit exceeded: %d/%d per hour", chatCreates, opts.MaxChatCreatesPerHour),
Valid: true,
},
})
return FireResult{Status: "rate_limited", Error: "chat creation rate limit exceeded"}
}
newChatID, err := chat.CreateChat(ctx, CreateChatOptions{
OwnerID: opts.AutomationOwnerID,
AutomationID: opts.AutomationID,
Title: fmt.Sprintf("[%s] %s", opts.AutomationName, time.Now().UTC().Format("2006-01-02 15:04")),
Instructions: opts.AutomationInstructions,
ModelConfigID: opts.AutomationModelConfigID,
MCPServerIDs: opts.AutomationMCPServerIDs,
Labels: opts.ResolvedLabels,
})
if err != nil {
logger.Error(ctx, "failed to create chat", slog.Error(err))
insertEvent(ctx, db, database.InsertAutomationEventParams{
AutomationID: opts.AutomationID,
TriggerID: triggerUUID,
Payload: opts.Payload,
FilterMatched: opts.FilterMatched,
ResolvedLabels: resolvedLabelsJSON,
Status: "error",
Error: sql.NullString{String: "failed to create chat", Valid: true},
})
return FireResult{Status: "error", Error: "failed to create chat"}
}
insertEvent(ctx, db, database.InsertAutomationEventParams{
AutomationID: opts.AutomationID,
TriggerID: triggerUUID,
Payload: opts.Payload,
FilterMatched: opts.FilterMatched,
ResolvedLabels: resolvedLabelsJSON,
CreatedChatID: uuid.NullUUID{UUID: newChatID, Valid: true},
Status: "created",
})
return FireResult{
Status: "created",
CreatedChatID: uuid.NullUUID{UUID: newChatID, Valid: true},
}
}
// findChatByLabels looks up an existing chat owned by the given user
// whose labels match the resolved label set.
func findChatByLabels(ctx context.Context, db database.Store, ownerID uuid.UUID, labels map[string]string) (uuid.UUID, bool) {
labelsJSON, err := json.Marshal(labels)
if err != nil {
return uuid.Nil, false
}
chats, err := db.GetChats(ctx, database.GetChatsParams{
OwnerID: ownerID,
LabelFilter: pqtype.NullRawMessage{
RawMessage: labelsJSON,
Valid: true,
},
LimitOpt: 1,
})
if err != nil || len(chats) == 0 {
return uuid.Nil, false
}
return chats[0].ID, true
}
// insertEvent is a fire-and-forget helper that logs errors but never
// fails the caller.
func insertEvent(ctx context.Context, db database.Store, params database.InsertAutomationEventParams) {
_, _ = db.InsertAutomationEvent(ctx, params)
}
+34
View File
@@ -0,0 +1,34 @@
package automations
import (
"encoding/json"
"reflect"
"github.com/tidwall/gjson"
)
// MatchFilter evaluates a gjson-based filter against a JSON payload.
// If filter is nil or empty, the match succeeds (everything passes).
// Each key in the filter is a gjson path; each value is the expected
// result. All entries must match for the filter to pass.
func MatchFilter(payload string, filter json.RawMessage) bool {
if len(filter) == 0 || string(filter) == "null" {
return true
}
var conditions map[string]any
if err := json.Unmarshal(filter, &conditions); err != nil {
return false
}
for path, expected := range conditions {
result := gjson.Get(payload, path)
if !result.Exists() {
return false
}
if !reflect.DeepEqual(result.Value(), expected) {
return false
}
}
return true
}
+48
View File
@@ -0,0 +1,48 @@
package automations_test
import (
"encoding/json"
"testing"
"github.com/stretchr/testify/require"
"github.com/coder/coder/v2/coderd/automations"
)
func TestMatchFilter(t *testing.T) {
t.Parallel()
payload := `{"action":"opened","repository":{"full_name":"coder/coder","private":true},"pull_request":{"number":42},"labels":["bug","urgent"],"sender":{"login":"octocat"}}`
tests := []struct {
name string
filter json.RawMessage
want bool
}{
{"nil filter matches everything", nil, true},
{"null filter matches everything", json.RawMessage(`null`), true},
{"empty filter matches everything", json.RawMessage(`{}`), true},
{"single match", json.RawMessage(`{"action":"opened"}`), true},
{"nested match", json.RawMessage(`{"repository.full_name":"coder/coder"}`), true},
{"multiple conditions all match", json.RawMessage(`{"action":"opened","repository.full_name":"coder/coder"}`), true},
{"value mismatch", json.RawMessage(`{"action":"closed"}`), false},
{"path does not exist", json.RawMessage(`{"nonexistent":"value"}`), false},
{"one of two conditions fails", json.RawMessage(`{"action":"opened","repository.full_name":"other/repo"}`), false},
{"numeric match", json.RawMessage(`{"pull_request.number":42}`), true},
{"numeric mismatch", json.RawMessage(`{"pull_request.number":99}`), false},
{"invalid filter json", json.RawMessage(`not json`), false},
{"boolean match", json.RawMessage(`{"repository.private":true}`), true},
{"boolean mismatch", json.RawMessage(`{"repository.private":false}`), false},
{"array value match", json.RawMessage(`{"labels":["bug","urgent"]}`), true},
{"array value mismatch", json.RawMessage(`{"labels":["bug"]}`), false},
{"object value match", json.RawMessage(`{"sender":{"login":"octocat"}}`), true},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
got := automations.MatchFilter(payload, tt.filter)
require.Equal(t, tt.want, got)
})
}
}
+23
View File
@@ -0,0 +1,23 @@
package automations
import (
"github.com/tidwall/gjson"
)
// ResolveLabels extracts label values from a JSON payload using gjson
// paths. Each key in labelPaths maps a label name to a gjson path
// expression. If a path doesn't match, that label is omitted.
func ResolveLabels(payload string, labelPaths map[string]string) map[string]string {
if len(labelPaths) == 0 {
return nil
}
labels := make(map[string]string, len(labelPaths))
for labelKey, gjsonPath := range labelPaths {
result := gjson.Get(payload, gjsonPath)
if result.Exists() {
labels[labelKey] = result.String()
}
}
return labels
}
+79
View File
@@ -0,0 +1,79 @@
package automations_test
import (
"testing"
"github.com/stretchr/testify/require"
"github.com/coder/coder/v2/coderd/automations"
)
func TestResolveLabels(t *testing.T) {
t.Parallel()
payload := `{"repository":{"full_name":"coder/coder"},"action":"opened","number":42,"draft":false}`
tests := []struct {
name string
labelPaths map[string]string
want map[string]string
}{
{
"nil label paths returns nil",
nil,
nil,
},
{
"empty label paths returns nil",
map[string]string{},
nil,
},
{
"simple path extraction",
map[string]string{"repo": "repository.full_name"},
map[string]string{"repo": "coder/coder"},
},
{
"multiple paths",
map[string]string{
"repo": "repository.full_name",
"action": "action",
},
map[string]string{
"repo": "coder/coder",
"action": "opened",
},
},
{
"missing path omitted",
map[string]string{
"repo": "repository.full_name",
"missing": "nonexistent.path",
},
map[string]string{"repo": "coder/coder"},
},
{
"numeric value coerced to string",
map[string]string{"num": "number"},
map[string]string{"num": "42"},
},
{
"boolean value coerced to string",
map[string]string{"draft": "draft"},
map[string]string{"draft": "false"},
},
{
"all paths missing returns empty map",
map[string]string{"a": "no.such.path", "b": "also.missing"},
map[string]string{},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
got := automations.ResolveLabels(payload, tt.labelPaths)
require.Equal(t, tt.want, got)
})
}
}
+33
View File
@@ -0,0 +1,33 @@
package automations
import (
"crypto/hmac"
"crypto/sha256"
"encoding/hex"
"strings"
)
// VerifySignature checks an HMAC-SHA256 signature in the format used
// by GitHub and many other webhook providers: "sha256=<hex-digest>".
// The comparison uses constant-time equality to prevent timing attacks.
func VerifySignature(payload []byte, secret string, signatureHeader string) bool {
if secret == "" || signatureHeader == "" {
return false
}
parts := strings.SplitN(signatureHeader, "=", 2)
if len(parts) != 2 || parts[0] != "sha256" {
return false
}
expectedMAC, err := hex.DecodeString(parts[1])
if err != nil {
return false
}
mac := hmac.New(sha256.New, []byte(secret))
mac.Write(payload)
actualMAC := mac.Sum(nil)
return hmac.Equal(actualMAC, expectedMAC)
}
+48
View File
@@ -0,0 +1,48 @@
package automations_test
import (
"crypto/hmac"
"crypto/sha256"
"encoding/hex"
"testing"
"github.com/stretchr/testify/require"
"github.com/coder/coder/v2/coderd/automations"
)
func TestVerifySignature(t *testing.T) {
t.Parallel()
secret := "test-secret"
payload := []byte(`{"action":"opened"}`)
// Compute valid signature.
mac := hmac.New(sha256.New, []byte(secret))
mac.Write(payload)
validSig := "sha256=" + hex.EncodeToString(mac.Sum(nil))
tests := []struct {
name string
secret string
sig string
want bool
}{
{"valid signature", secret, validSig, true},
{"wrong secret", "wrong-secret", validSig, false},
{"empty signature header", secret, "", false},
{"empty secret", "", validSig, false},
{"missing sha256 prefix", secret, hex.EncodeToString(mac.Sum(nil)), false},
{"wrong prefix", secret, "sha512=" + hex.EncodeToString(mac.Sum(nil)), false},
{"invalid hex", secret, "sha256=zzzz", false},
{"tampered payload signature", secret, "sha256=" + hex.EncodeToString([]byte("wrong")), false},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
got := automations.VerifySignature(payload, tt.secret, tt.sig)
require.Equal(t, tt.want, got)
})
}
}
+37 -12
View File
@@ -168,7 +168,6 @@ type Options struct {
ConnectionLogger connectionlog.ConnectionLogger
AgentConnectionUpdateFrequency time.Duration
AgentInactiveDisconnectTimeout time.Duration
ChatdInstructionLookupTimeout time.Duration
AWSCertificates awsidentity.Certificates
Authorizer rbac.Authorizer
AzureCertificates x509.VerifyOptions
@@ -783,10 +782,9 @@ func New(options *Options) *API {
ReplicaID: api.ID,
SubscribeFn: options.ChatSubscribeFn,
MaxChatsPerAcquire: int32(maxChatsPerAcquire), //nolint:gosec // maxChatsPerAcquire is clamped to int32 range above.
ProviderAPIKeys: ChatProviderAPIKeysFromDeploymentValues(options.DeploymentValues),
ProviderAPIKeys: chatProviderAPIKeysFromDeploymentValues(options.DeploymentValues),
AgentConn: api.agentProvider.AgentConn,
AgentInactiveDisconnectTimeout: api.AgentInactiveDisconnectTimeout,
InstructionLookupTimeout: options.ChatdInstructionLookupTimeout,
CreateWorkspace: api.chatCreateWorkspace,
StartWorkspace: api.chatStartWorkspace,
Pubsub: options.Pubsub,
@@ -1151,13 +1149,37 @@ func New(options *Options) *API {
})
})
})
// Experimental(agents): automation API routes gated by ExperimentAgents.
r.Route("/automations", func(r chi.Router) {
r.Use(
apiKeyMiddleware,
httpmw.RequireExperimentWithDevBypass(api.Experiments, codersdk.ExperimentAgents),
)
r.Post("/", api.postAutomation)
r.Get("/", api.listAutomations)
r.Route("/{automation}", func(r chi.Router) {
r.Use(httpmw.ExtractAutomationParam(options.Database))
r.Get("/", api.getAutomation)
r.Patch("/", api.patchAutomation)
r.Delete("/", api.deleteAutomation)
r.Get("/events", api.listAutomationEvents)
r.Post("/test", api.testAutomation)
r.Route("/triggers", func(r chi.Router) {
r.Post("/", api.postAutomationTrigger)
r.Get("/", api.listAutomationTriggers)
r.Route("/{trigger}", func(r chi.Router) {
r.Delete("/", api.deleteAutomationTrigger)
r.Post("/regenerate-secret", api.regenerateAutomationTriggerSecret)
})
})
})
})
// Experimental(agents): chat API routes gated by ExperimentAgents.
r.Route("/chats", func(r chi.Router) {
r.Use(
apiKeyMiddleware,
httpmw.RequireExperimentWithDevBypass(api.Experiments, codersdk.ExperimentAgents),
)
r.Get("/by-workspace", api.chatsByWorkspace)
r.Get("/", api.listChats)
r.Post("/", api.postChats)
r.Get("/models", api.listChatModels)
@@ -1223,13 +1245,6 @@ func New(options *Options) *API {
r.Delete("/", api.deleteChatUsageLimitGroupOverride)
})
})
r.Route("/user-provider-configs", func(r chi.Router) {
r.Get("/", api.listUserChatProviderConfigs)
r.Route("/{providerConfig}", func(r chi.Router) {
r.Put("/", api.upsertUserChatProviderKey)
r.Delete("/", api.deleteUserChatProviderKey)
})
})
r.Route("/{chat}", func(r chi.Router) {
r.Use(httpmw.ExtractChatParam(options.Database))
r.Get("/", api.getChat)
@@ -1243,7 +1258,6 @@ func New(options *Options) *API {
r.Get("/git", api.watchChatGit)
})
r.Post("/interrupt", api.interruptChat)
r.Post("/title/regenerate", api.regenerateChatTitle)
r.Get("/diff", api.getChatDiffContents)
r.Route("/queue/{queuedMessage}", func(r chi.Router) {
r.Delete("/", api.deleteChatQueuedMessage)
@@ -1310,6 +1324,11 @@ func New(options *Options) *API {
// r.Use(apiKeyMiddleware)
r.Get("/", api.derpMapUpdates)
})
// Unauthenticated webhook endpoint for automation triggers.
// Authentication is via HMAC signature, not API key.
r.Route("/automations/triggers/{trigger_id}/webhook", func(r chi.Router) {
r.Post("/", api.postAutomationWebhook)
})
r.Route("/deployment", func(r chi.Router) {
r.Use(apiKeyMiddleware)
r.Get("/config", api.deploymentValues)
@@ -2123,6 +2142,12 @@ type API struct {
ProfileCollecting atomic.Bool
}
// ChatDaemon returns the chatd server used for automation-initiated
// chat creation and messaging.
func (api *API) ChatDaemon() *chatd.Server {
return api.chatDaemon
}
// Close waits for all WebSocket connections to drain before returning.
func (api *API) Close() error {
select {
+6 -8
View File
@@ -149,13 +149,12 @@ type Options struct {
OneTimePasscodeValidityPeriod time.Duration
// IncludeProvisionerDaemon when true means to start an in-memory provisionerD
IncludeProvisionerDaemon bool
ChatdInstructionLookupTimeout time.Duration
ProvisionerDaemonVersion string
ProvisionerDaemonTags map[string]string
MetricsCacheRefreshInterval time.Duration
AgentStatsRefreshInterval time.Duration
DeploymentValues *codersdk.DeploymentValues
IncludeProvisionerDaemon bool
ProvisionerDaemonVersion string
ProvisionerDaemonTags map[string]string
MetricsCacheRefreshInterval time.Duration
AgentStatsRefreshInterval time.Duration
DeploymentValues *codersdk.DeploymentValues
// Set update check options to enable update check.
UpdateCheckOptions *updatecheck.Options
@@ -576,7 +575,6 @@ func NewOptions(t testing.TB, options *Options) (func(http.Handler), context.Can
// Force a long disconnection timeout to ensure
// agents are not marked as disconnected during slow tests.
AgentInactiveDisconnectTimeout: testutil.WaitShort,
ChatdInstructionLookupTimeout: options.ChatdInstructionLookupTimeout,
AccessURL: accessURL,
AppHostname: options.AppHostname,
AppHostnameRegex: appHostnameRegex,
+2 -2
View File
@@ -90,8 +90,8 @@ func (m *FakeConnectionLogger) Contains(t testing.TB, expected database.UpsertCo
t.Logf("connection log %d: expected Code %d, got %d", idx+1, expected.Code.Int32, cl.Code.Int32)
continue
}
if expected.IP.Valid && cl.IP.IPNet.String() != expected.IP.IPNet.String() {
t.Logf("connection log %d: expected IP %s, got %s", idx+1, expected.IP.IPNet, cl.IP.IPNet)
if expected.Ip.Valid && cl.Ip.IPNet.String() != expected.Ip.IPNet.String() {
t.Logf("connection log %d: expected IP %s, got %s", idx+1, expected.Ip.IPNet, cl.Ip.IPNet)
continue
}
if expected.UserAgent.Valid && cl.UserAgent.String != expected.UserAgent.String {
+5 -2
View File
@@ -7,10 +7,14 @@ type CheckConstraint string
// CheckConstraint enums.
const (
CheckAPIKeysAllowListNotEmpty CheckConstraint = "api_keys_allow_list_not_empty" // api_keys
CheckAutomationEventsStatusCheck CheckConstraint = "automation_events_status_check" // automation_events
CheckAutomationTriggersTypeCheck CheckConstraint = "automation_triggers_type_check" // automation_triggers
CheckAutomationsMaxChatCreatesPerHourCheck CheckConstraint = "automations_max_chat_creates_per_hour_check" // automations
CheckAutomationsMaxMessagesPerHourCheck CheckConstraint = "automations_max_messages_per_hour_check" // automations
CheckAutomationsStatusCheck CheckConstraint = "automations_status_check" // automations
CheckChatModelConfigsCompressionThresholdCheck CheckConstraint = "chat_model_configs_compression_threshold_check" // chat_model_configs
CheckChatModelConfigsContextLimitCheck CheckConstraint = "chat_model_configs_context_limit_check" // chat_model_configs
CheckChatProvidersProviderCheck CheckConstraint = "chat_providers_provider_check" // chat_providers
CheckValidCredentialPolicy CheckConstraint = "valid_credential_policy" // chat_providers
CheckChatUsageLimitConfigDefaultLimitMicrosCheck CheckConstraint = "chat_usage_limit_config_default_limit_micros_check" // chat_usage_limit_config
CheckChatUsageLimitConfigPeriodCheck CheckConstraint = "chat_usage_limit_config_period_check" // chat_usage_limit_config
CheckChatUsageLimitConfigSingletonCheck CheckConstraint = "chat_usage_limit_config_singleton_check" // chat_usage_limit_config
@@ -33,5 +37,4 @@ const (
CheckTelemetryLockEventTypeConstraint CheckConstraint = "telemetry_lock_event_type_constraint" // telemetry_locks
CheckValidationMonotonicOrder CheckConstraint = "validation_monotonic_order" // template_version_parameters
CheckUsageEventTypeCheck CheckConstraint = "usage_event_type_check" // usage_events
CheckUserChatProviderKeysAPIKeyCheck CheckConstraint = "user_chat_provider_keys_api_key_check" // user_chat_provider_keys
)
+95 -416
View File
@@ -999,16 +999,15 @@ func AIBridgeInterception(interception database.AIBridgeInterception, initiator
return sdkToolUsages[i].CreatedAt.Before(sdkToolUsages[j].CreatedAt)
})
intc := codersdk.AIBridgeInterception{
ID: interception.ID,
Initiator: MinimalUserFromVisibleUser(initiator),
Provider: interception.Provider,
ProviderName: interception.ProviderName,
Model: interception.Model,
Metadata: jsonOrEmptyMap(interception.Metadata),
StartedAt: interception.StartedAt,
TokenUsages: sdkTokenUsages,
UserPrompts: sdkUserPrompts,
ToolUsages: sdkToolUsages,
ID: interception.ID,
Initiator: MinimalUserFromVisibleUser(initiator),
Provider: interception.Provider,
Model: interception.Model,
Metadata: jsonOrEmptyMap(interception.Metadata),
StartedAt: interception.StartedAt,
TokenUsages: sdkTokenUsages,
UserPrompts: sdkUserPrompts,
ToolUsages: sdkToolUsages,
}
if interception.APIKeyID.Valid {
intc.APIKeyID = &interception.APIKeyID.String
@@ -1037,10 +1036,8 @@ func AIBridgeSession(row database.ListAIBridgeSessionsRow) codersdk.AIBridgeSess
StartedAt: row.StartedAt,
Threads: row.Threads,
TokenUsageSummary: codersdk.AIBridgeSessionTokenUsageSummary{
InputTokens: row.InputTokens,
OutputTokens: row.OutputTokens,
CacheReadInputTokens: row.CacheReadInputTokens,
CacheWriteInputTokens: row.CacheWriteInputTokens,
InputTokens: row.InputTokens,
OutputTokens: row.OutputTokens,
},
}
// Ensure non-nil slices for JSON serialization.
@@ -1064,15 +1061,13 @@ func AIBridgeSession(row database.ListAIBridgeSessionsRow) codersdk.AIBridgeSess
func AIBridgeTokenUsage(usage database.AIBridgeTokenUsage) codersdk.AIBridgeTokenUsage {
return codersdk.AIBridgeTokenUsage{
ID: usage.ID,
InterceptionID: usage.InterceptionID,
ProviderResponseID: usage.ProviderResponseID,
InputTokens: usage.InputTokens,
OutputTokens: usage.OutputTokens,
CacheReadInputTokens: usage.CacheReadInputTokens,
CacheWriteInputTokens: usage.CacheWriteInputTokens,
Metadata: jsonOrEmptyMap(usage.Metadata),
CreatedAt: usage.CreatedAt,
ID: usage.ID,
InterceptionID: usage.InterceptionID,
ProviderResponseID: usage.ProviderResponseID,
InputTokens: usage.InputTokens,
OutputTokens: usage.OutputTokens,
Metadata: jsonOrEmptyMap(usage.Metadata),
CreatedAt: usage.CreatedAt,
}
}
@@ -1102,291 +1097,6 @@ func AIBridgeToolUsage(usage database.AIBridgeToolUsage) codersdk.AIBridgeToolUs
}
}
// AIBridgeSessionThreads converts session metadata and thread interceptions
// into the threads response. It groups interceptions into threads, builds
// agentic actions from tool usages and model thoughts, and aggregates
// token usage with metadata.
func AIBridgeSessionThreads(
session database.ListAIBridgeSessionsRow,
interceptions []database.ListAIBridgeSessionThreadsRow,
tokenUsages []database.AIBridgeTokenUsage,
toolUsages []database.AIBridgeToolUsage,
userPrompts []database.AIBridgeUserPrompt,
modelThoughts []database.AIBridgeModelThought,
) codersdk.AIBridgeSessionThreadsResponse {
// Index subresources by interception ID.
tokensByInterception := make(map[uuid.UUID][]database.AIBridgeTokenUsage, len(interceptions))
for _, tu := range tokenUsages {
tokensByInterception[tu.InterceptionID] = append(tokensByInterception[tu.InterceptionID], tu)
}
toolsByInterception := make(map[uuid.UUID][]database.AIBridgeToolUsage, len(interceptions))
for _, tu := range toolUsages {
toolsByInterception[tu.InterceptionID] = append(toolsByInterception[tu.InterceptionID], tu)
}
promptsByInterception := make(map[uuid.UUID][]database.AIBridgeUserPrompt, len(interceptions))
for _, up := range userPrompts {
promptsByInterception[up.InterceptionID] = append(promptsByInterception[up.InterceptionID], up)
}
thoughtsByInterception := make(map[uuid.UUID][]database.AIBridgeModelThought, len(interceptions))
for _, mt := range modelThoughts {
thoughtsByInterception[mt.InterceptionID] = append(thoughtsByInterception[mt.InterceptionID], mt)
}
// Group interceptions by thread_id, preserving the order returned by the
// SQL query.
interceptionsByThread := make(map[uuid.UUID][]database.AIBridgeInterception, len(interceptions))
var threadIDs []uuid.UUID
for _, row := range interceptions {
if _, ok := interceptionsByThread[row.ThreadID]; !ok {
threadIDs = append(threadIDs, row.ThreadID)
}
interceptionsByThread[row.ThreadID] = append(interceptionsByThread[row.ThreadID], row.AIBridgeInterception)
}
// Build threads and track page time bounds.
threads := make([]codersdk.AIBridgeThread, 0, len(threadIDs))
var pageStartedAt, pageEndedAt *time.Time
for _, threadID := range threadIDs {
intcs := interceptionsByThread[threadID]
thread := buildAIBridgeThread(threadID, intcs, tokensByInterception, toolsByInterception, promptsByInterception, thoughtsByInterception)
for _, intc := range intcs {
if pageStartedAt == nil || intc.StartedAt.Before(*pageStartedAt) {
t := intc.StartedAt
pageStartedAt = &t
}
if intc.EndedAt.Valid {
if pageEndedAt == nil || intc.EndedAt.Time.After(*pageEndedAt) {
t := intc.EndedAt.Time
pageEndedAt = &t
}
}
}
threads = append(threads, thread)
}
// Aggregate session-level token usage metadata from all token
// usages in the session (not just the page).
sessionTokenMeta := aggregateTokenMetadata(tokenUsages)
resp := codersdk.AIBridgeSessionThreadsResponse{
ID: session.SessionID,
Initiator: MinimalUserFromVisibleUser(database.VisibleUser{
ID: session.UserID,
Username: session.UserUsername,
Name: session.UserName,
AvatarURL: session.UserAvatarUrl,
}),
Providers: session.Providers,
Models: session.Models,
Metadata: jsonOrEmptyMap(pqtype.NullRawMessage{RawMessage: session.Metadata, Valid: len(session.Metadata) > 0}),
StartedAt: session.StartedAt,
PageStartedAt: pageStartedAt,
PageEndedAt: pageEndedAt,
TokenUsageSummary: codersdk.AIBridgeSessionThreadsTokenUsage{
InputTokens: session.InputTokens,
OutputTokens: session.OutputTokens,
CacheReadInputTokens: session.CacheReadInputTokens,
CacheWriteInputTokens: session.CacheWriteInputTokens,
Metadata: sessionTokenMeta,
},
Threads: threads,
}
if resp.Providers == nil {
resp.Providers = []string{}
}
if resp.Models == nil {
resp.Models = []string{}
}
if session.Client != "" {
resp.Client = &session.Client
}
if !session.EndedAt.IsZero() {
resp.EndedAt = &session.EndedAt
}
return resp
}
func buildAIBridgeThread(
threadID uuid.UUID,
interceptions []database.AIBridgeInterception,
tokensByInterception map[uuid.UUID][]database.AIBridgeTokenUsage,
toolsByInterception map[uuid.UUID][]database.AIBridgeToolUsage,
promptsByInterception map[uuid.UUID][]database.AIBridgeUserPrompt,
thoughtsByInterception map[uuid.UUID][]database.AIBridgeModelThought,
) codersdk.AIBridgeThread {
// Find the root interception (where id == threadID) to get the
// thread prompt and model.
var rootIntc *database.AIBridgeInterception
for i := range interceptions {
if interceptions[i].ID == threadID {
rootIntc = &interceptions[i]
break
}
}
// Fallback to first interception if root not found.
if rootIntc == nil && len(interceptions) > 0 {
rootIntc = &interceptions[0]
}
thread := codersdk.AIBridgeThread{
ID: threadID,
}
if rootIntc != nil {
thread.Model = rootIntc.Model
thread.Provider = rootIntc.Provider
// Get first user prompt from root interception.
// A thread can only have one prompt, by definition, since we currently
// only store the last prompt observed in an interception.
if prompts := promptsByInterception[rootIntc.ID]; len(prompts) > 0 {
thread.Prompt = &prompts[0].Prompt
}
}
// Compute thread time bounds from interceptions.
for _, intc := range interceptions {
if thread.StartedAt.IsZero() || intc.StartedAt.Before(thread.StartedAt) {
thread.StartedAt = intc.StartedAt
}
if intc.EndedAt.Valid {
if thread.EndedAt == nil || intc.EndedAt.Time.After(*thread.EndedAt) {
t := intc.EndedAt.Time
thread.EndedAt = &t
}
}
}
// Build agentic actions grouped by interception. Each interception that
// has tool calls produces one action with all its tool calls, thinking
// blocks, and token usage.
var actions []codersdk.AIBridgeAgenticAction
for _, intc := range interceptions {
tools := toolsByInterception[intc.ID]
if len(tools) == 0 {
continue
}
// Thinking blocks for this interception.
thoughts := thoughtsByInterception[intc.ID]
thinking := make([]codersdk.AIBridgeModelThought, 0, len(thoughts))
for _, mt := range thoughts {
thinking = append(thinking, codersdk.AIBridgeModelThought{
Text: mt.Content,
})
}
// Token usage for the interception.
actionTokenUsage := aggregateTokenUsage(tokensByInterception[intc.ID])
// Build tool call list.
toolCalls := make([]codersdk.AIBridgeToolCall, 0, len(tools))
for _, tu := range tools {
toolCalls = append(toolCalls, codersdk.AIBridgeToolCall{
ID: tu.ID,
InterceptionID: tu.InterceptionID,
ProviderResponseID: tu.ProviderResponseID,
ServerURL: tu.ServerUrl.String,
Tool: tu.Tool,
Injected: tu.Injected,
Input: tu.Input,
Metadata: jsonOrEmptyMap(tu.Metadata),
CreatedAt: tu.CreatedAt,
})
}
actions = append(actions, codersdk.AIBridgeAgenticAction{
Model: intc.Model,
TokenUsage: actionTokenUsage,
Thinking: thinking,
ToolCalls: toolCalls,
})
}
if actions == nil {
// Make an empty slice so we don't serialize `null`.
actions = make([]codersdk.AIBridgeAgenticAction, 0)
}
thread.AgenticActions = actions
// Aggregate thread-level token usage.
var threadTokens []database.AIBridgeTokenUsage
for _, intc := range interceptions {
threadTokens = append(threadTokens, tokensByInterception[intc.ID]...)
}
thread.TokenUsage = aggregateTokenUsage(threadTokens)
return thread
}
// aggregateTokenUsage sums token usage rows and aggregates metadata.
func aggregateTokenUsage(tokens []database.AIBridgeTokenUsage) codersdk.AIBridgeSessionThreadsTokenUsage {
var inputTokens, outputTokens, cacheRead, cacheWrite int64
for _, tu := range tokens {
inputTokens += tu.InputTokens
outputTokens += tu.OutputTokens
cacheRead += tu.CacheReadInputTokens
cacheWrite += tu.CacheWriteInputTokens
}
return codersdk.AIBridgeSessionThreadsTokenUsage{
InputTokens: inputTokens,
OutputTokens: outputTokens,
CacheReadInputTokens: cacheRead,
CacheWriteInputTokens: cacheWrite,
Metadata: aggregateTokenMetadata(tokens),
}
}
// aggregateTokenMetadata sums all numeric values from the metadata
// JSONB across the given token usage rows by key. Nested objects are
// flattened using dot-notation (e.g. {"cache": {"read_tokens": 10}}
// becomes "cache.read_tokens"). Non-numeric leaves (strings,
// booleans, arrays, nulls) are silently skipped.
func aggregateTokenMetadata(tokens []database.AIBridgeTokenUsage) map[string]any {
sums := make(map[string]int64)
for _, tu := range tokens {
if !tu.Metadata.Valid || len(tu.Metadata.RawMessage) == 0 {
continue
}
var m map[string]json.RawMessage
if err := json.Unmarshal(tu.Metadata.RawMessage, &m); err != nil {
continue
}
flattenAndSum(sums, "", m)
}
result := make(map[string]any, len(sums))
for k, v := range sums {
result[k] = v
}
return result
}
// flattenAndSum recursively walks a JSON object and sums all numeric
// leaf values into sums, using dot-separated keys for nested objects.
func flattenAndSum(sums map[string]int64, prefix string, m map[string]json.RawMessage) {
for k, raw := range m {
key := k
if prefix != "" {
key = prefix + "." + k
}
// Try as a number first.
var n json.Number
if err := json.Unmarshal(raw, &n); err == nil {
if v, err := n.Int64(); err == nil {
sums[key] += v
}
continue
}
// Try as a nested object.
var nested map[string]json.RawMessage
if err := json.Unmarshal(raw, &nested); err == nil {
flattenAndSum(sums, key, nested)
}
// Arrays, strings, booleans, nulls are skipped.
}
}
func InvalidatedPresets(invalidatedPresets []database.UpdatePresetsLastInvalidatedAtRow) []codersdk.InvalidatedPreset {
var presets []codersdk.InvalidatedPreset
for _, p := range invalidatedPresets {
@@ -1525,114 +1235,6 @@ func nullInt64Ptr(v sql.NullInt64) *int64 {
return &value
}
// Chat converts a database.Chat to a codersdk.Chat. It coalesces
// nil slices and maps to empty values for JSON serialization and
// derives RootChatID from the parent chain when not explicitly set.
// When diffStatus is non-nil the response includes diff metadata.
// When files is non-empty the response includes file metadata;
// pass nil to omit the files field (e.g. list endpoints).
func Chat(c database.Chat, diffStatus *database.ChatDiffStatus, files []database.GetChatFileMetadataByChatIDRow) codersdk.Chat {
mcpServerIDs := c.MCPServerIDs
if mcpServerIDs == nil {
mcpServerIDs = []uuid.UUID{}
}
labels := map[string]string(c.Labels)
if labels == nil {
labels = map[string]string{}
}
chat := codersdk.Chat{
ID: c.ID,
OwnerID: c.OwnerID,
LastModelConfigID: c.LastModelConfigID,
Title: c.Title,
Status: codersdk.ChatStatus(c.Status),
Archived: c.Archived,
PinOrder: c.PinOrder,
CreatedAt: c.CreatedAt,
UpdatedAt: c.UpdatedAt,
MCPServerIDs: mcpServerIDs,
Labels: labels,
}
if c.LastError.Valid {
chat.LastError = &c.LastError.String
}
if c.ParentChatID.Valid {
parentChatID := c.ParentChatID.UUID
chat.ParentChatID = &parentChatID
}
switch {
case c.RootChatID.Valid:
rootChatID := c.RootChatID.UUID
chat.RootChatID = &rootChatID
case c.ParentChatID.Valid:
rootChatID := c.ParentChatID.UUID
chat.RootChatID = &rootChatID
default:
rootChatID := c.ID
chat.RootChatID = &rootChatID
}
if c.WorkspaceID.Valid {
chat.WorkspaceID = &c.WorkspaceID.UUID
}
if c.BuildID.Valid {
chat.BuildID = &c.BuildID.UUID
}
if c.AgentID.Valid {
chat.AgentID = &c.AgentID.UUID
}
if diffStatus != nil {
convertedDiffStatus := ChatDiffStatus(c.ID, diffStatus)
chat.DiffStatus = &convertedDiffStatus
}
if len(files) > 0 {
chat.Files = make([]codersdk.ChatFileMetadata, 0, len(files))
for _, row := range files {
chat.Files = append(chat.Files, codersdk.ChatFileMetadata{
ID: row.ID,
OwnerID: row.OwnerID,
OrganizationID: row.OrganizationID,
Name: row.Name,
MimeType: row.Mimetype,
CreatedAt: row.CreatedAt,
})
}
}
if c.LastInjectedContext.Valid {
var parts []codersdk.ChatMessagePart
// Internal fields are stripped at write time in
// chatd.updateLastInjectedContext, so no
// StripInternal call is needed here. Unmarshal
// errors are suppressed — the column is written by
// us with a known schema.
if err := json.Unmarshal(c.LastInjectedContext.RawMessage, &parts); err == nil {
chat.LastInjectedContext = parts
}
}
return chat
}
// ChatRows converts a slice of database.GetChatsRow (which embeds
// Chat plus HasUnread) to codersdk.Chat, looking up diff statuses
// from the provided map. When diffStatusesByChatID is non-nil,
// chats without an entry receive an empty DiffStatus.
func ChatRows(rows []database.GetChatsRow, diffStatusesByChatID map[uuid.UUID]database.ChatDiffStatus) []codersdk.Chat {
result := make([]codersdk.Chat, len(rows))
for i, row := range rows {
diffStatus, ok := diffStatusesByChatID[row.Chat.ID]
if ok {
result[i] = Chat(row.Chat, &diffStatus, nil)
} else {
result[i] = Chat(row.Chat, nil, nil)
if diffStatusesByChatID != nil {
emptyDiffStatus := ChatDiffStatus(row.Chat.ID, nil)
result[i].DiffStatus = &emptyDiffStatus
}
}
result[i].HasUnread = row.HasUnread
}
return result
}
// ChatDiffStatus converts a database.ChatDiffStatus to a
// codersdk.ChatDiffStatus. When status is nil an empty value
// containing only the chatID is returned.
@@ -1715,3 +1317,80 @@ func ChatDiffStatus(chatID uuid.UUID, status *database.ChatDiffStatus) codersdk.
return result
}
// Automation converts a database Automation to a codersdk Automation.
func Automation(a database.Automation) codersdk.Automation {
result := codersdk.Automation{
ID: a.ID,
OwnerID: a.OwnerID,
OrganizationID: a.OrganizationID,
Name: a.Name,
Description: a.Description,
Instructions: a.Instructions,
MCPServerIDs: a.MCPServerIDs,
AllowedTools: a.AllowedTools,
Status: codersdk.AutomationStatus(a.Status),
MaxChatCreatesPerHour: a.MaxChatCreatesPerHour,
MaxMessagesPerHour: a.MaxMessagesPerHour,
CreatedAt: a.CreatedAt,
UpdatedAt: a.UpdatedAt,
}
if a.ModelConfigID.Valid {
result.ModelConfigID = &a.ModelConfigID.UUID
}
if a.MCPServerIDs == nil {
result.MCPServerIDs = []uuid.UUID{}
}
if a.AllowedTools == nil {
result.AllowedTools = []string{}
}
return result
}
// AutomationEvent converts a database AutomationEvent to a codersdk
// AutomationEvent.
func AutomationEvent(e database.AutomationEvent) codersdk.AutomationEvent {
result := codersdk.AutomationEvent{
ID: e.ID,
AutomationID: e.AutomationID,
ReceivedAt: e.ReceivedAt,
Payload: e.Payload,
FilterMatched: e.FilterMatched,
ResolvedLabels: e.ResolvedLabels.RawMessage,
Status: codersdk.AutomationEventStatus(e.Status),
}
if e.TriggerID.Valid {
result.TriggerID = &e.TriggerID.UUID
}
if e.MatchedChatID.Valid {
result.MatchedChatID = &e.MatchedChatID.UUID
}
if e.CreatedChatID.Valid {
result.CreatedChatID = &e.CreatedChatID.UUID
}
if e.Error.Valid {
result.Error = &e.Error.String
}
return result
}
// AutomationTrigger converts a database AutomationTrigger to a codersdk
// AutomationTrigger.
func AutomationTrigger(t database.AutomationTrigger, accessURL string) codersdk.AutomationTrigger {
result := codersdk.AutomationTrigger{
ID: t.ID,
AutomationID: t.AutomationID,
Type: codersdk.AutomationTriggerType(t.Type),
Filter: t.Filter.RawMessage,
LabelPaths: t.LabelPaths.RawMessage,
CreatedAt: t.CreatedAt,
UpdatedAt: t.UpdatedAt,
}
if t.Type == "webhook" {
result.WebhookURL = accessURL + "/api/v2/automations/triggers/" + t.ID.String() + "/webhook"
}
if t.CronSchedule.Valid {
result.CronSchedule = &t.CronSchedule.String
}
return result
}
@@ -1,308 +0,0 @@
package db2sdk
import (
"encoding/json"
"testing"
"github.com/google/uuid"
"github.com/sqlc-dev/pqtype"
"github.com/stretchr/testify/require"
"github.com/coder/coder/v2/coderd/database"
)
func TestAggregateTokenMetadata(t *testing.T) {
t.Parallel()
t.Run("empty_input", func(t *testing.T) {
t.Parallel()
result := aggregateTokenMetadata(nil)
require.Empty(t, result)
})
t.Run("sums_across_rows", func(t *testing.T) {
t.Parallel()
tokens := []database.AIBridgeTokenUsage{
{
ID: uuid.New(),
Metadata: pqtype.NullRawMessage{
RawMessage: json.RawMessage(`{"cache_read_tokens":100,"reasoning_tokens":50}`),
Valid: true,
},
},
{
ID: uuid.New(),
Metadata: pqtype.NullRawMessage{
RawMessage: json.RawMessage(`{"cache_read_tokens":200,"reasoning_tokens":75}`),
Valid: true,
},
},
}
result := aggregateTokenMetadata(tokens)
require.Equal(t, int64(300), result["cache_read_tokens"])
require.Equal(t, int64(125), result["reasoning_tokens"])
require.Len(t, result, 2)
})
t.Run("skips_null_and_invalid_metadata", func(t *testing.T) {
t.Parallel()
tokens := []database.AIBridgeTokenUsage{
{
ID: uuid.New(),
Metadata: pqtype.NullRawMessage{Valid: false},
},
{
ID: uuid.New(),
Metadata: pqtype.NullRawMessage{
RawMessage: nil,
Valid: true,
},
},
{
ID: uuid.New(),
Metadata: pqtype.NullRawMessage{
RawMessage: json.RawMessage(`{"tokens":42}`),
Valid: true,
},
},
}
result := aggregateTokenMetadata(tokens)
require.Equal(t, int64(42), result["tokens"])
require.Len(t, result, 1)
})
t.Run("skips_non_integer_values", func(t *testing.T) {
t.Parallel()
tokens := []database.AIBridgeTokenUsage{
{
ID: uuid.New(),
Metadata: pqtype.NullRawMessage{
// Float values fail json.Number.Int64(), so they
// are silently dropped.
RawMessage: json.RawMessage(`{"good":10,"fractional":1.5}`),
Valid: true,
},
},
}
result := aggregateTokenMetadata(tokens)
require.Equal(t, int64(10), result["good"])
_, hasFractional := result["fractional"]
require.False(t, hasFractional)
})
t.Run("skips_malformed_json", func(t *testing.T) {
t.Parallel()
tokens := []database.AIBridgeTokenUsage{
{
ID: uuid.New(),
Metadata: pqtype.NullRawMessage{
RawMessage: json.RawMessage(`not json`),
Valid: true,
},
},
{
ID: uuid.New(),
Metadata: pqtype.NullRawMessage{
RawMessage: json.RawMessage(`{"tokens":5}`),
Valid: true,
},
},
}
result := aggregateTokenMetadata(tokens)
// The malformed row is skipped, the valid one is counted.
require.Equal(t, int64(5), result["tokens"])
require.Len(t, result, 1)
})
t.Run("flattens_nested_objects", func(t *testing.T) {
t.Parallel()
tokens := []database.AIBridgeTokenUsage{
{
ID: uuid.New(),
Metadata: pqtype.NullRawMessage{
RawMessage: json.RawMessage(`{
"cache_read_tokens": 100,
"cache": {"creation_tokens": 40, "read_tokens": 60},
"reasoning_tokens": 50,
"tags": ["a", "b"]
}`),
Valid: true,
},
},
{
ID: uuid.New(),
Metadata: pqtype.NullRawMessage{
RawMessage: json.RawMessage(`{
"cache_read_tokens": 200,
"cache": {"creation_tokens": 10}
}`),
Valid: true,
},
},
}
result := aggregateTokenMetadata(tokens)
require.Equal(t, int64(300), result["cache_read_tokens"])
require.Equal(t, int64(50), result["reasoning_tokens"])
require.Equal(t, int64(50), result["cache.creation_tokens"])
require.Equal(t, int64(60), result["cache.read_tokens"])
// Arrays are skipped.
_, hasTags := result["tags"]
require.False(t, hasTags)
require.Len(t, result, 4)
})
t.Run("flattens_deeply_nested_objects", func(t *testing.T) {
t.Parallel()
tokens := []database.AIBridgeTokenUsage{
{
ID: uuid.New(),
Metadata: pqtype.NullRawMessage{
RawMessage: json.RawMessage(`{
"provider": {
"anthropic": {"cache_creation_tokens": 100, "cache_read_tokens": 200},
"openai": {"reasoning_tokens": 50}
},
"total": 500
}`),
Valid: true,
},
},
}
result := aggregateTokenMetadata(tokens)
require.Equal(t, int64(100), result["provider.anthropic.cache_creation_tokens"])
require.Equal(t, int64(200), result["provider.anthropic.cache_read_tokens"])
require.Equal(t, int64(50), result["provider.openai.reasoning_tokens"])
require.Equal(t, int64(500), result["total"])
require.Len(t, result, 4)
})
// Real-world provider metadata shapes from
// https://github.com/coder/aibridge/issues/150.
t.Run("aggregates_real_provider_metadata", func(t *testing.T) {
t.Parallel()
tokens := []database.AIBridgeTokenUsage{
{
// Anthropic-style: cache fields are top-level.
ID: uuid.New(),
Metadata: pqtype.NullRawMessage{
RawMessage: json.RawMessage(`{
"cache_creation_input_tokens": 0,
"cache_read_input_tokens": 23490
}`),
Valid: true,
},
},
{
// OpenAI-style: cache fields are nested inside
// input_tokens_details.
ID: uuid.New(),
Metadata: pqtype.NullRawMessage{
RawMessage: json.RawMessage(`{
"input_tokens_details": {"cached_tokens": 11904}
}`),
Valid: true,
},
},
{
// Second Anthropic row to verify summing.
ID: uuid.New(),
Metadata: pqtype.NullRawMessage{
RawMessage: json.RawMessage(`{
"cache_creation_input_tokens": 500,
"cache_read_input_tokens": 10000
}`),
Valid: true,
},
},
}
result := aggregateTokenMetadata(tokens)
// Anthropic fields are summed across two rows.
require.Equal(t, int64(500), result["cache_creation_input_tokens"])
require.Equal(t, int64(33490), result["cache_read_input_tokens"])
// OpenAI nested field is flattened with dot notation.
require.Equal(t, int64(11904), result["input_tokens_details.cached_tokens"])
require.Len(t, result, 3)
})
t.Run("skips_string_boolean_null_values", func(t *testing.T) {
t.Parallel()
tokens := []database.AIBridgeTokenUsage{
{
ID: uuid.New(),
Metadata: pqtype.NullRawMessage{
RawMessage: json.RawMessage(`{"tokens":10,"name":"test","enabled":true,"nothing":null}`),
Valid: true,
},
},
}
result := aggregateTokenMetadata(tokens)
require.Equal(t, int64(10), result["tokens"])
require.Len(t, result, 1)
})
}
func TestAggregateTokenUsage(t *testing.T) {
t.Parallel()
t.Run("empty_input", func(t *testing.T) {
t.Parallel()
result := aggregateTokenUsage(nil)
require.Equal(t, int64(0), result.InputTokens)
require.Equal(t, int64(0), result.OutputTokens)
require.Empty(t, result.Metadata)
})
t.Run("sums_tokens_and_metadata", func(t *testing.T) {
t.Parallel()
tokens := []database.AIBridgeTokenUsage{
{
ID: uuid.New(),
InputTokens: 100,
OutputTokens: 50,
Metadata: pqtype.NullRawMessage{
RawMessage: json.RawMessage(`{"reasoning_tokens":20}`),
Valid: true,
},
},
{
ID: uuid.New(),
InputTokens: 200,
OutputTokens: 75,
Metadata: pqtype.NullRawMessage{
RawMessage: json.RawMessage(`{"reasoning_tokens":30}`),
Valid: true,
},
},
}
result := aggregateTokenUsage(tokens)
require.Equal(t, int64(300), result.InputTokens)
require.Equal(t, int64(125), result.OutputTokens)
require.Equal(t, int64(50), result.Metadata["reasoning_tokens"])
})
t.Run("handles_rows_without_metadata", func(t *testing.T) {
t.Parallel()
tokens := []database.AIBridgeTokenUsage{
{
ID: uuid.New(),
InputTokens: 500,
OutputTokens: 200,
Metadata: pqtype.NullRawMessage{Valid: false},
},
}
result := aggregateTokenUsage(tokens)
require.Equal(t, int64(500), result.InputTokens)
require.Equal(t, int64(200), result.OutputTokens)
require.Empty(t, result.Metadata)
})
}
+5 -191
View File
@@ -5,7 +5,6 @@ import (
"database/sql"
"encoding/json"
"fmt"
"reflect"
"testing"
"time"
@@ -259,13 +258,11 @@ func TestAIBridgeInterception(t *testing.T) {
},
tokenUsages: []database.AIBridgeTokenUsage{
{
ID: uuid.New(),
InterceptionID: interceptionID,
ProviderResponseID: "resp-123",
InputTokens: 100,
OutputTokens: 200,
CacheReadInputTokens: 50,
CacheWriteInputTokens: 10,
ID: uuid.New(),
InterceptionID: interceptionID,
ProviderResponseID: "resp-123",
InputTokens: 100,
OutputTokens: 200,
Metadata: pqtype.NullRawMessage{
RawMessage: json.RawMessage(`{"cache":"hit"}`),
Valid: true,
@@ -415,8 +412,6 @@ func TestAIBridgeInterception(t *testing.T) {
require.Equal(t, tu.ProviderResponseID, result.TokenUsages[i].ProviderResponseID)
require.Equal(t, tu.InputTokens, result.TokenUsages[i].InputTokens)
require.Equal(t, tu.OutputTokens, result.TokenUsages[i].OutputTokens)
require.Equal(t, tu.CacheReadInputTokens, result.TokenUsages[i].CacheReadInputTokens)
require.Equal(t, tu.CacheWriteInputTokens, result.TokenUsages[i].CacheWriteInputTokens)
}
// Verify user prompts are converted correctly.
@@ -518,187 +513,6 @@ func TestChatQueuedMessage_ParsesUserContentParts(t *testing.T) {
require.Equal(t, "queued text", queued.Content[0].Text)
}
func TestChat_AllFieldsPopulated(t *testing.T) {
t.Parallel()
// Every field of database.Chat is set to a non-zero value so
// that the reflection check below catches any field that
// db2sdk.Chat forgets to populate. When someone adds a new
// field to codersdk.Chat, this test will fail until the
// converter is updated.
now := dbtime.Now()
input := database.Chat{
ID: uuid.New(),
OwnerID: uuid.New(),
WorkspaceID: uuid.NullUUID{UUID: uuid.New(), Valid: true},
BuildID: uuid.NullUUID{UUID: uuid.New(), Valid: true},
AgentID: uuid.NullUUID{UUID: uuid.New(), Valid: true},
ParentChatID: uuid.NullUUID{UUID: uuid.New(), Valid: true},
RootChatID: uuid.NullUUID{UUID: uuid.New(), Valid: true},
LastModelConfigID: uuid.New(),
Title: "all-fields-test",
Status: database.ChatStatusRunning,
LastError: sql.NullString{String: "boom", Valid: true},
CreatedAt: now,
UpdatedAt: now,
Archived: true,
PinOrder: 1,
MCPServerIDs: []uuid.UUID{uuid.New()},
Labels: database.StringMap{"env": "prod"},
LastInjectedContext: pqtype.NullRawMessage{
// Use a context-file part to verify internal
// fields are not present (they are stripped at
// write time by chatd, not at read time).
RawMessage: json.RawMessage(`[{"type":"context-file","context_file_path":"/AGENTS.md"}]`),
Valid: true,
},
}
// Only ChatID is needed here. This test checks that
// Chat.DiffStatus is non-nil, not that every DiffStatus
// field is populated — that would be a separate test for
// the ChatDiffStatus converter.
diffStatus := &database.ChatDiffStatus{
ChatID: input.ID,
}
fileRows := []database.GetChatFileMetadataByChatIDRow{
{
ID: uuid.New(),
OwnerID: input.OwnerID,
OrganizationID: uuid.New(),
Name: "test.png",
Mimetype: "image/png",
CreatedAt: now,
},
}
got := db2sdk.Chat(input, diffStatus, fileRows)
v := reflect.ValueOf(got)
typ := v.Type()
// HasUnread is populated by ChatRows (which joins the
// read-cursor query), not by Chat. Warnings is a transient
// field populated by handlers, not the converter. Both are
// expected to remain zero here.
skip := map[string]bool{"HasUnread": true, "Warnings": true}
for i := range typ.NumField() {
field := typ.Field(i)
if skip[field.Name] {
continue
}
require.False(t, v.Field(i).IsZero(),
"codersdk.Chat field %q is zero-valued — db2sdk.Chat may not be populating it",
field.Name,
)
}
}
func TestChat_FileMetadataConversion(t *testing.T) {
t.Parallel()
ownerID := uuid.New()
orgID := uuid.New()
fileID := uuid.New()
now := dbtime.Now()
chat := database.Chat{
ID: uuid.New(),
OwnerID: ownerID,
LastModelConfigID: uuid.New(),
Title: "file metadata test",
Status: database.ChatStatusWaiting,
CreatedAt: now,
UpdatedAt: now,
}
rows := []database.GetChatFileMetadataByChatIDRow{
{
ID: fileID,
OwnerID: ownerID,
OrganizationID: orgID,
Name: "screenshot.png",
Mimetype: "image/png",
CreatedAt: now,
},
}
result := db2sdk.Chat(chat, nil, rows)
require.Len(t, result.Files, 1)
f := result.Files[0]
require.Equal(t, fileID, f.ID)
require.Equal(t, ownerID, f.OwnerID, "OwnerID must be mapped from DB row")
require.Equal(t, orgID, f.OrganizationID, "OrganizationID must be mapped from DB row")
require.Equal(t, "screenshot.png", f.Name)
require.Equal(t, "image/png", f.MimeType)
require.Equal(t, now, f.CreatedAt)
// Verify JSON serialization uses snake_case for mime_type.
data, err := json.Marshal(f)
require.NoError(t, err)
require.Contains(t, string(data), `"mime_type"`)
require.NotContains(t, string(data), `"mimetype"`)
}
func TestChat_NilFilesOmitted(t *testing.T) {
t.Parallel()
chat := database.Chat{
ID: uuid.New(),
OwnerID: uuid.New(),
LastModelConfigID: uuid.New(),
Title: "no files",
Status: database.ChatStatusWaiting,
CreatedAt: dbtime.Now(),
UpdatedAt: dbtime.Now(),
}
result := db2sdk.Chat(chat, nil, nil)
require.Empty(t, result.Files)
}
func TestChat_MultipleFiles(t *testing.T) {
t.Parallel()
now := dbtime.Now()
file1 := uuid.New()
file2 := uuid.New()
chat := database.Chat{
ID: uuid.New(),
OwnerID: uuid.New(),
LastModelConfigID: uuid.New(),
Title: "multi file test",
Status: database.ChatStatusWaiting,
CreatedAt: now,
UpdatedAt: now,
}
rows := []database.GetChatFileMetadataByChatIDRow{
{
ID: file1,
OwnerID: chat.OwnerID,
OrganizationID: uuid.New(),
Name: "a.png",
Mimetype: "image/png",
CreatedAt: now,
},
{
ID: file2,
OwnerID: chat.OwnerID,
OrganizationID: uuid.New(),
Name: "b.txt",
Mimetype: "text/plain",
CreatedAt: now,
},
}
result := db2sdk.Chat(chat, nil, rows)
require.Len(t, result.Files, 2)
require.Equal(t, "a.png", result.Files[0].Name)
require.Equal(t, "b.txt", result.Files[1].Name)
}
func TestChatQueuedMessage_MalformedContent(t *testing.T) {
t.Parallel()
+224 -265
View File
@@ -454,6 +454,7 @@ var (
rbac.ResourceOauth2App.Type: {policy.ActionCreate, policy.ActionRead, policy.ActionUpdate, policy.ActionDelete},
rbac.ResourceOauth2AppSecret.Type: {policy.ActionCreate, policy.ActionRead, policy.ActionUpdate, policy.ActionDelete},
rbac.ResourceChat.Type: {policy.ActionCreate, policy.ActionRead, policy.ActionUpdate, policy.ActionDelete},
rbac.ResourceAutomation.Type: {policy.ActionCreate, policy.ActionRead, policy.ActionUpdate, policy.ActionDelete},
}),
User: []rbac.Permission{},
ByOrgID: map[string]rbac.OrgPermissions{},
@@ -705,6 +706,7 @@ var (
DisplayName: "Chat Daemon",
Site: rbac.Permissions(map[string][]policy.Action{
rbac.ResourceChat.Type: {policy.ActionCreate, policy.ActionRead, policy.ActionUpdate, policy.ActionDelete},
rbac.ResourceAutomation.Type: {policy.ActionRead},
rbac.ResourceWorkspace.Type: {policy.ActionRead, policy.ActionUpdate},
rbac.ResourceDeploymentConfig.Type: {policy.ActionRead},
rbac.ResourceUser.Type: {policy.ActionReadPersonal},
@@ -1570,13 +1572,13 @@ func (q *querier) AllUserIDs(ctx context.Context, includeSystem bool) ([]uuid.UU
return q.db.AllUserIDs(ctx, includeSystem)
}
func (q *querier) ArchiveChatByID(ctx context.Context, id uuid.UUID) ([]database.Chat, error) {
func (q *querier) ArchiveChatByID(ctx context.Context, id uuid.UUID) error {
chat, err := q.db.GetChatByID(ctx, id)
if err != nil {
return nil, err
return err
}
if err := q.authorizeContext(ctx, policy.ActionUpdate, chat); err != nil {
return nil, err
return err
}
return q.db.ArchiveChatByID(ctx, id)
}
@@ -1627,13 +1629,6 @@ func (q *querier) BatchUpdateWorkspaceNextStartAt(ctx context.Context, arg datab
return q.db.BatchUpdateWorkspaceNextStartAt(ctx, arg)
}
func (q *querier) BatchUpsertConnectionLogs(ctx context.Context, arg database.BatchUpsertConnectionLogsParams) error {
if err := q.authorizeContext(ctx, policy.ActionUpdate, rbac.ResourceConnectionLog); err != nil {
return err
}
return q.db.BatchUpsertConnectionLogs(ctx, arg)
}
func (q *querier) BulkMarkNotificationMessagesFailed(ctx context.Context, arg database.BulkMarkNotificationMessagesFailedParams) (int64, error) {
if err := q.authorizeContext(ctx, policy.ActionUpdate, rbac.ResourceNotificationMessage); err != nil {
return 0, err
@@ -1738,6 +1733,28 @@ func (q *querier) CountAuditLogs(ctx context.Context, arg database.CountAuditLog
return q.db.CountAuthorizedAuditLogs(ctx, arg, prep)
}
func (q *querier) CountAutomationChatCreatesInWindow(ctx context.Context, arg database.CountAutomationChatCreatesInWindowParams) (int64, error) {
automation, err := q.db.GetAutomationByID(ctx, arg.AutomationID)
if err != nil {
return 0, err
}
if err := q.authorizeContext(ctx, policy.ActionRead, automation); err != nil {
return 0, err
}
return q.db.CountAutomationChatCreatesInWindow(ctx, arg)
}
func (q *querier) CountAutomationMessagesInWindow(ctx context.Context, arg database.CountAutomationMessagesInWindowParams) (int64, error) {
automation, err := q.db.GetAutomationByID(ctx, arg.AutomationID)
if err != nil {
return 0, err
}
if err := q.authorizeContext(ctx, policy.ActionRead, automation); err != nil {
return 0, err
}
return q.db.CountAutomationMessagesInWindow(ctx, arg)
}
func (q *querier) CountConnectionLogs(ctx context.Context, arg database.CountConnectionLogsParams) (int64, error) {
// Just like the actual query, shortcut if the user is an owner.
err := q.authorizeContext(ctx, policy.ActionRead, rbac.ResourceConnectionLog)
@@ -1849,6 +1866,25 @@ func (q *querier) DeleteApplicationConnectAPIKeysByUserID(ctx context.Context, u
return q.db.DeleteApplicationConnectAPIKeysByUserID(ctx, userID)
}
func (q *querier) DeleteAutomationByID(ctx context.Context, id uuid.UUID) error {
return fetchAndExec(q.log, q.auth, policy.ActionDelete, q.db.GetAutomationByID, q.db.DeleteAutomationByID)(ctx, id)
}
func (q *querier) DeleteAutomationTriggerByID(ctx context.Context, id uuid.UUID) error {
trigger, err := q.db.GetAutomationTriggerByID(ctx, id)
if err != nil {
return err
}
automation, err := q.db.GetAutomationByID(ctx, trigger.AutomationID)
if err != nil {
return err
}
if err := q.authorizeContext(ctx, policy.ActionUpdate, automation); err != nil {
return err
}
return q.db.DeleteAutomationTriggerByID(ctx, id)
}
func (q *querier) DeleteChatModelConfigByID(ctx context.Context, id uuid.UUID) error {
if err := q.authorizeContext(ctx, policy.ActionUpdate, rbac.ResourceDeploymentConfig); err != nil {
return err
@@ -2144,17 +2180,6 @@ func (q *querier) DeleteUserChatCompactionThreshold(ctx context.Context, arg dat
return q.db.DeleteUserChatCompactionThreshold(ctx, arg)
}
func (q *querier) DeleteUserChatProviderKey(ctx context.Context, arg database.DeleteUserChatProviderKeyParams) error {
u, err := q.db.GetUserByID(ctx, arg.UserID)
if err != nil {
return err
}
if err := q.authorizeContext(ctx, policy.ActionUpdatePersonal, u); err != nil {
return err
}
return q.db.DeleteUserChatProviderKey(ctx, arg)
}
func (q *querier) DeleteUserSecret(ctx context.Context, id uuid.UUID) error {
// First get the secret to check ownership
secret, err := q.GetUserSecret(ctx, id)
@@ -2404,6 +2429,13 @@ func (q *querier) GetActiveAISeatCount(ctx context.Context) (int64, error) {
return q.db.GetActiveAISeatCount(ctx)
}
func (q *querier) GetActiveCronTriggers(ctx context.Context) ([]database.GetActiveCronTriggersRow, error) {
if err := q.authorizeContext(ctx, policy.ActionRead, rbac.ResourceAutomation.All()); err != nil {
return nil, err
}
return q.db.GetActiveCronTriggers(ctx)
}
func (q *querier) GetActivePresetPrebuildSchedules(ctx context.Context) ([]database.TemplateVersionPresetPrebuildSchedule, error) {
if err := q.authorizeContext(ctx, policy.ActionRead, rbac.ResourceTemplate.All()); err != nil {
return nil, err
@@ -2495,6 +2527,55 @@ func (q *querier) GetAuthorizationUserRoles(ctx context.Context, userID uuid.UUI
return q.db.GetAuthorizationUserRoles(ctx, userID)
}
func (q *querier) GetAutomationByID(ctx context.Context, id uuid.UUID) (database.Automation, error) {
return fetch(q.log, q.auth, q.db.GetAutomationByID)(ctx, id)
}
func (q *querier) GetAutomationEvents(ctx context.Context, arg database.GetAutomationEventsParams) ([]database.AutomationEvent, error) {
automation, err := q.db.GetAutomationByID(ctx, arg.AutomationID)
if err != nil {
return nil, err
}
if err := q.authorizeContext(ctx, policy.ActionRead, automation); err != nil {
return nil, err
}
return q.db.GetAutomationEvents(ctx, arg)
}
func (q *querier) GetAutomationTriggerByID(ctx context.Context, id uuid.UUID) (database.AutomationTrigger, error) {
trigger, err := q.db.GetAutomationTriggerByID(ctx, id)
if err != nil {
return database.AutomationTrigger{}, err
}
automation, err := q.db.GetAutomationByID(ctx, trigger.AutomationID)
if err != nil {
return database.AutomationTrigger{}, err
}
if err := q.authorizeContext(ctx, policy.ActionRead, automation); err != nil {
return database.AutomationTrigger{}, err
}
return trigger, nil
}
func (q *querier) GetAutomationTriggersByAutomationID(ctx context.Context, automationID uuid.UUID) ([]database.AutomationTrigger, error) {
automation, err := q.db.GetAutomationByID(ctx, automationID)
if err != nil {
return nil, err
}
if err := q.authorizeContext(ctx, policy.ActionRead, automation); err != nil {
return nil, err
}
return q.db.GetAutomationTriggersByAutomationID(ctx, automationID)
}
func (q *querier) GetAutomations(ctx context.Context, arg database.GetAutomationsParams) ([]database.Automation, error) {
prep, err := prepareSQLFilter(ctx, q.auth, policy.ActionRead, rbac.ResourceAutomation.Type)
if err != nil {
return nil, xerrors.Errorf("(dev error) prepare sql filter: %w", err)
}
return q.db.GetAuthorizedAutomations(ctx, arg, prep)
}
func (q *querier) GetChatByID(ctx context.Context, id uuid.UUID) (database.Chat, error) {
return fetch(q.log, q.auth, q.db.GetChatByID)(ctx, id)
}
@@ -2583,10 +2664,6 @@ func (q *querier) GetChatFileByID(ctx context.Context, id uuid.UUID) (database.C
return file, nil
}
func (q *querier) GetChatFileMetadataByChatID(ctx context.Context, chatID uuid.UUID) ([]database.GetChatFileMetadataByChatIDRow, error) {
return fetchWithPostFilter(q.auth, policy.ActionRead, q.db.GetChatFileMetadataByChatID)(ctx, chatID)
}
func (q *querier) GetChatFilesByIDs(ctx context.Context, ids []uuid.UUID) ([]database.ChatFile, error) {
files, err := q.db.GetChatFilesByIDs(ctx, ids)
if err != nil {
@@ -2600,18 +2677,6 @@ func (q *querier) GetChatFilesByIDs(ctx context.Context, ids []uuid.UUID) ([]dat
return files, nil
}
func (q *querier) GetChatIncludeDefaultSystemPrompt(ctx context.Context) (bool, error) {
// The include-default-system-prompt flag is a deployment-wide setting read
// during chat creation by every authenticated user, so no RBAC policy
// check is needed. We still verify that a valid actor exists in the
// context to ensure this is never callable by an unauthenticated or
// system-internal path without an explicit actor.
if _, ok := ActorFromContext(ctx); !ok {
return false, ErrNoActor
}
return q.db.GetChatIncludeDefaultSystemPrompt(ctx)
}
func (q *querier) GetChatMessageByID(ctx context.Context, id int64) (database.ChatMessage, error) {
// ChatMessages are authorized through their parent Chat.
// We need to fetch the message first to get its chat_id.
@@ -2636,14 +2701,6 @@ func (q *querier) GetChatMessagesByChatID(ctx context.Context, arg database.GetC
return q.db.GetChatMessagesByChatID(ctx, arg)
}
func (q *querier) GetChatMessagesByChatIDAscPaginated(ctx context.Context, arg database.GetChatMessagesByChatIDAscPaginatedParams) ([]database.ChatMessage, error) {
_, err := q.GetChatByID(ctx, arg.ChatID)
if err != nil {
return nil, err
}
return q.db.GetChatMessagesByChatIDAscPaginated(ctx, arg)
}
func (q *querier) GetChatMessagesByChatIDDescPaginated(ctx context.Context, arg database.GetChatMessagesByChatIDDescPaginatedParams) ([]database.ChatMessage, error) {
_, err := q.GetChatByID(ctx, arg.ChatID)
if err != nil {
@@ -2716,18 +2773,6 @@ func (q *querier) GetChatSystemPrompt(ctx context.Context) (string, error) {
return q.db.GetChatSystemPrompt(ctx)
}
func (q *querier) GetChatSystemPromptConfig(ctx context.Context) (database.GetChatSystemPromptConfigRow, error) {
// The system prompt configuration is a deployment-wide setting read during
// chat creation by every authenticated user, so no RBAC policy check is
// needed. We still verify that a valid actor exists in the context to
// ensure this is never callable by an unauthenticated or system-internal
// path without an explicit actor.
if _, ok := ActorFromContext(ctx); !ok {
return database.GetChatSystemPromptConfigRow{}, ErrNoActor
}
return q.db.GetChatSystemPromptConfig(ctx)
}
// GetChatTemplateAllowlist requires deployment-config read permission,
// unlike the peer getters (GetChatDesktopEnabled, etc.) which only
// check actor presence. The allowlist is admin-configuration that
@@ -2770,7 +2815,7 @@ func (q *querier) GetChatWorkspaceTTL(ctx context.Context) (string, error) {
return q.db.GetChatWorkspaceTTL(ctx)
}
func (q *querier) GetChats(ctx context.Context, arg database.GetChatsParams) ([]database.GetChatsRow, error) {
func (q *querier) GetChats(ctx context.Context, arg database.GetChatsParams) ([]database.Chat, error) {
prep, err := prepareSQLFilter(ctx, q.auth, policy.ActionRead, rbac.ResourceChat.Type)
if err != nil {
return nil, xerrors.Errorf("(dev error) prepare sql filter: %w", err)
@@ -2778,10 +2823,6 @@ func (q *querier) GetChats(ctx context.Context, arg database.GetChatsParams) ([]
return q.db.GetAuthorizedChats(ctx, arg, prep)
}
func (q *querier) GetChatsByWorkspaceIDs(ctx context.Context, ids []uuid.UUID) ([]database.Chat, error) {
return fetchWithPostFilter(q.auth, policy.ActionRead, q.db.GetChatsByWorkspaceIDs)(ctx, ids)
}
func (q *querier) GetConnectionLogsOffset(ctx context.Context, arg database.GetConnectionLogsOffsetParams) ([]database.GetConnectionLogsOffsetRow, error) {
// Just like with the audit logs query, shortcut if the user is an owner.
err := q.authorizeContext(ctx, policy.ActionRead, rbac.ResourceConnectionLog)
@@ -2833,15 +2874,7 @@ func (q *querier) GetDERPMeshKey(ctx context.Context) (string, error) {
}
func (q *querier) GetDefaultChatModelConfig(ctx context.Context) (database.ChatModelConfig, error) {
// Any user who can read chat resources can read the default
// model config, since model resolution is required to create
// a chat. This avoids gating on ResourceDeploymentConfig
// which regular members lack.
act, ok := ActorFromContext(ctx)
if !ok {
return database.ChatModelConfig{}, ErrNoActor
}
if err := q.authorizeContext(ctx, policy.ActionRead, rbac.ResourceChat.WithOwner(act.ID)); err != nil {
if err := q.authorizeContext(ctx, policy.ActionRead, rbac.ResourceDeploymentConfig); err != nil {
return database.ChatModelConfig{}, err
}
return q.db.GetDefaultChatModelConfig(ctx)
@@ -3657,18 +3690,18 @@ func (q *querier) GetTailnetPeers(ctx context.Context, id uuid.UUID) ([]database
return q.db.GetTailnetPeers(ctx, id)
}
func (q *querier) GetTailnetTunnelPeerBindingsBatch(ctx context.Context, ids []uuid.UUID) ([]database.GetTailnetTunnelPeerBindingsBatchRow, error) {
func (q *querier) GetTailnetTunnelPeerBindings(ctx context.Context, srcID uuid.UUID) ([]database.GetTailnetTunnelPeerBindingsRow, error) {
if err := q.authorizeContext(ctx, policy.ActionRead, rbac.ResourceTailnetCoordinator); err != nil {
return nil, err
}
return q.db.GetTailnetTunnelPeerBindingsBatch(ctx, ids)
return q.db.GetTailnetTunnelPeerBindings(ctx, srcID)
}
func (q *querier) GetTailnetTunnelPeerIDsBatch(ctx context.Context, ids []uuid.UUID) ([]database.GetTailnetTunnelPeerIDsBatchRow, error) {
func (q *querier) GetTailnetTunnelPeerIDs(ctx context.Context, srcID uuid.UUID) ([]database.GetTailnetTunnelPeerIDsRow, error) {
if err := q.authorizeContext(ctx, policy.ActionRead, rbac.ResourceTailnetCoordinator); err != nil {
return nil, err
}
return q.db.GetTailnetTunnelPeerIDsBatch(ctx, ids)
return q.db.GetTailnetTunnelPeerIDs(ctx, srcID)
}
func (q *querier) GetTaskByID(ctx context.Context, id uuid.UUID) (database.Task, error) {
@@ -3987,13 +4020,6 @@ func (q *querier) GetUnexpiredLicenses(ctx context.Context) ([]database.License,
return q.db.GetUnexpiredLicenses(ctx)
}
func (q *querier) GetUserAISeatStates(ctx context.Context, userIDs []uuid.UUID) ([]uuid.UUID, error) {
if err := q.authorizeContext(ctx, policy.ActionRead, rbac.ResourceUser); err != nil {
return nil, err
}
return q.db.GetUserAISeatStates(ctx, userIDs)
}
func (q *querier) GetUserActivityInsights(ctx context.Context, arg database.GetUserActivityInsightsParams) ([]database.GetUserActivityInsightsRow, error) {
// Used by insights endpoints. Need to check both for auditors and for regular users with template acl perms.
if err := q.authorizeContext(ctx, policy.ActionViewInsights, rbac.ResourceTemplate); err != nil {
@@ -4046,17 +4072,6 @@ func (q *querier) GetUserChatCustomPrompt(ctx context.Context, userID uuid.UUID)
return q.db.GetUserChatCustomPrompt(ctx, userID)
}
func (q *querier) GetUserChatProviderKeys(ctx context.Context, userID uuid.UUID) ([]database.UserChatProviderKey, error) {
u, err := q.db.GetUserByID(ctx, userID)
if err != nil {
return nil, err
}
if err := q.authorizeContext(ctx, policy.ActionReadPersonal, u); err != nil {
return nil, err
}
return q.db.GetUserChatProviderKeys(ctx, userID)
}
func (q *querier) GetUserChatSpendInPeriod(ctx context.Context, arg database.GetUserChatSpendInPeriodParams) (int64, error) {
if err := q.authorizeContext(ctx, policy.ActionRead, rbac.ResourceChat.WithOwner(arg.UserID.String())); err != nil {
return 0, err
@@ -4801,6 +4816,32 @@ func (q *querier) InsertAuditLog(ctx context.Context, arg database.InsertAuditLo
return insert(q.log, q.auth, rbac.ResourceAuditLog, q.db.InsertAuditLog)(ctx, arg)
}
func (q *querier) InsertAutomation(ctx context.Context, arg database.InsertAutomationParams) (database.Automation, error) {
return insert(q.log, q.auth, rbac.ResourceAutomation.WithOwner(arg.OwnerID.String()).InOrg(arg.OrganizationID), q.db.InsertAutomation)(ctx, arg)
}
func (q *querier) InsertAutomationEvent(ctx context.Context, arg database.InsertAutomationEventParams) (database.AutomationEvent, error) {
automation, err := q.db.GetAutomationByID(ctx, arg.AutomationID)
if err != nil {
return database.AutomationEvent{}, err
}
if err := q.authorizeContext(ctx, policy.ActionUpdate, automation); err != nil {
return database.AutomationEvent{}, err
}
return q.db.InsertAutomationEvent(ctx, arg)
}
func (q *querier) InsertAutomationTrigger(ctx context.Context, arg database.InsertAutomationTriggerParams) (database.AutomationTrigger, error) {
automation, err := q.db.GetAutomationByID(ctx, arg.AutomationID)
if err != nil {
return database.AutomationTrigger{}, err
}
if err := q.authorizeContext(ctx, policy.ActionUpdate, automation); err != nil {
return database.AutomationTrigger{}, err
}
return q.db.InsertAutomationTrigger(ctx, arg)
}
func (q *querier) InsertChat(ctx context.Context, arg database.InsertChatParams) (database.Chat, error) {
return insert(q.log, q.auth, rbac.ResourceChat.WithOwner(arg.OwnerID.String()), q.db.InsertChat)(ctx, arg)
}
@@ -5397,25 +5438,6 @@ func (q *querier) InsertWorkspaceResourceMetadata(ctx context.Context, arg datab
return q.db.InsertWorkspaceResourceMetadata(ctx, arg)
}
func (q *querier) LinkChatFiles(ctx context.Context, arg database.LinkChatFilesParams) (int32, error) {
chat, err := q.db.GetChatByID(ctx, arg.ChatID)
if err != nil {
return 0, err
}
if err := q.authorizeContext(ctx, policy.ActionUpdate, chat); err != nil {
return 0, err
}
return q.db.LinkChatFiles(ctx, arg)
}
func (q *querier) ListAIBridgeClients(ctx context.Context, arg database.ListAIBridgeClientsParams) ([]string, error) {
prep, err := prepareSQLFilter(ctx, q.auth, policy.ActionRead, rbac.ResourceAibridgeInterception.Type)
if err != nil {
return nil, xerrors.Errorf("(dev error) prepare sql filter: %w", err)
}
return q.db.ListAuthorizedAIBridgeClients(ctx, arg, prep)
}
func (q *querier) ListAIBridgeInterceptions(ctx context.Context, arg database.ListAIBridgeInterceptionsParams) ([]database.ListAIBridgeInterceptionsRow, error) {
prep, err := prepareSQLFilter(ctx, q.auth, policy.ActionRead, rbac.ResourceAibridgeInterception.Type)
if err != nil {
@@ -5431,13 +5453,6 @@ func (q *querier) ListAIBridgeInterceptionsTelemetrySummaries(ctx context.Contex
return q.db.ListAIBridgeInterceptionsTelemetrySummaries(ctx, arg)
}
func (q *querier) ListAIBridgeModelThoughtsByInterceptionIDs(ctx context.Context, interceptionIDs []uuid.UUID) ([]database.AIBridgeModelThought, error) {
if err := q.authorizeContext(ctx, policy.ActionRead, rbac.ResourceAibridgeInterception); err != nil {
return nil, err
}
return q.db.ListAIBridgeModelThoughtsByInterceptionIDs(ctx, interceptionIDs)
}
func (q *querier) ListAIBridgeModels(ctx context.Context, arg database.ListAIBridgeModelsParams) ([]string, error) {
prep, err := prepareSQLFilter(ctx, q.auth, policy.ActionRead, rbac.ResourceAibridgeInterception.Type)
if err != nil {
@@ -5446,14 +5461,6 @@ func (q *querier) ListAIBridgeModels(ctx context.Context, arg database.ListAIBri
return q.db.ListAuthorizedAIBridgeModels(ctx, arg, prep)
}
func (q *querier) ListAIBridgeSessionThreads(ctx context.Context, arg database.ListAIBridgeSessionThreadsParams) ([]database.ListAIBridgeSessionThreadsRow, error) {
prep, err := prepareSQLFilter(ctx, q.auth, policy.ActionRead, rbac.ResourceAibridgeInterception.Type)
if err != nil {
return nil, xerrors.Errorf("(dev error) prepare sql filter: %w", err)
}
return q.db.ListAuthorizedAIBridgeSessionThreads(ctx, arg, prep)
}
func (q *querier) ListAIBridgeSessions(ctx context.Context, arg database.ListAIBridgeSessionsParams) ([]database.ListAIBridgeSessionsRow, error) {
prep, err := prepareSQLFilter(ctx, q.auth, policy.ActionRead, rbac.ResourceAibridgeInterception.Type)
if err != nil {
@@ -5591,17 +5598,6 @@ func (q *querier) PaginatedOrganizationMembers(ctx context.Context, arg database
return q.db.PaginatedOrganizationMembers(ctx, arg)
}
func (q *querier) PinChatByID(ctx context.Context, id uuid.UUID) error {
chat, err := q.db.GetChatByID(ctx, id)
if err != nil {
return err
}
if err := q.authorizeContext(ctx, policy.ActionUpdate, chat); err != nil {
return err
}
return q.db.PinChatByID(ctx, id)
}
func (q *querier) PopNextQueuedMessage(ctx context.Context, chatID uuid.UUID) (database.ChatQueuedMessage, error) {
chat, err := q.db.GetChatByID(ctx, chatID)
if err != nil {
@@ -5613,6 +5609,13 @@ func (q *querier) PopNextQueuedMessage(ctx context.Context, chatID uuid.UUID) (d
return q.db.PopNextQueuedMessage(ctx, chatID)
}
func (q *querier) PurgeOldAutomationEvents(ctx context.Context) error {
if err := q.authorizeContext(ctx, policy.ActionDelete, rbac.ResourceAutomation); err != nil {
return err
}
return q.db.PurgeOldAutomationEvents(ctx)
}
func (q *querier) ReduceWorkspaceAgentShareLevelToAuthenticatedByTemplate(ctx context.Context, templateID uuid.UUID) error {
template, err := q.db.GetTemplateByID(ctx, templateID)
if err != nil {
@@ -5693,13 +5696,13 @@ func (q *querier) TryAcquireLock(ctx context.Context, id int64) (bool, error) {
return q.db.TryAcquireLock(ctx, id)
}
func (q *querier) UnarchiveChatByID(ctx context.Context, id uuid.UUID) ([]database.Chat, error) {
func (q *querier) UnarchiveChatByID(ctx context.Context, id uuid.UUID) error {
chat, err := q.db.GetChatByID(ctx, id)
if err != nil {
return nil, err
return err
}
if err := q.authorizeContext(ctx, policy.ActionUpdate, chat); err != nil {
return nil, err
return err
}
return q.db.UnarchiveChatByID(ctx, id)
}
@@ -5727,17 +5730,6 @@ func (q *querier) UnfavoriteWorkspace(ctx context.Context, id uuid.UUID) error {
return update(q.log, q.auth, fetch, q.db.UnfavoriteWorkspace)(ctx, id)
}
func (q *querier) UnpinChatByID(ctx context.Context, id uuid.UUID) error {
chat, err := q.db.GetChatByID(ctx, id)
if err != nil {
return err
}
if err := q.authorizeContext(ctx, policy.ActionUpdate, chat); err != nil {
return err
}
return q.db.UnpinChatByID(ctx, id)
}
func (q *querier) UnsetDefaultChatModelConfigs(ctx context.Context) error {
if err := q.authorizeContext(ctx, policy.ActionUpdate, rbac.ResourceSystem); err != nil {
return err
@@ -5759,16 +5751,60 @@ func (q *querier) UpdateAPIKeyByID(ctx context.Context, arg database.UpdateAPIKe
return update(q.log, q.auth, fetch, q.db.UpdateAPIKeyByID)(ctx, arg)
}
func (q *querier) UpdateChatBuildAgentBinding(ctx context.Context, arg database.UpdateChatBuildAgentBindingParams) (database.Chat, error) {
chat, err := q.db.GetChatByID(ctx, arg.ID)
func (q *querier) UpdateAutomation(ctx context.Context, arg database.UpdateAutomationParams) (database.Automation, error) {
automation, err := q.db.GetAutomationByID(ctx, arg.ID)
if err != nil {
return database.Chat{}, err
return database.Automation{}, err
}
if err := q.authorizeContext(ctx, policy.ActionUpdate, chat); err != nil {
return database.Chat{}, err
if err := q.authorizeContext(ctx, policy.ActionUpdate, automation); err != nil {
return database.Automation{}, err
}
return q.db.UpdateAutomation(ctx, arg)
}
return q.db.UpdateChatBuildAgentBinding(ctx, arg)
func (q *querier) UpdateAutomationTrigger(ctx context.Context, arg database.UpdateAutomationTriggerParams) (database.AutomationTrigger, error) {
trigger, err := q.db.GetAutomationTriggerByID(ctx, arg.ID)
if err != nil {
return database.AutomationTrigger{}, err
}
automation, err := q.db.GetAutomationByID(ctx, trigger.AutomationID)
if err != nil {
return database.AutomationTrigger{}, err
}
if err := q.authorizeContext(ctx, policy.ActionUpdate, automation); err != nil {
return database.AutomationTrigger{}, err
}
return q.db.UpdateAutomationTrigger(ctx, arg)
}
func (q *querier) UpdateAutomationTriggerLastTriggeredAt(ctx context.Context, arg database.UpdateAutomationTriggerLastTriggeredAtParams) error {
trigger, err := q.db.GetAutomationTriggerByID(ctx, arg.ID)
if err != nil {
return err
}
automation, err := q.db.GetAutomationByID(ctx, trigger.AutomationID)
if err != nil {
return err
}
if err := q.authorizeContext(ctx, policy.ActionUpdate, automation); err != nil {
return err
}
return q.db.UpdateAutomationTriggerLastTriggeredAt(ctx, arg)
}
func (q *querier) UpdateAutomationTriggerWebhookSecret(ctx context.Context, arg database.UpdateAutomationTriggerWebhookSecretParams) (database.AutomationTrigger, error) {
trigger, err := q.db.GetAutomationTriggerByID(ctx, arg.ID)
if err != nil {
return database.AutomationTrigger{}, err
}
automation, err := q.db.GetAutomationByID(ctx, trigger.AutomationID)
if err != nil {
return database.AutomationTrigger{}, err
}
if err := q.authorizeContext(ctx, policy.ActionUpdate, automation); err != nil {
return database.AutomationTrigger{}, err
}
return q.db.UpdateAutomationTriggerWebhookSecret(ctx, arg)
}
func (q *querier) UpdateChatByID(ctx context.Context, arg database.UpdateChatByIDParams) (database.Chat, error) {
@@ -5782,15 +5818,15 @@ func (q *querier) UpdateChatByID(ctx context.Context, arg database.UpdateChatByI
return q.db.UpdateChatByID(ctx, arg)
}
func (q *querier) UpdateChatHeartbeats(ctx context.Context, arg database.UpdateChatHeartbeatsParams) ([]uuid.UUID, error) {
// The batch heartbeat is a system-level operation filtered by
// worker_id. Authorization is enforced by the AsChatd context
// at the call site rather than per-row, because checking each
// row individually would defeat the purpose of batching.
if err := q.authorizeContext(ctx, policy.ActionUpdate, rbac.ResourceChat); err != nil {
return nil, err
func (q *querier) UpdateChatHeartbeat(ctx context.Context, arg database.UpdateChatHeartbeatParams) (int64, error) {
chat, err := q.db.GetChatByID(ctx, arg.ID)
if err != nil {
return 0, err
}
return q.db.UpdateChatHeartbeats(ctx, arg)
if err := q.authorizeContext(ctx, policy.ActionUpdate, chat); err != nil {
return 0, err
}
return q.db.UpdateChatHeartbeat(ctx, arg)
}
func (q *querier) UpdateChatLabelsByID(ctx context.Context, arg database.UpdateChatLabelsByIDParams) (database.Chat, error) {
@@ -5804,39 +5840,6 @@ func (q *querier) UpdateChatLabelsByID(ctx context.Context, arg database.UpdateC
return q.db.UpdateChatLabelsByID(ctx, arg)
}
func (q *querier) UpdateChatLastInjectedContext(ctx context.Context, arg database.UpdateChatLastInjectedContextParams) (database.Chat, error) {
chat, err := q.db.GetChatByID(ctx, arg.ID)
if err != nil {
return database.Chat{}, err
}
if err := q.authorizeContext(ctx, policy.ActionUpdate, chat); err != nil {
return database.Chat{}, err
}
return q.db.UpdateChatLastInjectedContext(ctx, arg)
}
func (q *querier) UpdateChatLastModelConfigByID(ctx context.Context, arg database.UpdateChatLastModelConfigByIDParams) (database.Chat, error) {
chat, err := q.db.GetChatByID(ctx, arg.ID)
if err != nil {
return database.Chat{}, err
}
if err := q.authorizeContext(ctx, policy.ActionUpdate, chat); err != nil {
return database.Chat{}, err
}
return q.db.UpdateChatLastModelConfigByID(ctx, arg)
}
func (q *querier) UpdateChatLastReadMessageID(ctx context.Context, arg database.UpdateChatLastReadMessageIDParams) error {
chat, err := q.db.GetChatByID(ctx, arg.ID)
if err != nil {
return err
}
if err := q.authorizeContext(ctx, policy.ActionUpdate, chat); err != nil {
return err
}
return q.db.UpdateChatLastReadMessageID(ctx, arg)
}
func (q *querier) UpdateChatMCPServerIDs(ctx context.Context, arg database.UpdateChatMCPServerIDsParams) (database.Chat, error) {
chat, err := q.db.GetChatByID(ctx, arg.ID)
if err != nil {
@@ -5871,17 +5874,6 @@ func (q *querier) UpdateChatModelConfig(ctx context.Context, arg database.Update
return q.db.UpdateChatModelConfig(ctx, arg)
}
func (q *querier) UpdateChatPinOrder(ctx context.Context, arg database.UpdateChatPinOrderParams) error {
chat, err := q.db.GetChatByID(ctx, arg.ID)
if err != nil {
return err
}
if err := q.authorizeContext(ctx, policy.ActionUpdate, chat); err != nil {
return err
}
return q.db.UpdateChatPinOrder(ctx, arg)
}
func (q *querier) UpdateChatProvider(ctx context.Context, arg database.UpdateChatProviderParams) (database.ChatProvider, error) {
if err := q.authorizeContext(ctx, policy.ActionUpdate, rbac.ResourceDeploymentConfig); err != nil {
return database.ChatProvider{}, err
@@ -5902,18 +5894,7 @@ func (q *querier) UpdateChatStatus(ctx context.Context, arg database.UpdateChatS
return q.db.UpdateChatStatus(ctx, arg)
}
func (q *querier) UpdateChatStatusPreserveUpdatedAt(ctx context.Context, arg database.UpdateChatStatusPreserveUpdatedAtParams) (database.Chat, error) {
chat, err := q.db.GetChatByID(ctx, arg.ID)
if err != nil {
return database.Chat{}, err
}
if err := q.authorizeContext(ctx, policy.ActionUpdate, chat); err != nil {
return database.Chat{}, err
}
return q.db.UpdateChatStatusPreserveUpdatedAt(ctx, arg)
}
func (q *querier) UpdateChatWorkspaceBinding(ctx context.Context, arg database.UpdateChatWorkspaceBindingParams) (database.Chat, error) {
func (q *querier) UpdateChatWorkspace(ctx context.Context, arg database.UpdateChatWorkspaceParams) (database.Chat, error) {
chat, err := q.db.GetChatByID(ctx, arg.ID)
if err != nil {
return database.Chat{}, err
@@ -5922,7 +5903,15 @@ func (q *querier) UpdateChatWorkspaceBinding(ctx context.Context, arg database.U
return database.Chat{}, err
}
return q.db.UpdateChatWorkspaceBinding(ctx, arg)
// UpdateChatWorkspace is manually implemented for chat tables and may not be
// present on every wrapped store interface yet.
chatWorkspaceUpdater, ok := q.db.(interface {
UpdateChatWorkspace(context.Context, database.UpdateChatWorkspaceParams) (database.Chat, error)
})
if !ok {
return database.Chat{}, xerrors.New("update chat workspace is not implemented by wrapped store")
}
return chatWorkspaceUpdater.UpdateChatWorkspace(ctx, arg)
}
func (q *querier) UpdateCryptoKeyDeletesAt(ctx context.Context, arg database.UpdateCryptoKeyDeletesAtParams) (database.CryptoKey, error) {
@@ -6498,17 +6487,6 @@ func (q *querier) UpdateUserChatCustomPrompt(ctx context.Context, arg database.U
return q.db.UpdateUserChatCustomPrompt(ctx, arg)
}
func (q *querier) UpdateUserChatProviderKey(ctx context.Context, arg database.UpdateUserChatProviderKeyParams) (database.UserChatProviderKey, error) {
u, err := q.db.GetUserByID(ctx, arg.UserID)
if err != nil {
return database.UserChatProviderKey{}, err
}
if err := q.authorizeContext(ctx, policy.ActionUpdatePersonal, u); err != nil {
return database.UserChatProviderKey{}, err
}
return q.db.UpdateUserChatProviderKey(ctx, arg)
}
func (q *querier) UpdateUserDeletedByID(ctx context.Context, id uuid.UUID) error {
return deleteQ(q.log, q.auth, q.db.GetUserByID, q.db.UpdateUserDeletedByID)(ctx, id)
}
@@ -7037,13 +7015,6 @@ func (q *querier) UpsertChatDiffStatusReference(ctx context.Context, arg databas
return q.db.UpsertChatDiffStatusReference(ctx, arg)
}
func (q *querier) UpsertChatIncludeDefaultSystemPrompt(ctx context.Context, includeDefaultSystemPrompt bool) error {
if err := q.authorizeContext(ctx, policy.ActionUpdate, rbac.ResourceDeploymentConfig); err != nil {
return err
}
return q.db.UpsertChatIncludeDefaultSystemPrompt(ctx, includeDefaultSystemPrompt)
}
func (q *querier) UpsertChatSystemPrompt(ctx context.Context, value string) error {
if err := q.authorizeContext(ctx, policy.ActionUpdate, rbac.ResourceDeploymentConfig); err != nil {
return err
@@ -7087,6 +7058,13 @@ func (q *querier) UpsertChatWorkspaceTTL(ctx context.Context, workspaceTtl strin
return q.db.UpsertChatWorkspaceTTL(ctx, workspaceTtl)
}
func (q *querier) UpsertConnectionLog(ctx context.Context, arg database.UpsertConnectionLogParams) (database.ConnectionLog, error) {
if err := q.authorizeContext(ctx, policy.ActionUpdate, rbac.ResourceConnectionLog); err != nil {
return database.ConnectionLog{}, err
}
return q.db.UpsertConnectionLog(ctx, arg)
}
func (q *querier) UpsertDefaultProxy(ctx context.Context, arg database.UpsertDefaultProxyParams) error {
if err := q.authorizeContext(ctx, policy.ActionUpdate, rbac.ResourceSystem); err != nil {
return err
@@ -7229,17 +7207,6 @@ func (q *querier) UpsertTemplateUsageStats(ctx context.Context) error {
return q.db.UpsertTemplateUsageStats(ctx)
}
func (q *querier) UpsertUserChatProviderKey(ctx context.Context, arg database.UpsertUserChatProviderKeyParams) (database.UserChatProviderKey, error) {
u, err := q.db.GetUserByID(ctx, arg.UserID)
if err != nil {
return database.UserChatProviderKey{}, err
}
if err := q.authorizeContext(ctx, policy.ActionUpdatePersonal, u); err != nil {
return database.UserChatProviderKey{}, err
}
return q.db.UpsertUserChatProviderKey(ctx, arg)
}
func (q *querier) UpsertWebpushVAPIDKeys(ctx context.Context, arg database.UpsertWebpushVAPIDKeysParams) error {
if err := q.authorizeContext(ctx, policy.ActionUpdate, rbac.ResourceDeploymentConfig); err != nil {
return err
@@ -7388,14 +7355,6 @@ func (q *querier) ListAuthorizedAIBridgeModels(ctx context.Context, arg database
return q.ListAIBridgeModels(ctx, arg)
}
func (q *querier) ListAuthorizedAIBridgeClients(ctx context.Context, arg database.ListAIBridgeClientsParams, _ rbac.PreparedAuthorized) ([]string, error) {
// TODO: Delete this function, all ListAIBridgeClients should be
// authorized. For now just call ListAIBridgeClients on the authz
// querier. This cannot be deleted for now because it's included in
// the database.Store interface, so dbauthz needs to implement it.
return q.ListAIBridgeClients(ctx, arg)
}
func (q *querier) ListAuthorizedAIBridgeSessions(ctx context.Context, arg database.ListAIBridgeSessionsParams, prepared rbac.PreparedAuthorized) ([]database.ListAIBridgeSessionsRow, error) {
return q.db.ListAuthorizedAIBridgeSessions(ctx, arg, prepared)
}
@@ -7404,10 +7363,10 @@ func (q *querier) CountAuthorizedAIBridgeSessions(ctx context.Context, arg datab
return q.db.CountAuthorizedAIBridgeSessions(ctx, arg, prepared)
}
func (q *querier) ListAuthorizedAIBridgeSessionThreads(ctx context.Context, arg database.ListAIBridgeSessionThreadsParams, prepared rbac.PreparedAuthorized) ([]database.ListAIBridgeSessionThreadsRow, error) {
return q.db.ListAuthorizedAIBridgeSessionThreads(ctx, arg, prepared)
}
func (q *querier) GetAuthorizedChats(ctx context.Context, arg database.GetChatsParams, _ rbac.PreparedAuthorized) ([]database.GetChatsRow, error) {
func (q *querier) GetAuthorizedChats(ctx context.Context, arg database.GetChatsParams, _ rbac.PreparedAuthorized) ([]database.Chat, error) {
return q.GetChats(ctx, arg)
}
func (q *querier) GetAuthorizedAutomations(ctx context.Context, arg database.GetAutomationsParams, _ rbac.PreparedAuthorized) ([]database.Automation, error) {
return q.GetAutomations(ctx, arg)
}

Some files were not shown because too many files have changed in this diff Show More