Compare commits
130 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 119ddab975 | |||
| 53482adc2d | |||
| aa0e288b88 | |||
| 1c4a9ed745 | |||
| 5b6b7719df | |||
| 990c006f28 | |||
| 0cb942aab2 | |||
| b0a6802d12 | |||
| 8d08885792 | |||
| f68161350a | |||
| 9ac67a5253 | |||
| 7d0a0c6495 | |||
| 17dec2a70f | |||
| f796f3645f | |||
| d5ed51a190 | |||
| 796e8e4e18 | |||
| 5b28548d1c | |||
| b5da77ff55 | |||
| cc143c8990 | |||
| 3def04a3ee | |||
| d4a9c63e91 | |||
| 00217fefa5 | |||
| fce05d0428 | |||
| 2ebc076b9e | |||
| e71dc6dd4d | |||
| ca3ae3643d | |||
| ed5c06f039 | |||
| 1221622bf0 | |||
| aa5ec0bfcc | |||
| 16add93908 | |||
| fe13fd065c | |||
| fb788530b3 | |||
| f4dc8f6b11 | |||
| bbeff0d4b5 | |||
| 78fa8094cc | |||
| a85e00eed0 | |||
| da50a34414 | |||
| cd784c755a | |||
| eb4860aac3 | |||
| 07fbe8ca7d | |||
| ba0a64d483 | |||
| 7757cd8e08 | |||
| fc1e0beb3b | |||
| 3a4a0b7270 | |||
| 11c1afb5e9 | |||
| be2e641162 | |||
| 83e2699914 | |||
| 515ba209fd | |||
| d15bfc2cb0 | |||
| 308053b0e4 | |||
| 7c29355e84 | |||
| 7c048d8eb4 | |||
| e81275a91c | |||
| 4a363b0d85 | |||
| 7dc81bdef1 | |||
| ee855f9618 | |||
| b1c42bb630 | |||
| c048a4093e | |||
| 1cc23a3144 | |||
| dee5ec51c0 | |||
| e3c59c00cd | |||
| ba734f8b10 | |||
| 28062862a0 | |||
| 129e3509a3 | |||
| 53a1b6d67e | |||
| 8c8b307b97 | |||
| 12f87acad6 | |||
| 7198f9040d | |||
| ddafdbcbce | |||
| 196dc51edf | |||
| 2a51687ff3 | |||
| 19e44f4136 | |||
| 2ea89e1f1b | |||
| faa5db0cf0 | |||
| 3758b02595 | |||
| 7861fcf1f6 | |||
| 4b52656958 | |||
| f5b98aa12d | |||
| 7ddde0887b | |||
| bec426b24f | |||
| d6df78c9b9 | |||
| 19390a5841 | |||
| 2d03f7fd3d | |||
| 153a66b579 | |||
| 5cba59af79 | |||
| 1d16ff1ca6 | |||
| b86161e0a6 | |||
| a164d508cf | |||
| b9f140e53e | |||
| 7f7b13f0ab | |||
| e2bbd12137 | |||
| e769d1bd7d | |||
| cccb680ec2 | |||
| e8fb418820 | |||
| 2c5e003c91 | |||
| f44a8994da | |||
| 84b94a8376 | |||
| 2a990ce758 | |||
| c86f1288f1 | |||
| 9440adf435 | |||
| 755e8be5ad | |||
| c9e335c453 | |||
| 2d1f35f8a6 | |||
| b0036af57b | |||
| 2953245862 | |||
| 5d07014f9f | |||
| 002e88fefc | |||
| bbf3fbc830 | |||
| 9fa103929a | |||
| acd2ff63a7 | |||
| e3e17e15f7 | |||
| af678606fc | |||
| 3190406de3 | |||
| 3ce82bb885 | |||
| 348a3bd693 | |||
| 75f1503b41 | |||
| c33cd19a05 | |||
| adcea865c7 | |||
| 5e3bccd96c | |||
| 3950947c58 | |||
| b3d5b8d13c | |||
| a00afe4b5a | |||
| a5cc579453 | |||
| ef3aade647 | |||
| 3cc31de57a | |||
| d2c308e481 | |||
| 953c3bdc0f | |||
| ca879ffae6 | |||
| 0880a4685b | |||
| 3f8e3007d8 |
@@ -111,8 +111,8 @@ Tier 2 file filters:
|
||||
|
||||
- **Modernization Reviewer**: one instance per language present in the diff. Filter by extension:
|
||||
- Go: `*.go` — reference `.claude/docs/GO.md` before reviewing.
|
||||
- TypeScript: `*.ts` `*.tsx`
|
||||
- React: `*.tsx` `*.jsx`
|
||||
- TypeScript: `*.ts` `*.tsx`: reference `.agents/skills/deep-review/references/typescript.md` before reviewing.
|
||||
- React: `*.tsx` `*.jsx`: reference `.agents/skills/deep-review/references/react.md` before reviewing.
|
||||
|
||||
`.tsx` files match both TypeScript and React filters. Spawn both instances when the diff contains `.tsx` changes — TS covers language-level patterns; React covers component and hooks patterns. Before spawning, verify each instance's filter produces a non-empty diff. Skip instances whose filtered diff is empty.
|
||||
|
||||
@@ -155,9 +155,11 @@ File scope: {filter from step 2}.
|
||||
Output file: {REVIEW_DIR}/{role-name}.md
|
||||
```
|
||||
|
||||
For the Modernization Reviewer (Go), add after the methodology line:
|
||||
For Modernization Reviewer instances, add the language reference after the methodology line:
|
||||
|
||||
> Read `.claude/docs/GO.md` as your Go language reference before reviewing.
|
||||
- **Go:** `Read .claude/docs/GO.md as your Go language reference before reviewing.`
|
||||
- **TypeScript:** `Read .agents/skills/deep-review/references/typescript.md as your TypeScript language reference before reviewing.`
|
||||
- **React:** `Read .agents/skills/deep-review/references/react.md as your React language reference before reviewing.`
|
||||
|
||||
For re-reviews, append to both Tier 1 and Tier 2 prompts:
|
||||
|
||||
|
||||
@@ -0,0 +1,305 @@
|
||||
# Modern React (18–19.2) + Compiler 1.0 — Reference
|
||||
|
||||
Reference for writing idiomatic React. Covers what changed, what it replaced, and what to reach for. Includes React Compiler patterns — what the compiler handles automatically, what it changes semantically, and how to verify its behavior empirically. Scope: client-side SPA patterns only. Server Components, `use server`, and `use client` directives are framework-specific and omitted. Check the project's React version and compiler config before reaching for newer APIs.
|
||||
|
||||
## How modern React thinks differently
|
||||
|
||||
**Concurrent rendering** (18): React can now pause, interrupt, and resume renders. This is the foundation everything else builds on. Most existing code "just works," but components that produce side effects during render (mutations, subscriptions, network calls in the render body) are unsafe and will misbehave. Concurrent features are opt-in — they only activate when you use a concurrent API like `startTransition` or `useDeferredValue`.
|
||||
|
||||
**Urgent vs. non-urgent updates** (18): The `startTransition` / `useTransition` API introduces a formal split between updates that must feel immediate (typing, clicking) and updates that can be interrupted (filtering a large list, navigating to a new screen). Non-urgent updates yield to urgent ones mid-render. Use this instead of `setTimeout` or manual debounce when you want the UI to stay responsive during expensive re-renders.
|
||||
|
||||
**Actions** (19): Async functions passed to `startTransition` are called "Actions." They automatically manage pending state, error handling, and optimistic updates as a unit. The `useActionState` hook and `<form action={fn}>` prop are built on this. The pattern replaces the hand-rolled `isPending/setIsPending` + `try/catch` + `setError` boilerplate that was previously necessary for every data mutation.
|
||||
|
||||
**Automatic batching** (18): State updates are now batched everywhere — inside `setTimeout`, `Promise.then`, native event handlers, etc. Previously batching only happened inside React-managed event handlers. If you genuinely need a synchronous flush, use `flushSync`.
|
||||
|
||||
**Automatic memoization** (Compiler 1.0): React Compiler is a build-time Babel plugin that automatically inserts memoization into components and hooks. It replaces manual `useMemo`, `useCallback`, and `React.memo` — including conditional memoization and memoization after early returns, which manual APIs cannot express. The compiler only processes components and hooks, not standalone functions. It understands data flow and mutability through its own HIR (High-level Intermediate Representation), so it can memoize more granularly than a human would. Projects adopt it incrementally — typically via path-based Babel overrides or the `"use memo"` directive. Components that violate the Rules of React are silently skipped (no build error), so the automated lint tools that check compiler compatibility matter.
|
||||
|
||||
## Replace these patterns
|
||||
|
||||
The left column reflects patterns common before React 18/19. Write the right column instead. The "Since" column tells you the minimum React version required.
|
||||
|
||||
| Old pattern | Modern replacement | Since |
|
||||
| ----------------------------------------------------------------- | ------------------------------------------------------------------------------ | ----- |
|
||||
| `ReactDOM.render(<App />, el)` | `createRoot(el).render(<App />)` | 18 |
|
||||
| `ReactDOM.hydrate(<App />, el)` | `hydrateRoot(el, <App />)` | 18 |
|
||||
| `ReactDOM.unmountComponentAtNode(el)` | `root.unmount()` | 18 |
|
||||
| `ReactDOM.findDOMNode(this)` | DOM ref: `const ref = useRef(); ref.current` | 18 |
|
||||
| `<Context.Provider value={v}>` | `<Context value={v}>` | 19 |
|
||||
| `React.forwardRef((props, ref) => ...)` | `function Comp({ ref, ...props }) { ... }` (ref as a regular prop) | 19 |
|
||||
| String ref `ref="input"` in class components | Callback ref or `createRef()` | 19 |
|
||||
| `Heading.propTypes = { ... }` | TypeScript / ES6 type annotations | 19 |
|
||||
| `Component.defaultProps = { ... }` on function components | ES6 default parameters `({ text = 'Hi' })` | 19 |
|
||||
| Legacy Context: `contextTypes` + `getChildContext` | `React.createContext()` + `contextType` | 19 |
|
||||
| `import { act } from 'react-dom/test-utils'` | `import { act } from 'react'` | 19 |
|
||||
| `import ShallowRenderer from 'react-test-renderer/shallow'` | `import ShallowRenderer from 'react-shallow-renderer'` | 19 |
|
||||
| Manual `isPending` state around async calls | `const [isPending, startTransition] = useTransition()` | 18 |
|
||||
| Manual optimistic state + revert logic | `useOptimistic(currentValue)` | 19 |
|
||||
| `useEffect` to subscribe to external stores | `useSyncExternalStore(subscribe, getSnapshot)` | 18 |
|
||||
| Hand-rolled unique ID (counter, random, index) | `useId()` — SSR-safe, hydration-safe | 18 |
|
||||
| `useEffect` to inject `<title>` or `<meta>` / `react-helmet` | Render `<title>`, `<meta>`, `<link>` directly in components; React hoists them | 19 |
|
||||
| `ReactDOM.useFormState(action, initial)` (Canary name) | `useActionState(action, initial)` | 19 |
|
||||
| `useReducer<React.Reducer<State, Action>>(reducer)` | `useReducer(reducer)` — infers from the reducer function | 19 |
|
||||
| `<div ref={current => (instance = current)} />` (implicit return) | `<div ref={current => { instance = current }} />` (explicit block body) | 19 |
|
||||
| `useRef<T>()` with no argument | `useRef<T>(undefined)` or `useRef<T \| null>(null)` — argument is now required | 19 |
|
||||
| `MutableRefObject<T>` type annotation | `RefObject<T>` — all refs are mutable now; `MutableRefObject` is deprecated | 19 |
|
||||
| `React.createFactory('button')` | `<button />` JSX | 19 |
|
||||
| `useMemo(() => expr, [deps])` in compiled components | `const val = expr;` — compiler memoizes automatically | C 1.0 |
|
||||
| `useCallback(fn, [deps])` in compiled components | `const fn = () => { ... };` — compiler memoizes automatically | C 1.0 |
|
||||
| `React.memo(Component)` in compiled components | Plain component — compiler skips re-render when props are unchanged | C 1.0 |
|
||||
| `eslint-plugin-react-compiler` (standalone) | `eslint-plugin-react-hooks@latest` (compiler rules merged into recommended) | C 1.0 |
|
||||
| `useRef` + `useLayoutEffect` for stable callbacks | `useEffectEvent(fn)` — compiler handles both, but `useEffectEvent` is clearer | 19.2 |
|
||||
|
||||
## New capabilities
|
||||
|
||||
These enable things that weren't practical before. Reach for them in the described situations.
|
||||
|
||||
| What | Since | When to use it |
|
||||
| -------------------------------------------------------------------- | ------- | --------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
|
||||
| `useTransition()` / `startTransition()` | 18 | Mark a state update as non-urgent so React can interrupt it to handle clicks or keystrokes. The `isPending` boolean lets you show a loading indicator without blocking the UI. |
|
||||
| `useDeferredValue(value, initialValue?)` | 18 / 19 | Defer re-rendering a slow subtree: pass the deferred value as a prop, wrap the expensive child in `memo`. Unlike debounce, uses no fixed timeout — renders as soon as the browser is idle. The `initialValue` arg (19) avoids a flash on first render. |
|
||||
| `useId()` | 18 | Generate a stable, SSR-consistent ID for accessibility attributes (`htmlFor`, `aria-describedby`). Do not use for list keys. |
|
||||
| `useSyncExternalStore(subscribe, getSnapshot, getServerSnapshot?)` | 18 | Subscribe to external (non-React) state stores safely under concurrent rendering. Preferred over `useEffect`-based subscriptions in libraries. |
|
||||
| `useActionState(action, initialState)` | 19 | Manage an async mutation: returns `[state, wrappedAction, isPending]`. Handles pending, result, and error state as a unit. Replaces the manual `isPending` + `try/catch` + `setError` pattern. |
|
||||
| `useOptimistic(currentValue)` | 19 | Show a speculative value while an async Action is in flight. Returns `[optimisticValue, setOptimistic]`. React automatically reverts to `currentValue` when the transition settles. |
|
||||
| `use(promiseOrContext)` | 19 | Read a promise or Context value inside a component or custom hook. Unlike hooks, `use` can be called conditionally (after early returns). Promises must come from a cache — do not create them during render. |
|
||||
| `useFormStatus()` (from `react-dom`) | 19 | Read `{ pending, data, method, action }` of the nearest parent `<form>` Action. Works across component boundaries without prop drilling — useful for submit buttons inside design-system components. |
|
||||
| `useEffectEvent(fn)` | 19.2 | Extract a non-reactive callback from an effect. The function sees the latest props/state without being listed in deps, and is never stale. Replaces the `useRef`-and-mutate-in-layout-effect workaround for stable event-like callbacks. The compiler has built-in knowledge of this hook and correctly prunes its return value from effect dependency arrays. Both `useEffectEvent` and the old ref workaround compile cleanly; `useEffectEvent` is preferred for clarity. |
|
||||
| `<Activity>` | 19.2 | Hide part of the UI while preserving its state and DOM. React deprioritizes updates to hidden content. Use via framework APIs for route prerendering or tab preservation — not a direct replacement for CSS `visibility`. |
|
||||
| `captureOwnerStack()` | 19.1 | Dev-only API that returns a string showing which components are responsible for rendering the current component (owner stack, not call stack). Useful for custom error overlays. Returns `null` in production. |
|
||||
| `<form action={fn}>` | 19 | Pass an async function as a form's `action` prop. React handles submission, pending state, and automatic form reset on success. Works with `useActionState` and `useFormStatus`. |
|
||||
| Ref cleanup function | 19 | Return a cleanup function from a ref callback: `ref={el => { ...; return () => cleanup(); }}`. React calls it on unmount. Replaces the pattern of checking `el === null` in the callback. |
|
||||
| `<link rel="stylesheet" precedence="default">` | 19 | Declare a stylesheet next to the component that needs it. React deduplicates and inserts it in the correct order before revealing Suspense content. |
|
||||
| `preinit`, `preload`, `prefetchDNS`, `preconnect` (from `react-dom`) | 19 | Imperatively hint the browser to load resources early. Call from render or event handlers. React deduplicates hints across the component tree. |
|
||||
| React Compiler (`babel-plugin-react-compiler`) | C 1.0 | Build-time automatic memoization for components and hooks. Install, add to Babel/Vite pipeline. Projects typically start with path-based overrides to compile a subset of files. |
|
||||
| `"use memo"` directive | C 1.0 | Opt a single function into compilation when using `compilationMode: 'annotation'`. Place at the start of the function body. Module-level `"use memo"` at the top of a file compiles all functions in that file. |
|
||||
| `"use no memo"` directive | C 1.0 | Temporary escape hatch — skip compilation for a specific component or hook that causes a runtime regression. Not a permanent solution. Place at the start of the function body. |
|
||||
| Compiler-powered ESLint rules | C 1.0 | Rules for purity, refs, set-state-in-render, immutability, etc. now ship in `eslint-plugin-react-hooks` recommended preset. Surface Rules-of-React violations even without the compiler installed. Note: some projects use Biome instead — check project lint config. |
|
||||
|
||||
## Key APIs
|
||||
|
||||
### `useTransition` and `startTransition` (18)
|
||||
|
||||
`useTransition` returns `[isPending, startTransition]`. Wrap any state update that is not directly tied to the user's current gesture inside `startTransition`. React will render the old UI while computing the new one, and `isPending` is `true` during that window.
|
||||
|
||||
In React 19, `startTransition` can accept an async function (an "Action"). React sets `isPending` to `true` for the entire duration of the async work, not just during the synchronous part.
|
||||
|
||||
```tsx
|
||||
// 18: synchronous transition
|
||||
const [isPending, startTransition] = useTransition();
|
||||
startTransition(() => setQuery(input));
|
||||
|
||||
// 19: async Action — isPending stays true until the await settles
|
||||
startTransition(async () => {
|
||||
const err = await updateName(name);
|
||||
if (err) setError(err);
|
||||
});
|
||||
```
|
||||
|
||||
Use `startTransition` (the module-level export) when you cannot use the hook (outside a component, in a router callback, etc.).
|
||||
|
||||
### `useDeferredValue` (18 / 19)
|
||||
|
||||
Creates a "lagging" copy of a value. Pass it to a memoized, expensive component so that React can render the stale UI while computing the updated one.
|
||||
|
||||
```tsx
|
||||
// 19: initialValue shows '' on first render; avoids loading flash
|
||||
const deferred = useDeferredValue(searchQuery, "");
|
||||
return <Results query={deferred} />; // Results wrapped in memo
|
||||
```
|
||||
|
||||
`deferred !== searchQuery` while the deferred render is in progress — use this to show a "stale" indicator.
|
||||
|
||||
### `useActionState` (19)
|
||||
|
||||
Replaces the `useState` + `isPending` + `try/catch` + `setError` boilerplate for any async operation that can be retried or submitted as a form.
|
||||
|
||||
```tsx
|
||||
const [error, submitAction, isPending] = useActionState(
|
||||
async (prevState, formData) => {
|
||||
const err = await updateName(formData.get("name"));
|
||||
if (err) return err; // returned value becomes next state
|
||||
redirect("/profile");
|
||||
return null;
|
||||
},
|
||||
null, // initialState
|
||||
);
|
||||
|
||||
// Use submitAction as the form's action prop or call it directly
|
||||
<form action={submitAction}>
|
||||
<input name="name" />
|
||||
<button disabled={isPending}>Save</button>
|
||||
{error && <p>{error}</p>}
|
||||
</form>;
|
||||
```
|
||||
|
||||
### `useOptimistic` (19)
|
||||
|
||||
Shows a speculative value immediately while an async Action is in progress. React automatically reverts to the server-confirmed value when the Action resolves or rejects.
|
||||
|
||||
```tsx
|
||||
const [optimisticName, setOptimisticName] = useOptimistic(currentName);
|
||||
|
||||
const submit = async (formData) => {
|
||||
const newName = formData.get("name");
|
||||
setOptimisticName(newName); // shows immediately
|
||||
await updateName(newName); // reverts if this throws
|
||||
};
|
||||
```
|
||||
|
||||
### `use()` (19)
|
||||
|
||||
Unlike hooks, `use` can appear after conditional statements. Two primary uses:
|
||||
|
||||
**Reading a promise** (must be stable — from a cache, not created inline):
|
||||
|
||||
```tsx
|
||||
function Comments({ commentsPromise }) {
|
||||
const comments = use(commentsPromise); // suspends until resolved
|
||||
return comments.map((c) => <p key={c.id}>{c.text}</p>);
|
||||
}
|
||||
```
|
||||
|
||||
**Reading context after an early return** (hooks cannot appear after `return`):
|
||||
|
||||
```tsx
|
||||
function Heading({ children }) {
|
||||
if (!children) return null;
|
||||
const theme = use(ThemeContext); // valid here; hooks would not be
|
||||
return <h1 style={{ color: theme.color }}>{children}</h1>;
|
||||
}
|
||||
```
|
||||
|
||||
### `useSyncExternalStore` (18)
|
||||
|
||||
The correct way for libraries (and app code) to subscribe to non-React state. Prevents tearing under concurrent rendering.
|
||||
|
||||
```tsx
|
||||
const value = useSyncExternalStore(
|
||||
store.subscribe, // called when store changes
|
||||
store.getSnapshot, // returns current value (must be stable reference if unchanged)
|
||||
store.getServerSnapshot, // optional: for SSR
|
||||
);
|
||||
```
|
||||
|
||||
## Verifying compiler behavior
|
||||
|
||||
The compiler is a black box unless you inspect its output. When reviewing code in compiled paths, run the compiler on the specific code to see what it actually does. Do not guess — verify.
|
||||
|
||||
**Run the compiler on a code snippet:**
|
||||
|
||||
```sh
|
||||
cd site && node -e "
|
||||
const {transformSync} = require('@babel/core');
|
||||
const code = \`<paste component here>\`;
|
||||
const diagnostics = [];
|
||||
const result = transformSync(code, {
|
||||
plugins: [
|
||||
['@babel/plugin-syntax-typescript', {isTSX: true}],
|
||||
['babel-plugin-react-compiler', {
|
||||
logger: {
|
||||
logEvent(_, event) {
|
||||
if (event.kind === 'CompileError' || event.kind === 'CompileSkip') {
|
||||
diagnostics.push(event.detail?.toString?.()?.substring(0, 200));
|
||||
}
|
||||
},
|
||||
},
|
||||
}],
|
||||
],
|
||||
filename: 'test.tsx',
|
||||
});
|
||||
console.log('Compiled:', result.code.includes('_c('));
|
||||
if (diagnostics.length) console.log('Diagnostics:', diagnostics);
|
||||
console.log(result.code);
|
||||
"
|
||||
```
|
||||
|
||||
**Reading compiled output:**
|
||||
|
||||
- `const $ = _c(N)` — allocates N memoization cache slots.
|
||||
- `if ($[n] !== dep)` — cache invalidation guard. Re-computes when `dep` changes (referential equality).
|
||||
- `if ($[n] === Symbol.for("react.memo_cache_sentinel"))` — one-time initialization. Runs once on first render, cached forever after. This is how the compiler handles expressions with no reactive dependencies.
|
||||
- `_temp` functions — pure callbacks the compiler hoisted out of the component body.
|
||||
|
||||
**Check all compiled files at once:**
|
||||
|
||||
```sh
|
||||
cd site && pnpm run lint:compiler
|
||||
```
|
||||
|
||||
This runs the compiler on every file in the compiled paths and reports CompileError / CompileSkip diagnostics. Zero diagnostics means all functions compiled cleanly.
|
||||
|
||||
**What the compiler catches vs. what it does not:**
|
||||
|
||||
The compiler emits `CompileError` for mutations of props, state, or hook arguments during render, and for `ref.current` access during render. The project's lint pipeline catches these automatically — do not flag them in review.
|
||||
|
||||
The compiler does **not** flag impure function calls during render (`Math.random()`, `Date.now()`, `new Date()`). Instead it silently memoizes them with a sentinel guard, freezing the value after first render. This changes semantics without any diagnostic. Verify suspicious calls by running the compiler and checking for sentinel guards in the output.
|
||||
|
||||
## Pitfalls
|
||||
|
||||
Things that are easy to get wrong even when you know the modern API exists. Check your output against these.
|
||||
|
||||
**Effects run twice in development with StrictMode.** React 18 intentionally mounts → unmounts → remounts every component in dev to surface effects that are not resilient to remounting. This is not a bug. If an effect breaks on the second mount, it is missing a cleanup function. Write `return () => cleanup()` from every effect that sets up a subscription, timer, or external resource.
|
||||
|
||||
**Concurrent rendering can call render multiple times.** The render function (component body) may be called more than once before React commits to the DOM. Side effects (mutations, subscriptions, logging) in the render body will run multiple times. Move them into `useEffect` or event handlers.
|
||||
|
||||
**Do not create promises during render and pass them to `use()`.** A new promise is created every render, causing an infinite suspend-retry loop. Create the promise outside the component (module level), or use a caching library (SWR, React Query, `cache()` from React) to stabilize it.
|
||||
|
||||
**`useOptimistic` reverts automatically — do not fight it.** The optimistic value is a presentation layer only. When the Action settles, React replaces it with the real `currentValue` you passed in. Do not try to sync optimistic state back to your real state; let React handle the revert.
|
||||
|
||||
**`flushSync` opts out of automatic batching.** If third-party code or a browser API (e.g. `ResizeObserver`) calls `setState` and you need synchronous DOM flushing, wrap with `flushSync(() => setState(...))`. This is a last resort; prefer letting React batch.
|
||||
|
||||
**`forwardRef` still works in React 19 but will be deprecated.** Function components accept `ref` as a plain prop now. New code should use the prop directly. Existing `forwardRef` wrappers continue to work without changes; migrate when convenient.
|
||||
|
||||
**`<Activity>` does not unmount.** Content inside a hidden `<Activity>` boundary stays mounted. Effects keep running. Use it for preserving scroll position or form state, not for preventing expensive mounts — use lazy loading for that.
|
||||
|
||||
**TypeScript: implicit returns from ref callbacks are now type errors.** In React 19, returning anything other than a cleanup function (or nothing) from a ref callback is rejected by the TypeScript types. The most common case is arrow-function refs that implicitly return the DOM node:
|
||||
|
||||
```tsx
|
||||
// Error in React 19 types:
|
||||
<div ref={el => (instance = el)} />
|
||||
|
||||
// Fix — use a block body:
|
||||
<div ref={el => { instance = el; }} />
|
||||
```
|
||||
|
||||
**TypeScript: `useRef` now requires an argument.** `useRef<T>()` with no argument is a type error. Pass `undefined` for mutable refs or `null` for DOM refs you initialize on mount: `useRef<T>(undefined)` / `useRef<HTMLDivElement | null>(null)`.
|
||||
|
||||
**`useId` output format changed across versions.** React 18 produced `:r0:`. React 19.1 changed it to `«r0»`. React 19.2 changed it again to `_r0`. Do not parse or depend on the specific format — treat it as an opaque string.
|
||||
|
||||
**`useFormStatus` reads the nearest parent `<form>` with a function `action`.** It does not reflect native HTML form submissions — only React Actions. A submit button that is a sibling of `<form>` (rather than a descendant) will not see the form's status.
|
||||
|
||||
**Context as a provider (`<Context>`) requires React 19; `<Context.Provider>` still works.** Do not use `<Context>` shorthand in a codebase that needs to support React 18. The two forms can coexist during migration.
|
||||
|
||||
**Compiler freezes impure expressions silently.** `Math.random()`, `Date.now()`, `new Date()`, and `window.innerWidth` in a component body all compile without diagnostics. The compiler wraps them in a sentinel guard (`Symbol.for("react.memo_cache_sentinel")`) that runs the expression once and caches the result forever. The value never updates on re-render. Fix: move to a `useState` initializer (`useState(() => Math.random())`), `useEffect`, or event handler.
|
||||
|
||||
**Component granularity affects compiler optimization.** When one pattern in a component causes a `CompileError` (e.g., a necessary `ref.current` read during render), the compiler skips the **entire** component. If the rest of the component would benefit from compilation, extract the non-compilable pattern into a small child component. This keeps the parent compiled.
|
||||
|
||||
**The compiler only memoizes components and hooks.** Standalone utility functions (even expensive ones called during render) are not compiled. If a utility function is truly expensive, it still needs its own caching strategy outside of React (e.g., a module-level cache, `WeakMap`, etc.).
|
||||
|
||||
**Changing memoization can shift `useEffect` firing.** A value that was unstable before compilation may become stable after, causing an effect that depended on it to fire less often. Conversely, future compiler changes may alter memoization granularity. Effects that use memoized values as dependencies should be resilient to these changes — they should be true synchronization effects, not "run this when X changes" hacks.
|
||||
|
||||
## Behavioral changes that affect code
|
||||
|
||||
- **Automatic batching** (18): State updates in `setTimeout`, `Promise.then`, `addEventListener` callbacks, etc. are now batched into a single re-render. Previously only React synthetic event handlers were batched. Code that relied on unbatched updates (reading DOM synchronously after each `setState`) must use `flushSync`.
|
||||
|
||||
- **StrictMode double-invoke** (18): In development, every component is mounted → unmounted → remounted with the previous state. Every effect runs cleanup → setup twice on initial mount. `useMemo` and `useCallback` also double-invoke their functions. Production behavior is unchanged. If a test or component breaks under this, the component had a latent cleanup bug.
|
||||
|
||||
- **StrictMode ref double-invoke** (19): In development, ref callbacks are also invoked twice on mount (attach → detach → attach). Return a cleanup function from the ref callback to handle detach correctly.
|
||||
|
||||
- **StrictMode memoization reuse** (19): During the second pass of double-rendering, `useMemo` and `useCallback` now reuse the cached result from the first pass instead of calling the function again. Components that are already StrictMode-compatible should not notice a difference.
|
||||
|
||||
- **Suspense fallback commits immediately** (19): When a component suspends, React now commits the nearest `<Suspense>` fallback without waiting for sibling trees to finish rendering. After the fallback is shown, React "pre-warms" suspended siblings in the background. This makes fallbacks appear faster but changes the order of rendering work.
|
||||
|
||||
- **Error re-throwing removed** (19): Errors that are not caught by an Error Boundary are now reported to `window.reportError` (not re-thrown). Errors caught by an Error Boundary go to `console.error` once. If your production monitoring relied on the re-thrown error, add handlers to `createRoot`: `createRoot(el, { onUncaughtError, onCaughtError })`.
|
||||
|
||||
- **Transitions in `popstate` are synchronous** (19): Browser back/forward navigation triggers synchronous transition flushing. This ensures the URL and UI update together atomically during history navigation.
|
||||
|
||||
- **`useEffect` from discrete events flushes synchronously** (18): Effects triggered by a click or keydown (discrete events) are now flushed synchronously before the browser paints, consistent with `useLayoutEffect` for those cases.
|
||||
|
||||
- **Hydration mismatches treated as errors** (18 / improved in 19): Text content mismatches between server HTML and client render revert to client rendering up to the nearest `<Suspense>` boundary. React 19 logs a single diff instead of multiple warnings, making mismatches much easier to diagnose.
|
||||
|
||||
- **New JSX transform required** (19): The automatic JSX runtime introduced in 2020 (`react/jsx-runtime`) is now mandatory. The classic transform (which required `import React from 'react'` in every file) is no longer supported. Most toolchains have already shipped the new transform; check your Babel or TypeScript config if you see warnings.
|
||||
|
||||
- **UMD builds removed** (19): React no longer ships UMD bundles. Load via npm and a bundler, or use an ESM CDN (`import React from "https://esm.sh/react@19"`).
|
||||
|
||||
- **React Compiler automatic memoization** (Compiler 1.0): Build-time Babel plugin that inserts memoization into components and hooks. Components that follow the Rules of React are automatically memoized; components that violate them are silently skipped (no build error, no runtime change). The compiler can memoize conditionally and after early returns — things impossible with manual `useMemo`/`useCallback`. Works with React 17+ via `react-compiler-runtime`; best with React 19+. Projects adopt incrementally via path-based Babel overrides, `compilationMode: 'annotation'`, or the `"use memo"` / `"use no memo"` directives. Check the project's Vite/Babel config to know which paths are compiled. Compiled components show a "Memo ✨" badge in React DevTools.
|
||||
@@ -0,0 +1,199 @@
|
||||
# Modern TypeScript (5.0–6.0 RC) — Reference
|
||||
|
||||
Reference for writing idiomatic TypeScript. Covers what changed, what it replaced, and what to reach for. Respect the project's minimum TypeScript version: don't emit features from a version newer than what the project targets. Check `package.json` and `tsconfig.json` before writing code.
|
||||
|
||||
## How modern TypeScript thinks differently
|
||||
|
||||
The 5.x era resolves years of module system ambiguity and cleans house on legacy options. Three themes dominate:
|
||||
|
||||
**Module semantics are explicit.** `--verbatimModuleSyntax` (5.0) makes import/export intent visible in source: type imports must carry `type`, value imports stay. Combined with `--module preserve` or `--moduleResolution bundler`, the compiler now accurately models what bundlers and modern runtimes actually do. `import defer` (5.9) extends the model to deferred evaluation.
|
||||
|
||||
**Resource lifetimes are first-class.** `using` and `await using` (5.2) provide deterministic cleanup without `try/finally`. Any object implementing `Symbol.dispose` participates. `DisposableStack` handles ad-hoc multi-resource cleanup in functions where creating a full class is overkill.
|
||||
|
||||
**Inference is smarter about what it knows.** Inferred type predicates (5.5) let `.filter(x => x !== undefined)` produce `T[]` instead of `(T | undefined)[]` automatically. `NoInfer<T>` (5.4) gives library authors precise control over which parameters drive inference. Narrowing now survives closures after last assignment, constant indexed accesses, and `switch (true)` patterns.
|
||||
|
||||
**TypeScript 6.0 is a transition release toward 7.0** (the Go-native port). It turns years of soft deprecations into errors and changes several defaults. Most impactful: `types` defaults to `[]` (must list `@types` packages explicitly), `rootDir` defaults to `.`, `strict` defaults to `true`, `module` defaults to `esnext`. Projects relying on implicit behavior need explicit config. Check the deprecations section before upgrading.
|
||||
|
||||
## Replace these patterns
|
||||
|
||||
The left column reflects patterns still common before TypeScript 5.x. Write the right column instead. The "Since" column tells you the minimum TypeScript version required.
|
||||
|
||||
| Old pattern | Modern replacement | Since |
|
||||
| ---------------------------------------------------------------------------------------------------------------- | ---------------------------------------------------------------------------------------------------------------------------- | -------------------------------- | ------ |
|
||||
| `--experimentalDecorators` + legacy decorator signatures | Standard decorators (TC39): `function dec(target, context: ClassMethodDecoratorContext)` — no flag needed | 5.0 |
|
||||
| Requiring callers to add `as const` at call sites | `<const T extends HasNames>(arg: T)` — `const` modifier on type parameter | 5.0 |
|
||||
| `--importsNotUsedAsValues` + `--preserveValueImports` | `--verbatimModuleSyntax` | 5.0 |
|
||||
| `import { Foo } from "..."` when `Foo` is only used as a type | `import { type Foo } from "..."` or `import type { Foo } from "..."` | 5.0 |
|
||||
| `"extends": "@tsconfig/strictest/tsconfig.json"` chain | `"extends": ["@tsconfig/strictest/tsconfig.json", "./tsconfig.base.json"]` (array form) | 5.0 |
|
||||
| `try { ... } finally { resource.close(); resource.delete(); }` | `using resource = acquireResource()` — calls `[Symbol.dispose]()` automatically | 5.2 |
|
||||
| `try { ... } finally { await resource.close() }` | `await using resource = acquireAsyncResource()` | 5.2 |
|
||||
| Ad-hoc cleanup with multiple `try/finally` blocks | `using cleanup = new DisposableStack(); cleanup.defer(() => ...)` | 5.2 |
|
||||
| `import data from "./data.json" assert { type: "json" }` | `import data from "./data.json" with { type: "json" }` | 5.3 |
|
||||
| `.filter(Boolean)` or `.filter(x => !!x)` to remove nulls | `.filter(x => x !== undefined)` or `.filter(x => x !== null)` (infers type predicate) | 5.5 |
|
||||
| Extra phantom type param to block inference bleed: `<C extends string, D extends C>` | `NoInfer<C>` on the parameter you don't want to drive inference | 5.4 |
|
||||
| `/** @typedef {import("./types").Foo} Foo */` in JS files | `/** @import { Foo } from "./types" */` (JSDoc `@import` tag) | 5.5 |
|
||||
| `myArray.reverse()` mutating in place | `myArray.toReversed()` (returns new array) | 5.2 |
|
||||
| `myArray.sort(cmp)` mutating in place | `myArray.toSorted(cmp)` (returns new array) | 5.2 |
|
||||
| `const copy = [...arr]; copy[i] = v` | `arr.with(i, v)` (returns new array) | 5.2 |
|
||||
| Manual `has`/`get`/`set` pattern on `Map` | `map.getOrInsert(key, defaultValue)` or `getOrInsertComputed(key, fn)` | 6.0 RC |
|
||||
| `new RegExp(str.replace(/[.\*+?^${}() | [\]\\]/g, '\\$&'))` | `new RegExp(RegExp.escape(str))` | 6.0 RC |
|
||||
| `--moduleResolution node` (node10) | `--moduleResolution nodenext` (Node.js) or `--moduleResolution bundler` (bundlers/Bun) | 6.0 RC |
|
||||
| `"baseUrl": "./src"` + `"@app/*": ["app/*"]` in paths | Remove `baseUrl`; use `"@app/*": ["./src/app/*"]` in paths directly | 6.0 RC |
|
||||
| `module Foo { export const x = 1; }` | `namespace Foo { export const x = 1; }` | 6.0 RC |
|
||||
| `export * from "..."` when all re-exported members are types | `export type * from "..."` (or `export type * as ns from "..."`) | 5.0 |
|
||||
| `function f(): undefined { return undefined; }` — explicit return required in `: undefined`-returning function | Remove the `return` entirely; `undefined`-returning functions no longer require any return statement | 5.1 |
|
||||
| Manual type predicate annotation on a simple arrow: `(x: T \| undefined): x is T => x !== undefined` | Remove the annotation; TypeScript infers `x is T` from `!== null/undefined` and `instanceof` checks automatically | 5.5 |
|
||||
| `const val = obj[key]; if (typeof val === "string") { use(val); }` — extract to const to narrow indexed access | `if (typeof obj[key] === "string") { obj[key].toUpperCase(); }` directly — both `obj` and `key` must be effectively constant | 5.5 |
|
||||
| Copy narrowed `let`/param to a `const`, or restructure code to escape stale closure narrowing after reassignment | Remove the copy; narrowing survives into closures created after the last assignment to the variable | 5.4 |
|
||||
| `(arr as string[]).filter(...)` or restructure to avoid "not callable" errors on `string[] \| number[]` | Call `.filter`, `.find`, `.some`, `.every`, `.reduce` directly on union-of-array types | 5.2 |
|
||||
| `if`/`else` chain used to work around lack of narrowing inside a `switch (true)` body | `switch (true)` — each `case` condition now narrows the tested variable in its clause | 5.3 |
|
||||
|
||||
## New capabilities
|
||||
|
||||
These enable things that weren't practical before. Reach for them in the described situations.
|
||||
|
||||
| What | Since | When to use it |
|
||||
| ----------------------------------------------- | ------ | --------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
|
||||
| `using` / `await using` declarations | 5.2 | Any resource needing deterministic cleanup (file handles, DB connections, locks, event listeners). Object must implement `Symbol.dispose` / `Symbol.asyncDispose`. |
|
||||
| `DisposableStack` / `AsyncDisposableStack` | 5.2 | Ad-hoc multi-resource cleanup without creating a class. Call `.defer(fn)` right after acquiring each resource. Stack disposes in LIFO order. |
|
||||
| `const` modifier on type parameters | 5.0 | Force `const`-like (literal/readonly tuple) inference at call sites without requiring callers to write `as const`. Constraint must use `readonly` arrays. |
|
||||
| Decorator metadata (`Symbol.metadata`) | 5.2 | Attach and read per-class metadata from decorators via `context.metadata`. Retrieved as `MyClass[Symbol.metadata]`. Requires `Symbol.metadata ??= Symbol(...)` polyfill. |
|
||||
| `NoInfer<T>` utility type | 5.4 | Prevent a parameter from contributing inference candidates for `T`. Use when one argument should be the "source of truth" and others should only be checked against it. |
|
||||
| Inferred type predicates | 5.5 | Filter callbacks that test for `!== null` or `instanceof` now automatically produce a type predicate. `Array.prototype.filter` then narrows the result array type. |
|
||||
| `--isolatedDeclarations` | 5.5 | Require explicit return types on exported declarations. Unlocks parallel declaration emit by external tooling (esbuild, oxc, etc.) without needing a full type-checker pass. |
|
||||
| `${configDir}` in tsconfig paths | 5.5 | Anchor `typeRoots`, `paths`, `outDir`, etc. in a shared base tsconfig to the _consuming_ project's directory, not the shared file's location. |
|
||||
| Always-truthy/nullish check errors | 5.6 | Catches regex literals in `if`, arrow functions as comparators, `?? 100` on non-nullable left side, misplaced parentheses. No API to call; existing bugs now surface as errors. |
|
||||
| Iterator helper methods (`IteratorObject`) | 5.6 | Built-in iterators from `Map`, `Set`, generators, etc. now have `.map()`, `.filter()`, `.take()`, `.drop()`, `.flatMap()`, `.toArray()`, `.reduce()`, etc. Use `Iterator.from(iterable)` to wrap any iterable. |
|
||||
| `--noUncheckedSideEffectImports` | 5.6 | Error when a side-effect import (`import "..."`) resolves to nothing. Catches typos in polyfill or CSS imports. |
|
||||
| `--noCheck` | 5.6 | Skip type checking entirely during emit. Useful for separating "fast emit" from "thorough check" pipeline stages, especially with `--isolatedDeclarations`. |
|
||||
| `--rewriteRelativeImportExtensions` | 5.7 | Rewrite `.ts`→`.js`, `.tsx`→`.jsx`, `.mts`→`.mjs`, `.cts`→`.cjs` in relative imports during emit. Required when writing `.ts` imports for Node.js strip-types mode and still needing `.js` output for library distribution. |
|
||||
| `--erasableSyntaxOnly` | 5.8 | Error on constructs that can't be type-stripped by Node.js `--experimental-strip-types`: `enum`, `namespace` with code, parameter properties, `import =` aliases. |
|
||||
| `require()` of ESM under `--module nodenext` | 5.8 | Node.js 22+ allows CJS to `require()` ESM files (no top-level `await`). TypeScript now allows this under `nodenext` without error. |
|
||||
| `import defer * as ns from "..."` | 5.9 | Defer module _evaluation_ (not loading) until first property access. Module is loaded and verified at import time; side-effects are delayed. Only works with `--module preserve` or `esnext`. |
|
||||
| `Set` algebra methods | 5.5 | Non-mutating: `union`, `intersection`, `difference`, `symmetricDifference` → new `Set`. Predicate: `isSubsetOf`, `isSupersetOf`, `isDisjointFrom` → `boolean`. Requires `esnext` or `es2025` lib. |
|
||||
| `Object.groupBy` / `Map.groupBy` | 5.4 | Group an iterable into buckets by key function. Return type has all keys as optional (not every key is guaranteed present). Requires `esnext` or `es2024`+ lib. |
|
||||
| `Temporal` API types | 6.0 RC | `Temporal.Now`, `Temporal.Instant`, `Temporal.PlainDate`, etc. Available under `esnext` or `esnext.temporal` lib. Usable in runtimes that already ship it (V8 118+, SpiderMonkey, etc.). |
|
||||
| `@satisfies` in JSDoc | 5.0 | Validates that a JS expression satisfies a type without widening it — the TS `satisfies` operator for `.js` files. Write `/** @satisfies {MyType} */` above the declaration or inline on a parenthesized expression. |
|
||||
| `@overload` in JSDoc | 5.0 | Declare multiple call signatures for a JS function. Each JSDoc comment tagged `@overload` is treated as a distinct overload; the final JSDoc comment (without `@overload`) describes the implementation signature. |
|
||||
| Getter/setter with completely unrelated types | 5.1 | `get style(): CSSStyleDeclaration` and `set style(v: string)` can now have fully unrelated types, provided both have explicit type annotations. Previously the getter type was required to be a subtype of the setter type. |
|
||||
| `instanceof` narrowing via `Symbol.hasInstance` | 5.3 | When a class defines `static [Symbol.hasInstance](val: unknown): val is T`, the `instanceof` operator now narrows to the predicate type `T`, not the class type itself. Useful when the runtime check and the structural type differ. |
|
||||
| Regex literal syntax checking | 5.5 | TypeScript validates regex literal syntax: malformed groups, nonexistent backreferences, named capture mismatches, and features not available at the current `--target`. No API needed; existing latent bugs surface as errors automatically. |
|
||||
| `--build` continues past intermediate errors | 5.6 | `tsc --build` no longer stops at the first failing project. All projects are built and errors reported together. Use `--stopOnBuildErrors` to restore the old stop-on-first-error behavior. Useful for monorepos during upgrades. |
|
||||
| `--module node18` | 5.8 | Stable `--module` flag for Node.js 18 semantics: disallows `require()` of ESM (unlike `nodenext`) and still allows import assertions. Use when pinned to Node 18 and not ready for `nodenext` behavior changes. |
|
||||
| `--module node20` | 5.9 | Stable `--module` flag for Node.js 20 semantics: permits `require()` of ESM, rejects import assertions. Implies `--target es2023` (unlike `nodenext`, which floats to `esnext`). |
|
||||
|
||||
## Key APIs
|
||||
|
||||
### `Disposable` / `AsyncDisposable` / stacks (5.2)
|
||||
|
||||
Global types provided by TypeScript's lib (requires `esnext.disposable` or `esnext` in `lib`):
|
||||
|
||||
- `Disposable` — `{ [Symbol.dispose](): void }`
|
||||
- `AsyncDisposable` — `{ [Symbol.asyncDispose](): PromiseLike<void> }`
|
||||
- `DisposableStack` — `defer(fn)`, `use(resource)`, `adopt(value, disposeFn)`, `move()`. Is itself `Disposable`.
|
||||
- `AsyncDisposableStack` — async equivalent. Is itself `AsyncDisposable`.
|
||||
- `SuppressedError` — thrown when both the scope body and a `[Symbol.dispose]` throw. `.error` holds the dispose-phase error; `.suppressed` holds the original error.
|
||||
|
||||
Polyfill the symbols in older runtimes:
|
||||
|
||||
```ts
|
||||
Symbol.dispose ??= Symbol("Symbol.dispose");
|
||||
Symbol.asyncDispose ??= Symbol("Symbol.asyncDispose");
|
||||
```
|
||||
|
||||
### Decorator context types (5.0)
|
||||
|
||||
Each decorator kind receives a typed context object as its second parameter:
|
||||
|
||||
- `ClassDecoratorContext`
|
||||
- `ClassMethodDecoratorContext`
|
||||
- `ClassGetterDecoratorContext`
|
||||
- `ClassSetterDecoratorContext`
|
||||
- `ClassFieldDecoratorContext`
|
||||
- `ClassAccessorDecoratorContext`
|
||||
|
||||
All context objects have `.name`, `.kind`, `.static`, `.private`, and `.metadata`. Method/getter/setter/accessor contexts also have `.addInitializer(fn)` for running code at construction time.
|
||||
|
||||
### `IteratorObject` (5.6)
|
||||
|
||||
`IteratorObject<T, TReturn, TNext>` is the new type for built-in iterable iterators. Key methods: `map`, `filter`, `take`, `drop`, `flatMap`, `forEach`, `reduce`, `some`, `every`, `find`, `toArray`. Not the same as the pre-existing structural `Iterator<T>` protocol.
|
||||
|
||||
- Generators produce `Generator<T>` which extends `IteratorObject`.
|
||||
- `Map.prototype.entries()` returns `MapIterator<[K, V]>`, `Set.prototype.values()` returns `SetIterator<T>`, etc.
|
||||
- `Iterator.from(iterable)` converts any `Iterable` to an `IteratorObject`.
|
||||
- `AsyncIteratorObject` exists for async parity.
|
||||
- `--strictBuiltinIteratorReturn` (new `--strict`-mode flag in 5.6) makes the return type of `BuiltinIteratorReturn` be `undefined` instead of `any`, catching unchecked `done` access.
|
||||
|
||||
### Array copying methods (5.2)
|
||||
|
||||
Declared on `Array`, `ReadonlyArray`, and all `TypedArray` types. Use these instead of the mutating variants when you need to preserve the original:
|
||||
|
||||
| Mutating | Non-mutating copy |
|
||||
| ---------------------------------- | ------------------------------------- |
|
||||
| `arr.sort(cmp)` | `arr.toSorted(cmp)` |
|
||||
| `arr.reverse()` | `arr.toReversed()` |
|
||||
| `arr.splice(start, del, ...items)` | `arr.toSpliced(start, del, ...items)` |
|
||||
| `arr[i] = v` | `arr.with(i, v)` |
|
||||
|
||||
## Pitfalls
|
||||
|
||||
Things easy to get wrong even when you know the modern API exists. Check your output against these.
|
||||
|
||||
**tsconfig defaults changed hard in 6.0.** `types: []` means no `@types/*` packages load implicitly. If you see floods of "cannot find name 'process'" or "cannot find module 'fs'" after upgrading to 6.0, add `"types": ["node"]` (or whatever you need) to `compilerOptions`. `rootDir: "."` means a project with source in `src/` will emit to `dist/src/` instead of `dist/` — add `"rootDir": "./src"` explicitly. `strict: true` by default means projects with loose code see new errors.
|
||||
|
||||
**`using` requires a runtime polyfill on older runtimes.** `Symbol.dispose` and `Symbol.asyncDispose` don't exist before Node.js 18.x / Chrome 120. Add the two-line polyfill at your entry point. `DisposableStack` and `AsyncDisposableStack` need a more substantial polyfill (e.g. from `@microsoft/using-polyfill`).
|
||||
|
||||
**`using` disposes in LIFO order.** Resources declared later in a scope are disposed first. Declare in the order you want reversed cleanup (acquisition order). `DisposableStack.defer` also runs in LIFO order.
|
||||
|
||||
**Inferred type predicates have if-and-only-if semantics.** `x => !!x` does NOT infer `x is NonNullable<T>` because `0`, `""`, and `false` are falsy but not absent. TypeScript correctly refuses the predicate. Use `x => x !== undefined` or `x => x !== null` for precise null/undefined filters. If a predicate isn't being inferred, the false branch is probably ambiguous.
|
||||
|
||||
**`--verbatimModuleSyntax` breaks CJS `require` emit.** Under this flag ESM `import`/`export` is emitted verbatim. You cannot produce `require()` calls from standard `import` syntax. For CJS output you must use `import foo = require("foo")` and `export = { ... }` syntax explicitly.
|
||||
|
||||
**`NoInfer<T>` doesn't prevent `T` from being resolved, only from being contributed at that position.** Other parameters can still infer `T`. It means "don't use me as an inference candidate", not "block `T` from being resolved".
|
||||
|
||||
**`--isolatedDeclarations` requires explicit return types on all exports.** Exported arrow functions, function declarations, and class methods all need annotations if their return type isn't trivially inferrable from a literal or type assertion. Editor quick-fixes can add them automatically.
|
||||
|
||||
**Standard decorators are incompatible with `--experimentalDecorators`.** Different type signatures, metadata model, and emit. A decorator written for one will not work with the other. `--emitDecoratorMetadata` is not supported with standard decorators. Don't mix the two systems in one project.
|
||||
|
||||
**`import defer` does not downlevel.** TypeScript does not transform `import defer` to polyfill-compatible code. The module is still _loaded_ eagerly (must exist); only _evaluation_ is deferred. Only use it under `--module preserve` or `esnext` with a runtime or bundler that supports it.
|
||||
|
||||
**`--erasableSyntaxOnly` prohibits parameter properties.** `constructor(public x: number)` is not allowed. Expand to an explicit field declaration plus assignment in the constructor body.
|
||||
|
||||
**Closure narrowing is invalidated if the variable is assigned anywhere in a nested function.** TypeScript cannot know when a nested function will run, so any assignment to a `let`/param inside a nested function — even a no-op like `value = value` — invalidates narrowing for all closures in the outer scope. Only the outer "no further assignments after this point" pattern is safe.
|
||||
|
||||
**Constant indexed access narrowing requires both `obj` and `key` to be unmodified between the check and the use.** If either is a `let` that could be reassigned, TypeScript will not narrow `obj[key]`. Extract the value to a `const` in that case.
|
||||
|
||||
**`switch (true)` narrowing does not carry across fall-through cases.** In a `switch (true)`, each `case` condition narrows independently. A variable narrowed in `case typeof x === "string":` that falls through to the next case will have its narrowing widened by the next condition, not accumulated from the previous one.
|
||||
|
||||
**`const` type parameter modifier falls back when constraint is mutable.** `<const T extends string[]>(args: T)` falls back to `string[]` because `readonly ["a", "b"]` isn't assignable to `string[]`. Use `<const T extends readonly string[]>` for arrays.
|
||||
|
||||
**`assert` import syntax errors under `--module nodenext` since 5.8.** Any remaining `import x from "..." assert { ... }` must be updated to `import x from "..." with { ... }`.
|
||||
|
||||
**`Array.prototype.filter(x => x !== null)` now narrows to non-null (5.5).** This is almost always correct, but if you intentionally needed the nullable type downstream, add an explicit annotation: `const items: (T | null)[] = arr.filter(x => x !== null)`.
|
||||
|
||||
## Behavioral changes that affect code
|
||||
|
||||
- **All enums are union enums** (5.0): Every enum member gets its own literal type. Out-of-domain literal assignment to an enum type now errors. Cross-enum assignment between enums with identical names but differing values now errors.
|
||||
- **Relational operators no longer allow implicit string/number coercions** (5.0): `ns > 4` where `ns: number | string` is a type error. Use `+ns > 4` to explicitly coerce.
|
||||
- **`--module`/`--moduleResolution` must agree on node flavor** (5.2): Mixing `--module nodenext` with `--moduleResolution bundler` is an error. Use `--module nodenext` alone or `--module esnext --moduleResolution bundler`.
|
||||
- **Deprecations from 5.0 become hard errors in 5.5**: `--importsNotUsedAsValues`, `--preserveValueImports`, `--target ES3`, `--out`, and several others are fully removed in 5.5. They can no longer be specified, even with `"ignoreDeprecations": "5.0"`. Migrate to `--verbatimModuleSyntax` for the import flags.
|
||||
- **Type-only imports conflicting with local values** (5.4): Under `--isolatedModules`, `import { Foo } from "..."` where a local `let Foo` also exists now errors. Use `import type { Foo }` or `import { type Foo }`.
|
||||
- **Reference directives no longer synthesized or preserved in declaration emit** (5.5): `/// <reference types="node" />` TypeScript used to add automatically is no longer emitted. User-written directives are dropped unless they carry `preserve="true"`. Update library `tsconfig.json` if you relied on this.
|
||||
- **`.mts` files never emit CJS; `.cts` files never emit ESM** (5.6): Regardless of `--module` setting. Previously the extension was ignored in some modes.
|
||||
- **JSON imports under `--module nodenext` require `with { type: "json" }`** (5.7): `import data from "./config.json"` without the attribute is now a type error.
|
||||
- **`TypedArray`s are now generic** (5.7): `Uint8Array` is `Uint8Array<TArrayBuffer extends ArrayBufferLike = ArrayBufferLike>`. Code passing `Buffer` (from `@types/node`) to typed-array parameters may see new errors. Update `@types/node` to a version that matches.
|
||||
- **`import assert { ... }` is an error under `--module nodenext`** (5.8): Node.js 22 dropped support for the old syntax. Use `with { ... }`.
|
||||
- **`types` defaults to `[]` in 6.0**: All implicit `@types/*` loading stops. Add an explicit `"types": ["node"]` or the array will remain empty. Using `"types": ["*"]` restores the 5.x behavior.
|
||||
- **`rootDir` defaults to `.` (the tsconfig directory) in 6.0**: Previously inferred from the common ancestor of all source files. Projects with `"include": ["./src"]` and no explicit `rootDir` will now emit into `dist/src/` instead of `dist/`. Add `"rootDir": "./src"` to fix.
|
||||
- **`strict` defaults to `true` in 6.0**: Projects that were implicitly not strict will see new errors. Set `"strict": false` explicitly if you're not ready to fix them.
|
||||
- **`--baseUrl` deprecated in 6.0** and no longer acts as a module resolution root. Add explicit prefixes to your `paths` entries instead.
|
||||
- **`--moduleResolution node` (node10) deprecated in 6.0**: Removed in 7.0. Migrate to `nodenext` or `bundler`.
|
||||
- **`amd`, `umd`, `systemjs`, `none` module targets deprecated in 6.0**: Removed in 7.0. Migrate to a bundler.
|
||||
- **`--outFile` removed in 6.0**: Use a bundler (esbuild, Rollup, Webpack, etc.).
|
||||
- **`module Foo { }` syntax removed in 6.0**: Rename all such declarations to `namespace Foo { }`.
|
||||
- **`--esModuleInterop false` and `--allowSyntheticDefaultImports false` removed in 6.0**: Safe interop is now always on. Default imports from CJS modules (`import express from "express"`) are always valid.
|
||||
- **Explicit `typeRoots` disables upward `node_modules/@types` fallback** (5.1): When `typeRoots` is specified and a lookup fails in those directories, TypeScript no longer walks parent directories for `@types`. If you relied on the fallback, add `"./node_modules/@types"` explicitly to your `typeRoots` array.
|
||||
- **`super.` on instance field properties is a type error** (5.3): Calling `super.foo()` where `foo` is a class field (arrow function assigned in the constructor) rather than a prototype method now errors. Instance fields don't exist on the prototype; `super.field` is `undefined` at runtime.
|
||||
- **`--build` always emits `.tsbuildinfo`** (5.6): Previously only written when `--incremental` or `--composite` was set. Now written unconditionally in any `--build` invocation. Update `.gitignore` or CI artifact management if needed.
|
||||
- **`.mts`/`.cts` extensions and `package.json` `"type"` respected in all module modes** (5.6): Format-specific extensions and the `"type"` field inside `node_modules` are now honored regardless of `--module` setting (except `amd`, `umd`, `system`). A `.mts` file will never emit CJS output even under `--module commonjs`.
|
||||
- **Granular return expression checking** (5.8): Each branch of a conditional expression (`cond ? a : b`) directly inside a `return` statement is now checked individually against the declared return type. Previously an `any`-typed branch could silently suppress type errors in the other branch.
|
||||
@@ -4,7 +4,7 @@ description: |
|
||||
inputs:
|
||||
version:
|
||||
description: "The Go version to use."
|
||||
default: "1.25.7"
|
||||
default: "1.25.8"
|
||||
use-cache:
|
||||
description: "Whether to use the cache."
|
||||
default: "true"
|
||||
|
||||
@@ -82,9 +82,6 @@ updates:
|
||||
mui:
|
||||
patterns:
|
||||
- "@mui*"
|
||||
radix:
|
||||
patterns:
|
||||
- "@radix-ui/*"
|
||||
react:
|
||||
patterns:
|
||||
- "react"
|
||||
|
||||
@@ -204,7 +204,7 @@ jobs:
|
||||
|
||||
# Needed for helm chart linting
|
||||
- name: Install helm
|
||||
uses: azure/setup-helm@1a275c3b69536ee54be43f2070a358922e12c8d4 # v4.3.1
|
||||
uses: azure/setup-helm@dda3372f752e03dde6b3237bc9431cdc2f7a02a2 # v5.0.0
|
||||
with:
|
||||
version: v3.9.2
|
||||
continue-on-error: true
|
||||
@@ -870,7 +870,7 @@ jobs:
|
||||
# the check to pass. This is desired in PRs, but not in mainline.
|
||||
- name: Publish to Chromatic (non-mainline)
|
||||
if: github.ref != 'refs/heads/main' && github.repository_owner == 'coder'
|
||||
uses: chromaui/action@07791f8243f4cb2698bf4d00426baf4b2d1cb7e0 # v13.3.5
|
||||
uses: chromaui/action@f191a0224b10e1a38b2091cefb7b7a2337009116 # v16.0.0
|
||||
env:
|
||||
NODE_OPTIONS: "--max_old_space_size=4096"
|
||||
STORYBOOK: true
|
||||
@@ -902,7 +902,7 @@ jobs:
|
||||
# infinitely "in progress" in mainline unless we re-review each build.
|
||||
- name: Publish to Chromatic (mainline)
|
||||
if: github.ref == 'refs/heads/main' && github.repository_owner == 'coder'
|
||||
uses: chromaui/action@07791f8243f4cb2698bf4d00426baf4b2d1cb7e0 # v13.3.5
|
||||
uses: chromaui/action@f191a0224b10e1a38b2091cefb7b7a2337009116 # v16.0.0
|
||||
env:
|
||||
NODE_OPTIONS: "--max_old_space_size=4096"
|
||||
STORYBOOK: true
|
||||
|
||||
@@ -240,6 +240,7 @@ jobs:
|
||||
- name: Create Coder Task for Documentation Check
|
||||
if: steps.check-secrets.outputs.skip != 'true'
|
||||
id: create_task
|
||||
continue-on-error: true
|
||||
uses: ./.github/actions/create-task-action
|
||||
with:
|
||||
coder-url: ${{ secrets.DOC_CHECK_CODER_URL }}
|
||||
@@ -254,8 +255,21 @@ jobs:
|
||||
github-issue-url: ${{ steps.determine-context.outputs.pr_url }}
|
||||
comment-on-issue: false
|
||||
|
||||
- name: Handle Task Creation Failure
|
||||
if: steps.check-secrets.outputs.skip != 'true' && steps.create_task.outcome != 'success'
|
||||
run: |
|
||||
{
|
||||
echo "## Documentation Check Task"
|
||||
echo ""
|
||||
echo "⚠️ The external Coder task service was unavailable, so this"
|
||||
echo "advisory documentation check did not run."
|
||||
echo ""
|
||||
echo "Maintainers can rerun the workflow or trigger it manually"
|
||||
echo "after the service recovers."
|
||||
} >> "${GITHUB_STEP_SUMMARY}"
|
||||
|
||||
- name: Write Task Info
|
||||
if: steps.check-secrets.outputs.skip != 'true'
|
||||
if: steps.check-secrets.outputs.skip != 'true' && steps.create_task.outcome == 'success'
|
||||
env:
|
||||
TASK_CREATED: ${{ steps.create_task.outputs.task-created }}
|
||||
TASK_NAME: ${{ steps.create_task.outputs.task-name }}
|
||||
@@ -273,7 +287,7 @@ jobs:
|
||||
} >> "${GITHUB_STEP_SUMMARY}"
|
||||
|
||||
- name: Wait for Task Completion
|
||||
if: steps.check-secrets.outputs.skip != 'true'
|
||||
if: steps.check-secrets.outputs.skip != 'true' && steps.create_task.outcome == 'success'
|
||||
id: wait_task
|
||||
env:
|
||||
TASK_NAME: ${{ steps.create_task.outputs.task-name }}
|
||||
@@ -363,7 +377,7 @@ jobs:
|
||||
fi
|
||||
|
||||
- name: Fetch Task Logs
|
||||
if: always() && steps.check-secrets.outputs.skip != 'true'
|
||||
if: always() && steps.check-secrets.outputs.skip != 'true' && steps.create_task.outcome == 'success'
|
||||
env:
|
||||
TASK_NAME: ${{ steps.create_task.outputs.task-name }}
|
||||
run: |
|
||||
@@ -376,7 +390,7 @@ jobs:
|
||||
echo "::endgroup::"
|
||||
|
||||
- name: Cleanup Task
|
||||
if: always() && steps.check-secrets.outputs.skip != 'true'
|
||||
if: always() && steps.check-secrets.outputs.skip != 'true' && steps.create_task.outcome == 'success'
|
||||
env:
|
||||
TASK_NAME: ${{ steps.create_task.outputs.task-name }}
|
||||
run: |
|
||||
@@ -390,6 +404,7 @@ jobs:
|
||||
- name: Write Final Summary
|
||||
if: always() && steps.check-secrets.outputs.skip != 'true'
|
||||
env:
|
||||
CREATE_TASK_OUTCOME: ${{ steps.create_task.outcome }}
|
||||
TASK_NAME: ${{ steps.create_task.outputs.task-name }}
|
||||
TASK_MESSAGE: ${{ steps.wait_task.outputs.task_message }}
|
||||
RESULT_URI: ${{ steps.wait_task.outputs.result_uri }}
|
||||
@@ -400,10 +415,15 @@ jobs:
|
||||
echo "---"
|
||||
echo "### Result"
|
||||
echo ""
|
||||
echo "**Status:** ${TASK_MESSAGE:-Task completed}"
|
||||
if [[ -n "${RESULT_URI}" ]]; then
|
||||
echo "**Comment:** ${RESULT_URI}"
|
||||
if [[ "${CREATE_TASK_OUTCOME}" == "success" ]]; then
|
||||
echo "**Status:** ${TASK_MESSAGE:-Task completed}"
|
||||
if [[ -n "${RESULT_URI}" ]]; then
|
||||
echo "**Comment:** ${RESULT_URI}"
|
||||
fi
|
||||
echo ""
|
||||
echo "Task \`${TASK_NAME}\` has been cleaned up."
|
||||
else
|
||||
echo "**Status:** Skipped because the external Coder task"
|
||||
echo "service was unavailable."
|
||||
fi
|
||||
echo ""
|
||||
echo "Task \`${TASK_NAME}\` has been cleaned up."
|
||||
} >> "${GITHUB_STEP_SUMMARY}"
|
||||
|
||||
@@ -5,8 +5,6 @@ on:
|
||||
branches:
|
||||
- main
|
||||
- "release/2.[0-9]+"
|
||||
release:
|
||||
types: [published]
|
||||
|
||||
permissions:
|
||||
contents: read
|
||||
@@ -29,9 +27,9 @@ jobs:
|
||||
|
||||
- name: Detect next release version
|
||||
id: version
|
||||
# Find the highest release/2.X branch (exact pattern, no suffixes like
|
||||
# release/2.31_hotfix) and derive the next minor version for the release
|
||||
# currently in development on main.
|
||||
# Find the highest release/2.X branch (exact pattern, no suffixes
|
||||
# like release/2.31_hotfix) and derive the next minor version for
|
||||
# the release currently in development on main.
|
||||
run: |
|
||||
LATEST_MINOR=$(git branch -r | grep -E '^\s*origin/release/2\.[0-9]+$' | \
|
||||
sed 's/.*release\/2\.//' | sort -n | tail -1)
|
||||
@@ -40,8 +38,10 @@ jobs:
|
||||
echo "skip=true" >> "$GITHUB_OUTPUT"
|
||||
exit 0
|
||||
fi
|
||||
echo "version=2.$((LATEST_MINOR + 1))" >> "$GITHUB_OUTPUT"
|
||||
NEXT="2.$((LATEST_MINOR + 1))"
|
||||
echo "version=$NEXT" >> "$GITHUB_OUTPUT"
|
||||
echo "skip=false" >> "$GITHUB_OUTPUT"
|
||||
echo "Detected next release: $NEXT"
|
||||
|
||||
- name: Sync issues
|
||||
id: sync
|
||||
@@ -51,6 +51,7 @@ jobs:
|
||||
access_key: ${{ secrets.LINEAR_ACCESS_KEY }}
|
||||
command: sync
|
||||
version: ${{ steps.version.outputs.version }}
|
||||
name: ${{ steps.version.outputs.version }}
|
||||
timeout: 300
|
||||
|
||||
sync-release-branch:
|
||||
@@ -76,6 +77,7 @@ jobs:
|
||||
access_key: ${{ secrets.LINEAR_ACCESS_KEY }}
|
||||
command: sync
|
||||
version: ${{ steps.version.outputs.version }}
|
||||
name: ${{ steps.version.outputs.version }}
|
||||
timeout: 300
|
||||
|
||||
code-freeze:
|
||||
@@ -106,38 +108,3 @@ jobs:
|
||||
version: ${{ steps.version.outputs.version }}
|
||||
timeout: 300
|
||||
|
||||
complete:
|
||||
name: Complete Linear release
|
||||
if: github.event_name == 'release'
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
|
||||
with:
|
||||
persist-credentials: false
|
||||
|
||||
- name: Extract release version
|
||||
id: version
|
||||
# Strip "v" prefix and patch: "v2.31.0" -> "2.31". Also detect whether
|
||||
# this is a minor release (v*.*.0) — patch releases (v2.31.1, v2.31.2,
|
||||
# ...) are grouped into the same Linear release and must not re-complete
|
||||
# it after it has already shipped.
|
||||
run: |
|
||||
VERSION=$(echo "$TAG" | sed 's/^v//' | cut -d. -f1,2)
|
||||
echo "version=$VERSION" >> "$GITHUB_OUTPUT"
|
||||
if [[ "$TAG" =~ ^v[0-9]+\.[0-9]+\.0$ ]]; then
|
||||
echo "is_minor=true" >> "$GITHUB_OUTPUT"
|
||||
else
|
||||
echo "is_minor=false" >> "$GITHUB_OUTPUT"
|
||||
fi
|
||||
env:
|
||||
TAG: ${{ github.event.release.tag_name }}
|
||||
|
||||
- name: Complete release
|
||||
id: complete
|
||||
if: steps.version.outputs.is_minor == 'true'
|
||||
uses: linear/linear-release-action@755d50b5adb7dd42b976ee9334952745d62ceb2d # v0.6.0
|
||||
with:
|
||||
access_key: ${{ secrets.LINEAR_ACCESS_KEY }}
|
||||
command: complete
|
||||
version: ${{ steps.version.outputs.version }}
|
||||
timeout: 300
|
||||
|
||||
@@ -9,6 +9,7 @@ on:
|
||||
options:
|
||||
- mainline
|
||||
- stable
|
||||
- rc
|
||||
release_notes:
|
||||
description: Release notes for the publishing the release. This is required to create a release.
|
||||
dry_run:
|
||||
@@ -119,9 +120,19 @@ jobs:
|
||||
exit 1
|
||||
fi
|
||||
|
||||
# 2.10.2 -> release/2.10
|
||||
# Derive the release branch from the version tag.
|
||||
# Standard: 2.10.2 -> release/2.10
|
||||
# RC: 2.32.0-rc.0 -> release/2.32-rc.0
|
||||
version="$(./scripts/version.sh)"
|
||||
release_branch=release/${version%.*}
|
||||
if [[ "$version" == *-rc.* ]]; then
|
||||
# Extract major.minor and rc suffix from e.g. 2.32.0-rc.0
|
||||
base_version="${version%%-rc.*}" # 2.32.0
|
||||
major_minor="${base_version%.*}" # 2.32
|
||||
rc_suffix="${version##*-rc.}" # 0
|
||||
release_branch="release/${major_minor}-rc.${rc_suffix}"
|
||||
else
|
||||
release_branch=release/${version%.*}
|
||||
fi
|
||||
branch_contains_tag=$(git branch --remotes --contains "${GITHUB_REF}" --list "*/${release_branch}" --format='%(refname)')
|
||||
if [[ -z "${branch_contains_tag}" ]]; then
|
||||
echo "Ref tag must exist in a branch named ${release_branch} when creating a release, did you use scripts/release.sh?"
|
||||
@@ -531,6 +542,9 @@ jobs:
|
||||
if [[ $CODER_RELEASE_CHANNEL == "stable" ]]; then
|
||||
publish_args+=(--stable)
|
||||
fi
|
||||
if [[ $CODER_RELEASE_CHANNEL == "rc" ]]; then
|
||||
publish_args+=(--rc)
|
||||
fi
|
||||
if [[ $CODER_DRY_RUN == *t* ]]; then
|
||||
publish_args+=(--dry-run)
|
||||
fi
|
||||
@@ -563,6 +577,35 @@ jobs:
|
||||
VERSION: ${{ steps.version.outputs.version }}
|
||||
CREATED_LATEST_TAG: ${{ steps.build_docker.outputs.created_latest_tag }}
|
||||
|
||||
# Mark the Linear release as shipped.
|
||||
- name: Extract Linear release version
|
||||
if: ${{ !inputs.dry_run }}
|
||||
id: linear_version
|
||||
run: |
|
||||
# Skip RC releases — they must not complete the Linear release.
|
||||
if [[ "$VERSION" == *-rc* ]]; then
|
||||
echo "RC release (${VERSION}), skipping Linear release completion."
|
||||
echo "skip=true" >> "$GITHUB_OUTPUT"
|
||||
exit 0
|
||||
fi
|
||||
# Strip patch to get the Linear release version (e.g. 2.32.0 -> 2.32).
|
||||
linear_version=$(echo "$VERSION" | cut -d. -f1,2)
|
||||
echo "version=$linear_version" >> "$GITHUB_OUTPUT"
|
||||
echo "skip=false" >> "$GITHUB_OUTPUT"
|
||||
echo "Completing Linear release ${linear_version}"
|
||||
env:
|
||||
VERSION: ${{ steps.version.outputs.version }}
|
||||
|
||||
- name: Complete Linear release
|
||||
if: ${{ !inputs.dry_run && steps.linear_version.outputs.skip != 'true' }}
|
||||
continue-on-error: true
|
||||
uses: linear/linear-release-action@755d50b5adb7dd42b976ee9334952745d62ceb2d # v0.6.0
|
||||
with:
|
||||
access_key: ${{ secrets.LINEAR_ACCESS_KEY }}
|
||||
command: complete
|
||||
version: ${{ steps.linear_version.outputs.version }}
|
||||
timeout: 300
|
||||
|
||||
- name: Authenticate to Google Cloud
|
||||
uses: google-github-actions/auth@7c6bc770dae815cd3e89ee6cdf493a5fab2cc093 # v3.0.0
|
||||
with:
|
||||
@@ -614,7 +657,7 @@ jobs:
|
||||
retention-days: 7
|
||||
|
||||
- name: Send repository-dispatch event
|
||||
if: ${{ !inputs.dry_run }}
|
||||
if: ${{ !inputs.dry_run && inputs.release_channel != 'rc' }}
|
||||
uses: peter-evans/repository-dispatch@28959ce8df70de7be546dd1250a005dd32156697 # v4.0.1
|
||||
with:
|
||||
token: ${{ secrets.CDRCI_GITHUB_TOKEN }}
|
||||
@@ -702,7 +745,7 @@ jobs:
|
||||
name: Publish to winget-pkgs
|
||||
runs-on: windows-latest
|
||||
needs: release
|
||||
if: ${{ !inputs.dry_run }}
|
||||
if: ${{ !inputs.dry_run && inputs.release_channel != 'rc' }}
|
||||
|
||||
steps:
|
||||
- name: Harden Runner
|
||||
|
||||
@@ -46,6 +46,14 @@ jobs:
|
||||
echo " replacement: \"https://github.com/coder/coder/tree/${HEAD_SHA}/\""
|
||||
} >> .github/.linkspector.yml
|
||||
|
||||
# TODO: Remove this workaround once action-linkspector sets
|
||||
# package-manager-cache: false in its internal setup-node step.
|
||||
# See: https://github.com/UmbrellaDocs/action-linkspector/issues/54
|
||||
- name: Enable corepack and create pnpm store
|
||||
run: |
|
||||
corepack enable pnpm
|
||||
mkdir -p "$(pnpm store path --silent)"
|
||||
|
||||
- name: Check Markdown links
|
||||
uses: umbrelladocs/action-linkspector@37c85bcde51b30bf929936502bac6bfb7e8f0a4d # v1.4.1
|
||||
id: markdown-link-check
|
||||
|
||||
@@ -110,6 +110,9 @@ app, err := api.Database.GetOAuth2ProviderAppByClientID(ctx, clientID)
|
||||
- For experimental or unstable API paths, skip public doc generation with
|
||||
`// @x-apidocgen {"skip": true}` after the `@Router` annotation. This
|
||||
keeps them out of the published API reference until they stabilize.
|
||||
- Experimental chat endpoints in `coderd/exp_chats.go` omit swagger
|
||||
annotations entirely. Do not add `@Summary`, `@Router`, or other
|
||||
swagger comments to handlers in that file.
|
||||
|
||||
### Database Query Naming
|
||||
|
||||
|
||||
@@ -988,6 +988,7 @@ coderd/httpmw/loggermw/loggermock/loggermock.go: coderd/httpmw/loggermw/logger.g
|
||||
|
||||
codersdk/workspacesdk/agentconnmock/agentconnmock.go: codersdk/workspacesdk/agentconn.go
|
||||
go generate ./codersdk/workspacesdk/agentconnmock/
|
||||
./scripts/format_go_file.sh "$@"
|
||||
touch "$@"
|
||||
|
||||
$(AIBRIDGED_MOCKS): enterprise/aibridged/client.go enterprise/aibridged/pool.go
|
||||
|
||||
+16
-8
@@ -38,6 +38,7 @@ import (
|
||||
"cdr.dev/slog/v3"
|
||||
"github.com/coder/clistat"
|
||||
"github.com/coder/coder/v2/agent/agentcontainers"
|
||||
"github.com/coder/coder/v2/agent/agentcontextconfig"
|
||||
"github.com/coder/coder/v2/agent/agentexec"
|
||||
"github.com/coder/coder/v2/agent/agentfiles"
|
||||
"github.com/coder/coder/v2/agent/agentgit"
|
||||
@@ -308,12 +309,13 @@ type agent struct {
|
||||
containerAPI *agentcontainers.API
|
||||
gitAPIOptions []agentgit.Option
|
||||
|
||||
filesAPI *agentfiles.API
|
||||
gitAPI *agentgit.API
|
||||
processAPI *agentproc.API
|
||||
desktopAPI *agentdesktop.API
|
||||
mcpManager *agentmcp.Manager
|
||||
mcpAPI *agentmcp.API
|
||||
filesAPI *agentfiles.API
|
||||
gitAPI *agentgit.API
|
||||
processAPI *agentproc.API
|
||||
desktopAPI *agentdesktop.API
|
||||
mcpManager *agentmcp.Manager
|
||||
mcpAPI *agentmcp.API
|
||||
contextConfigAPI *agentcontextconfig.API
|
||||
|
||||
socketServerEnabled bool
|
||||
socketPath string
|
||||
@@ -396,11 +398,17 @@ func (a *agent) init() {
|
||||
gitOpts := append([]agentgit.Option{agentgit.WithClock(a.clock)}, a.gitAPIOptions...)
|
||||
a.gitAPI = agentgit.NewAPI(a.logger.Named("git"), pathStore, gitOpts...)
|
||||
desktop := agentdesktop.NewPortableDesktop(
|
||||
a.logger.Named("desktop"), a.execer, a.scriptRunner.ScriptBinDir(),
|
||||
a.logger.Named("desktop"), a.execer, a.scriptRunner.ScriptBinDir(), nil,
|
||||
)
|
||||
a.desktopAPI = agentdesktop.NewAPI(a.logger.Named("desktop"), desktop, a.clock)
|
||||
a.mcpManager = agentmcp.NewManager(a.logger.Named("mcp"))
|
||||
a.mcpAPI = agentmcp.NewAPI(a.logger.Named("mcp"), a.mcpManager)
|
||||
a.contextConfigAPI = agentcontextconfig.NewAPI(func() string {
|
||||
if m := a.manifest.Load(); m != nil {
|
||||
return m.Directory
|
||||
}
|
||||
return ""
|
||||
})
|
||||
a.reconnectingPTYServer = reconnectingpty.NewServer(
|
||||
a.logger.Named("reconnecting-pty"),
|
||||
a.sshServer,
|
||||
@@ -1358,7 +1366,7 @@ func (a *agent) handleManifest(manifestOK *checkpoint) func(ctx context.Context,
|
||||
// lifecycle transition to avoid delaying Ready.
|
||||
// This runs inside the tracked goroutine so it
|
||||
// is properly awaited on shutdown.
|
||||
if mcpErr := a.mcpManager.Connect(a.gracefulCtx, manifest.Directory); mcpErr != nil {
|
||||
if mcpErr := a.mcpManager.Connect(a.gracefulCtx, a.contextConfigAPI.Config().MCPConfigFiles); mcpErr != nil {
|
||||
a.logger.Warn(ctx, "failed to connect to workspace MCP servers", slog.Error(mcpErr))
|
||||
}
|
||||
})
|
||||
|
||||
@@ -1,6 +1,8 @@
|
||||
package agent
|
||||
|
||||
import (
|
||||
"path/filepath"
|
||||
"runtime"
|
||||
"testing"
|
||||
|
||||
"github.com/google/uuid"
|
||||
@@ -8,10 +10,22 @@ import (
|
||||
|
||||
"cdr.dev/slog/v3"
|
||||
"cdr.dev/slog/v3/sloggers/slogtest"
|
||||
"github.com/coder/coder/v2/agent/agentcontextconfig"
|
||||
"github.com/coder/coder/v2/agent/proto"
|
||||
agentsdk "github.com/coder/coder/v2/codersdk/agentsdk"
|
||||
"github.com/coder/coder/v2/testutil"
|
||||
)
|
||||
|
||||
// platformAbsPath constructs an absolute path that is valid
|
||||
// on the current platform. On Windows, paths must include a
|
||||
// drive letter to be considered absolute.
|
||||
func platformAbsPath(parts ...string) string {
|
||||
if runtime.GOOS == "windows" {
|
||||
return `C:\` + filepath.Join(parts...)
|
||||
}
|
||||
return "/" + filepath.Join(parts...)
|
||||
}
|
||||
|
||||
// TestReportConnectionEmpty tests that reportConnection() doesn't choke if given an empty IP string, which is what we
|
||||
// send if we cannot get the remote address.
|
||||
func TestReportConnectionEmpty(t *testing.T) {
|
||||
@@ -42,3 +56,41 @@ func TestReportConnectionEmpty(t *testing.T) {
|
||||
require.Equal(t, proto.Connection_DISCONNECT, req1.GetConnection().GetAction())
|
||||
require.Equal(t, "because", req1.GetConnection().GetReason())
|
||||
}
|
||||
|
||||
func TestContextConfigAPI_InitOnce(t *testing.T) {
|
||||
// Not parallel: uses t.Setenv to clear env vars.
|
||||
|
||||
// Clear env vars so defaults are used and the test is
|
||||
// hermetic regardless of the surrounding environment.
|
||||
t.Setenv(agentcontextconfig.EnvInstructionsDirs, "")
|
||||
t.Setenv(agentcontextconfig.EnvInstructionsFile, "")
|
||||
t.Setenv(agentcontextconfig.EnvSkillsDirs, "")
|
||||
t.Setenv(agentcontextconfig.EnvSkillMetaFile, "")
|
||||
t.Setenv(agentcontextconfig.EnvMCPConfigFiles, "")
|
||||
|
||||
// After the fix, contextConfigAPI is set once in init() and
|
||||
// never reassigned. Config() evaluates lazily via the
|
||||
// manifest, so there is no concurrent write to race with.
|
||||
dir1 := platformAbsPath("dir1")
|
||||
dir2 := platformAbsPath("dir2")
|
||||
|
||||
a := &agent{}
|
||||
a.manifest.Store(&agentsdk.Manifest{Directory: dir1})
|
||||
a.contextConfigAPI = agentcontextconfig.NewAPI(func() string {
|
||||
if m := a.manifest.Load(); m != nil {
|
||||
return m.Directory
|
||||
}
|
||||
return ""
|
||||
})
|
||||
|
||||
cfg1 := a.contextConfigAPI.Config()
|
||||
require.NotEmpty(t, cfg1.MCPConfigFiles)
|
||||
require.Contains(t, cfg1.MCPConfigFiles[0], dir1)
|
||||
|
||||
// Simulate manifest update on reconnection — no field
|
||||
// reassignment needed, the lazy closure picks it up.
|
||||
a.manifest.Store(&agentsdk.Manifest{Directory: dir2})
|
||||
cfg2 := a.contextConfigAPI.Config()
|
||||
require.NotEmpty(t, cfg2.MCPConfigFiles)
|
||||
require.Contains(t, cfg2.MCPConfigFiles[0], dir2)
|
||||
}
|
||||
|
||||
+15
-8
@@ -3007,7 +3007,7 @@ func TestAgent_Speedtest(t *testing.T) {
|
||||
|
||||
func TestAgent_Reconnect(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := testutil.Context(t, testutil.WaitShort)
|
||||
ctx := testutil.Context(t, testutil.WaitLong)
|
||||
logger := testutil.Logger(t)
|
||||
// After the agent is disconnected from a coordinator, it's supposed
|
||||
// to reconnect!
|
||||
@@ -3020,7 +3020,8 @@ func TestAgent_Reconnect(t *testing.T) {
|
||||
logger,
|
||||
agentID,
|
||||
agentsdk.Manifest{
|
||||
DERPMap: derpMap,
|
||||
DERPMap: derpMap,
|
||||
Directory: "/test/workspace",
|
||||
},
|
||||
statsCh,
|
||||
fCoordinator,
|
||||
@@ -3033,13 +3034,19 @@ func TestAgent_Reconnect(t *testing.T) {
|
||||
})
|
||||
defer closer.Close()
|
||||
|
||||
call1 := testutil.RequireReceive(ctx, t, fCoordinator.CoordinateCalls)
|
||||
require.Equal(t, client.GetNumRefreshTokenCalls(), 1)
|
||||
close(call1.Resps) // hang up
|
||||
// expect reconnect
|
||||
// Each iteration forces the agent to reconnect by closing
|
||||
// the current coordinate call while the tracked HTTP server
|
||||
// goroutine (from connection 1's createTailnet) is still
|
||||
// alive, widening the race window.
|
||||
const reconnections = 5
|
||||
for i := range reconnections {
|
||||
call := testutil.RequireReceive(ctx, t, fCoordinator.CoordinateCalls)
|
||||
require.Equal(t, i+1, client.GetNumRefreshTokenCalls())
|
||||
close(call.Resps) // hang up — triggers reconnect
|
||||
}
|
||||
// Verify final reconnect succeeds.
|
||||
testutil.RequireReceive(ctx, t, fCoordinator.CoordinateCalls)
|
||||
// Check that the agent refreshes the token when it reconnects.
|
||||
require.Equal(t, client.GetNumRefreshTokenCalls(), 2)
|
||||
require.Equal(t, reconnections+1, client.GetNumRefreshTokenCalls())
|
||||
closer.Close()
|
||||
}
|
||||
|
||||
|
||||
@@ -0,0 +1,84 @@
|
||||
package agentcontextconfig
|
||||
|
||||
import (
|
||||
"cmp"
|
||||
"net/http"
|
||||
"os"
|
||||
"strings"
|
||||
|
||||
"github.com/go-chi/chi/v5"
|
||||
|
||||
"github.com/coder/coder/v2/coderd/httpapi"
|
||||
"github.com/coder/coder/v2/codersdk/workspacesdk"
|
||||
)
|
||||
|
||||
// Env var names for context configuration. Prefixed with EXP_
|
||||
// to indicate these are experimental and may change.
|
||||
const (
|
||||
EnvInstructionsDirs = "CODER_AGENT_EXP_INSTRUCTIONS_DIRS"
|
||||
EnvInstructionsFile = "CODER_AGENT_EXP_INSTRUCTIONS_FILE"
|
||||
EnvSkillsDirs = "CODER_AGENT_EXP_SKILLS_DIRS"
|
||||
EnvSkillMetaFile = "CODER_AGENT_EXP_SKILL_META_FILE"
|
||||
EnvMCPConfigFiles = "CODER_AGENT_EXP_MCP_CONFIG_FILES"
|
||||
)
|
||||
|
||||
// Defaults are defined in codersdk/workspacesdk so both
|
||||
// the agent and server can reference them without a
|
||||
// cross-layer import.
|
||||
|
||||
// API exposes the resolved context configuration through the
|
||||
// agent's HTTP API.
|
||||
type API struct {
|
||||
workingDir func() string
|
||||
}
|
||||
|
||||
// NewAPI accepts a closure that returns the working directory.
|
||||
// The directory is evaluated lazily on each call to Config(),
|
||||
// so the caller can update it after construction.
|
||||
func NewAPI(workingDir func() string) *API {
|
||||
if workingDir == nil {
|
||||
workingDir = func() string { return "" }
|
||||
}
|
||||
return &API{workingDir: workingDir}
|
||||
}
|
||||
|
||||
// Config reads env vars and resolves paths. Exported for use
|
||||
// by the MCP manager and tests.
|
||||
func Config(workingDir string) workspacesdk.ContextConfigResponse {
|
||||
// TrimSpace all env vars before cmp.Or so that a
|
||||
// whitespace-only value falls through to the default
|
||||
// consistently. ResolvePaths also trims each comma-
|
||||
// separated entry, but without pre-trimming here a
|
||||
// bare " " would bypass cmp.Or and produce nil.
|
||||
instructionsDir := cmp.Or(strings.TrimSpace(os.Getenv(EnvInstructionsDirs)), workspacesdk.DefaultInstructionsDir)
|
||||
instructionsFile := cmp.Or(strings.TrimSpace(os.Getenv(EnvInstructionsFile)), workspacesdk.DefaultInstructionsFile)
|
||||
skillsDir := cmp.Or(strings.TrimSpace(os.Getenv(EnvSkillsDirs)), workspacesdk.DefaultSkillsDir)
|
||||
skillMetaFile := cmp.Or(strings.TrimSpace(os.Getenv(EnvSkillMetaFile)), workspacesdk.DefaultSkillMetaFile)
|
||||
mcpConfigFile := cmp.Or(strings.TrimSpace(os.Getenv(EnvMCPConfigFiles)), workspacesdk.DefaultMCPConfigFile)
|
||||
|
||||
return workspacesdk.ContextConfigResponse{
|
||||
InstructionsDirs: ResolvePaths(instructionsDir, workingDir),
|
||||
InstructionsFile: instructionsFile,
|
||||
SkillsDirs: ResolvePaths(skillsDir, workingDir),
|
||||
SkillMetaFile: skillMetaFile,
|
||||
MCPConfigFiles: ResolvePaths(mcpConfigFile, workingDir),
|
||||
}
|
||||
}
|
||||
|
||||
// Config returns the resolved config for use by other agent
|
||||
// components (e.g. MCP manager).
|
||||
func (api *API) Config() workspacesdk.ContextConfigResponse {
|
||||
return Config(api.workingDir())
|
||||
}
|
||||
|
||||
// Routes returns the HTTP handler for the context config
|
||||
// endpoint.
|
||||
func (api *API) Routes() http.Handler {
|
||||
r := chi.NewRouter()
|
||||
r.Get("/", api.handleGet)
|
||||
return r
|
||||
}
|
||||
|
||||
func (api *API) handleGet(rw http.ResponseWriter, r *http.Request) {
|
||||
httpapi.Write(r.Context(), rw, http.StatusOK, api.Config())
|
||||
}
|
||||
@@ -0,0 +1,117 @@
|
||||
package agentcontextconfig_test
|
||||
|
||||
import (
|
||||
"path/filepath"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/coder/coder/v2/agent/agentcontextconfig"
|
||||
"github.com/coder/coder/v2/codersdk/workspacesdk"
|
||||
)
|
||||
|
||||
func TestConfig(t *testing.T) {
|
||||
t.Run("Defaults", func(t *testing.T) {
|
||||
fakeHome := t.TempDir()
|
||||
t.Setenv("HOME", fakeHome)
|
||||
t.Setenv("USERPROFILE", fakeHome)
|
||||
|
||||
// Clear all env vars so defaults are used.
|
||||
t.Setenv(agentcontextconfig.EnvInstructionsDirs, "")
|
||||
t.Setenv(agentcontextconfig.EnvInstructionsFile, "")
|
||||
t.Setenv(agentcontextconfig.EnvSkillsDirs, "")
|
||||
t.Setenv(agentcontextconfig.EnvSkillMetaFile, "")
|
||||
t.Setenv(agentcontextconfig.EnvMCPConfigFiles, "")
|
||||
|
||||
workDir := platformAbsPath("work")
|
||||
cfg := agentcontextconfig.Config(workDir)
|
||||
|
||||
require.Equal(t, workspacesdk.DefaultInstructionsFile, cfg.InstructionsFile)
|
||||
require.Equal(t, workspacesdk.DefaultSkillMetaFile, cfg.SkillMetaFile)
|
||||
// Default instructions dir is "~/.coder" which resolves
|
||||
// to the home directory.
|
||||
require.Equal(t, []string{filepath.Join(fakeHome, ".coder")}, cfg.InstructionsDirs)
|
||||
// Default skills dir is ".agents/skills" (relative),
|
||||
// resolved against the working directory.
|
||||
require.Equal(t, []string{filepath.Join(workDir, ".agents", "skills")}, cfg.SkillsDirs)
|
||||
// Default MCP config file is ".mcp.json" (relative),
|
||||
// resolved against the working directory.
|
||||
require.Equal(t, []string{filepath.Join(workDir, ".mcp.json")}, cfg.MCPConfigFiles)
|
||||
})
|
||||
|
||||
t.Run("CustomEnvVars", func(t *testing.T) {
|
||||
fakeHome := t.TempDir()
|
||||
t.Setenv("HOME", fakeHome)
|
||||
t.Setenv("USERPROFILE", fakeHome)
|
||||
|
||||
optInstructions := platformAbsPath("opt", "instructions")
|
||||
optSkills := platformAbsPath("opt", "skills")
|
||||
optMCP := platformAbsPath("opt", "mcp.json")
|
||||
|
||||
t.Setenv(agentcontextconfig.EnvInstructionsDirs, optInstructions)
|
||||
t.Setenv(agentcontextconfig.EnvInstructionsFile, "CUSTOM.md")
|
||||
t.Setenv(agentcontextconfig.EnvSkillsDirs, optSkills)
|
||||
t.Setenv(agentcontextconfig.EnvSkillMetaFile, "META.yaml")
|
||||
t.Setenv(agentcontextconfig.EnvMCPConfigFiles, optMCP)
|
||||
|
||||
workDir := platformAbsPath("work")
|
||||
cfg := agentcontextconfig.Config(workDir)
|
||||
|
||||
require.Equal(t, "CUSTOM.md", cfg.InstructionsFile)
|
||||
require.Equal(t, "META.yaml", cfg.SkillMetaFile)
|
||||
require.Equal(t, []string{optInstructions}, cfg.InstructionsDirs)
|
||||
require.Equal(t, []string{optSkills}, cfg.SkillsDirs)
|
||||
require.Equal(t, []string{optMCP}, cfg.MCPConfigFiles)
|
||||
})
|
||||
|
||||
t.Run("WhitespaceInFileNames", func(t *testing.T) {
|
||||
t.Setenv(agentcontextconfig.EnvInstructionsDirs, "")
|
||||
t.Setenv(agentcontextconfig.EnvInstructionsFile, " CLAUDE.md ")
|
||||
t.Setenv(agentcontextconfig.EnvSkillsDirs, "")
|
||||
t.Setenv(agentcontextconfig.EnvSkillMetaFile, "")
|
||||
t.Setenv(agentcontextconfig.EnvMCPConfigFiles, "")
|
||||
|
||||
workDir := platformAbsPath("work")
|
||||
cfg := agentcontextconfig.Config(workDir)
|
||||
|
||||
require.Equal(t, "CLAUDE.md", cfg.InstructionsFile)
|
||||
})
|
||||
|
||||
t.Run("CommaSeparatedDirs", func(t *testing.T) {
|
||||
a := platformAbsPath("opt", "a")
|
||||
b := platformAbsPath("opt", "b")
|
||||
|
||||
t.Setenv(agentcontextconfig.EnvInstructionsDirs, a+","+b)
|
||||
t.Setenv(agentcontextconfig.EnvInstructionsFile, "")
|
||||
t.Setenv(agentcontextconfig.EnvSkillsDirs, "")
|
||||
t.Setenv(agentcontextconfig.EnvSkillMetaFile, "")
|
||||
t.Setenv(agentcontextconfig.EnvMCPConfigFiles, "")
|
||||
|
||||
workDir := platformAbsPath("work")
|
||||
cfg := agentcontextconfig.Config(workDir)
|
||||
|
||||
require.Equal(t, []string{a, b}, cfg.InstructionsDirs)
|
||||
})
|
||||
}
|
||||
|
||||
func TestNewAPI_LazyDirectory(t *testing.T) {
|
||||
t.Setenv(agentcontextconfig.EnvInstructionsDirs, "")
|
||||
t.Setenv(agentcontextconfig.EnvInstructionsFile, "")
|
||||
t.Setenv(agentcontextconfig.EnvSkillsDirs, "")
|
||||
t.Setenv(agentcontextconfig.EnvSkillMetaFile, "")
|
||||
t.Setenv(agentcontextconfig.EnvMCPConfigFiles, "")
|
||||
|
||||
dir := ""
|
||||
api := agentcontextconfig.NewAPI(func() string { return dir })
|
||||
|
||||
// Before directory is set, relative paths resolve to nothing.
|
||||
cfg := api.Config()
|
||||
require.Empty(t, cfg.SkillsDirs)
|
||||
require.Empty(t, cfg.MCPConfigFiles)
|
||||
|
||||
// After setting the directory, Config() picks it up lazily.
|
||||
dir = platformAbsPath("work")
|
||||
cfg = api.Config()
|
||||
require.NotEmpty(t, cfg.SkillsDirs)
|
||||
require.Equal(t, []string{filepath.Join(dir, ".agents", "skills")}, cfg.SkillsDirs)
|
||||
}
|
||||
@@ -0,0 +1,55 @@
|
||||
package agentcontextconfig
|
||||
|
||||
import (
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// ResolvePath resolves a single path that may be absolute,
|
||||
// home-relative (~/ or ~), or relative to the given base
|
||||
// directory. Returns an absolute path. Empty input returns empty.
|
||||
func ResolvePath(raw, baseDir string) string {
|
||||
raw = strings.TrimSpace(raw)
|
||||
if raw == "" {
|
||||
return ""
|
||||
}
|
||||
switch {
|
||||
case raw == "~":
|
||||
home, err := os.UserHomeDir()
|
||||
if err != nil {
|
||||
return ""
|
||||
}
|
||||
return home
|
||||
case strings.HasPrefix(raw, "~/"):
|
||||
home, err := os.UserHomeDir()
|
||||
if err != nil {
|
||||
return ""
|
||||
}
|
||||
return filepath.Join(home, raw[2:])
|
||||
case filepath.IsAbs(raw):
|
||||
return raw
|
||||
default:
|
||||
if baseDir == "" {
|
||||
return ""
|
||||
}
|
||||
return filepath.Join(baseDir, raw)
|
||||
}
|
||||
}
|
||||
|
||||
// ResolvePaths splits a comma-separated list of paths and
|
||||
// resolves each entry independently. Empty entries and entries
|
||||
// that resolve to empty strings are skipped.
|
||||
func ResolvePaths(raw, baseDir string) []string {
|
||||
if strings.TrimSpace(raw) == "" {
|
||||
return nil
|
||||
}
|
||||
parts := strings.Split(raw, ",")
|
||||
out := make([]string, 0, len(parts))
|
||||
for _, p := range parts {
|
||||
if resolved := ResolvePath(p, baseDir); resolved != "" {
|
||||
out = append(out, resolved)
|
||||
}
|
||||
}
|
||||
return out
|
||||
}
|
||||
@@ -0,0 +1,152 @@
|
||||
package agentcontextconfig_test
|
||||
|
||||
import (
|
||||
"path/filepath"
|
||||
"runtime"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/coder/coder/v2/agent/agentcontextconfig"
|
||||
)
|
||||
|
||||
// platformAbsPath constructs an absolute path that is valid
|
||||
// on the current platform. On Windows paths must include a
|
||||
// drive letter to be considered absolute.
|
||||
func platformAbsPath(parts ...string) string {
|
||||
if runtime.GOOS == "windows" {
|
||||
return `C:\` + filepath.Join(parts...)
|
||||
}
|
||||
return "/" + filepath.Join(parts...)
|
||||
}
|
||||
|
||||
func TestResolvePath(t *testing.T) { //nolint:tparallel // subtests using t.Setenv cannot be parallel
|
||||
t.Run("EmptyInput", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
require.Equal(t, "", agentcontextconfig.ResolvePath("", platformAbsPath("base")))
|
||||
})
|
||||
|
||||
t.Run("WhitespaceOnly", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
require.Equal(t, "", agentcontextconfig.ResolvePath(" ", platformAbsPath("base")))
|
||||
})
|
||||
|
||||
// Tests that use t.Setenv cannot be parallel.
|
||||
t.Run("TildeAlone", func(t *testing.T) {
|
||||
fakeHome := t.TempDir()
|
||||
t.Setenv("HOME", fakeHome)
|
||||
t.Setenv("USERPROFILE", fakeHome)
|
||||
got := agentcontextconfig.ResolvePath("~", platformAbsPath("base"))
|
||||
require.Equal(t, fakeHome, got)
|
||||
})
|
||||
|
||||
t.Run("TildeSlashPath", func(t *testing.T) {
|
||||
fakeHome := t.TempDir()
|
||||
t.Setenv("HOME", fakeHome)
|
||||
t.Setenv("USERPROFILE", fakeHome)
|
||||
got := agentcontextconfig.ResolvePath("~/docs/readme", platformAbsPath("base"))
|
||||
require.Equal(t, filepath.Join(fakeHome, "docs", "readme"), got)
|
||||
})
|
||||
|
||||
t.Run("AbsolutePath", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
p := platformAbsPath("etc", "coder")
|
||||
got := agentcontextconfig.ResolvePath(p, platformAbsPath("base"))
|
||||
require.Equal(t, p, got)
|
||||
})
|
||||
|
||||
t.Run("RelativePath", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
base := platformAbsPath("work")
|
||||
got := agentcontextconfig.ResolvePath("foo/bar", base)
|
||||
require.Equal(t, filepath.Join(base, "foo", "bar"), got)
|
||||
})
|
||||
|
||||
t.Run("RelativePathWithWhitespace", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
base := platformAbsPath("work")
|
||||
got := agentcontextconfig.ResolvePath(" foo/bar ", base)
|
||||
require.Equal(t, filepath.Join(base, "foo", "bar"), got)
|
||||
})
|
||||
|
||||
t.Run("RelativePathWithEmptyBaseDir", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
got := agentcontextconfig.ResolvePath(".agents/skills", "")
|
||||
require.Equal(t, "", got)
|
||||
})
|
||||
}
|
||||
|
||||
func TestResolvePath_HomeUnset(t *testing.T) {
|
||||
// Cannot be parallel — modifies HOME env var.
|
||||
t.Setenv("HOME", "")
|
||||
// Also clear USERPROFILE for Windows compatibility.
|
||||
t.Setenv("USERPROFILE", "")
|
||||
|
||||
require.Equal(t, "", agentcontextconfig.ResolvePath("~", platformAbsPath("base")))
|
||||
require.Equal(t, "", agentcontextconfig.ResolvePath("~/docs", platformAbsPath("base")))
|
||||
}
|
||||
|
||||
func TestResolvePaths(t *testing.T) { //nolint:tparallel // subtests using t.Setenv cannot be parallel
|
||||
t.Run("EmptyString", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
require.Nil(t, agentcontextconfig.ResolvePaths("", platformAbsPath("base")))
|
||||
})
|
||||
|
||||
t.Run("WhitespaceOnly", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
require.Nil(t, agentcontextconfig.ResolvePaths(" ", platformAbsPath("base")))
|
||||
})
|
||||
|
||||
t.Run("SingleEntry", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
p := platformAbsPath("abs", "path")
|
||||
got := agentcontextconfig.ResolvePaths(p, platformAbsPath("base"))
|
||||
require.Equal(t, []string{p}, got)
|
||||
})
|
||||
|
||||
// Tests that use t.Setenv cannot be parallel.
|
||||
t.Run("MultipleEntries", func(t *testing.T) {
|
||||
fakeHome := t.TempDir()
|
||||
t.Setenv("HOME", fakeHome)
|
||||
t.Setenv("USERPROFILE", fakeHome)
|
||||
b := platformAbsPath("b")
|
||||
base := platformAbsPath("base")
|
||||
got := agentcontextconfig.ResolvePaths("~/a,"+b+",rel", base)
|
||||
require.Equal(t, []string{
|
||||
filepath.Join(fakeHome, "a"),
|
||||
b,
|
||||
filepath.Join(base, "rel"),
|
||||
}, got)
|
||||
})
|
||||
|
||||
t.Run("TrimsWhitespace", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
a := platformAbsPath("a")
|
||||
b := platformAbsPath("b")
|
||||
got := agentcontextconfig.ResolvePaths(" "+a+" , "+b+" ", platformAbsPath("base"))
|
||||
require.Equal(t, []string{a, b}, got)
|
||||
})
|
||||
|
||||
t.Run("SkipsEmptyEntries", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
a := platformAbsPath("a")
|
||||
b := platformAbsPath("b")
|
||||
got := agentcontextconfig.ResolvePaths(a+",,"+b+",", platformAbsPath("base"))
|
||||
require.Equal(t, []string{a, b}, got)
|
||||
})
|
||||
|
||||
t.Run("TrailingComma", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
p := platformAbsPath("only")
|
||||
got := agentcontextconfig.ResolvePaths(p+",", platformAbsPath("base"))
|
||||
require.Equal(t, []string{p}, got)
|
||||
})
|
||||
|
||||
t.Run("RelativePathSkippedWhenBaseDirEmpty", func(t *testing.T) {
|
||||
fakeHome := t.TempDir()
|
||||
t.Setenv("HOME", fakeHome)
|
||||
t.Setenv("USERPROFILE", fakeHome)
|
||||
got := agentcontextconfig.ResolvePaths("~/.coder,.agents/skills", "")
|
||||
require.Equal(t, []string{filepath.Join(fakeHome, ".coder")}, got)
|
||||
})
|
||||
}
|
||||
@@ -148,6 +148,11 @@ func (m *manager) start(req workspacesdk.StartProcessRequest, chatID string) (*p
|
||||
for k, v := range req.Env {
|
||||
cmd.Env = append(cmd.Env, fmt.Sprintf("%s=%s", k, v))
|
||||
}
|
||||
// Propagate the chat ID so child processes (e.g.
|
||||
// GIT_ASKPASS) can send it back to the server.
|
||||
if chatID != "" {
|
||||
cmd.Env = append(cmd.Env, fmt.Sprintf("CODER_CHAT_ID=%s", chatID))
|
||||
}
|
||||
|
||||
if err := cmd.Start(); err != nil {
|
||||
cancel()
|
||||
|
||||
@@ -211,7 +211,7 @@ func TestServer_X11_EvictionLRU(t *testing.T) {
|
||||
require.NoError(t, err)
|
||||
stderr, err := sess.StderrPipe()
|
||||
require.NoError(t, err)
|
||||
require.NoError(t, sess.Shell())
|
||||
require.NoError(t, sess.Start("sh"))
|
||||
|
||||
// The SSH server lazily starts the session. We need to write a command
|
||||
// and read back to ensure the X11 forwarding is started.
|
||||
|
||||
@@ -32,6 +32,7 @@ func (a *agent) apiHandler() http.Handler {
|
||||
r.Mount("/api/v0/processes", a.processAPI.Routes())
|
||||
r.Mount("/api/v0/desktop", a.desktopAPI.Routes())
|
||||
r.Mount("/api/v0/mcp", a.mcpAPI.Routes())
|
||||
r.Mount("/api/v0/context-config", a.contextConfigAPI.Routes())
|
||||
|
||||
if a.devcontainers {
|
||||
r.Mount("/api/v0/containers", a.containerAPI.Routes())
|
||||
|
||||
@@ -1,12 +1,17 @@
|
||||
package agentdesktop
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"io"
|
||||
"net/http"
|
||||
"strconv"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/go-chi/chi/v5"
|
||||
"github.com/google/uuid"
|
||||
|
||||
"cdr.dev/slog/v3"
|
||||
"github.com/coder/coder/v2/agent/agentssh"
|
||||
@@ -47,6 +52,9 @@ type API struct {
|
||||
logger slog.Logger
|
||||
desktop Desktop
|
||||
clock quartz.Clock
|
||||
|
||||
closeMu sync.Mutex
|
||||
closed bool
|
||||
}
|
||||
|
||||
// NewAPI creates a new desktop streaming API.
|
||||
@@ -66,6 +74,10 @@ func (a *API) Routes() http.Handler {
|
||||
r := chi.NewRouter()
|
||||
r.Get("/vnc", a.handleDesktopVNC)
|
||||
r.Post("/action", a.handleAction)
|
||||
r.Route("/recording", func(r chi.Router) {
|
||||
r.Post("/start", a.handleRecordingStart)
|
||||
r.Post("/stop", a.handleRecordingStop)
|
||||
})
|
||||
return r
|
||||
}
|
||||
|
||||
@@ -116,6 +128,9 @@ func (a *API) handleAction(rw http.ResponseWriter, r *http.Request) {
|
||||
ctx := r.Context()
|
||||
handlerStart := a.clock.Now()
|
||||
|
||||
// Update last desktop action timestamp for idle recording monitor.
|
||||
a.desktop.RecordActivity()
|
||||
|
||||
// Ensure the desktop is running and grab native dimensions.
|
||||
cfg, err := a.desktop.Start(ctx)
|
||||
if err != nil {
|
||||
@@ -480,9 +495,150 @@ func (a *API) handleAction(rw http.ResponseWriter, r *http.Request) {
|
||||
|
||||
// Close shuts down the desktop session if one is running.
|
||||
func (a *API) Close() error {
|
||||
a.closeMu.Lock()
|
||||
if a.closed {
|
||||
a.closeMu.Unlock()
|
||||
return nil
|
||||
}
|
||||
a.closed = true
|
||||
a.closeMu.Unlock()
|
||||
|
||||
return a.desktop.Close()
|
||||
}
|
||||
|
||||
// decodeRecordingRequest decodes and validates a recording request
|
||||
// from the HTTP body, returning the recording ID. Returns false if
|
||||
// the request was invalid and an error response was already written.
|
||||
func (*API) decodeRecordingRequest(rw http.ResponseWriter, r *http.Request) (string, bool) {
|
||||
ctx := r.Context()
|
||||
var req struct {
|
||||
RecordingID string `json:"recording_id"`
|
||||
}
|
||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
|
||||
Message: "Failed to decode request body.",
|
||||
Detail: err.Error(),
|
||||
})
|
||||
return "", false
|
||||
}
|
||||
if req.RecordingID == "" {
|
||||
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
|
||||
Message: "Missing recording_id.",
|
||||
})
|
||||
return "", false
|
||||
}
|
||||
if _, err := uuid.Parse(req.RecordingID); err != nil {
|
||||
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
|
||||
Message: "Invalid recording_id format.",
|
||||
Detail: "recording_id must be a valid UUID.",
|
||||
})
|
||||
return "", false
|
||||
}
|
||||
return req.RecordingID, true
|
||||
}
|
||||
|
||||
func (a *API) handleRecordingStart(rw http.ResponseWriter, r *http.Request) {
|
||||
ctx := r.Context()
|
||||
recordingID, ok := a.decodeRecordingRequest(rw, r)
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
|
||||
a.closeMu.Lock()
|
||||
if a.closed {
|
||||
a.closeMu.Unlock()
|
||||
httpapi.Write(ctx, rw, http.StatusServiceUnavailable, codersdk.Response{
|
||||
Message: "Desktop API is shutting down.",
|
||||
})
|
||||
return
|
||||
}
|
||||
a.closeMu.Unlock()
|
||||
|
||||
if err := a.desktop.StartRecording(ctx, recordingID); err != nil {
|
||||
if errors.Is(err, ErrDesktopClosed) {
|
||||
httpapi.Write(ctx, rw, http.StatusServiceUnavailable, codersdk.Response{
|
||||
Message: "Desktop API is shutting down.",
|
||||
})
|
||||
return
|
||||
}
|
||||
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
|
||||
Message: "Failed to start recording.",
|
||||
Detail: err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
httpapi.Write(ctx, rw, http.StatusOK, codersdk.Response{
|
||||
Message: "Recording started.",
|
||||
})
|
||||
}
|
||||
|
||||
func (a *API) handleRecordingStop(rw http.ResponseWriter, r *http.Request) {
|
||||
ctx := r.Context()
|
||||
|
||||
recordingID, ok := a.decodeRecordingRequest(rw, r)
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
|
||||
a.closeMu.Lock()
|
||||
if a.closed {
|
||||
a.closeMu.Unlock()
|
||||
httpapi.Write(ctx, rw, http.StatusServiceUnavailable, codersdk.Response{
|
||||
Message: "Desktop API is shutting down.",
|
||||
})
|
||||
return
|
||||
}
|
||||
a.closeMu.Unlock()
|
||||
|
||||
// Stop recording (idempotent).
|
||||
// Use a context detached from the HTTP request so that if the
|
||||
// connection drops, the recording process can still shut down
|
||||
// gracefully. WithoutCancel preserves request-scoped values.
|
||||
stopCtx, stopCancel := context.WithTimeout(context.WithoutCancel(r.Context()), 30*time.Second)
|
||||
defer stopCancel()
|
||||
artifact, err := a.desktop.StopRecording(stopCtx, recordingID)
|
||||
if err != nil {
|
||||
if errors.Is(err, ErrUnknownRecording) {
|
||||
httpapi.Write(ctx, rw, http.StatusNotFound, codersdk.Response{
|
||||
Message: "Recording not found.",
|
||||
Detail: err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
if errors.Is(err, ErrRecordingCorrupted) {
|
||||
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
|
||||
Message: "Recording is corrupted.",
|
||||
Detail: err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
|
||||
Message: "Failed to stop recording.",
|
||||
Detail: err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
defer artifact.Reader.Close()
|
||||
|
||||
if artifact.Size > workspacesdk.MaxRecordingSize {
|
||||
a.logger.Warn(ctx, "recording file exceeds maximum size",
|
||||
slog.F("recording_id", recordingID),
|
||||
slog.F("size", artifact.Size),
|
||||
slog.F("max_size", workspacesdk.MaxRecordingSize),
|
||||
)
|
||||
httpapi.Write(ctx, rw, http.StatusRequestEntityTooLarge, codersdk.Response{
|
||||
Message: "Recording file exceeds maximum allowed size.",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
rw.Header().Set("Content-Type", "video/mp4")
|
||||
rw.Header().Set("Content-Length", strconv.FormatInt(artifact.Size, 10))
|
||||
rw.WriteHeader(http.StatusOK)
|
||||
_, _ = io.Copy(rw, artifact.Reader)
|
||||
}
|
||||
|
||||
// coordFromAction extracts the coordinate pair from a DesktopAction,
|
||||
// returning an error if the coordinate field is missing.
|
||||
func coordFromAction(action DesktopAction) (x, y int, err error) {
|
||||
|
||||
@@ -4,12 +4,17 @@ import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"os"
|
||||
"slices"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"golang.org/x/xerrors"
|
||||
@@ -21,6 +26,16 @@ import (
|
||||
"github.com/coder/quartz"
|
||||
)
|
||||
|
||||
// Test recording UUIDs used across tests.
|
||||
const (
|
||||
testRecIDDefault = "870e1f02-8118-4300-a37e-4adb0117baf3"
|
||||
testRecIDStartIdempotent = "250a2ffb-a5e5-4c94-9754-4d6a4ab7ba20"
|
||||
testRecIDStopIdempotent = "38f8a378-f98f-4758-a4ae-950b44cf989a"
|
||||
testRecIDConcurrentA = "8dc173eb-23c6-4601-a485-b6dfb2a42c3a"
|
||||
testRecIDConcurrentB = "fea490d4-70f0-4798-a181-29d65ce25ae1"
|
||||
testRecIDRestart = "75173a0d-b018-4e2e-a771-defa3fc6af69"
|
||||
)
|
||||
|
||||
// Ensure fakeDesktop satisfies the Desktop interface at compile time.
|
||||
var _ agentdesktop.Desktop = (*fakeDesktop)(nil)
|
||||
|
||||
@@ -43,6 +58,14 @@ type fakeDesktop struct {
|
||||
lastTyped string
|
||||
lastKeyDown string
|
||||
lastKeyUp string
|
||||
|
||||
// Recording tracking (guarded by recMu).
|
||||
recMu sync.Mutex
|
||||
recordings map[string]string // ID → file path
|
||||
stopCalls []string // recording IDs passed to StopRecording
|
||||
recStopCh chan string // optional: signaled when StopRecording is called
|
||||
startCount int // incremented on each new recording start
|
||||
activityCount int // incremented by RecordActivity
|
||||
}
|
||||
|
||||
func (f *fakeDesktop) Start(context.Context) (agentdesktop.DisplayConfig, error) {
|
||||
@@ -107,11 +130,140 @@ func (f *fakeDesktop) CursorPosition(context.Context) (x int, y int, err error)
|
||||
return f.cursorPos[0], f.cursorPos[1], nil
|
||||
}
|
||||
|
||||
func (f *fakeDesktop) StartRecording(_ context.Context, recordingID string) error {
|
||||
f.recMu.Lock()
|
||||
defer f.recMu.Unlock()
|
||||
if f.recordings == nil {
|
||||
f.recordings = make(map[string]string)
|
||||
}
|
||||
if path, ok := f.recordings[recordingID]; ok {
|
||||
// Check if already stopped (file still exists but stop was
|
||||
// called). For the fake, a stopped recording means its ID
|
||||
// appears in stopCalls. In that case, remove the old file
|
||||
// and start fresh.
|
||||
stopped := slices.Contains(f.stopCalls, recordingID)
|
||||
if !stopped {
|
||||
// Active recording - no-op.
|
||||
return nil
|
||||
}
|
||||
// Completed recording - discard old file, start fresh.
|
||||
_ = os.Remove(path)
|
||||
delete(f.recordings, recordingID)
|
||||
}
|
||||
f.startCount++
|
||||
tmpFile, err := os.CreateTemp("", "fake-recording-*.mp4")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
_, _ = tmpFile.Write([]byte(fmt.Sprintf("fake-mp4-data-%s-%d", recordingID, f.startCount)))
|
||||
_ = tmpFile.Close()
|
||||
f.recordings[recordingID] = tmpFile.Name()
|
||||
return nil
|
||||
}
|
||||
|
||||
func (f *fakeDesktop) StopRecording(_ context.Context, recordingID string) (*agentdesktop.RecordingArtifact, error) {
|
||||
f.recMu.Lock()
|
||||
defer f.recMu.Unlock()
|
||||
if f.recordings == nil {
|
||||
return nil, agentdesktop.ErrUnknownRecording
|
||||
}
|
||||
path, ok := f.recordings[recordingID]
|
||||
if !ok {
|
||||
return nil, agentdesktop.ErrUnknownRecording
|
||||
}
|
||||
f.stopCalls = append(f.stopCalls, recordingID)
|
||||
if f.recStopCh != nil {
|
||||
select {
|
||||
case f.recStopCh <- recordingID:
|
||||
default:
|
||||
}
|
||||
}
|
||||
file, err := os.Open(path)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
info, err := file.Stat()
|
||||
if err != nil {
|
||||
_ = file.Close()
|
||||
return nil, err
|
||||
}
|
||||
return &agentdesktop.RecordingArtifact{
|
||||
Reader: file,
|
||||
Size: info.Size(),
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (f *fakeDesktop) RecordActivity() {
|
||||
f.recMu.Lock()
|
||||
f.activityCount++
|
||||
f.recMu.Unlock()
|
||||
}
|
||||
|
||||
func (f *fakeDesktop) Close() error {
|
||||
f.closed = true
|
||||
f.recMu.Lock()
|
||||
defer f.recMu.Unlock()
|
||||
for _, path := range f.recordings {
|
||||
_ = os.Remove(path)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// failStartRecordingDesktop wraps fakeDesktop and overrides
|
||||
// StartRecording to always return an error.
|
||||
type failStartRecordingDesktop struct {
|
||||
fakeDesktop
|
||||
startRecordingErr error
|
||||
}
|
||||
|
||||
func (f *failStartRecordingDesktop) StartRecording(_ context.Context, _ string) error {
|
||||
return f.startRecordingErr
|
||||
}
|
||||
|
||||
// corruptedStopDesktop wraps fakeDesktop and overrides
|
||||
// StopRecording to always return ErrRecordingCorrupted.
|
||||
type corruptedStopDesktop struct {
|
||||
fakeDesktop
|
||||
}
|
||||
|
||||
func (*corruptedStopDesktop) StopRecording(_ context.Context, _ string) (*agentdesktop.RecordingArtifact, error) {
|
||||
return nil, agentdesktop.ErrRecordingCorrupted
|
||||
}
|
||||
|
||||
// oversizedFakeDesktop wraps fakeDesktop and expands recording files
|
||||
// beyond MaxRecordingSize when StopRecording is called.
|
||||
type oversizedFakeDesktop struct {
|
||||
fakeDesktop
|
||||
}
|
||||
|
||||
func (f *oversizedFakeDesktop) StopRecording(ctx context.Context, recordingID string) (*agentdesktop.RecordingArtifact, error) {
|
||||
artifact, err := f.fakeDesktop.StopRecording(ctx, recordingID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
// Close the original reader since we're going to re-open after truncation.
|
||||
artifact.Reader.Close()
|
||||
|
||||
// Look up the path from the fakeDesktop recordings.
|
||||
f.fakeDesktop.recMu.Lock()
|
||||
path := f.fakeDesktop.recordings[recordingID]
|
||||
f.fakeDesktop.recMu.Unlock()
|
||||
|
||||
// Expand the file to exceed the maximum recording size.
|
||||
if err := os.Truncate(path, workspacesdk.MaxRecordingSize+1); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
// Re-open the truncated file.
|
||||
file, err := os.Open(path)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &agentdesktop.RecordingArtifact{
|
||||
Reader: file,
|
||||
Size: workspacesdk.MaxRecordingSize + 1,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func TestHandleDesktopVNC_StartError(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
@@ -134,6 +286,37 @@ func TestHandleDesktopVNC_StartError(t *testing.T) {
|
||||
assert.Equal(t, "Failed to start desktop session.", resp.Message)
|
||||
}
|
||||
|
||||
func TestHandleAction_CallsRecordActivity(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
logger := slogtest.Make(t, nil)
|
||||
fake := &fakeDesktop{
|
||||
startCfg: agentdesktop.DisplayConfig{Width: 1920, Height: 1080},
|
||||
}
|
||||
api := agentdesktop.NewAPI(logger, fake, nil)
|
||||
defer api.Close()
|
||||
|
||||
body := agentdesktop.DesktopAction{
|
||||
Action: "left_click",
|
||||
Coordinate: &[2]int{100, 200},
|
||||
}
|
||||
b, err := json.Marshal(body)
|
||||
require.NoError(t, err)
|
||||
|
||||
rr := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodPost, "/action", bytes.NewReader(b))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
|
||||
handler := api.Routes()
|
||||
handler.ServeHTTP(rr, req)
|
||||
require.Equal(t, http.StatusOK, rr.Code)
|
||||
|
||||
fake.recMu.Lock()
|
||||
count := fake.activityCount
|
||||
fake.recMu.Unlock()
|
||||
assert.Equal(t, 1, count, "handleAction should call RecordActivity exactly once")
|
||||
}
|
||||
|
||||
func TestHandleAction_Screenshot(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
@@ -574,3 +757,481 @@ func TestHandleAction_CursorPositionReturnsDeclaredCoordinates(t *testing.T) {
|
||||
// Native (960,540) in 1920x1080 should map to declared space in 1280x720.
|
||||
assert.Equal(t, "x=640,y=360", resp.Output)
|
||||
}
|
||||
|
||||
func TestRecordingStartStop(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
logger := slogtest.Make(t, nil)
|
||||
fake := &fakeDesktop{
|
||||
startCfg: agentdesktop.DisplayConfig{Width: 1920, Height: 1080},
|
||||
}
|
||||
api := agentdesktop.NewAPI(logger, fake, nil)
|
||||
defer api.Close()
|
||||
|
||||
handler := api.Routes()
|
||||
|
||||
// Start recording.
|
||||
startBody, err := json.Marshal(map[string]string{"recording_id": testRecIDDefault})
|
||||
require.NoError(t, err)
|
||||
rr := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodPost, "/recording/start", bytes.NewReader(startBody))
|
||||
handler.ServeHTTP(rr, req)
|
||||
require.Equal(t, http.StatusOK, rr.Code)
|
||||
|
||||
// Stop recording.
|
||||
stopBody, err := json.Marshal(map[string]string{"recording_id": testRecIDDefault})
|
||||
require.NoError(t, err)
|
||||
rr = httptest.NewRecorder()
|
||||
req = httptest.NewRequest(http.MethodPost, "/recording/stop", bytes.NewReader(stopBody))
|
||||
handler.ServeHTTP(rr, req)
|
||||
require.Equal(t, http.StatusOK, rr.Code)
|
||||
assert.Equal(t, "video/mp4", rr.Header().Get("Content-Type"))
|
||||
assert.Equal(t, []byte("fake-mp4-data-"+testRecIDDefault+"-1"), rr.Body.Bytes())
|
||||
}
|
||||
|
||||
func TestRecordingStartFails(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
logger := slogtest.Make(t, nil)
|
||||
fake := &failStartRecordingDesktop{
|
||||
fakeDesktop: fakeDesktop{
|
||||
startCfg: agentdesktop.DisplayConfig{Width: 1920, Height: 1080},
|
||||
},
|
||||
startRecordingErr: xerrors.New("start recording error"),
|
||||
}
|
||||
api := agentdesktop.NewAPI(logger, fake, nil)
|
||||
defer api.Close()
|
||||
|
||||
handler := api.Routes()
|
||||
|
||||
body, err := json.Marshal(map[string]string{"recording_id": uuid.New().String()})
|
||||
require.NoError(t, err)
|
||||
rr := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodPost, "/recording/start", bytes.NewReader(body))
|
||||
handler.ServeHTTP(rr, req)
|
||||
|
||||
assert.Equal(t, http.StatusInternalServerError, rr.Code)
|
||||
|
||||
var resp codersdk.Response
|
||||
err = json.NewDecoder(rr.Body).Decode(&resp)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "Failed to start recording.", resp.Message)
|
||||
}
|
||||
|
||||
func TestRecordingStartIdempotent(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
logger := slogtest.Make(t, nil)
|
||||
fake := &fakeDesktop{
|
||||
startCfg: agentdesktop.DisplayConfig{Width: 1920, Height: 1080},
|
||||
}
|
||||
api := agentdesktop.NewAPI(logger, fake, nil)
|
||||
defer api.Close()
|
||||
|
||||
handler := api.Routes()
|
||||
|
||||
// Start same recording twice - both should succeed.
|
||||
for range 2 {
|
||||
body, err := json.Marshal(map[string]string{"recording_id": testRecIDStartIdempotent})
|
||||
require.NoError(t, err)
|
||||
rr := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodPost, "/recording/start", bytes.NewReader(body))
|
||||
handler.ServeHTTP(rr, req)
|
||||
require.Equal(t, http.StatusOK, rr.Code)
|
||||
}
|
||||
|
||||
// Stop once, verify normal response.
|
||||
stopBody, err := json.Marshal(map[string]string{"recording_id": testRecIDStartIdempotent})
|
||||
require.NoError(t, err)
|
||||
rr := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodPost, "/recording/stop", bytes.NewReader(stopBody))
|
||||
handler.ServeHTTP(rr, req)
|
||||
require.Equal(t, http.StatusOK, rr.Code)
|
||||
assert.Equal(t, "video/mp4", rr.Header().Get("Content-Type"))
|
||||
assert.Equal(t, []byte("fake-mp4-data-"+testRecIDStartIdempotent+"-1"), rr.Body.Bytes())
|
||||
}
|
||||
|
||||
func TestRecordingStopIdempotent(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
logger := slogtest.Make(t, nil)
|
||||
fake := &fakeDesktop{
|
||||
startCfg: agentdesktop.DisplayConfig{Width: 1920, Height: 1080},
|
||||
}
|
||||
api := agentdesktop.NewAPI(logger, fake, nil)
|
||||
defer api.Close()
|
||||
|
||||
handler := api.Routes()
|
||||
|
||||
// Start recording.
|
||||
startBody, err := json.Marshal(map[string]string{"recording_id": testRecIDStopIdempotent})
|
||||
require.NoError(t, err)
|
||||
rr := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodPost, "/recording/start", bytes.NewReader(startBody))
|
||||
handler.ServeHTTP(rr, req)
|
||||
require.Equal(t, http.StatusOK, rr.Code)
|
||||
|
||||
// Stop twice - both should succeed with identical data.
|
||||
var bodies [2][]byte
|
||||
for i := range 2 {
|
||||
body, err := json.Marshal(map[string]string{"recording_id": testRecIDStopIdempotent})
|
||||
require.NoError(t, err)
|
||||
recorder := httptest.NewRecorder()
|
||||
request := httptest.NewRequest(http.MethodPost, "/recording/stop", bytes.NewReader(body))
|
||||
handler.ServeHTTP(recorder, request)
|
||||
require.Equal(t, http.StatusOK, recorder.Code)
|
||||
assert.Equal(t, "video/mp4", recorder.Header().Get("Content-Type"))
|
||||
bodies[i] = recorder.Body.Bytes()
|
||||
}
|
||||
assert.Equal(t, bodies[0], bodies[1])
|
||||
}
|
||||
|
||||
func TestRecordingStopInvalidIDFormat(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
logger := slogtest.Make(t, nil)
|
||||
fake := &fakeDesktop{
|
||||
startCfg: agentdesktop.DisplayConfig{Width: 1920, Height: 1080},
|
||||
}
|
||||
api := agentdesktop.NewAPI(logger, fake, nil)
|
||||
defer api.Close()
|
||||
|
||||
handler := api.Routes()
|
||||
|
||||
body, err := json.Marshal(map[string]string{"recording_id": "not-a-uuid"})
|
||||
require.NoError(t, err)
|
||||
rr := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodPost, "/recording/stop", bytes.NewReader(body))
|
||||
handler.ServeHTTP(rr, req)
|
||||
assert.Equal(t, http.StatusBadRequest, rr.Code)
|
||||
}
|
||||
|
||||
func TestRecordingStopUnknownRecording(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
logger := slogtest.Make(t, nil)
|
||||
fake := &fakeDesktop{
|
||||
startCfg: agentdesktop.DisplayConfig{Width: 1920, Height: 1080},
|
||||
}
|
||||
api := agentdesktop.NewAPI(logger, fake, nil)
|
||||
defer api.Close()
|
||||
|
||||
handler := api.Routes()
|
||||
|
||||
// Send a valid UUID that was never started - should reach
|
||||
// StopRecording, get ErrUnknownRecording, and return 404.
|
||||
body, err := json.Marshal(map[string]string{"recording_id": uuid.New().String()})
|
||||
require.NoError(t, err)
|
||||
rr := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodPost, "/recording/stop", bytes.NewReader(body))
|
||||
handler.ServeHTTP(rr, req)
|
||||
assert.Equal(t, http.StatusNotFound, rr.Code)
|
||||
|
||||
var resp codersdk.Response
|
||||
err = json.NewDecoder(rr.Body).Decode(&resp)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "Recording not found.", resp.Message)
|
||||
}
|
||||
|
||||
func TestRecordingStopOversizedFile(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
logger := slogtest.Make(t, nil)
|
||||
fake := &oversizedFakeDesktop{
|
||||
fakeDesktop: fakeDesktop{
|
||||
startCfg: agentdesktop.DisplayConfig{Width: 1920, Height: 1080},
|
||||
},
|
||||
}
|
||||
api := agentdesktop.NewAPI(logger, fake, nil)
|
||||
defer api.Close()
|
||||
|
||||
handler := api.Routes()
|
||||
|
||||
// Start recording.
|
||||
recID := uuid.New().String()
|
||||
startBody, err := json.Marshal(map[string]string{"recording_id": recID})
|
||||
require.NoError(t, err)
|
||||
rr := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodPost, "/recording/start", bytes.NewReader(startBody))
|
||||
handler.ServeHTTP(rr, req)
|
||||
require.Equal(t, http.StatusOK, rr.Code)
|
||||
|
||||
// Stop recording - file exceeds max size, expect 413.
|
||||
stopBody, err := json.Marshal(map[string]string{"recording_id": recID})
|
||||
require.NoError(t, err)
|
||||
rr = httptest.NewRecorder()
|
||||
req = httptest.NewRequest(http.MethodPost, "/recording/stop", bytes.NewReader(stopBody))
|
||||
handler.ServeHTTP(rr, req)
|
||||
assert.Equal(t, http.StatusRequestEntityTooLarge, rr.Code)
|
||||
|
||||
var resp codersdk.Response
|
||||
err = json.NewDecoder(rr.Body).Decode(&resp)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "Recording file exceeds maximum allowed size.", resp.Message)
|
||||
}
|
||||
|
||||
func TestRecordingMultipleSimultaneous(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
logger := slogtest.Make(t, nil)
|
||||
fake := &fakeDesktop{
|
||||
startCfg: agentdesktop.DisplayConfig{Width: 1920, Height: 1080},
|
||||
}
|
||||
api := agentdesktop.NewAPI(logger, fake, nil)
|
||||
defer api.Close()
|
||||
|
||||
handler := api.Routes()
|
||||
|
||||
// Start two recordings with different IDs.
|
||||
for _, id := range []string{testRecIDConcurrentA, testRecIDConcurrentB} {
|
||||
body, err := json.Marshal(map[string]string{"recording_id": id})
|
||||
require.NoError(t, err)
|
||||
rr := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodPost, "/recording/start", bytes.NewReader(body))
|
||||
handler.ServeHTTP(rr, req)
|
||||
require.Equal(t, http.StatusOK, rr.Code)
|
||||
}
|
||||
|
||||
// Stop both and verify each returns its own data.
|
||||
expected := map[string][]byte{
|
||||
testRecIDConcurrentA: []byte("fake-mp4-data-" + testRecIDConcurrentA + "-1"),
|
||||
testRecIDConcurrentB: []byte("fake-mp4-data-" + testRecIDConcurrentB + "-2"),
|
||||
}
|
||||
for _, id := range []string{testRecIDConcurrentA, testRecIDConcurrentB} {
|
||||
body, err := json.Marshal(map[string]string{"recording_id": id})
|
||||
require.NoError(t, err)
|
||||
rr := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodPost, "/recording/stop", bytes.NewReader(body))
|
||||
handler.ServeHTTP(rr, req)
|
||||
require.Equal(t, http.StatusOK, rr.Code)
|
||||
assert.Equal(t, "video/mp4", rr.Header().Get("Content-Type"))
|
||||
assert.Equal(t, expected[id], rr.Body.Bytes())
|
||||
}
|
||||
}
|
||||
|
||||
func TestRecordingStartMalformedBody(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
logger := slogtest.Make(t, nil)
|
||||
fake := &fakeDesktop{
|
||||
startCfg: agentdesktop.DisplayConfig{Width: 1920, Height: 1080},
|
||||
}
|
||||
api := agentdesktop.NewAPI(logger, fake, nil)
|
||||
defer api.Close()
|
||||
|
||||
handler := api.Routes()
|
||||
|
||||
rr := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodPost, "/recording/start", bytes.NewReader([]byte("not json")))
|
||||
handler.ServeHTTP(rr, req)
|
||||
assert.Equal(t, http.StatusBadRequest, rr.Code)
|
||||
}
|
||||
|
||||
func TestRecordingStartEmptyID(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
logger := slogtest.Make(t, nil)
|
||||
fake := &fakeDesktop{
|
||||
startCfg: agentdesktop.DisplayConfig{Width: 1920, Height: 1080},
|
||||
}
|
||||
api := agentdesktop.NewAPI(logger, fake, nil)
|
||||
defer api.Close()
|
||||
|
||||
handler := api.Routes()
|
||||
|
||||
body, err := json.Marshal(map[string]string{"recording_id": ""})
|
||||
require.NoError(t, err)
|
||||
rr := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodPost, "/recording/start", bytes.NewReader(body))
|
||||
handler.ServeHTTP(rr, req)
|
||||
assert.Equal(t, http.StatusBadRequest, rr.Code)
|
||||
}
|
||||
|
||||
func TestRecordingStopEmptyID(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
logger := slogtest.Make(t, nil)
|
||||
fake := &fakeDesktop{
|
||||
startCfg: agentdesktop.DisplayConfig{Width: 1920, Height: 1080},
|
||||
}
|
||||
api := agentdesktop.NewAPI(logger, fake, nil)
|
||||
defer api.Close()
|
||||
|
||||
handler := api.Routes()
|
||||
|
||||
body, err := json.Marshal(map[string]string{"recording_id": ""})
|
||||
require.NoError(t, err)
|
||||
rr := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodPost, "/recording/stop", bytes.NewReader(body))
|
||||
handler.ServeHTTP(rr, req)
|
||||
assert.Equal(t, http.StatusBadRequest, rr.Code)
|
||||
}
|
||||
|
||||
func TestRecordingStopMalformedBody(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
logger := slogtest.Make(t, nil)
|
||||
fake := &fakeDesktop{
|
||||
startCfg: agentdesktop.DisplayConfig{Width: 1920, Height: 1080},
|
||||
}
|
||||
api := agentdesktop.NewAPI(logger, fake, nil)
|
||||
defer api.Close()
|
||||
|
||||
handler := api.Routes()
|
||||
|
||||
rr := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodPost, "/recording/stop", bytes.NewReader([]byte("not json")))
|
||||
handler.ServeHTTP(rr, req)
|
||||
assert.Equal(t, http.StatusBadRequest, rr.Code)
|
||||
}
|
||||
|
||||
func TestRecordingStartAfterCompleted(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
logger := slogtest.Make(t, nil)
|
||||
fake := &fakeDesktop{
|
||||
startCfg: agentdesktop.DisplayConfig{Width: 1920, Height: 1080},
|
||||
}
|
||||
api := agentdesktop.NewAPI(logger, fake, nil)
|
||||
defer api.Close()
|
||||
|
||||
handler := api.Routes()
|
||||
|
||||
// Step 1: Start recording.
|
||||
startBody, err := json.Marshal(map[string]string{"recording_id": testRecIDRestart})
|
||||
require.NoError(t, err)
|
||||
rr := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodPost, "/recording/start", bytes.NewReader(startBody))
|
||||
handler.ServeHTTP(rr, req)
|
||||
require.Equal(t, http.StatusOK, rr.Code)
|
||||
|
||||
// Step 2: Stop recording (gets first MP4 data).
|
||||
stopBody, err := json.Marshal(map[string]string{"recording_id": testRecIDRestart})
|
||||
require.NoError(t, err)
|
||||
rr = httptest.NewRecorder()
|
||||
req = httptest.NewRequest(http.MethodPost, "/recording/stop", bytes.NewReader(stopBody))
|
||||
handler.ServeHTTP(rr, req)
|
||||
require.Equal(t, http.StatusOK, rr.Code)
|
||||
assert.Equal(t, "video/mp4", rr.Header().Get("Content-Type"))
|
||||
firstData := rr.Body.Bytes()
|
||||
require.NotEmpty(t, firstData)
|
||||
|
||||
// Step 3: Start again with the same ID - should succeed
|
||||
// (old file discarded, new recording started).
|
||||
rr = httptest.NewRecorder()
|
||||
req = httptest.NewRequest(http.MethodPost, "/recording/start", bytes.NewReader(startBody))
|
||||
handler.ServeHTTP(rr, req)
|
||||
require.Equal(t, http.StatusOK, rr.Code)
|
||||
|
||||
// Step 4: Stop again - should return NEW MP4 data.
|
||||
rr = httptest.NewRecorder()
|
||||
req = httptest.NewRequest(http.MethodPost, "/recording/stop", bytes.NewReader(stopBody))
|
||||
handler.ServeHTTP(rr, req)
|
||||
require.Equal(t, http.StatusOK, rr.Code)
|
||||
assert.Equal(t, "video/mp4", rr.Header().Get("Content-Type"))
|
||||
secondData := rr.Body.Bytes()
|
||||
require.NotEmpty(t, secondData)
|
||||
|
||||
// The two recordings should have different data because the
|
||||
// fake increments a counter on each fresh start.
|
||||
assert.NotEqual(t, firstData, secondData,
|
||||
"restarted recording should produce different data")
|
||||
}
|
||||
|
||||
func TestRecordingStartAfterClose(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
logger := slogtest.Make(t, nil)
|
||||
fake := &fakeDesktop{
|
||||
startCfg: agentdesktop.DisplayConfig{Width: 1920, Height: 1080},
|
||||
}
|
||||
api := agentdesktop.NewAPI(logger, fake, nil)
|
||||
|
||||
handler := api.Routes()
|
||||
|
||||
// Close the API before sending the request.
|
||||
api.Close()
|
||||
|
||||
body, err := json.Marshal(map[string]string{"recording_id": uuid.New().String()})
|
||||
require.NoError(t, err)
|
||||
rr := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodPost, "/recording/start", bytes.NewReader(body))
|
||||
handler.ServeHTTP(rr, req)
|
||||
|
||||
assert.Equal(t, http.StatusServiceUnavailable, rr.Code)
|
||||
|
||||
var resp codersdk.Response
|
||||
err = json.NewDecoder(rr.Body).Decode(&resp)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "Desktop API is shutting down.", resp.Message)
|
||||
}
|
||||
|
||||
func TestRecordingStartDesktopClosed(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
logger := slogtest.Make(t, nil)
|
||||
// StartRecording returns ErrDesktopClosed to simulate a race
|
||||
// where the desktop is closed between the API-level check and
|
||||
// the desktop-level StartRecording call.
|
||||
fake := &failStartRecordingDesktop{
|
||||
fakeDesktop: fakeDesktop{
|
||||
startCfg: agentdesktop.DisplayConfig{Width: 1920, Height: 1080},
|
||||
},
|
||||
startRecordingErr: agentdesktop.ErrDesktopClosed,
|
||||
}
|
||||
api := agentdesktop.NewAPI(logger, fake, nil)
|
||||
defer api.Close()
|
||||
|
||||
handler := api.Routes()
|
||||
|
||||
body, err := json.Marshal(map[string]string{"recording_id": uuid.New().String()})
|
||||
require.NoError(t, err)
|
||||
rr := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodPost, "/recording/start", bytes.NewReader(body))
|
||||
handler.ServeHTTP(rr, req)
|
||||
|
||||
assert.Equal(t, http.StatusServiceUnavailable, rr.Code)
|
||||
|
||||
var resp codersdk.Response
|
||||
err = json.NewDecoder(rr.Body).Decode(&resp)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "Desktop API is shutting down.", resp.Message)
|
||||
}
|
||||
|
||||
func TestRecordingStopCorrupted(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
logger := slogtest.Make(t, nil)
|
||||
fake := &corruptedStopDesktop{
|
||||
fakeDesktop: fakeDesktop{
|
||||
startCfg: agentdesktop.DisplayConfig{Width: 1920, Height: 1080},
|
||||
},
|
||||
}
|
||||
api := agentdesktop.NewAPI(logger, fake, nil)
|
||||
defer api.Close()
|
||||
|
||||
handler := api.Routes()
|
||||
|
||||
// Start a recording so the stop has something to find.
|
||||
recID := uuid.New().String()
|
||||
startBody, err := json.Marshal(map[string]string{"recording_id": recID})
|
||||
require.NoError(t, err)
|
||||
rr := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodPost, "/recording/start", bytes.NewReader(startBody))
|
||||
handler.ServeHTTP(rr, req)
|
||||
require.Equal(t, http.StatusOK, rr.Code)
|
||||
|
||||
// Stop returns ErrRecordingCorrupted.
|
||||
stopBody, err := json.Marshal(map[string]string{"recording_id": recID})
|
||||
require.NoError(t, err)
|
||||
rr = httptest.NewRecorder()
|
||||
req = httptest.NewRequest(http.MethodPost, "/recording/stop", bytes.NewReader(stopBody))
|
||||
handler.ServeHTTP(rr, req)
|
||||
|
||||
assert.Equal(t, http.StatusInternalServerError, rr.Code)
|
||||
|
||||
var respStop codersdk.Response
|
||||
err = json.NewDecoder(rr.Body).Decode(&respStop)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "Recording is corrupted.", respStop.Message)
|
||||
}
|
||||
|
||||
@@ -2,7 +2,10 @@ package agentdesktop
|
||||
|
||||
import (
|
||||
"context"
|
||||
"io"
|
||||
"net"
|
||||
|
||||
"golang.org/x/xerrors"
|
||||
)
|
||||
|
||||
// Desktop abstracts a virtual desktop session running inside a workspace.
|
||||
@@ -58,10 +61,52 @@ type Desktop interface {
|
||||
// CursorPosition returns the current cursor coordinates.
|
||||
CursorPosition(ctx context.Context) (x, y int, err error)
|
||||
|
||||
// RecordActivity marks the desktop as having received user
|
||||
// interaction, resetting the idle-recording timer.
|
||||
RecordActivity()
|
||||
|
||||
// StartRecording begins recording the desktop to an MP4 file
|
||||
// using the caller-provided recording ID. Safe to call
|
||||
// repeatedly - active recordings continue unchanged, stopped
|
||||
// recordings are discarded and restarted. Concurrent recordings
|
||||
// are supported.
|
||||
StartRecording(ctx context.Context, recordingID string) error
|
||||
|
||||
// StopRecording finalizes the recording identified by the given
|
||||
// ID. Idempotent - safe to call on an already-stopped recording.
|
||||
// Returns a RecordingArtifact that the caller can stream. The
|
||||
// caller must close the artifact when done. Returns an error if
|
||||
// the recording ID is unknown.
|
||||
StopRecording(ctx context.Context, recordingID string) (*RecordingArtifact, error)
|
||||
|
||||
// Close shuts down the desktop session and cleans up resources.
|
||||
Close() error
|
||||
}
|
||||
|
||||
// ErrUnknownRecording is returned by StopRecording when the
|
||||
// recording ID is not recognized.
|
||||
var ErrUnknownRecording = xerrors.New("unknown recording ID")
|
||||
|
||||
// ErrDesktopClosed is returned when an operation is attempted on a
|
||||
// closed desktop session.
|
||||
var ErrDesktopClosed = xerrors.New("desktop closed")
|
||||
|
||||
// ErrRecordingCorrupted is returned by StopRecording when the
|
||||
// recording process was force-killed and the artifact is likely
|
||||
// incomplete or corrupt.
|
||||
var ErrRecordingCorrupted = xerrors.New("recording corrupted: process was force-killed")
|
||||
|
||||
// RecordingArtifact is a finalized recording returned by StopRecording.
|
||||
// The caller streams the artifact and must call Close when done. The
|
||||
// artifact remains valid even if the same recording ID is restarted
|
||||
// or the desktop is closed while the caller is reading.
|
||||
type RecordingArtifact struct {
|
||||
// Reader is the MP4 content. Callers must close it when done.
|
||||
Reader io.ReadCloser
|
||||
// Size is the byte length of the MP4 content.
|
||||
Size int64
|
||||
}
|
||||
|
||||
// DisplayConfig describes a running desktop session.
|
||||
type DisplayConfig struct {
|
||||
Width int // native width in pixels
|
||||
|
||||
@@ -3,6 +3,7 @@ package agentdesktop
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
"os"
|
||||
@@ -11,6 +12,7 @@ import (
|
||||
"runtime"
|
||||
"strconv"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"golang.org/x/xerrors"
|
||||
@@ -18,6 +20,7 @@ import (
|
||||
"cdr.dev/slog/v3"
|
||||
"github.com/coder/coder/v2/agent/agentexec"
|
||||
"github.com/coder/coder/v2/codersdk/workspacesdk"
|
||||
"github.com/coder/quartz"
|
||||
)
|
||||
|
||||
// portableDesktopOutput is the JSON output from
|
||||
@@ -49,32 +52,65 @@ type screenshotOutput struct {
|
||||
Data string `json:"data"`
|
||||
}
|
||||
|
||||
// recordingProcess tracks a single desktop recording subprocess.
|
||||
type recordingProcess struct {
|
||||
cmd *exec.Cmd
|
||||
filePath string
|
||||
stopped bool
|
||||
killed bool // true when the process was SIGKILLed
|
||||
done chan struct{} // closed when cmd.Wait() returns
|
||||
waitErr error // set before done is closed
|
||||
stopOnce sync.Once
|
||||
idleCancel context.CancelFunc // cancels the per-recording idle goroutine
|
||||
idleDone chan struct{} // closed when idle goroutine exits
|
||||
}
|
||||
|
||||
// maxConcurrentRecordings is the maximum number of active (non-stopped)
|
||||
// recordings allowed at once. This prevents resource exhaustion.
|
||||
const maxConcurrentRecordings = 5
|
||||
|
||||
// idleTimeout is the duration of desktop inactivity after which all
|
||||
// active recordings are automatically stopped.
|
||||
const idleTimeout = 10 * time.Minute
|
||||
|
||||
// portableDesktop implements Desktop by shelling out to the
|
||||
// portabledesktop CLI via agentexec.Execer.
|
||||
type portableDesktop struct {
|
||||
logger slog.Logger
|
||||
execer agentexec.Execer
|
||||
scriptBinDir string // coder script bin directory
|
||||
clock quartz.Clock
|
||||
|
||||
mu sync.Mutex
|
||||
session *desktopSession // nil until started
|
||||
binPath string // resolved path to binary, cached
|
||||
closed bool
|
||||
mu sync.Mutex
|
||||
session *desktopSession // nil until started
|
||||
binPath string // resolved path to binary, cached
|
||||
closed bool
|
||||
recordings map[string]*recordingProcess // guarded by mu
|
||||
lastDesktopActionAt atomic.Int64
|
||||
}
|
||||
|
||||
// NewPortableDesktop creates a Desktop backed by the portabledesktop
|
||||
// CLI binary, using execer to spawn child processes. scriptBinDir is
|
||||
// the coder script bin directory checked for the binary.
|
||||
// the coder script bin directory checked for the binary. If clk is
|
||||
// nil, a real clock is used.
|
||||
func NewPortableDesktop(
|
||||
logger slog.Logger,
|
||||
execer agentexec.Execer,
|
||||
scriptBinDir string,
|
||||
clk quartz.Clock,
|
||||
) Desktop {
|
||||
return &portableDesktop{
|
||||
if clk == nil {
|
||||
clk = quartz.NewReal()
|
||||
}
|
||||
pd := &portableDesktop{
|
||||
logger: logger,
|
||||
execer: execer,
|
||||
scriptBinDir: scriptBinDir,
|
||||
clock: clk,
|
||||
recordings: make(map[string]*recordingProcess),
|
||||
}
|
||||
pd.lastDesktopActionAt.Store(clk.Now().UnixNano())
|
||||
return pd
|
||||
}
|
||||
|
||||
// Start launches the desktop session (idempotent).
|
||||
@@ -83,7 +119,7 @@ func (p *portableDesktop) Start(ctx context.Context) (DisplayConfig, error) {
|
||||
defer p.mu.Unlock()
|
||||
|
||||
if p.closed {
|
||||
return DisplayConfig{}, xerrors.New("desktop is closed")
|
||||
return DisplayConfig{}, ErrDesktopClosed
|
||||
}
|
||||
|
||||
if err := p.ensureBinary(ctx); err != nil {
|
||||
@@ -313,23 +349,328 @@ func (p *portableDesktop) CursorPosition(ctx context.Context) (x int, y int, err
|
||||
return result.X, result.Y, nil
|
||||
}
|
||||
|
||||
// Close shuts down the desktop session and cleans up resources.
|
||||
func (p *portableDesktop) Close() error {
|
||||
// StartRecording begins recording the desktop to an MP4 file.
|
||||
// Three-state idempotency: active recordings are no-ops,
|
||||
// completed recordings are discarded and restarted.
|
||||
func (p *portableDesktop) StartRecording(ctx context.Context, recordingID string) error {
|
||||
// Ensure the desktop session is running before acquiring the
|
||||
// recording lock. Start is independently locked and idempotent.
|
||||
if _, err := p.Start(ctx); err != nil {
|
||||
return xerrors.Errorf("ensure desktop session: %w", err)
|
||||
}
|
||||
|
||||
p.mu.Lock()
|
||||
defer p.mu.Unlock()
|
||||
|
||||
if p.closed {
|
||||
return ErrDesktopClosed
|
||||
}
|
||||
|
||||
// Three-state idempotency:
|
||||
// - Active recording → no-op, continue recording.
|
||||
// - Completed recording → discard old file, start fresh.
|
||||
// - Unknown ID → fall through to start a new recording.
|
||||
if rec, ok := p.recordings[recordingID]; ok {
|
||||
if !rec.stopped {
|
||||
select {
|
||||
case <-rec.done:
|
||||
// Process exited unexpectedly; treat as completed
|
||||
// so we fall through to discard the old file and
|
||||
// restart.
|
||||
default:
|
||||
// Active recording - no-op, continue recording.
|
||||
return nil
|
||||
}
|
||||
}
|
||||
// Completed recording - discard old file, start fresh.
|
||||
if err := os.Remove(rec.filePath); err != nil && !os.IsNotExist(err) {
|
||||
p.logger.Warn(ctx, "failed to remove old recording file",
|
||||
slog.F("recording_id", recordingID),
|
||||
slog.F("file_path", rec.filePath),
|
||||
slog.Error(err),
|
||||
)
|
||||
}
|
||||
delete(p.recordings, recordingID)
|
||||
}
|
||||
|
||||
// Check concurrent recording limit.
|
||||
if p.lockedActiveRecordingCount() >= maxConcurrentRecordings {
|
||||
return xerrors.Errorf("too many concurrent recordings (max %d)", maxConcurrentRecordings)
|
||||
}
|
||||
|
||||
// GC sweep: remove stopped recordings with stale files.
|
||||
p.lockedCleanStaleRecordings(ctx)
|
||||
|
||||
if err := p.ensureBinary(ctx); err != nil {
|
||||
return xerrors.Errorf("ensure portabledesktop binary: %w", err)
|
||||
}
|
||||
|
||||
filePath := filepath.Join(os.TempDir(), "coder-recording-"+recordingID+".mp4")
|
||||
|
||||
// Use a background context so the process outlives the HTTP
|
||||
// request that triggered it.
|
||||
procCtx, procCancel := context.WithCancel(context.Background())
|
||||
|
||||
//nolint:gosec // portabledesktop is a trusted binary resolved via ensureBinary.
|
||||
cmd := p.execer.CommandContext(procCtx, p.binPath, "record",
|
||||
// The following options are used to speed up the recording when the desktop is idle.
|
||||
// They were taken out of an example in the portabledesktop repo.
|
||||
// There's likely room for improvement to optimize the values.
|
||||
"--idle-speedup", "20",
|
||||
"--idle-min-duration", "0.35",
|
||||
"--idle-noise-tolerance", "-38dB",
|
||||
filePath)
|
||||
|
||||
if err := cmd.Start(); err != nil {
|
||||
procCancel()
|
||||
return xerrors.Errorf("start recording process: %w", err)
|
||||
}
|
||||
|
||||
rec := &recordingProcess{
|
||||
cmd: cmd,
|
||||
filePath: filePath,
|
||||
done: make(chan struct{}),
|
||||
}
|
||||
go func() {
|
||||
rec.waitErr = cmd.Wait()
|
||||
close(rec.done)
|
||||
// avoid a context resource leak by canceling the context
|
||||
procCancel()
|
||||
}()
|
||||
|
||||
p.recordings[recordingID] = rec
|
||||
|
||||
p.logger.Info(ctx, "started desktop recording",
|
||||
slog.F("recording_id", recordingID),
|
||||
slog.F("file_path", filePath),
|
||||
slog.F("pid", cmd.Process.Pid),
|
||||
)
|
||||
|
||||
// Record activity so a recording started on an already-idle
|
||||
// desktop does not stop immediately.
|
||||
p.lastDesktopActionAt.Store(p.clock.Now().UnixNano())
|
||||
|
||||
// Spawn a per-recording idle goroutine.
|
||||
idleCtx, idleCancel := context.WithCancel(context.Background())
|
||||
rec.idleCancel = idleCancel
|
||||
rec.idleDone = make(chan struct{})
|
||||
go func() {
|
||||
defer close(rec.idleDone)
|
||||
p.monitorRecordingIdle(idleCtx, rec)
|
||||
}()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// StopRecording finalizes the recording. Idempotent - safe to call
|
||||
// on an already-stopped recording. Returns a RecordingArtifact
|
||||
// that the caller can stream. The caller must close the Reader
|
||||
// on the returned artifact to avoid leaking file descriptors.
|
||||
func (p *portableDesktop) StopRecording(ctx context.Context, recordingID string) (*RecordingArtifact, error) {
|
||||
p.mu.Lock()
|
||||
rec, ok := p.recordings[recordingID]
|
||||
if !ok {
|
||||
p.mu.Unlock()
|
||||
return nil, ErrUnknownRecording
|
||||
}
|
||||
|
||||
p.lockedStopRecordingProcess(ctx, rec, false)
|
||||
killed := rec.killed
|
||||
p.mu.Unlock()
|
||||
|
||||
p.logger.Info(ctx, "stopped desktop recording",
|
||||
slog.F("recording_id", recordingID),
|
||||
slog.F("file_path", rec.filePath),
|
||||
)
|
||||
|
||||
if killed {
|
||||
return nil, ErrRecordingCorrupted
|
||||
}
|
||||
|
||||
// Open the file and return an artifact. Each call opens a fresh
|
||||
// file descriptor so the caller is insulated from restarts and
|
||||
// desktop close.
|
||||
f, err := os.Open(rec.filePath)
|
||||
if err != nil {
|
||||
return nil, xerrors.Errorf("open recording artifact: %w", err)
|
||||
}
|
||||
info, err := f.Stat()
|
||||
if err != nil {
|
||||
_ = f.Close()
|
||||
return nil, xerrors.Errorf("stat recording artifact: %w", err)
|
||||
}
|
||||
return &RecordingArtifact{
|
||||
Reader: f,
|
||||
Size: info.Size(),
|
||||
}, nil
|
||||
}
|
||||
|
||||
// lockedStopRecordingProcess stops a single recording via stopOnce.
|
||||
// It sends SIGINT, waits up to 15 seconds for graceful exit, then
|
||||
// SIGKILLs. When force is true the process is SIGKILLed immediately
|
||||
// without attempting a graceful shutdown. Must be called while p.mu
|
||||
// is held; the lock is held for the full duration so that no
|
||||
// concurrent StopRecording caller can read rec.stopped = true
|
||||
// before the process has finished writing the MP4 file.
|
||||
//
|
||||
//nolint:revive // force flag keeps shared stopOnce/cleanup logic in one place.
|
||||
func (p *portableDesktop) lockedStopRecordingProcess(ctx context.Context, rec *recordingProcess, force bool) {
|
||||
rec.stopOnce.Do(func() {
|
||||
if force {
|
||||
_ = rec.cmd.Process.Kill()
|
||||
rec.killed = true
|
||||
} else {
|
||||
_ = interruptRecordingProcess(rec.cmd.Process)
|
||||
timer := p.clock.NewTimer(15*time.Second, "agentdesktop", "stop_timeout")
|
||||
defer timer.Stop()
|
||||
select {
|
||||
case <-rec.done:
|
||||
case <-ctx.Done():
|
||||
_ = rec.cmd.Process.Kill()
|
||||
rec.killed = true
|
||||
case <-timer.C:
|
||||
_ = rec.cmd.Process.Kill()
|
||||
rec.killed = true
|
||||
}
|
||||
}
|
||||
rec.stopped = true
|
||||
if rec.idleCancel != nil {
|
||||
rec.idleCancel()
|
||||
}
|
||||
})
|
||||
// NOTE: We intentionally do not wait on rec.done here.
|
||||
// If goleak is added to this package's tests, this may
|
||||
// need revisiting to avoid flakes.
|
||||
}
|
||||
|
||||
// lockedActiveRecordingCount returns the number of recordings that
|
||||
// are still actively running. Must be called while p.mu is held.
|
||||
// The max concurrency is low (maxConcurrentRecordings = 5), so a
|
||||
// full scan is cheap and avoids maintaining a separate counter.
|
||||
func (p *portableDesktop) lockedActiveRecordingCount() int {
|
||||
active := 0
|
||||
for _, rec := range p.recordings {
|
||||
if rec.stopped {
|
||||
continue
|
||||
}
|
||||
select {
|
||||
case <-rec.done:
|
||||
default:
|
||||
active++
|
||||
}
|
||||
}
|
||||
return active
|
||||
}
|
||||
|
||||
// lockedCleanStaleRecordings removes stopped recordings whose temp
|
||||
// files are older than one hour. Must be called while p.mu is held.
|
||||
func (p *portableDesktop) lockedCleanStaleRecordings(ctx context.Context) {
|
||||
for id, rec := range p.recordings {
|
||||
if !rec.stopped {
|
||||
continue
|
||||
}
|
||||
info, err := os.Stat(rec.filePath)
|
||||
if err != nil {
|
||||
// File already removed or inaccessible; drop entry.
|
||||
delete(p.recordings, id)
|
||||
continue
|
||||
}
|
||||
if p.clock.Since(info.ModTime()) > time.Hour {
|
||||
if err := os.Remove(rec.filePath); err != nil && !os.IsNotExist(err) {
|
||||
p.logger.Warn(ctx, "failed to remove stale recording file",
|
||||
slog.F("recording_id", id),
|
||||
slog.F("file_path", rec.filePath),
|
||||
slog.Error(err),
|
||||
)
|
||||
}
|
||||
delete(p.recordings, id)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Close shuts down the desktop session and cleans up resources.
|
||||
func (p *portableDesktop) Close() error {
|
||||
p.mu.Lock()
|
||||
p.closed = true
|
||||
if p.session != nil {
|
||||
p.session.cancel()
|
||||
// Xvnc is a child process — killing it cleans up the X
|
||||
// session.
|
||||
_ = p.session.cmd.Process.Kill()
|
||||
_ = p.session.cmd.Wait()
|
||||
p.session = nil
|
||||
|
||||
// Force-kill all active recordings. The stopOnce inside
|
||||
// lockedStopRecordingProcess makes this safe for
|
||||
// already-stopped recordings.
|
||||
for _, rec := range p.recordings {
|
||||
p.lockedStopRecordingProcess(context.Background(), rec, true)
|
||||
}
|
||||
|
||||
// Snapshot recording file paths and idle goroutine channels
|
||||
// for cleanup, then clear the map.
|
||||
type recEntry struct {
|
||||
id string
|
||||
filePath string
|
||||
idleDone chan struct{}
|
||||
}
|
||||
var allRecs []recEntry
|
||||
for id, rec := range p.recordings {
|
||||
allRecs = append(allRecs, recEntry{id: id, filePath: rec.filePath, idleDone: rec.idleDone})
|
||||
delete(p.recordings, id)
|
||||
}
|
||||
session := p.session
|
||||
p.session = nil
|
||||
p.mu.Unlock()
|
||||
|
||||
// Wait for all per-recording idle goroutines to exit.
|
||||
for _, entry := range allRecs {
|
||||
if entry.idleDone != nil {
|
||||
<-entry.idleDone
|
||||
}
|
||||
}
|
||||
|
||||
// Remove all recording files and wait for the session to
|
||||
// exit with a timeout so a slow filesystem or hung process
|
||||
// cannot block agent shutdown indefinitely.
|
||||
cleanupDone := make(chan struct{})
|
||||
go func() {
|
||||
defer close(cleanupDone)
|
||||
for _, entry := range allRecs {
|
||||
if err := os.Remove(entry.filePath); err != nil && !os.IsNotExist(err) {
|
||||
p.logger.Warn(context.Background(), "failed to remove recording file on close",
|
||||
slog.F("recording_id", entry.id),
|
||||
slog.F("file_path", entry.filePath),
|
||||
slog.Error(err),
|
||||
)
|
||||
}
|
||||
}
|
||||
if session != nil {
|
||||
session.cancel()
|
||||
if err := session.cmd.Process.Kill(); err != nil {
|
||||
p.logger.Warn(context.Background(), "failed to kill portabledesktop process",
|
||||
slog.Error(err),
|
||||
)
|
||||
}
|
||||
if err := session.cmd.Wait(); err != nil {
|
||||
var exitErr *exec.ExitError
|
||||
if !errors.As(err, &exitErr) {
|
||||
p.logger.Warn(context.Background(), "portabledesktop process exited with error",
|
||||
slog.Error(err),
|
||||
)
|
||||
}
|
||||
}
|
||||
}
|
||||
}()
|
||||
timer := p.clock.NewTimer(15*time.Second, "agentdesktop", "close_cleanup_timeout")
|
||||
defer timer.Stop()
|
||||
select {
|
||||
case <-cleanupDone:
|
||||
case <-timer.C:
|
||||
p.logger.Warn(context.Background(), "timed out waiting for close cleanup")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// RecordActivity marks the desktop as having received user
|
||||
// interaction, resetting the idle-recording timer.
|
||||
func (p *portableDesktop) RecordActivity() {
|
||||
p.lastDesktopActionAt.Store(p.clock.Now().UnixNano())
|
||||
}
|
||||
|
||||
// runCmd executes a portabledesktop subcommand and returns combined
|
||||
// output. The caller must have previously called ensureBinary.
|
||||
func (p *portableDesktop) runCmd(ctx context.Context, args ...string) (string, error) {
|
||||
@@ -397,3 +738,31 @@ func (p *portableDesktop) ensureBinary(ctx context.Context) error {
|
||||
|
||||
return xerrors.New("portabledesktop binary not found in PATH or script bin directory")
|
||||
}
|
||||
|
||||
// monitorRecordingIdle watches for desktop inactivity and stops the
|
||||
// given recording when the idle timeout is reached.
|
||||
func (p *portableDesktop) monitorRecordingIdle(ctx context.Context, rec *recordingProcess) {
|
||||
timer := p.clock.NewTimer(idleTimeout, "agentdesktop", "recording_idle")
|
||||
defer timer.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-timer.C:
|
||||
lastNano := p.lastDesktopActionAt.Load()
|
||||
lastAction := time.Unix(0, lastNano)
|
||||
elapsed := p.clock.Since(lastAction)
|
||||
if elapsed >= idleTimeout {
|
||||
p.mu.Lock()
|
||||
p.lockedStopRecordingProcess(context.Background(), rec, false)
|
||||
p.mu.Unlock()
|
||||
return
|
||||
}
|
||||
// Activity happened; reset with remaining budget.
|
||||
timer.Reset(idleTimeout-elapsed, "agentdesktop", "recording_idle")
|
||||
case <-rec.done:
|
||||
return
|
||||
case <-ctx.Done():
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -9,13 +9,17 @@ import (
|
||||
"strings"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"cdr.dev/slog/v3/sloggers/slogtest"
|
||||
"github.com/coder/coder/v2/agent/agentexec"
|
||||
"github.com/coder/coder/v2/pty"
|
||||
"github.com/coder/coder/v2/testutil"
|
||||
"github.com/coder/quartz"
|
||||
)
|
||||
|
||||
// recordedExecer implements agentexec.Execer by recording every
|
||||
@@ -86,6 +90,7 @@ func TestPortableDesktop_Start_ParsesOutput(t *testing.T) {
|
||||
execer: rec,
|
||||
scriptBinDir: t.TempDir(),
|
||||
binPath: "portabledesktop", // pre-set so ensureBinary is a no-op
|
||||
clock: quartz.NewReal(),
|
||||
}
|
||||
|
||||
ctx := t.Context()
|
||||
@@ -117,6 +122,7 @@ func TestPortableDesktop_Start_Idempotent(t *testing.T) {
|
||||
execer: rec,
|
||||
scriptBinDir: t.TempDir(),
|
||||
binPath: "portabledesktop",
|
||||
clock: quartz.NewReal(),
|
||||
}
|
||||
|
||||
ctx := t.Context()
|
||||
@@ -159,6 +165,7 @@ func TestPortableDesktop_Screenshot(t *testing.T) {
|
||||
execer: rec,
|
||||
scriptBinDir: t.TempDir(),
|
||||
binPath: "portabledesktop",
|
||||
clock: quartz.NewReal(),
|
||||
}
|
||||
|
||||
ctx := t.Context()
|
||||
@@ -184,6 +191,7 @@ func TestPortableDesktop_Screenshot_WithTargetDimensions(t *testing.T) {
|
||||
execer: rec,
|
||||
scriptBinDir: t.TempDir(),
|
||||
binPath: "portabledesktop",
|
||||
clock: quartz.NewReal(),
|
||||
}
|
||||
|
||||
ctx := t.Context()
|
||||
@@ -282,6 +290,7 @@ func TestPortableDesktop_MouseMethods(t *testing.T) {
|
||||
execer: rec,
|
||||
scriptBinDir: t.TempDir(),
|
||||
binPath: "portabledesktop",
|
||||
clock: quartz.NewReal(),
|
||||
}
|
||||
|
||||
err := tt.invoke(t.Context(), pd)
|
||||
@@ -289,7 +298,6 @@ func TestPortableDesktop_MouseMethods(t *testing.T) {
|
||||
|
||||
cmds := rec.allCommands()
|
||||
require.NotEmpty(t, cmds, "expected at least one command")
|
||||
|
||||
// Find at least one recorded command that contains
|
||||
// all expected argument substrings.
|
||||
found := false
|
||||
@@ -367,6 +375,7 @@ func TestPortableDesktop_KeyboardMethods(t *testing.T) {
|
||||
execer: rec,
|
||||
scriptBinDir: t.TempDir(),
|
||||
binPath: "portabledesktop",
|
||||
clock: quartz.NewReal(),
|
||||
}
|
||||
|
||||
err := tt.invoke(t.Context(), pd)
|
||||
@@ -423,6 +432,7 @@ func TestPortableDesktop_Close(t *testing.T) {
|
||||
execer: rec,
|
||||
scriptBinDir: t.TempDir(),
|
||||
binPath: "portabledesktop",
|
||||
clock: quartz.NewReal(),
|
||||
}
|
||||
|
||||
ctx := t.Context()
|
||||
@@ -445,7 +455,7 @@ func TestPortableDesktop_Close(t *testing.T) {
|
||||
// Subsequent Start must fail.
|
||||
_, err = pd.Start(ctx)
|
||||
require.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "desktop is closed")
|
||||
assert.Contains(t, err.Error(), "desktop closed")
|
||||
}
|
||||
|
||||
// --- ensureBinary tests ---
|
||||
@@ -539,7 +549,410 @@ func TestEnsureBinary_NotFound(t *testing.T) {
|
||||
assert.Contains(t, err.Error(), "not found")
|
||||
}
|
||||
|
||||
func TestPortableDesktop_StartRecording(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
logger := slogtest.Make(t, nil)
|
||||
rec := &recordedExecer{
|
||||
scripts: map[string]string{
|
||||
"record": `trap 'exit 0' INT; sleep 120 & wait`,
|
||||
"up": `printf '{"vncPort":5901,"geometry":"1920x1080"}\n' && sleep 120`,
|
||||
},
|
||||
}
|
||||
|
||||
clk := quartz.NewReal()
|
||||
pd := &portableDesktop{
|
||||
logger: logger,
|
||||
execer: rec,
|
||||
scriptBinDir: t.TempDir(),
|
||||
clock: clk,
|
||||
binPath: "portabledesktop",
|
||||
recordings: make(map[string]*recordingProcess),
|
||||
}
|
||||
pd.lastDesktopActionAt.Store(clk.Now().UnixNano())
|
||||
|
||||
ctx := t.Context()
|
||||
recID := uuid.New().String()
|
||||
err := pd.StartRecording(ctx, recID)
|
||||
require.NoError(t, err)
|
||||
|
||||
cmds := rec.allCommands()
|
||||
require.NotEmpty(t, cmds)
|
||||
// Find the record command (not the up command).
|
||||
found := false
|
||||
for _, cmd := range cmds {
|
||||
joined := strings.Join(cmd, " ")
|
||||
if strings.Contains(joined, "record") && strings.Contains(joined, "coder-recording-"+recID) {
|
||||
found = true
|
||||
break
|
||||
}
|
||||
}
|
||||
assert.True(t, found, "expected a record command with the recording ID")
|
||||
|
||||
require.NoError(t, pd.Close())
|
||||
}
|
||||
|
||||
func TestPortableDesktop_StartRecording_ConcurrentLimit(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
logger := slogtest.Make(t, nil)
|
||||
rec := &recordedExecer{
|
||||
scripts: map[string]string{
|
||||
"record": `trap 'exit 0' INT; sleep 120 & wait`,
|
||||
"up": `printf '{"vncPort":5901,"geometry":"1920x1080"}\n' && sleep 120`,
|
||||
},
|
||||
}
|
||||
|
||||
clk := quartz.NewReal()
|
||||
pd := &portableDesktop{
|
||||
logger: logger,
|
||||
execer: rec,
|
||||
scriptBinDir: t.TempDir(),
|
||||
clock: clk,
|
||||
binPath: "portabledesktop",
|
||||
recordings: make(map[string]*recordingProcess),
|
||||
}
|
||||
pd.lastDesktopActionAt.Store(clk.Now().UnixNano())
|
||||
|
||||
ctx := t.Context()
|
||||
|
||||
for i := range maxConcurrentRecordings {
|
||||
err := pd.StartRecording(ctx, uuid.New().String())
|
||||
require.NoError(t, err, "recording %d should succeed", i)
|
||||
}
|
||||
|
||||
err := pd.StartRecording(ctx, uuid.New().String())
|
||||
require.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "too many concurrent recordings")
|
||||
|
||||
require.NoError(t, pd.Close())
|
||||
}
|
||||
|
||||
func TestPortableDesktop_StopRecording_ReturnsArtifact(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
logger := slogtest.Make(t, nil)
|
||||
rec := &recordedExecer{
|
||||
scripts: map[string]string{
|
||||
"record": `trap 'exit 0' INT; sleep 120 & wait`,
|
||||
"up": `printf '{"vncPort":5901,"geometry":"1920x1080"}\n' && sleep 120`,
|
||||
},
|
||||
}
|
||||
|
||||
clk := quartz.NewReal()
|
||||
pd := &portableDesktop{
|
||||
logger: logger,
|
||||
execer: rec,
|
||||
scriptBinDir: t.TempDir(),
|
||||
clock: clk,
|
||||
binPath: "portabledesktop",
|
||||
recordings: make(map[string]*recordingProcess),
|
||||
}
|
||||
pd.lastDesktopActionAt.Store(clk.Now().UnixNano())
|
||||
|
||||
ctx := t.Context()
|
||||
recID := uuid.New().String()
|
||||
err := pd.StartRecording(ctx, recID)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Write a dummy MP4 file at the expected path so StopRecording
|
||||
// can open it as an artifact.
|
||||
filePath := filepath.Join(os.TempDir(), "coder-recording-"+recID+".mp4")
|
||||
require.NoError(t, os.WriteFile(filePath, []byte("fake-mp4-data"), 0o600))
|
||||
t.Cleanup(func() { _ = os.Remove(filePath) })
|
||||
|
||||
artifact, err := pd.StopRecording(ctx, recID)
|
||||
require.NoError(t, err)
|
||||
defer artifact.Reader.Close()
|
||||
assert.Equal(t, int64(len("fake-mp4-data")), artifact.Size)
|
||||
|
||||
require.NoError(t, pd.Close())
|
||||
}
|
||||
|
||||
func TestPortableDesktop_StopRecording_UnknownID(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
logger := slogtest.Make(t, nil)
|
||||
rec := &recordedExecer{
|
||||
scripts: map[string]string{
|
||||
"record": `trap 'exit 0' INT; sleep 120 & wait`,
|
||||
},
|
||||
}
|
||||
|
||||
clk := quartz.NewReal()
|
||||
pd := &portableDesktop{
|
||||
logger: logger,
|
||||
execer: rec,
|
||||
scriptBinDir: t.TempDir(),
|
||||
clock: clk,
|
||||
binPath: "portabledesktop",
|
||||
recordings: make(map[string]*recordingProcess),
|
||||
}
|
||||
pd.lastDesktopActionAt.Store(clk.Now().UnixNano())
|
||||
|
||||
ctx := t.Context()
|
||||
_, err := pd.StopRecording(ctx, uuid.New().String())
|
||||
require.ErrorIs(t, err, ErrUnknownRecording)
|
||||
|
||||
require.NoError(t, pd.Close())
|
||||
}
|
||||
|
||||
// Ensure that portableDesktop satisfies the Desktop interface at
|
||||
// compile time. This uses the unexported type so it lives in the
|
||||
// internal test package.
|
||||
var _ Desktop = (*portableDesktop)(nil)
|
||||
|
||||
func TestPortableDesktop_IdleTimeout_StopsRecordings(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
logger := slogtest.Make(t, nil)
|
||||
rec := &recordedExecer{
|
||||
scripts: map[string]string{
|
||||
"record": `trap 'exit 0' INT; sleep 120 & wait`,
|
||||
"up": `printf '{"vncPort":5901,"geometry":"1920x1080"}\n' && sleep 120`,
|
||||
},
|
||||
}
|
||||
|
||||
clk := quartz.NewMock(t)
|
||||
pd := &portableDesktop{
|
||||
logger: logger,
|
||||
execer: rec,
|
||||
scriptBinDir: t.TempDir(),
|
||||
clock: clk,
|
||||
binPath: "portabledesktop",
|
||||
recordings: make(map[string]*recordingProcess),
|
||||
}
|
||||
pd.lastDesktopActionAt.Store(clk.Now().UnixNano())
|
||||
|
||||
ctx := t.Context()
|
||||
recID := uuid.New().String()
|
||||
|
||||
// Install the trap before StartRecording so it is guaranteed
|
||||
// to catch the idle monitor's NewTimer call regardless of
|
||||
// goroutine scheduling.
|
||||
trap := clk.Trap().NewTimer("agentdesktop", "recording_idle")
|
||||
|
||||
err := pd.StartRecording(ctx, recID)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Verify recording is active.
|
||||
pd.mu.Lock()
|
||||
require.False(t, pd.recordings[recID].stopped)
|
||||
pd.mu.Unlock()
|
||||
|
||||
// Wait for the idle monitor timer to be created and release
|
||||
// it so the monitor enters its select loop.
|
||||
trap.MustWait(ctx).MustRelease(ctx)
|
||||
trap.Close()
|
||||
|
||||
// The stop-all path calls lockedStopRecordingProcess which
|
||||
// creates a per-recording 15s stop_timeout timer.
|
||||
stopTrap := clk.Trap().NewTimer("agentdesktop", "stop_timeout")
|
||||
|
||||
// Advance past idle timeout to trigger the stop-all.
|
||||
clk.Advance(idleTimeout)
|
||||
|
||||
// Wait for the stop timer to be created, then release it.
|
||||
stopTrap.MustWait(ctx).MustRelease(ctx)
|
||||
stopTrap.Close()
|
||||
|
||||
// The recording process should now be stopped.
|
||||
require.Eventually(t, func() bool {
|
||||
pd.mu.Lock()
|
||||
defer pd.mu.Unlock()
|
||||
rec, ok := pd.recordings[recID]
|
||||
return ok && rec.stopped
|
||||
}, testutil.WaitShort, testutil.IntervalFast)
|
||||
|
||||
require.NoError(t, pd.Close())
|
||||
}
|
||||
|
||||
func TestPortableDesktop_IdleTimeout_ActivityResetsTimer(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
logger := slogtest.Make(t, nil)
|
||||
rec := &recordedExecer{
|
||||
scripts: map[string]string{
|
||||
"record": `trap 'exit 0' INT; sleep 120 & wait`,
|
||||
"up": `printf '{"vncPort":5901,"geometry":"1920x1080"}\n' && sleep 120`,
|
||||
},
|
||||
}
|
||||
|
||||
clk := quartz.NewMock(t)
|
||||
pd := &portableDesktop{
|
||||
logger: logger,
|
||||
execer: rec,
|
||||
scriptBinDir: t.TempDir(),
|
||||
clock: clk,
|
||||
binPath: "portabledesktop",
|
||||
recordings: make(map[string]*recordingProcess),
|
||||
}
|
||||
pd.lastDesktopActionAt.Store(clk.Now().UnixNano())
|
||||
|
||||
ctx := t.Context()
|
||||
recID := uuid.New().String()
|
||||
|
||||
// Install the trap before StartRecording so it is guaranteed
|
||||
// to catch the idle monitor's NewTimer call regardless of
|
||||
// goroutine scheduling.
|
||||
trap := clk.Trap().NewTimer("agentdesktop", "recording_idle")
|
||||
|
||||
err := pd.StartRecording(ctx, recID)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Wait for the idle monitor timer to be created.
|
||||
trap.MustWait(ctx).MustRelease(ctx)
|
||||
trap.Close()
|
||||
|
||||
// Advance most of the way but not past the timeout.
|
||||
clk.Advance(idleTimeout - time.Minute)
|
||||
|
||||
// Record activity to reset the timer.
|
||||
pd.RecordActivity()
|
||||
|
||||
// Trap the Reset call that the idle monitor makes when it
|
||||
// sees recent activity.
|
||||
resetTrap := clk.Trap().TimerReset("agentdesktop", "recording_idle")
|
||||
|
||||
// Advance past the original idle timeout deadline. The
|
||||
// monitor should see the recent activity and reset instead
|
||||
// of stopping.
|
||||
clk.Advance(time.Minute)
|
||||
|
||||
resetTrap.MustWait(ctx).MustRelease(ctx)
|
||||
resetTrap.Close()
|
||||
|
||||
// Recording should still be active because activity was
|
||||
// recorded.
|
||||
pd.mu.Lock()
|
||||
require.False(t, pd.recordings[recID].stopped)
|
||||
pd.mu.Unlock()
|
||||
|
||||
require.NoError(t, pd.Close())
|
||||
}
|
||||
|
||||
func TestPortableDesktop_IdleTimeout_MultipleRecordings(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
logger := slogtest.Make(t, nil)
|
||||
rec := &recordedExecer{
|
||||
scripts: map[string]string{
|
||||
"record": `trap 'exit 0' INT; sleep 120 & wait`,
|
||||
"up": `printf '{"vncPort":5901,"geometry":"1920x1080"}\n' && sleep 120`,
|
||||
},
|
||||
}
|
||||
|
||||
clk := quartz.NewMock(t)
|
||||
pd := &portableDesktop{
|
||||
logger: logger,
|
||||
execer: rec,
|
||||
scriptBinDir: t.TempDir(),
|
||||
clock: clk,
|
||||
binPath: "portabledesktop",
|
||||
recordings: make(map[string]*recordingProcess),
|
||||
}
|
||||
pd.lastDesktopActionAt.Store(clk.Now().UnixNano())
|
||||
|
||||
ctx := t.Context()
|
||||
recID1 := uuid.New().String()
|
||||
recID2 := uuid.New().String()
|
||||
|
||||
// Trap idle timer creation for both recordings.
|
||||
trap := clk.Trap().NewTimer("agentdesktop", "recording_idle")
|
||||
|
||||
err := pd.StartRecording(ctx, recID1)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Wait for first recording's idle timer.
|
||||
trap.MustWait(ctx).MustRelease(ctx)
|
||||
|
||||
err = pd.StartRecording(ctx, recID2)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Wait for second recording's idle timer.
|
||||
trap.MustWait(ctx).MustRelease(ctx)
|
||||
trap.Close()
|
||||
|
||||
// Trap the stop timers that will be created when idle fires.
|
||||
stopTrap := clk.Trap().NewTimer("agentdesktop", "stop_timeout")
|
||||
|
||||
// Advance past idle timeout.
|
||||
clk.Advance(idleTimeout)
|
||||
|
||||
// Wait for both stop timers.
|
||||
stopTrap.MustWait(ctx).MustRelease(ctx)
|
||||
stopTrap.MustWait(ctx).MustRelease(ctx)
|
||||
stopTrap.Close()
|
||||
|
||||
// Both recordings should be stopped.
|
||||
require.Eventually(t, func() bool {
|
||||
pd.mu.Lock()
|
||||
defer pd.mu.Unlock()
|
||||
r1, ok1 := pd.recordings[recID1]
|
||||
r2, ok2 := pd.recordings[recID2]
|
||||
return ok1 && r1.stopped && ok2 && r2.stopped
|
||||
}, testutil.WaitShort, testutil.IntervalFast)
|
||||
|
||||
require.NoError(t, pd.Close())
|
||||
}
|
||||
|
||||
func TestPortableDesktop_StartRecording_ReturnsErrDesktopClosed(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
logger := slogtest.Make(t, nil)
|
||||
rec := &recordedExecer{
|
||||
scripts: map[string]string{
|
||||
"up": `printf '{"vncPort":5901,"geometry":"1920x1080"}\n' && sleep 120`,
|
||||
},
|
||||
}
|
||||
|
||||
clk := quartz.NewReal()
|
||||
pd := &portableDesktop{
|
||||
logger: logger,
|
||||
execer: rec,
|
||||
scriptBinDir: t.TempDir(),
|
||||
clock: clk,
|
||||
binPath: "portabledesktop",
|
||||
recordings: make(map[string]*recordingProcess),
|
||||
}
|
||||
pd.lastDesktopActionAt.Store(clk.Now().UnixNano())
|
||||
|
||||
// Start and close the desktop so it's in the closed state.
|
||||
ctx := t.Context()
|
||||
_, err := pd.Start(ctx)
|
||||
require.NoError(t, err)
|
||||
require.NoError(t, pd.Close())
|
||||
|
||||
// StartRecording should now return ErrDesktopClosed.
|
||||
err = pd.StartRecording(ctx, uuid.New().String())
|
||||
require.ErrorIs(t, err, ErrDesktopClosed)
|
||||
}
|
||||
|
||||
func TestPortableDesktop_Start_ReturnsErrDesktopClosed(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
logger := slogtest.Make(t, nil)
|
||||
rec := &recordedExecer{
|
||||
scripts: map[string]string{
|
||||
"up": `printf '{"vncPort":5901,"geometry":"1920x1080"}\n' && sleep 120`,
|
||||
},
|
||||
}
|
||||
|
||||
pd := &portableDesktop{
|
||||
logger: logger,
|
||||
execer: rec,
|
||||
scriptBinDir: t.TempDir(),
|
||||
clock: quartz.NewReal(),
|
||||
binPath: "portabledesktop",
|
||||
recordings: make(map[string]*recordingProcess),
|
||||
}
|
||||
pd.lastDesktopActionAt.Store(pd.clock.Now().UnixNano())
|
||||
|
||||
ctx := t.Context()
|
||||
_, err := pd.Start(ctx)
|
||||
require.NoError(t, err)
|
||||
require.NoError(t, pd.Close())
|
||||
|
||||
_, err = pd.Start(ctx)
|
||||
require.ErrorIs(t, err, ErrDesktopClosed)
|
||||
}
|
||||
|
||||
@@ -0,0 +1,12 @@
|
||||
//go:build !windows
|
||||
|
||||
package agentdesktop
|
||||
|
||||
import "os"
|
||||
|
||||
// interruptRecordingProcess sends a SIGINT to the recording process
|
||||
// for graceful shutdown. On Unix, os.Interrupt is delivered as
|
||||
// SIGINT which lets the recorder finalize the MP4 container.
|
||||
func interruptRecordingProcess(p *os.Process) error {
|
||||
return p.Signal(os.Interrupt)
|
||||
}
|
||||
@@ -0,0 +1,10 @@
|
||||
package agentdesktop
|
||||
|
||||
import "os"
|
||||
|
||||
// interruptRecordingProcess kills the recording process directly
|
||||
// because os.Process.Signal(os.Interrupt) is not supported on
|
||||
// Windows and returns an error without delivering a signal.
|
||||
func interruptRecordingProcess(p *os.Process) error {
|
||||
return p.Kill()
|
||||
}
|
||||
+34
-11
@@ -6,7 +6,6 @@ import (
|
||||
"fmt"
|
||||
"io/fs"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"slices"
|
||||
"strings"
|
||||
"sync"
|
||||
@@ -70,16 +69,40 @@ func NewManager(logger slog.Logger) *Manager {
|
||||
}
|
||||
}
|
||||
|
||||
// Connect discovers .mcp.json in dir and connects to all
|
||||
// configured servers. Failed servers are logged and skipped.
|
||||
func (m *Manager) Connect(ctx context.Context, dir string) error {
|
||||
path := filepath.Join(dir, ".mcp.json")
|
||||
configs, err := ParseConfig(path)
|
||||
if err != nil {
|
||||
if errors.Is(err, fs.ErrNotExist) {
|
||||
return nil
|
||||
// Connect reads MCP config files at the given absolute paths and
|
||||
// connects to all configured servers. Failed servers are logged
|
||||
// and skipped. Missing config files are silently skipped.
|
||||
func (m *Manager) Connect(ctx context.Context, mcpConfigFiles []string) error {
|
||||
var allConfigs []ServerConfig
|
||||
for _, configPath := range mcpConfigFiles {
|
||||
configs, err := ParseConfig(configPath)
|
||||
if err != nil {
|
||||
if errors.Is(err, fs.ErrNotExist) {
|
||||
continue
|
||||
}
|
||||
m.logger.Warn(ctx, "failed to parse MCP config",
|
||||
slog.F("path", configPath),
|
||||
slog.Error(err),
|
||||
)
|
||||
continue
|
||||
}
|
||||
return xerrors.Errorf("parse mcp config: %w", err)
|
||||
allConfigs = append(allConfigs, configs...)
|
||||
}
|
||||
|
||||
// Deduplicate by server name; first occurrence wins.
|
||||
seen := make(map[string]struct{})
|
||||
deduped := make([]ServerConfig, 0, len(allConfigs))
|
||||
for _, cfg := range allConfigs {
|
||||
if _, ok := seen[cfg.Name]; ok {
|
||||
continue
|
||||
}
|
||||
seen[cfg.Name] = struct{}{}
|
||||
deduped = append(deduped, cfg)
|
||||
}
|
||||
allConfigs = deduped
|
||||
|
||||
if len(allConfigs) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Connect to servers in parallel without holding the
|
||||
@@ -95,7 +118,7 @@ func (m *Manager) Connect(ctx context.Context, dir string) error {
|
||||
connected []connectedServer
|
||||
)
|
||||
var eg errgroup.Group
|
||||
for _, cfg := range configs {
|
||||
for _, cfg := range allConfigs {
|
||||
eg.Go(func() error {
|
||||
c, err := m.connectServer(ctx, cfg)
|
||||
if err != nil {
|
||||
|
||||
+33
-6
@@ -17,6 +17,7 @@ import (
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/prometheus/client_golang/prometheus"
|
||||
"golang.org/x/xerrors"
|
||||
"gopkg.in/natefinch/lumberjack.v2"
|
||||
@@ -272,11 +273,14 @@ func workspaceAgent() *serpent.Command {
|
||||
logger.Info(ctx, "agent devcontainer detection not enabled")
|
||||
}
|
||||
|
||||
reinitEvents := agentsdk.WaitForReinitLoop(ctx, logger, client)
|
||||
reinitCtx, reinitCancel := context.WithCancel(ctx)
|
||||
defer reinitCancel()
|
||||
reinitEvents := agentsdk.WaitForReinitLoop(reinitCtx, logger, client)
|
||||
|
||||
var (
|
||||
lastErr error
|
||||
mustExit bool
|
||||
lastOwnerID uuid.UUID
|
||||
lastErr error
|
||||
mustExit bool
|
||||
)
|
||||
for {
|
||||
prometheusRegistry := prometheus.NewRegistry()
|
||||
@@ -343,9 +347,32 @@ func workspaceAgent() *serpent.Command {
|
||||
case <-ctx.Done():
|
||||
logger.Info(ctx, "agent shutting down", slog.Error(context.Cause(ctx)))
|
||||
mustExit = true
|
||||
case event := <-reinitEvents:
|
||||
logger.Info(ctx, "agent received instruction to reinitialize",
|
||||
slog.F("workspace_id", event.WorkspaceID), slog.F("reason", event.Reason))
|
||||
case event, ok := <-reinitEvents:
|
||||
switch {
|
||||
case !ok:
|
||||
// Channel closed — the reinit loop exited
|
||||
// (terminal 409 or context expired). Keep
|
||||
// running the current agent until the parent
|
||||
// context is canceled.
|
||||
logger.Info(ctx, "reinit channel closed, running without reinit capability")
|
||||
reinitEvents = nil
|
||||
<-ctx.Done()
|
||||
mustExit = true
|
||||
case event.OwnerID != uuid.Nil && event.OwnerID == lastOwnerID:
|
||||
// Duplicate reinit for same owner — already
|
||||
// reinitialized. Cancel the reinit loop
|
||||
// goroutine and keep the current agent.
|
||||
logger.Info(ctx, "skipping redundant reinit, owner unchanged",
|
||||
slog.F("owner_id", event.OwnerID))
|
||||
reinitCancel()
|
||||
reinitEvents = nil
|
||||
<-ctx.Done()
|
||||
mustExit = true
|
||||
default:
|
||||
lastOwnerID = event.OwnerID
|
||||
logger.Info(ctx, "agent received instruction to reinitialize",
|
||||
slog.F("workspace_id", event.WorkspaceID), slog.F("reason", event.Reason))
|
||||
}
|
||||
}
|
||||
|
||||
lastErr = agnt.Close()
|
||||
|
||||
+13
-2
@@ -352,8 +352,6 @@ func TestScheduleOverride(t *testing.T) {
|
||||
require.NoError(t, err, "invalid schedule")
|
||||
ownerClient, _, _, ws := setupTestSchedule(t, sched)
|
||||
now := time.Now()
|
||||
// To avoid the likelihood of time-related flakes, only matching up to the hour.
|
||||
expectedDeadline := now.In(loc).Add(10 * time.Hour).Format("2006-01-02T15:")
|
||||
|
||||
// When: we override the stop schedule
|
||||
inv, root := clitest.New(t,
|
||||
@@ -364,6 +362,19 @@ func TestScheduleOverride(t *testing.T) {
|
||||
pty := ptytest.New(t).Attach(inv)
|
||||
require.NoError(t, inv.Run())
|
||||
|
||||
// Fetch the workspace to get the actual deadline set by the
|
||||
// server. Computing our own expected deadline from a separately
|
||||
// captured time.Now() is racy: the CLI command calls time.Now()
|
||||
// internally, and with the Asia/Kolkata +05:30 offset the hour
|
||||
// boundary falls at :30 UTC minutes. A small delay between our
|
||||
// time.Now() and the command's is enough to land in different
|
||||
// hours.
|
||||
updated, err := ownerClient.Workspace(context.Background(), ws[0].ID)
|
||||
require.NoError(t, err)
|
||||
require.False(t, updated.LatestBuild.Deadline.IsZero(), "deadline should be set after extend")
|
||||
require.WithinDuration(t, now.Add(10*time.Hour), updated.LatestBuild.Deadline.Time, 5*time.Minute)
|
||||
expectedDeadline := updated.LatestBuild.Deadline.Time.In(loc).Format(time.RFC3339)
|
||||
|
||||
// Then: the updated schedule should be shown
|
||||
pty.ExpectMatch(ws[0].OwnerName + "/" + ws[0].Name)
|
||||
pty.ExpectMatch(sched.Humanize())
|
||||
|
||||
@@ -165,6 +165,37 @@ func TestSyncCommands_Golden(t *testing.T) {
|
||||
clitest.TestGoldenFile(t, "TestSyncCommands_Golden/want_success", outBuf.Bytes(), nil)
|
||||
})
|
||||
|
||||
t.Run("want_multiple_deps", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
path, cleanup := setupSocketServer(t)
|
||||
defer cleanup()
|
||||
|
||||
ctx := testutil.Context(t, testutil.WaitShort)
|
||||
|
||||
var outBuf bytes.Buffer
|
||||
inv, _ := clitest.New(t, "exp", "sync", "want", "test-unit", "dep-1", "dep-2", "dep-3", "--socket-path", path)
|
||||
inv.Stdout = &outBuf
|
||||
inv.Stderr = &outBuf
|
||||
|
||||
err := inv.WithContext(ctx).Run()
|
||||
require.NoError(t, err)
|
||||
|
||||
// Verify all dependencies were registered by checking status.
|
||||
outBuf.Reset()
|
||||
inv, _ = clitest.New(t, "exp", "sync", "status", "test-unit", "--socket-path", path, "--output", "json")
|
||||
inv.Stdout = &outBuf
|
||||
inv.Stderr = &outBuf
|
||||
|
||||
err = inv.WithContext(ctx).Run()
|
||||
require.NoError(t, err)
|
||||
|
||||
// The output should mention all three dependencies.
|
||||
output := outBuf.String()
|
||||
require.Contains(t, output, "dep-1")
|
||||
require.Contains(t, output, "dep-2")
|
||||
require.Contains(t, output, "dep-3")
|
||||
})
|
||||
|
||||
t.Run("complete", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
path, cleanup := setupSocketServer(t)
|
||||
|
||||
+9
-8
@@ -11,17 +11,16 @@ import (
|
||||
|
||||
func (*RootCmd) syncWant(socketPath *string) *serpent.Command {
|
||||
cmd := &serpent.Command{
|
||||
Use: "want <unit> <depends-on>",
|
||||
Short: "Declare that a unit depends on another unit completing before it can start",
|
||||
Long: "Declare that a unit depends on another unit completing before it can start. The unit specified first will not start until the second has signaled that it has completed.",
|
||||
Use: "want <unit> <depends-on> [depends-on...]",
|
||||
Short: "Declare that a unit depends on other units completing before it can start",
|
||||
Long: "Declare that a unit depends on one or more other units completing before it can start. The unit specified first will not start until all subsequent units have signaled that they have completed.",
|
||||
Handler: func(i *serpent.Invocation) error {
|
||||
ctx := i.Context()
|
||||
|
||||
if len(i.Args) != 2 {
|
||||
return xerrors.New("exactly two arguments are required: unit and depends-on")
|
||||
if len(i.Args) < 2 {
|
||||
return xerrors.New("at least two arguments are required: unit and one or more depends-on")
|
||||
}
|
||||
dependentUnit := unit.ID(i.Args[0])
|
||||
dependsOn := unit.ID(i.Args[1])
|
||||
|
||||
opts := []agentsocket.Option{}
|
||||
if *socketPath != "" {
|
||||
@@ -34,8 +33,10 @@ func (*RootCmd) syncWant(socketPath *string) *serpent.Command {
|
||||
}
|
||||
defer client.Close()
|
||||
|
||||
if err := client.SyncWant(ctx, dependentUnit, dependsOn); err != nil {
|
||||
return xerrors.Errorf("declare dependency failed: %w", err)
|
||||
for _, dep := range i.Args[1:] {
|
||||
if err := client.SyncWant(ctx, dependentUnit, unit.ID(dep)); err != nil {
|
||||
return xerrors.Errorf("declare dependency failed: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
cliui.Info(i.Stdout, "Success")
|
||||
|
||||
+1
-1
@@ -16,7 +16,7 @@ SUBCOMMANDS:
|
||||
ping Test agent socket connectivity and health
|
||||
start Wait until all unit dependencies are satisfied
|
||||
status Show unit status and dependency state
|
||||
want Declare that a unit depends on another unit completing before it
|
||||
want Declare that a unit depends on other units completing before it
|
||||
can start
|
||||
|
||||
OPTIONS:
|
||||
|
||||
+5
-5
@@ -1,13 +1,13 @@
|
||||
coder v0.0.0-devel
|
||||
|
||||
USAGE:
|
||||
coder exp sync want <unit> <depends-on>
|
||||
coder exp sync want <unit> <depends-on> [depends-on...]
|
||||
|
||||
Declare that a unit depends on another unit completing before it can start
|
||||
Declare that a unit depends on other units completing before it can start
|
||||
|
||||
Declare that a unit depends on another unit completing before it can start.
|
||||
The unit specified first will not start until the second has signaled that it
|
||||
has completed.
|
||||
Declare that a unit depends on one or more other units completing before it
|
||||
can start. The unit specified first will not start until all subsequent units
|
||||
have signaled that they have completed.
|
||||
|
||||
———
|
||||
Run `coder --help` for a list of global options.
|
||||
|
||||
+7
-2
@@ -857,13 +857,18 @@ aibridgeproxy:
|
||||
# Comma-separated list of AI provider domains for which HTTPS traffic will be
|
||||
# decrypted and routed through AI Bridge. Requests to other domains will be
|
||||
# tunneled directly without decryption. Supported domains: api.anthropic.com,
|
||||
# api.openai.com, api.individual.githubcopilot.com.
|
||||
# (default: api.anthropic.com,api.openai.com,api.individual.githubcopilot.com,
|
||||
# api.openai.com, api.individual.githubcopilot.com,
|
||||
# api.business.githubcopilot.com, api.enterprise.githubcopilot.com, chatgpt.com.
|
||||
# (default:
|
||||
# api.anthropic.com,api.openai.com,api.individual.githubcopilot.com,api.business.githubcopilot.com,api.enterprise.githubcopilot.com,chatgpt.com,
|
||||
# type: string-array)
|
||||
domain_allowlist:
|
||||
- api.anthropic.com
|
||||
- api.openai.com
|
||||
- api.individual.githubcopilot.com
|
||||
- api.business.githubcopilot.com
|
||||
- api.enterprise.githubcopilot.com
|
||||
- chatgpt.com
|
||||
# URL of an upstream HTTP proxy to chain tunneled (non-allowlisted) requests
|
||||
# through. Format: http://[user:pass@]host:port or https://[user:pass@]host:port.
|
||||
# (default: <unset>, type: string)
|
||||
|
||||
@@ -20,6 +20,21 @@ const HeaderCoderToken = "X-Coder-AI-Governance-Token" //nolint:gosec // This is
|
||||
// request forwarded to aibridged for cross-service log correlation.
|
||||
const HeaderCoderRequestID = "X-Coder-AI-Governance-Request-Id"
|
||||
|
||||
// Copilot provider.
|
||||
const (
|
||||
ProviderCopilotBusiness = "copilot-business"
|
||||
HostCopilotBusiness = "api.business.githubcopilot.com"
|
||||
ProviderCopilotEnterprise = "copilot-enterprise"
|
||||
HostCopilotEnterprise = "api.enterprise.githubcopilot.com"
|
||||
)
|
||||
|
||||
// ChatGPT provider.
|
||||
const (
|
||||
ProviderChatGPT = "chatgpt"
|
||||
HostChatGPT = "chatgpt.com"
|
||||
BaseURLChatGPT = "https://" + HostChatGPT + "/backend-api/codex"
|
||||
)
|
||||
|
||||
// IsBYOK reports whether the request is using BYOK mode, determined
|
||||
// by the presence of the X-Coder-AI-Governance-Token header.
|
||||
func IsBYOK(header http.Header) bool {
|
||||
|
||||
Generated
+24
-2
@@ -10205,12 +10205,26 @@ const docTemplate = `{
|
||||
],
|
||||
"summary": "Get workspace agent reinitialization",
|
||||
"operationId": "get-workspace-agent-reinitialization",
|
||||
"parameters": [
|
||||
{
|
||||
"type": "boolean",
|
||||
"description": "Opt in to durable reinit checks",
|
||||
"name": "wait",
|
||||
"in": "query"
|
||||
}
|
||||
],
|
||||
"responses": {
|
||||
"200": {
|
||||
"description": "OK",
|
||||
"schema": {
|
||||
"$ref": "#/definitions/agentsdk.ReinitializationEvent"
|
||||
}
|
||||
},
|
||||
"409": {
|
||||
"description": "Conflict",
|
||||
"schema": {
|
||||
"$ref": "#/definitions/codersdk.Response"
|
||||
}
|
||||
}
|
||||
},
|
||||
"security": [
|
||||
@@ -12647,11 +12661,16 @@ const docTemplate = `{
|
||||
"agentsdk.ReinitializationEvent": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"owner_id": {
|
||||
"type": "string",
|
||||
"format": "uuid"
|
||||
},
|
||||
"reason": {
|
||||
"$ref": "#/definitions/agentsdk.ReinitializationReason"
|
||||
},
|
||||
"workspaceID": {
|
||||
"type": "string"
|
||||
"workspace_id": {
|
||||
"type": "string",
|
||||
"format": "uuid"
|
||||
}
|
||||
}
|
||||
},
|
||||
@@ -12894,6 +12913,9 @@ const docTemplate = `{
|
||||
"provider": {
|
||||
"type": "string"
|
||||
},
|
||||
"provider_name": {
|
||||
"type": "string"
|
||||
},
|
||||
"started_at": {
|
||||
"type": "string",
|
||||
"format": "date-time"
|
||||
|
||||
Generated
+24
-2
@@ -9038,12 +9038,26 @@
|
||||
"tags": ["Agents"],
|
||||
"summary": "Get workspace agent reinitialization",
|
||||
"operationId": "get-workspace-agent-reinitialization",
|
||||
"parameters": [
|
||||
{
|
||||
"type": "boolean",
|
||||
"description": "Opt in to durable reinit checks",
|
||||
"name": "wait",
|
||||
"in": "query"
|
||||
}
|
||||
],
|
||||
"responses": {
|
||||
"200": {
|
||||
"description": "OK",
|
||||
"schema": {
|
||||
"$ref": "#/definitions/agentsdk.ReinitializationEvent"
|
||||
}
|
||||
},
|
||||
"409": {
|
||||
"description": "Conflict",
|
||||
"schema": {
|
||||
"$ref": "#/definitions/codersdk.Response"
|
||||
}
|
||||
}
|
||||
},
|
||||
"security": [
|
||||
@@ -11229,11 +11243,16 @@
|
||||
"agentsdk.ReinitializationEvent": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"owner_id": {
|
||||
"type": "string",
|
||||
"format": "uuid"
|
||||
},
|
||||
"reason": {
|
||||
"$ref": "#/definitions/agentsdk.ReinitializationReason"
|
||||
},
|
||||
"workspaceID": {
|
||||
"type": "string"
|
||||
"workspace_id": {
|
||||
"type": "string",
|
||||
"format": "uuid"
|
||||
}
|
||||
}
|
||||
},
|
||||
@@ -11472,6 +11491,9 @@
|
||||
"provider": {
|
||||
"type": "string"
|
||||
},
|
||||
"provider_name": {
|
||||
"type": "string"
|
||||
},
|
||||
"started_at": {
|
||||
"type": "string",
|
||||
"format": "date-time"
|
||||
|
||||
@@ -582,5 +582,20 @@ func (api *API) createAPIKey(ctx context.Context, params apikey.CreateParams) (*
|
||||
Value: sessionToken,
|
||||
Path: "/",
|
||||
HttpOnly: true,
|
||||
// MaxAge is set so the browser persists the cookie to disk rather
|
||||
// than keeping it in memory as a session cookie. Standalone PWAs
|
||||
// (display: standalone) run in their own browser process, and
|
||||
// mobile OSes kill that process when the app is swiped away —
|
||||
// deleting in-memory cookies and forcing an unexpected login.
|
||||
//
|
||||
// We use a long static value (1 year) instead of the key's
|
||||
// LifetimeSeconds because the server refreshes the key's
|
||||
// ExpiresAt on activity but does not re-set the cookie. Tying
|
||||
// MaxAge to the key lifetime would cause the cookie to expire
|
||||
// client-side even when the server-side key is still valid.
|
||||
//
|
||||
// Security is not affected: the server validates ExpiresAt on
|
||||
// every request regardless of the cookie's MaxAge.
|
||||
MaxAge: int((365 * 24 * time.Hour).Seconds()),
|
||||
}), &newkey, nil
|
||||
}
|
||||
|
||||
@@ -394,6 +394,55 @@ func TestSessionExpiry(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
// TestSessionCookieMaxAge verifies that the session cookie is a persistent
|
||||
// cookie (has MaxAge set) rather than a session cookie. Standalone PWAs
|
||||
// run in their own browser process and mobile OSes purge in-memory
|
||||
// (session) cookies when that process is killed, so the cookie must be
|
||||
// persisted to disk.
|
||||
func TestSessionCookieMaxAge(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong)
|
||||
defer cancel()
|
||||
|
||||
client := coderdtest.New(t, nil)
|
||||
|
||||
// Create the first user (password-based login).
|
||||
req := codersdk.CreateFirstUserRequest{
|
||||
Email: "testuser@coder.com",
|
||||
Username: "testuser",
|
||||
Password: "SomeSecurePassword!",
|
||||
}
|
||||
_, err := client.CreateFirstUser(ctx, req)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Login via the raw HTTP endpoint so we can inspect the Set-Cookie header.
|
||||
loginURL, err := client.URL.Parse("/api/v2/users/login")
|
||||
require.NoError(t, err)
|
||||
|
||||
res, err := client.Request(ctx, http.MethodPost, loginURL.String(), codersdk.LoginWithPasswordRequest{
|
||||
Email: req.Email,
|
||||
Password: req.Password,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
defer res.Body.Close()
|
||||
require.Equal(t, http.StatusCreated, res.StatusCode)
|
||||
|
||||
oneYear := int((365 * 24 * time.Hour).Seconds())
|
||||
var found bool
|
||||
for _, cookie := range res.Cookies() {
|
||||
if cookie.Name == codersdk.SessionTokenCookie {
|
||||
// MaxAge should be set to a long value so the browser
|
||||
// persists the cookie to disk. The server handles real
|
||||
// expiry via the API key's ExpiresAt field.
|
||||
require.Equal(t, oneYear, cookie.MaxAge,
|
||||
"Session cookie MaxAge should be set to 1 year for disk persistence")
|
||||
found = true
|
||||
}
|
||||
}
|
||||
require.True(t, found, "session cookie should be present in login response")
|
||||
}
|
||||
|
||||
func TestAPIKey_OK(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
|
||||
+1
-1
@@ -220,7 +220,7 @@ func (api *API) checkAuthorization(rw http.ResponseWriter, r *http.Request) {
|
||||
Type: string(v.Object.ResourceType),
|
||||
AnyOrgOwner: v.Object.AnyOrgOwner,
|
||||
}
|
||||
if obj.Owner == "me" {
|
||||
if obj.Owner == codersdk.Me {
|
||||
obj.Owner = auth.ID
|
||||
}
|
||||
|
||||
|
||||
+8
-1
@@ -782,7 +782,7 @@ func New(options *Options) *API {
|
||||
ReplicaID: api.ID,
|
||||
SubscribeFn: options.ChatSubscribeFn,
|
||||
MaxChatsPerAcquire: int32(maxChatsPerAcquire), //nolint:gosec // maxChatsPerAcquire is clamped to int32 range above.
|
||||
ProviderAPIKeys: chatProviderAPIKeysFromDeploymentValues(options.DeploymentValues),
|
||||
ProviderAPIKeys: ChatProviderAPIKeysFromDeploymentValues(options.DeploymentValues),
|
||||
AgentConn: api.agentProvider.AgentConn,
|
||||
AgentInactiveDisconnectTimeout: api.AgentInactiveDisconnectTimeout,
|
||||
CreateWorkspace: api.chatCreateWorkspace,
|
||||
@@ -1221,6 +1221,13 @@ func New(options *Options) *API {
|
||||
r.Delete("/", api.deleteChatUsageLimitGroupOverride)
|
||||
})
|
||||
})
|
||||
r.Route("/user-provider-configs", func(r chi.Router) {
|
||||
r.Get("/", api.listUserChatProviderConfigs)
|
||||
r.Route("/{providerConfig}", func(r chi.Router) {
|
||||
r.Put("/", api.upsertUserChatProviderKey)
|
||||
r.Delete("/", api.deleteUserChatProviderKey)
|
||||
})
|
||||
})
|
||||
r.Route("/{chat}", func(r chi.Router) {
|
||||
r.Use(httpmw.ExtractChatParam(options.Database))
|
||||
r.Get("/", api.getChat)
|
||||
|
||||
@@ -10,6 +10,7 @@ const (
|
||||
CheckChatModelConfigsCompressionThresholdCheck CheckConstraint = "chat_model_configs_compression_threshold_check" // chat_model_configs
|
||||
CheckChatModelConfigsContextLimitCheck CheckConstraint = "chat_model_configs_context_limit_check" // chat_model_configs
|
||||
CheckChatProvidersProviderCheck CheckConstraint = "chat_providers_provider_check" // chat_providers
|
||||
CheckValidCredentialPolicy CheckConstraint = "valid_credential_policy" // chat_providers
|
||||
CheckChatUsageLimitConfigDefaultLimitMicrosCheck CheckConstraint = "chat_usage_limit_config_default_limit_micros_check" // chat_usage_limit_config
|
||||
CheckChatUsageLimitConfigPeriodCheck CheckConstraint = "chat_usage_limit_config_period_check" // chat_usage_limit_config
|
||||
CheckChatUsageLimitConfigSingletonCheck CheckConstraint = "chat_usage_limit_config_singleton_check" // chat_usage_limit_config
|
||||
@@ -32,4 +33,5 @@ const (
|
||||
CheckTelemetryLockEventTypeConstraint CheckConstraint = "telemetry_lock_event_type_constraint" // telemetry_locks
|
||||
CheckValidationMonotonicOrder CheckConstraint = "validation_monotonic_order" // template_version_parameters
|
||||
CheckUsageEventTypeCheck CheckConstraint = "usage_event_type_check" // usage_events
|
||||
CheckUserChatProviderKeysAPIKeyCheck CheckConstraint = "user_chat_provider_keys_api_key_check" // user_chat_provider_keys
|
||||
)
|
||||
|
||||
@@ -999,15 +999,16 @@ func AIBridgeInterception(interception database.AIBridgeInterception, initiator
|
||||
return sdkToolUsages[i].CreatedAt.Before(sdkToolUsages[j].CreatedAt)
|
||||
})
|
||||
intc := codersdk.AIBridgeInterception{
|
||||
ID: interception.ID,
|
||||
Initiator: MinimalUserFromVisibleUser(initiator),
|
||||
Provider: interception.Provider,
|
||||
Model: interception.Model,
|
||||
Metadata: jsonOrEmptyMap(interception.Metadata),
|
||||
StartedAt: interception.StartedAt,
|
||||
TokenUsages: sdkTokenUsages,
|
||||
UserPrompts: sdkUserPrompts,
|
||||
ToolUsages: sdkToolUsages,
|
||||
ID: interception.ID,
|
||||
Initiator: MinimalUserFromVisibleUser(initiator),
|
||||
Provider: interception.Provider,
|
||||
ProviderName: interception.ProviderName,
|
||||
Model: interception.Model,
|
||||
Metadata: jsonOrEmptyMap(interception.Metadata),
|
||||
StartedAt: interception.StartedAt,
|
||||
TokenUsages: sdkTokenUsages,
|
||||
UserPrompts: sdkUserPrompts,
|
||||
ToolUsages: sdkToolUsages,
|
||||
}
|
||||
if interception.APIKeyID.Valid {
|
||||
intc.APIKeyID = &interception.APIKeyID.String
|
||||
@@ -1572,6 +1573,17 @@ func Chat(c database.Chat, diffStatus *database.ChatDiffStatus) codersdk.Chat {
|
||||
convertedDiffStatus := ChatDiffStatus(c.ID, diffStatus)
|
||||
chat.DiffStatus = &convertedDiffStatus
|
||||
}
|
||||
if c.LastInjectedContext.Valid {
|
||||
var parts []codersdk.ChatMessagePart
|
||||
// Internal fields are stripped at write time in
|
||||
// chatd.updateLastInjectedContext, so no
|
||||
// StripInternal call is needed here. Unmarshal
|
||||
// errors are suppressed — the column is written by
|
||||
// us with a known schema.
|
||||
if err := json.Unmarshal(c.LastInjectedContext.RawMessage, &parts); err == nil {
|
||||
chat.LastInjectedContext = parts
|
||||
}
|
||||
}
|
||||
return chat
|
||||
}
|
||||
|
||||
|
||||
@@ -541,6 +541,13 @@ func TestChat_AllFieldsPopulated(t *testing.T) {
|
||||
PinOrder: 1,
|
||||
MCPServerIDs: []uuid.UUID{uuid.New()},
|
||||
Labels: database.StringMap{"env": "prod"},
|
||||
LastInjectedContext: pqtype.NullRawMessage{
|
||||
// Use a context-file part to verify internal
|
||||
// fields are not present (they are stripped at
|
||||
// write time by chatd, not at read time).
|
||||
RawMessage: json.RawMessage(`[{"type":"context-file","context_file_path":"/AGENTS.md"}]`),
|
||||
Valid: true,
|
||||
},
|
||||
}
|
||||
// Only ChatID is needed here. This test checks that
|
||||
// Chat.DiffStatus is non-nil, not that every DiffStatus
|
||||
|
||||
@@ -1570,13 +1570,13 @@ func (q *querier) AllUserIDs(ctx context.Context, includeSystem bool) ([]uuid.UU
|
||||
return q.db.AllUserIDs(ctx, includeSystem)
|
||||
}
|
||||
|
||||
func (q *querier) ArchiveChatByID(ctx context.Context, id uuid.UUID) error {
|
||||
func (q *querier) ArchiveChatByID(ctx context.Context, id uuid.UUID) ([]database.Chat, error) {
|
||||
chat, err := q.db.GetChatByID(ctx, id)
|
||||
if err != nil {
|
||||
return err
|
||||
return nil, err
|
||||
}
|
||||
if err := q.authorizeContext(ctx, policy.ActionUpdate, chat); err != nil {
|
||||
return err
|
||||
return nil, err
|
||||
}
|
||||
return q.db.ArchiveChatByID(ctx, id)
|
||||
}
|
||||
@@ -2137,6 +2137,17 @@ func (q *querier) DeleteUserChatCompactionThreshold(ctx context.Context, arg dat
|
||||
return q.db.DeleteUserChatCompactionThreshold(ctx, arg)
|
||||
}
|
||||
|
||||
func (q *querier) DeleteUserChatProviderKey(ctx context.Context, arg database.DeleteUserChatProviderKeyParams) error {
|
||||
u, err := q.db.GetUserByID(ctx, arg.UserID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if err := q.authorizeContext(ctx, policy.ActionUpdatePersonal, u); err != nil {
|
||||
return err
|
||||
}
|
||||
return q.db.DeleteUserChatProviderKey(ctx, arg)
|
||||
}
|
||||
|
||||
func (q *querier) DeleteUserSecret(ctx context.Context, id uuid.UUID) error {
|
||||
// First get the secret to check ownership
|
||||
secret, err := q.GetUserSecret(ctx, id)
|
||||
@@ -2811,7 +2822,15 @@ func (q *querier) GetDERPMeshKey(ctx context.Context) (string, error) {
|
||||
}
|
||||
|
||||
func (q *querier) GetDefaultChatModelConfig(ctx context.Context) (database.ChatModelConfig, error) {
|
||||
if err := q.authorizeContext(ctx, policy.ActionRead, rbac.ResourceDeploymentConfig); err != nil {
|
||||
// Any user who can read chat resources can read the default
|
||||
// model config, since model resolution is required to create
|
||||
// a chat. This avoids gating on ResourceDeploymentConfig
|
||||
// which regular members lack.
|
||||
act, ok := ActorFromContext(ctx)
|
||||
if !ok {
|
||||
return database.ChatModelConfig{}, ErrNoActor
|
||||
}
|
||||
if err := q.authorizeContext(ctx, policy.ActionRead, rbac.ResourceChat.WithOwner(act.ID)); err != nil {
|
||||
return database.ChatModelConfig{}, err
|
||||
}
|
||||
return q.db.GetDefaultChatModelConfig(ctx)
|
||||
@@ -4016,6 +4035,17 @@ func (q *querier) GetUserChatCustomPrompt(ctx context.Context, userID uuid.UUID)
|
||||
return q.db.GetUserChatCustomPrompt(ctx, userID)
|
||||
}
|
||||
|
||||
func (q *querier) GetUserChatProviderKeys(ctx context.Context, userID uuid.UUID) ([]database.UserChatProviderKey, error) {
|
||||
u, err := q.db.GetUserByID(ctx, userID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := q.authorizeContext(ctx, policy.ActionReadPersonal, u); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return q.db.GetUserChatProviderKeys(ctx, userID)
|
||||
}
|
||||
|
||||
func (q *querier) GetUserChatSpendInPeriod(ctx context.Context, arg database.GetUserChatSpendInPeriodParams) (int64, error) {
|
||||
if err := q.authorizeContext(ctx, policy.ActionRead, rbac.ResourceChat.WithOwner(arg.UserID.String())); err != nil {
|
||||
return 0, err
|
||||
@@ -5641,13 +5671,13 @@ func (q *querier) TryAcquireLock(ctx context.Context, id int64) (bool, error) {
|
||||
return q.db.TryAcquireLock(ctx, id)
|
||||
}
|
||||
|
||||
func (q *querier) UnarchiveChatByID(ctx context.Context, id uuid.UUID) error {
|
||||
func (q *querier) UnarchiveChatByID(ctx context.Context, id uuid.UUID) ([]database.Chat, error) {
|
||||
chat, err := q.db.GetChatByID(ctx, id)
|
||||
if err != nil {
|
||||
return err
|
||||
return nil, err
|
||||
}
|
||||
if err := q.authorizeContext(ctx, policy.ActionUpdate, chat); err != nil {
|
||||
return err
|
||||
return nil, err
|
||||
}
|
||||
return q.db.UnarchiveChatByID(ctx, id)
|
||||
}
|
||||
@@ -5752,6 +5782,17 @@ func (q *querier) UpdateChatLabelsByID(ctx context.Context, arg database.UpdateC
|
||||
return q.db.UpdateChatLabelsByID(ctx, arg)
|
||||
}
|
||||
|
||||
func (q *querier) UpdateChatLastInjectedContext(ctx context.Context, arg database.UpdateChatLastInjectedContextParams) (database.Chat, error) {
|
||||
chat, err := q.db.GetChatByID(ctx, arg.ID)
|
||||
if err != nil {
|
||||
return database.Chat{}, err
|
||||
}
|
||||
if err := q.authorizeContext(ctx, policy.ActionUpdate, chat); err != nil {
|
||||
return database.Chat{}, err
|
||||
}
|
||||
return q.db.UpdateChatLastInjectedContext(ctx, arg)
|
||||
}
|
||||
|
||||
func (q *querier) UpdateChatLastModelConfigByID(ctx context.Context, arg database.UpdateChatLastModelConfigByIDParams) (database.Chat, error) {
|
||||
chat, err := q.db.GetChatByID(ctx, arg.ID)
|
||||
if err != nil {
|
||||
@@ -6435,6 +6476,17 @@ func (q *querier) UpdateUserChatCustomPrompt(ctx context.Context, arg database.U
|
||||
return q.db.UpdateUserChatCustomPrompt(ctx, arg)
|
||||
}
|
||||
|
||||
func (q *querier) UpdateUserChatProviderKey(ctx context.Context, arg database.UpdateUserChatProviderKeyParams) (database.UserChatProviderKey, error) {
|
||||
u, err := q.db.GetUserByID(ctx, arg.UserID)
|
||||
if err != nil {
|
||||
return database.UserChatProviderKey{}, err
|
||||
}
|
||||
if err := q.authorizeContext(ctx, policy.ActionUpdatePersonal, u); err != nil {
|
||||
return database.UserChatProviderKey{}, err
|
||||
}
|
||||
return q.db.UpdateUserChatProviderKey(ctx, arg)
|
||||
}
|
||||
|
||||
func (q *querier) UpdateUserDeletedByID(ctx context.Context, id uuid.UUID) error {
|
||||
return deleteQ(q.log, q.auth, q.db.GetUserByID, q.db.UpdateUserDeletedByID)(ctx, id)
|
||||
}
|
||||
@@ -7162,6 +7214,17 @@ func (q *querier) UpsertTemplateUsageStats(ctx context.Context) error {
|
||||
return q.db.UpsertTemplateUsageStats(ctx)
|
||||
}
|
||||
|
||||
func (q *querier) UpsertUserChatProviderKey(ctx context.Context, arg database.UpsertUserChatProviderKeyParams) (database.UserChatProviderKey, error) {
|
||||
u, err := q.db.GetUserByID(ctx, arg.UserID)
|
||||
if err != nil {
|
||||
return database.UserChatProviderKey{}, err
|
||||
}
|
||||
if err := q.authorizeContext(ctx, policy.ActionUpdatePersonal, u); err != nil {
|
||||
return database.UserChatProviderKey{}, err
|
||||
}
|
||||
return q.db.UpsertUserChatProviderKey(ctx, arg)
|
||||
}
|
||||
|
||||
func (q *querier) UpsertWebpushVAPIDKeys(ctx context.Context, arg database.UpsertWebpushVAPIDKeysParams) error {
|
||||
if err := q.authorizeContext(ctx, policy.ActionUpdate, rbac.ResourceDeploymentConfig); err != nil {
|
||||
return err
|
||||
|
||||
@@ -392,14 +392,14 @@ func (s *MethodTestSuite) TestChats() {
|
||||
s.Run("ArchiveChatByID", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
|
||||
chat := testutil.Fake(s.T(), faker, database.Chat{})
|
||||
dbm.EXPECT().GetChatByID(gomock.Any(), chat.ID).Return(chat, nil).AnyTimes()
|
||||
dbm.EXPECT().ArchiveChatByID(gomock.Any(), chat.ID).Return(nil).AnyTimes()
|
||||
check.Args(chat.ID).Asserts(chat, policy.ActionUpdate).Returns()
|
||||
dbm.EXPECT().ArchiveChatByID(gomock.Any(), chat.ID).Return([]database.Chat{chat}, nil).AnyTimes()
|
||||
check.Args(chat.ID).Asserts(chat, policy.ActionUpdate).Returns([]database.Chat{chat})
|
||||
}))
|
||||
s.Run("UnarchiveChatByID", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
|
||||
chat := testutil.Fake(s.T(), faker, database.Chat{})
|
||||
dbm.EXPECT().GetChatByID(gomock.Any(), chat.ID).Return(chat, nil).AnyTimes()
|
||||
dbm.EXPECT().UnarchiveChatByID(gomock.Any(), chat.ID).Return(nil).AnyTimes()
|
||||
check.Args(chat.ID).Asserts(chat, policy.ActionUpdate).Returns()
|
||||
dbm.EXPECT().UnarchiveChatByID(gomock.Any(), chat.ID).Return([]database.Chat{chat}, nil).AnyTimes()
|
||||
check.Args(chat.ID).Asserts(chat, policy.ActionUpdate).Returns([]database.Chat{chat})
|
||||
}))
|
||||
s.Run("PinChatByID", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
|
||||
chat := testutil.Fake(s.T(), faker, database.Chat{})
|
||||
@@ -631,7 +631,7 @@ func (s *MethodTestSuite) TestChats() {
|
||||
s.Run("GetDefaultChatModelConfig", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
|
||||
config := testutil.Fake(s.T(), faker, database.ChatModelConfig{})
|
||||
dbm.EXPECT().GetDefaultChatModelConfig(gomock.Any()).Return(config, nil).AnyTimes()
|
||||
check.Asserts(rbac.ResourceDeploymentConfig, policy.ActionRead).Returns(config)
|
||||
check.Asserts(rbac.ResourceChat.WithOwner(testActorID.String()), policy.ActionRead).Returns(config)
|
||||
}))
|
||||
s.Run("GetChatModelConfigs", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
|
||||
configA := testutil.Fake(s.T(), faker, database.ChatModelConfig{})
|
||||
@@ -721,7 +721,9 @@ func (s *MethodTestSuite) TestChats() {
|
||||
check.Args(threshold).Asserts(rbac.ResourceChat, policy.ActionRead).Returns(chats)
|
||||
}))
|
||||
s.Run("InsertChat", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
|
||||
arg := testutil.Fake(s.T(), faker, database.InsertChatParams{})
|
||||
arg := testutil.Fake(s.T(), faker, database.InsertChatParams{
|
||||
Status: database.ChatStatusWaiting,
|
||||
})
|
||||
chat := testutil.Fake(s.T(), faker, database.Chat{OwnerID: arg.OwnerID})
|
||||
dbm.EXPECT().InsertChat(gomock.Any(), arg).Return(chat, nil).AnyTimes()
|
||||
check.Args(arg).Asserts(rbac.ResourceChat.WithOwner(arg.OwnerID.String()), policy.ActionCreate).Returns(chat)
|
||||
@@ -1204,6 +1206,19 @@ func (s *MethodTestSuite) TestChats() {
|
||||
dbm.EXPECT().UpdateChatMCPServerIDs(gomock.Any(), arg).Return(chat, nil).AnyTimes()
|
||||
check.Args(arg).Asserts(chat, policy.ActionUpdate).Returns(chat)
|
||||
}))
|
||||
s.Run("UpdateChatLastInjectedContext", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
|
||||
chat := testutil.Fake(s.T(), faker, database.Chat{})
|
||||
arg := database.UpdateChatLastInjectedContextParams{
|
||||
ID: chat.ID,
|
||||
LastInjectedContext: pqtype.NullRawMessage{
|
||||
RawMessage: json.RawMessage(`[{"type":"text","text":"test"}]`),
|
||||
Valid: true,
|
||||
},
|
||||
}
|
||||
dbm.EXPECT().GetChatByID(gomock.Any(), chat.ID).Return(chat, nil).AnyTimes()
|
||||
dbm.EXPECT().UpdateChatLastInjectedContext(gomock.Any(), arg).Return(chat, nil).AnyTimes()
|
||||
check.Args(arg).Asserts(chat, policy.ActionUpdate).Returns(chat)
|
||||
}))
|
||||
s.Run("UpdateChatLastReadMessageID", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
|
||||
chat := testutil.Fake(s.T(), faker, database.Chat{})
|
||||
arg := database.UpdateChatLastReadMessageIDParams{
|
||||
@@ -2392,6 +2407,36 @@ func (s *MethodTestSuite) TestUser() {
|
||||
dbm.EXPECT().GetUserChatCustomPrompt(gomock.Any(), u.ID).Return("my custom prompt", nil).AnyTimes()
|
||||
check.Args(u.ID).Asserts(u, policy.ActionReadPersonal).Returns("my custom prompt")
|
||||
}))
|
||||
s.Run("GetUserChatProviderKeys", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
|
||||
u := testutil.Fake(s.T(), faker, database.User{})
|
||||
key := testutil.Fake(s.T(), faker, database.UserChatProviderKey{UserID: u.ID})
|
||||
dbm.EXPECT().GetUserByID(gomock.Any(), u.ID).Return(u, nil).AnyTimes()
|
||||
dbm.EXPECT().GetUserChatProviderKeys(gomock.Any(), u.ID).Return([]database.UserChatProviderKey{key}, nil).AnyTimes()
|
||||
check.Args(u.ID).Asserts(u, policy.ActionReadPersonal).Returns([]database.UserChatProviderKey{key})
|
||||
}))
|
||||
s.Run("DeleteUserChatProviderKey", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
|
||||
u := testutil.Fake(s.T(), faker, database.User{})
|
||||
arg := database.DeleteUserChatProviderKeyParams{UserID: u.ID, ChatProviderID: uuid.New()}
|
||||
dbm.EXPECT().GetUserByID(gomock.Any(), u.ID).Return(u, nil).AnyTimes()
|
||||
dbm.EXPECT().DeleteUserChatProviderKey(gomock.Any(), arg).Return(nil).AnyTimes()
|
||||
check.Args(arg).Asserts(u, policy.ActionUpdatePersonal).Returns()
|
||||
}))
|
||||
s.Run("UpdateUserChatProviderKey", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
|
||||
u := testutil.Fake(s.T(), faker, database.User{})
|
||||
arg := database.UpdateUserChatProviderKeyParams{UserID: u.ID, ChatProviderID: uuid.New(), APIKey: "updated-api-key"}
|
||||
key := testutil.Fake(s.T(), faker, database.UserChatProviderKey{UserID: u.ID, ChatProviderID: arg.ChatProviderID, APIKey: arg.APIKey})
|
||||
dbm.EXPECT().GetUserByID(gomock.Any(), u.ID).Return(u, nil).AnyTimes()
|
||||
dbm.EXPECT().UpdateUserChatProviderKey(gomock.Any(), arg).Return(key, nil).AnyTimes()
|
||||
check.Args(arg).Asserts(u, policy.ActionUpdatePersonal).Returns(key)
|
||||
}))
|
||||
s.Run("UpsertUserChatProviderKey", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
|
||||
u := testutil.Fake(s.T(), faker, database.User{})
|
||||
arg := database.UpsertUserChatProviderKeyParams{UserID: u.ID, ChatProviderID: uuid.New(), APIKey: "upserted-api-key"}
|
||||
key := testutil.Fake(s.T(), faker, database.UserChatProviderKey{UserID: u.ID, ChatProviderID: arg.ChatProviderID, APIKey: arg.APIKey})
|
||||
dbm.EXPECT().GetUserByID(gomock.Any(), u.ID).Return(u, nil).AnyTimes()
|
||||
dbm.EXPECT().UpsertUserChatProviderKey(gomock.Any(), arg).Return(key, nil).AnyTimes()
|
||||
check.Args(arg).Asserts(u, policy.ActionUpdatePersonal).Returns(key)
|
||||
}))
|
||||
s.Run("UpdateUserChatCustomPrompt", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
|
||||
u := testutil.Fake(s.T(), faker, database.User{})
|
||||
uc := database.UserConfig{UserID: u.ID, Key: "chat_custom_prompt", Value: "my custom prompt"}
|
||||
|
||||
@@ -1591,6 +1591,7 @@ func AIBridgeInterception(t testing.TB, db database.Store, seed database.InsertA
|
||||
APIKeyID: seed.APIKeyID,
|
||||
InitiatorID: takeFirst(seed.InitiatorID, uuid.New()),
|
||||
Provider: takeFirst(seed.Provider, "provider"),
|
||||
ProviderName: takeFirst(seed.ProviderName, "provider-name"),
|
||||
Model: takeFirst(seed.Model, "model"),
|
||||
Metadata: takeFirstSlice(seed.Metadata, json.RawMessage("{}")),
|
||||
StartedAt: takeFirst(seed.StartedAt, dbtime.Now()),
|
||||
|
||||
@@ -160,12 +160,12 @@ func (m queryMetricsStore) AllUserIDs(ctx context.Context, includeSystem bool) (
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) ArchiveChatByID(ctx context.Context, id uuid.UUID) error {
|
||||
func (m queryMetricsStore) ArchiveChatByID(ctx context.Context, id uuid.UUID) ([]database.Chat, error) {
|
||||
start := time.Now()
|
||||
r0 := m.s.ArchiveChatByID(ctx, id)
|
||||
r0, r1 := m.s.ArchiveChatByID(ctx, id)
|
||||
m.queryLatencies.WithLabelValues("ArchiveChatByID").Observe(time.Since(start).Seconds())
|
||||
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "ArchiveChatByID").Inc()
|
||||
return r0
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) ArchiveUnusedTemplateVersions(ctx context.Context, arg database.ArchiveUnusedTemplateVersionsParams) ([]uuid.UUID, error) {
|
||||
@@ -696,6 +696,14 @@ func (m queryMetricsStore) DeleteUserChatCompactionThreshold(ctx context.Context
|
||||
return r0
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) DeleteUserChatProviderKey(ctx context.Context, arg database.DeleteUserChatProviderKeyParams) error {
|
||||
start := time.Now()
|
||||
r0 := m.s.DeleteUserChatProviderKey(ctx, arg)
|
||||
m.queryLatencies.WithLabelValues("DeleteUserChatProviderKey").Observe(time.Since(start).Seconds())
|
||||
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "DeleteUserChatProviderKey").Inc()
|
||||
return r0
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) DeleteUserSecret(ctx context.Context, id uuid.UUID) error {
|
||||
start := time.Now()
|
||||
r0 := m.s.DeleteUserSecret(ctx, id)
|
||||
@@ -2528,6 +2536,14 @@ func (m queryMetricsStore) GetUserChatCustomPrompt(ctx context.Context, userID u
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) GetUserChatProviderKeys(ctx context.Context, userID uuid.UUID) ([]database.UserChatProviderKey, error) {
|
||||
start := time.Now()
|
||||
r0, r1 := m.s.GetUserChatProviderKeys(ctx, userID)
|
||||
m.queryLatencies.WithLabelValues("GetUserChatProviderKeys").Observe(time.Since(start).Seconds())
|
||||
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "GetUserChatProviderKeys").Inc()
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) GetUserChatSpendInPeriod(ctx context.Context, arg database.GetUserChatSpendInPeriodParams) (int64, error) {
|
||||
start := time.Now()
|
||||
r0, r1 := m.s.GetUserChatSpendInPeriod(ctx, arg)
|
||||
@@ -4024,12 +4040,12 @@ func (m queryMetricsStore) TryAcquireLock(ctx context.Context, pgTryAdvisoryXact
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) UnarchiveChatByID(ctx context.Context, id uuid.UUID) error {
|
||||
func (m queryMetricsStore) UnarchiveChatByID(ctx context.Context, id uuid.UUID) ([]database.Chat, error) {
|
||||
start := time.Now()
|
||||
r0 := m.s.UnarchiveChatByID(ctx, id)
|
||||
r0, r1 := m.s.UnarchiveChatByID(ctx, id)
|
||||
m.queryLatencies.WithLabelValues("UnarchiveChatByID").Observe(time.Since(start).Seconds())
|
||||
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "UnarchiveChatByID").Inc()
|
||||
return r0
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) UnarchiveTemplateVersion(ctx context.Context, arg database.UnarchiveTemplateVersionParams) error {
|
||||
@@ -4112,6 +4128,14 @@ func (m queryMetricsStore) UpdateChatLabelsByID(ctx context.Context, arg databas
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) UpdateChatLastInjectedContext(ctx context.Context, arg database.UpdateChatLastInjectedContextParams) (database.Chat, error) {
|
||||
start := time.Now()
|
||||
r0, r1 := m.s.UpdateChatLastInjectedContext(ctx, arg)
|
||||
m.queryLatencies.WithLabelValues("UpdateChatLastInjectedContext").Observe(time.Since(start).Seconds())
|
||||
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "UpdateChatLastInjectedContext").Inc()
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) UpdateChatLastModelConfigByID(ctx context.Context, arg database.UpdateChatLastModelConfigByIDParams) (database.Chat, error) {
|
||||
start := time.Now()
|
||||
r0, r1 := m.s.UpdateChatLastModelConfigByID(ctx, arg)
|
||||
@@ -4552,6 +4576,14 @@ func (m queryMetricsStore) UpdateUserChatCustomPrompt(ctx context.Context, arg d
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) UpdateUserChatProviderKey(ctx context.Context, arg database.UpdateUserChatProviderKeyParams) (database.UserChatProviderKey, error) {
|
||||
start := time.Now()
|
||||
r0, r1 := m.s.UpdateUserChatProviderKey(ctx, arg)
|
||||
m.queryLatencies.WithLabelValues("UpdateUserChatProviderKey").Observe(time.Since(start).Seconds())
|
||||
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "UpdateUserChatProviderKey").Inc()
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) UpdateUserDeletedByID(ctx context.Context, id uuid.UUID) error {
|
||||
start := time.Now()
|
||||
r0 := m.s.UpdateUserDeletedByID(ctx, id)
|
||||
@@ -5144,6 +5176,14 @@ func (m queryMetricsStore) UpsertTemplateUsageStats(ctx context.Context) error {
|
||||
return r0
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) UpsertUserChatProviderKey(ctx context.Context, arg database.UpsertUserChatProviderKeyParams) (database.UserChatProviderKey, error) {
|
||||
start := time.Now()
|
||||
r0, r1 := m.s.UpsertUserChatProviderKey(ctx, arg)
|
||||
m.queryLatencies.WithLabelValues("UpsertUserChatProviderKey").Observe(time.Since(start).Seconds())
|
||||
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "UpsertUserChatProviderKey").Inc()
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) UpsertWebpushVAPIDKeys(ctx context.Context, arg database.UpsertWebpushVAPIDKeysParams) error {
|
||||
start := time.Now()
|
||||
r0 := m.s.UpsertWebpushVAPIDKeys(ctx, arg)
|
||||
|
||||
@@ -148,11 +148,12 @@ func (mr *MockStoreMockRecorder) AllUserIDs(ctx, includeSystem any) *gomock.Call
|
||||
}
|
||||
|
||||
// ArchiveChatByID mocks base method.
|
||||
func (m *MockStore) ArchiveChatByID(ctx context.Context, id uuid.UUID) error {
|
||||
func (m *MockStore) ArchiveChatByID(ctx context.Context, id uuid.UUID) ([]database.Chat, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "ArchiveChatByID", ctx, id)
|
||||
ret0, _ := ret[0].(error)
|
||||
return ret0
|
||||
ret0, _ := ret[0].([]database.Chat)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// ArchiveChatByID indicates an expected call of ArchiveChatByID.
|
||||
@@ -1170,6 +1171,20 @@ func (mr *MockStoreMockRecorder) DeleteUserChatCompactionThreshold(ctx, arg any)
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteUserChatCompactionThreshold", reflect.TypeOf((*MockStore)(nil).DeleteUserChatCompactionThreshold), ctx, arg)
|
||||
}
|
||||
|
||||
// DeleteUserChatProviderKey mocks base method.
|
||||
func (m *MockStore) DeleteUserChatProviderKey(ctx context.Context, arg database.DeleteUserChatProviderKeyParams) error {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "DeleteUserChatProviderKey", ctx, arg)
|
||||
ret0, _ := ret[0].(error)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// DeleteUserChatProviderKey indicates an expected call of DeleteUserChatProviderKey.
|
||||
func (mr *MockStoreMockRecorder) DeleteUserChatProviderKey(ctx, arg any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteUserChatProviderKey", reflect.TypeOf((*MockStore)(nil).DeleteUserChatProviderKey), ctx, arg)
|
||||
}
|
||||
|
||||
// DeleteUserSecret mocks base method.
|
||||
func (m *MockStore) DeleteUserSecret(ctx context.Context, id uuid.UUID) error {
|
||||
m.ctrl.T.Helper()
|
||||
@@ -4728,6 +4743,21 @@ func (mr *MockStoreMockRecorder) GetUserChatCustomPrompt(ctx, userID any) *gomoc
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetUserChatCustomPrompt", reflect.TypeOf((*MockStore)(nil).GetUserChatCustomPrompt), ctx, userID)
|
||||
}
|
||||
|
||||
// GetUserChatProviderKeys mocks base method.
|
||||
func (m *MockStore) GetUserChatProviderKeys(ctx context.Context, userID uuid.UUID) ([]database.UserChatProviderKey, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "GetUserChatProviderKeys", ctx, userID)
|
||||
ret0, _ := ret[0].([]database.UserChatProviderKey)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// GetUserChatProviderKeys indicates an expected call of GetUserChatProviderKeys.
|
||||
func (mr *MockStoreMockRecorder) GetUserChatProviderKeys(ctx, userID any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetUserChatProviderKeys", reflect.TypeOf((*MockStore)(nil).GetUserChatProviderKeys), ctx, userID)
|
||||
}
|
||||
|
||||
// GetUserChatSpendInPeriod mocks base method.
|
||||
func (m *MockStore) GetUserChatSpendInPeriod(ctx context.Context, arg database.GetUserChatSpendInPeriodParams) (int64, error) {
|
||||
m.ctrl.T.Helper()
|
||||
@@ -7632,11 +7662,12 @@ func (mr *MockStoreMockRecorder) TryAcquireLock(ctx, pgTryAdvisoryXactLock any)
|
||||
}
|
||||
|
||||
// UnarchiveChatByID mocks base method.
|
||||
func (m *MockStore) UnarchiveChatByID(ctx context.Context, id uuid.UUID) error {
|
||||
func (m *MockStore) UnarchiveChatByID(ctx context.Context, id uuid.UUID) ([]database.Chat, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "UnarchiveChatByID", ctx, id)
|
||||
ret0, _ := ret[0].(error)
|
||||
return ret0
|
||||
ret0, _ := ret[0].([]database.Chat)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// UnarchiveChatByID indicates an expected call of UnarchiveChatByID.
|
||||
@@ -7790,6 +7821,21 @@ func (mr *MockStoreMockRecorder) UpdateChatLabelsByID(ctx, arg any) *gomock.Call
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateChatLabelsByID", reflect.TypeOf((*MockStore)(nil).UpdateChatLabelsByID), ctx, arg)
|
||||
}
|
||||
|
||||
// UpdateChatLastInjectedContext mocks base method.
|
||||
func (m *MockStore) UpdateChatLastInjectedContext(ctx context.Context, arg database.UpdateChatLastInjectedContextParams) (database.Chat, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "UpdateChatLastInjectedContext", ctx, arg)
|
||||
ret0, _ := ret[0].(database.Chat)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// UpdateChatLastInjectedContext indicates an expected call of UpdateChatLastInjectedContext.
|
||||
func (mr *MockStoreMockRecorder) UpdateChatLastInjectedContext(ctx, arg any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateChatLastInjectedContext", reflect.TypeOf((*MockStore)(nil).UpdateChatLastInjectedContext), ctx, arg)
|
||||
}
|
||||
|
||||
// UpdateChatLastModelConfigByID mocks base method.
|
||||
func (m *MockStore) UpdateChatLastModelConfigByID(ctx context.Context, arg database.UpdateChatLastModelConfigByIDParams) (database.Chat, error) {
|
||||
m.ctrl.T.Helper()
|
||||
@@ -8588,6 +8634,21 @@ func (mr *MockStoreMockRecorder) UpdateUserChatCustomPrompt(ctx, arg any) *gomoc
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateUserChatCustomPrompt", reflect.TypeOf((*MockStore)(nil).UpdateUserChatCustomPrompt), ctx, arg)
|
||||
}
|
||||
|
||||
// UpdateUserChatProviderKey mocks base method.
|
||||
func (m *MockStore) UpdateUserChatProviderKey(ctx context.Context, arg database.UpdateUserChatProviderKeyParams) (database.UserChatProviderKey, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "UpdateUserChatProviderKey", ctx, arg)
|
||||
ret0, _ := ret[0].(database.UserChatProviderKey)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// UpdateUserChatProviderKey indicates an expected call of UpdateUserChatProviderKey.
|
||||
func (mr *MockStoreMockRecorder) UpdateUserChatProviderKey(ctx, arg any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateUserChatProviderKey", reflect.TypeOf((*MockStore)(nil).UpdateUserChatProviderKey), ctx, arg)
|
||||
}
|
||||
|
||||
// UpdateUserDeletedByID mocks base method.
|
||||
func (m *MockStore) UpdateUserDeletedByID(ctx context.Context, id uuid.UUID) error {
|
||||
m.ctrl.T.Helper()
|
||||
@@ -9654,6 +9715,21 @@ func (mr *MockStoreMockRecorder) UpsertTemplateUsageStats(ctx any) *gomock.Call
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpsertTemplateUsageStats", reflect.TypeOf((*MockStore)(nil).UpsertTemplateUsageStats), ctx)
|
||||
}
|
||||
|
||||
// UpsertUserChatProviderKey mocks base method.
|
||||
func (m *MockStore) UpsertUserChatProviderKey(ctx context.Context, arg database.UpsertUserChatProviderKeyParams) (database.UserChatProviderKey, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "UpsertUserChatProviderKey", ctx, arg)
|
||||
ret0, _ := ret[0].(database.UserChatProviderKey)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// UpsertUserChatProviderKey indicates an expected call of UpsertUserChatProviderKey.
|
||||
func (mr *MockStoreMockRecorder) UpsertUserChatProviderKey(ctx, arg any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpsertUserChatProviderKey", reflect.TypeOf((*MockStore)(nil).UpsertUserChatProviderKey), ctx, arg)
|
||||
}
|
||||
|
||||
// UpsertWebpushVAPIDKeys mocks base method.
|
||||
func (m *MockStore) UpsertWebpushVAPIDKeys(ctx context.Context, arg database.UpsertWebpushVAPIDKeysParams) error {
|
||||
m.ctrl.T.Helper()
|
||||
|
||||
Generated
+42
-4
@@ -1100,7 +1100,8 @@ CREATE TABLE aibridge_interceptions (
|
||||
thread_parent_id uuid,
|
||||
thread_root_id uuid,
|
||||
client_session_id character varying(256),
|
||||
session_id text GENERATED ALWAYS AS (COALESCE(client_session_id, ((thread_root_id)::text)::character varying, ((id)::text)::character varying)) STORED NOT NULL
|
||||
session_id text GENERATED ALWAYS AS (COALESCE(client_session_id, ((thread_root_id)::text)::character varying, ((id)::text)::character varying)) STORED NOT NULL,
|
||||
provider_name text DEFAULT ''::text NOT NULL
|
||||
);
|
||||
|
||||
COMMENT ON TABLE aibridge_interceptions IS 'Audit log of requests intercepted by AI Bridge';
|
||||
@@ -1115,6 +1116,8 @@ COMMENT ON COLUMN aibridge_interceptions.client_session_id IS 'The session ID su
|
||||
|
||||
COMMENT ON COLUMN aibridge_interceptions.session_id IS 'Groups related interceptions into a logical session. Determined by a priority chain: (1) client_session_id — an explicit session identifier supplied by the calling client (e.g. Claude Code); (2) thread_root_id — the root of an agentic thread detected by Bridge through tool-call correlation, used when the client does not supply its own session ID; (3) id — the interception''s own ID, used as a last resort so every interception belongs to exactly one session even if it is standalone. This is a generated column stored on disk so it can be indexed and joined without recomputing the COALESCE on every query.';
|
||||
|
||||
COMMENT ON COLUMN aibridge_interceptions.provider_name IS 'The provider instance name which may differ from provider when multiple instances of the same provider type exist.';
|
||||
|
||||
CREATE TABLE aibridge_model_thoughts (
|
||||
interception_id uuid NOT NULL,
|
||||
content text NOT NULL,
|
||||
@@ -1338,7 +1341,11 @@ CREATE TABLE chat_providers (
|
||||
created_at timestamp with time zone DEFAULT now() NOT NULL,
|
||||
updated_at timestamp with time zone DEFAULT now() NOT NULL,
|
||||
base_url text DEFAULT ''::text NOT NULL,
|
||||
CONSTRAINT chat_providers_provider_check CHECK ((provider = ANY (ARRAY['anthropic'::text, 'azure'::text, 'bedrock'::text, 'google'::text, 'openai'::text, 'openai-compat'::text, 'openrouter'::text, 'vercel'::text])))
|
||||
central_api_key_enabled boolean DEFAULT true NOT NULL,
|
||||
allow_user_api_key boolean DEFAULT false NOT NULL,
|
||||
allow_central_api_key_fallback boolean DEFAULT false NOT NULL,
|
||||
CONSTRAINT chat_providers_provider_check CHECK ((provider = ANY (ARRAY['anthropic'::text, 'azure'::text, 'bedrock'::text, 'google'::text, 'openai'::text, 'openai-compat'::text, 'openrouter'::text, 'vercel'::text]))),
|
||||
CONSTRAINT valid_credential_policy CHECK (((central_api_key_enabled OR allow_user_api_key) AND ((NOT allow_central_api_key_fallback) OR (central_api_key_enabled AND allow_user_api_key))))
|
||||
);
|
||||
|
||||
COMMENT ON COLUMN chat_providers.api_key_key_id IS 'The ID of the key used to encrypt the provider API key. If this is NULL, the API key is not encrypted';
|
||||
@@ -1403,7 +1410,8 @@ CREATE TABLE chats (
|
||||
build_id uuid,
|
||||
agent_id uuid,
|
||||
pin_order integer DEFAULT 0 NOT NULL,
|
||||
last_read_message_id bigint
|
||||
last_read_message_id bigint,
|
||||
last_injected_context jsonb
|
||||
);
|
||||
|
||||
CREATE TABLE connection_logs (
|
||||
@@ -2748,6 +2756,17 @@ COMMENT ON TABLE usage_events_daily IS 'usage_events_daily is a daily rollup of
|
||||
|
||||
COMMENT ON COLUMN usage_events_daily.day IS 'The date of the summed usage events, always in UTC.';
|
||||
|
||||
CREATE TABLE user_chat_provider_keys (
|
||||
id uuid DEFAULT gen_random_uuid() NOT NULL,
|
||||
user_id uuid NOT NULL,
|
||||
chat_provider_id uuid NOT NULL,
|
||||
api_key text NOT NULL,
|
||||
api_key_key_id text,
|
||||
created_at timestamp with time zone DEFAULT now() NOT NULL,
|
||||
updated_at timestamp with time zone DEFAULT now() NOT NULL,
|
||||
CONSTRAINT user_chat_provider_keys_api_key_check CHECK ((api_key <> ''::text))
|
||||
);
|
||||
|
||||
CREATE TABLE user_configs (
|
||||
user_id uuid NOT NULL,
|
||||
key character varying(256) NOT NULL,
|
||||
@@ -2789,7 +2808,8 @@ CREATE TABLE user_secrets (
|
||||
env_name text DEFAULT ''::text NOT NULL,
|
||||
file_path text DEFAULT ''::text NOT NULL,
|
||||
created_at timestamp with time zone DEFAULT CURRENT_TIMESTAMP NOT NULL,
|
||||
updated_at timestamp with time zone DEFAULT CURRENT_TIMESTAMP NOT NULL
|
||||
updated_at timestamp with time zone DEFAULT CURRENT_TIMESTAMP NOT NULL,
|
||||
value_key_id text
|
||||
);
|
||||
|
||||
CREATE TABLE user_status_changes (
|
||||
@@ -3544,6 +3564,12 @@ ALTER TABLE ONLY usage_events_daily
|
||||
ALTER TABLE ONLY usage_events
|
||||
ADD CONSTRAINT usage_events_pkey PRIMARY KEY (id);
|
||||
|
||||
ALTER TABLE ONLY user_chat_provider_keys
|
||||
ADD CONSTRAINT user_chat_provider_keys_pkey PRIMARY KEY (id);
|
||||
|
||||
ALTER TABLE ONLY user_chat_provider_keys
|
||||
ADD CONSTRAINT user_chat_provider_keys_user_id_chat_provider_id_key UNIQUE (user_id, chat_provider_id);
|
||||
|
||||
ALTER TABLE ONLY user_configs
|
||||
ADD CONSTRAINT user_configs_pkey PRIMARY KEY (user_id, key);
|
||||
|
||||
@@ -4254,6 +4280,15 @@ ALTER TABLE ONLY templates
|
||||
ALTER TABLE ONLY templates
|
||||
ADD CONSTRAINT templates_organization_id_fkey FOREIGN KEY (organization_id) REFERENCES organizations(id) ON DELETE CASCADE;
|
||||
|
||||
ALTER TABLE ONLY user_chat_provider_keys
|
||||
ADD CONSTRAINT user_chat_provider_keys_api_key_key_id_fkey FOREIGN KEY (api_key_key_id) REFERENCES dbcrypt_keys(active_key_digest);
|
||||
|
||||
ALTER TABLE ONLY user_chat_provider_keys
|
||||
ADD CONSTRAINT user_chat_provider_keys_chat_provider_id_fkey FOREIGN KEY (chat_provider_id) REFERENCES chat_providers(id) ON DELETE CASCADE;
|
||||
|
||||
ALTER TABLE ONLY user_chat_provider_keys
|
||||
ADD CONSTRAINT user_chat_provider_keys_user_id_fkey FOREIGN KEY (user_id) REFERENCES users(id) ON DELETE CASCADE;
|
||||
|
||||
ALTER TABLE ONLY user_configs
|
||||
ADD CONSTRAINT user_configs_user_id_fkey FOREIGN KEY (user_id) REFERENCES users(id) ON DELETE CASCADE;
|
||||
|
||||
@@ -4272,6 +4307,9 @@ ALTER TABLE ONLY user_links
|
||||
ALTER TABLE ONLY user_secrets
|
||||
ADD CONSTRAINT user_secrets_user_id_fkey FOREIGN KEY (user_id) REFERENCES users(id) ON DELETE CASCADE;
|
||||
|
||||
ALTER TABLE ONLY user_secrets
|
||||
ADD CONSTRAINT user_secrets_value_key_id_fkey FOREIGN KEY (value_key_id) REFERENCES dbcrypt_keys(active_key_digest);
|
||||
|
||||
ALTER TABLE ONLY user_status_changes
|
||||
ADD CONSTRAINT user_status_changes_user_id_fkey FOREIGN KEY (user_id) REFERENCES users(id);
|
||||
|
||||
|
||||
@@ -92,12 +92,16 @@ const (
|
||||
ForeignKeyTemplateVersionsTemplateID ForeignKeyConstraint = "template_versions_template_id_fkey" // ALTER TABLE ONLY template_versions ADD CONSTRAINT template_versions_template_id_fkey FOREIGN KEY (template_id) REFERENCES templates(id) ON DELETE CASCADE;
|
||||
ForeignKeyTemplatesCreatedBy ForeignKeyConstraint = "templates_created_by_fkey" // ALTER TABLE ONLY templates ADD CONSTRAINT templates_created_by_fkey FOREIGN KEY (created_by) REFERENCES users(id) ON DELETE RESTRICT;
|
||||
ForeignKeyTemplatesOrganizationID ForeignKeyConstraint = "templates_organization_id_fkey" // ALTER TABLE ONLY templates ADD CONSTRAINT templates_organization_id_fkey FOREIGN KEY (organization_id) REFERENCES organizations(id) ON DELETE CASCADE;
|
||||
ForeignKeyUserChatProviderKeysAPIKeyKeyID ForeignKeyConstraint = "user_chat_provider_keys_api_key_key_id_fkey" // ALTER TABLE ONLY user_chat_provider_keys ADD CONSTRAINT user_chat_provider_keys_api_key_key_id_fkey FOREIGN KEY (api_key_key_id) REFERENCES dbcrypt_keys(active_key_digest);
|
||||
ForeignKeyUserChatProviderKeysChatProviderID ForeignKeyConstraint = "user_chat_provider_keys_chat_provider_id_fkey" // ALTER TABLE ONLY user_chat_provider_keys ADD CONSTRAINT user_chat_provider_keys_chat_provider_id_fkey FOREIGN KEY (chat_provider_id) REFERENCES chat_providers(id) ON DELETE CASCADE;
|
||||
ForeignKeyUserChatProviderKeysUserID ForeignKeyConstraint = "user_chat_provider_keys_user_id_fkey" // ALTER TABLE ONLY user_chat_provider_keys ADD CONSTRAINT user_chat_provider_keys_user_id_fkey FOREIGN KEY (user_id) REFERENCES users(id) ON DELETE CASCADE;
|
||||
ForeignKeyUserConfigsUserID ForeignKeyConstraint = "user_configs_user_id_fkey" // ALTER TABLE ONLY user_configs ADD CONSTRAINT user_configs_user_id_fkey FOREIGN KEY (user_id) REFERENCES users(id) ON DELETE CASCADE;
|
||||
ForeignKeyUserDeletedUserID ForeignKeyConstraint = "user_deleted_user_id_fkey" // ALTER TABLE ONLY user_deleted ADD CONSTRAINT user_deleted_user_id_fkey FOREIGN KEY (user_id) REFERENCES users(id);
|
||||
ForeignKeyUserLinksOauthAccessTokenKeyID ForeignKeyConstraint = "user_links_oauth_access_token_key_id_fkey" // ALTER TABLE ONLY user_links ADD CONSTRAINT user_links_oauth_access_token_key_id_fkey FOREIGN KEY (oauth_access_token_key_id) REFERENCES dbcrypt_keys(active_key_digest);
|
||||
ForeignKeyUserLinksOauthRefreshTokenKeyID ForeignKeyConstraint = "user_links_oauth_refresh_token_key_id_fkey" // ALTER TABLE ONLY user_links ADD CONSTRAINT user_links_oauth_refresh_token_key_id_fkey FOREIGN KEY (oauth_refresh_token_key_id) REFERENCES dbcrypt_keys(active_key_digest);
|
||||
ForeignKeyUserLinksUserID ForeignKeyConstraint = "user_links_user_id_fkey" // ALTER TABLE ONLY user_links ADD CONSTRAINT user_links_user_id_fkey FOREIGN KEY (user_id) REFERENCES users(id) ON DELETE CASCADE;
|
||||
ForeignKeyUserSecretsUserID ForeignKeyConstraint = "user_secrets_user_id_fkey" // ALTER TABLE ONLY user_secrets ADD CONSTRAINT user_secrets_user_id_fkey FOREIGN KEY (user_id) REFERENCES users(id) ON DELETE CASCADE;
|
||||
ForeignKeyUserSecretsValueKeyID ForeignKeyConstraint = "user_secrets_value_key_id_fkey" // ALTER TABLE ONLY user_secrets ADD CONSTRAINT user_secrets_value_key_id_fkey FOREIGN KEY (value_key_id) REFERENCES dbcrypt_keys(active_key_digest);
|
||||
ForeignKeyUserStatusChangesUserID ForeignKeyConstraint = "user_status_changes_user_id_fkey" // ALTER TABLE ONLY user_status_changes ADD CONSTRAINT user_status_changes_user_id_fkey FOREIGN KEY (user_id) REFERENCES users(id);
|
||||
ForeignKeyWebpushSubscriptionsUserID ForeignKeyConstraint = "webpush_subscriptions_user_id_fkey" // ALTER TABLE ONLY webpush_subscriptions ADD CONSTRAINT webpush_subscriptions_user_id_fkey FOREIGN KEY (user_id) REFERENCES users(id) ON DELETE CASCADE;
|
||||
ForeignKeyWorkspaceAgentDevcontainersSubagentID ForeignKeyConstraint = "workspace_agent_devcontainers_subagent_id_fkey" // ALTER TABLE ONLY workspace_agent_devcontainers ADD CONSTRAINT workspace_agent_devcontainers_subagent_id_fkey FOREIGN KEY (subagent_id) REFERENCES workspace_agents(id) ON DELETE CASCADE;
|
||||
|
||||
@@ -0,0 +1 @@
|
||||
ALTER TABLE chats DROP COLUMN last_injected_context;
|
||||
@@ -0,0 +1 @@
|
||||
ALTER TABLE chats ADD COLUMN last_injected_context JSONB;
|
||||
@@ -0,0 +1,4 @@
|
||||
-- Remove 'agents-access' from all users who have it.
|
||||
UPDATE users
|
||||
SET rbac_roles = array_remove(rbac_roles, 'agents-access')
|
||||
WHERE 'agents-access' = ANY(rbac_roles);
|
||||
@@ -0,0 +1,5 @@
|
||||
-- Grant 'agents-access' to every user who has ever created a chat.
|
||||
UPDATE users
|
||||
SET rbac_roles = array_append(rbac_roles, 'agents-access')
|
||||
WHERE id IN (SELECT DISTINCT owner_id FROM chats)
|
||||
AND NOT ('agents-access' = ANY(rbac_roles));
|
||||
@@ -0,0 +1 @@
|
||||
ALTER TABLE aibridge_interceptions DROP COLUMN provider_name;
|
||||
@@ -0,0 +1,6 @@
|
||||
ALTER TABLE aibridge_interceptions ADD COLUMN provider_name TEXT NOT NULL DEFAULT '';
|
||||
|
||||
COMMENT ON COLUMN aibridge_interceptions.provider_name IS 'The provider instance name which may differ from provider when multiple instances of the same provider type exist.';
|
||||
|
||||
-- Backfill existing records with the provider type as the provider name.
|
||||
UPDATE aibridge_interceptions SET provider_name = provider WHERE provider_name = '';
|
||||
@@ -0,0 +1,8 @@
|
||||
DROP TABLE IF EXISTS user_chat_provider_keys;
|
||||
|
||||
ALTER TABLE chat_providers DROP CONSTRAINT IF EXISTS valid_credential_policy;
|
||||
|
||||
ALTER TABLE chat_providers
|
||||
DROP COLUMN IF EXISTS central_api_key_enabled,
|
||||
DROP COLUMN IF EXISTS allow_user_api_key,
|
||||
DROP COLUMN IF EXISTS allow_central_api_key_fallback;
|
||||
@@ -0,0 +1,24 @@
|
||||
ALTER TABLE chat_providers
|
||||
ADD COLUMN central_api_key_enabled BOOLEAN NOT NULL DEFAULT TRUE,
|
||||
ADD COLUMN allow_user_api_key BOOLEAN NOT NULL DEFAULT FALSE,
|
||||
ADD COLUMN allow_central_api_key_fallback BOOLEAN NOT NULL DEFAULT FALSE;
|
||||
|
||||
ALTER TABLE chat_providers
|
||||
ADD CONSTRAINT valid_credential_policy CHECK (
|
||||
(central_api_key_enabled OR allow_user_api_key) AND
|
||||
(
|
||||
NOT allow_central_api_key_fallback OR
|
||||
(central_api_key_enabled AND allow_user_api_key)
|
||||
)
|
||||
);
|
||||
|
||||
CREATE TABLE user_chat_provider_keys (
|
||||
id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
|
||||
user_id UUID NOT NULL REFERENCES users(id) ON DELETE CASCADE,
|
||||
chat_provider_id UUID NOT NULL REFERENCES chat_providers(id) ON DELETE CASCADE,
|
||||
api_key TEXT NOT NULL CHECK (api_key != ''),
|
||||
api_key_key_id TEXT REFERENCES dbcrypt_keys(active_key_digest),
|
||||
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
|
||||
updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
|
||||
UNIQUE (user_id, chat_provider_id)
|
||||
);
|
||||
@@ -0,0 +1,3 @@
|
||||
ALTER TABLE user_secrets
|
||||
DROP CONSTRAINT user_secrets_value_key_id_fkey,
|
||||
DROP COLUMN value_key_id;
|
||||
@@ -0,0 +1,5 @@
|
||||
ALTER TABLE user_secrets
|
||||
ADD COLUMN value_key_id TEXT;
|
||||
|
||||
ALTER TABLE ONLY user_secrets
|
||||
ADD CONSTRAINT user_secrets_value_key_id_fkey FOREIGN KEY (value_key_id) REFERENCES dbcrypt_keys(active_key_digest);
|
||||
@@ -877,3 +877,149 @@ func TestMigration000387MigrateTaskWorkspaces(t *testing.T) {
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, 0, antCount, "antagonist workspaces (deleted and regular) should not be migrated")
|
||||
}
|
||||
|
||||
func TestMigration000457ChatAccessRole(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
const migrationVersion = 457
|
||||
|
||||
sqlDB := testSQLDB(t)
|
||||
|
||||
// Migrate up to the migration before the one that grants
|
||||
// agents-access roles.
|
||||
next, err := migrations.Stepper(sqlDB)
|
||||
require.NoError(t, err)
|
||||
for {
|
||||
version, more, err := next()
|
||||
require.NoError(t, err)
|
||||
if !more {
|
||||
t.Fatalf("migration %d not found", migrationVersion)
|
||||
}
|
||||
if version == migrationVersion-1 {
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
ctx := testutil.Context(t, testutil.WaitSuperLong)
|
||||
|
||||
// Define test users.
|
||||
userWithChat := uuid.New() // Has a chat, no agents-access role.
|
||||
userAlreadyHasRole := uuid.New() // Has a chat and already has agents-access.
|
||||
userNoChat := uuid.New() // No chat at all.
|
||||
userWithChatAndRoles := uuid.New() // Has a chat and other existing roles.
|
||||
|
||||
now := time.Now().UTC().Truncate(time.Microsecond)
|
||||
|
||||
// We need a chat_provider and chat_model_config for the chats FK.
|
||||
providerID := uuid.New()
|
||||
modelConfigID := uuid.New()
|
||||
|
||||
tx, err := sqlDB.BeginTx(ctx, nil)
|
||||
require.NoError(t, err)
|
||||
defer tx.Rollback()
|
||||
|
||||
fixtures := []struct {
|
||||
query string
|
||||
args []any
|
||||
}{
|
||||
// Insert test users with varying rbac_roles.
|
||||
{
|
||||
`INSERT INTO users (id, username, email, hashed_password, created_at, updated_at, status, rbac_roles, login_type)
|
||||
VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9)`,
|
||||
[]any{userWithChat, "user-with-chat", "chat@test.com", []byte{}, now, now, "active", pq.StringArray{}, "password"},
|
||||
},
|
||||
{
|
||||
`INSERT INTO users (id, username, email, hashed_password, created_at, updated_at, status, rbac_roles, login_type)
|
||||
VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9)`,
|
||||
[]any{userAlreadyHasRole, "user-already-has-role", "already@test.com", []byte{}, now, now, "active", pq.StringArray{"agents-access"}, "password"},
|
||||
},
|
||||
{
|
||||
`INSERT INTO users (id, username, email, hashed_password, created_at, updated_at, status, rbac_roles, login_type)
|
||||
VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9)`,
|
||||
[]any{userNoChat, "user-no-chat", "nochat@test.com", []byte{}, now, now, "active", pq.StringArray{}, "password"},
|
||||
},
|
||||
{
|
||||
`INSERT INTO users (id, username, email, hashed_password, created_at, updated_at, status, rbac_roles, login_type)
|
||||
VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9)`,
|
||||
[]any{userWithChatAndRoles, "user-with-roles", "roles@test.com", []byte{}, now, now, "active", pq.StringArray{"template-admin"}, "password"},
|
||||
},
|
||||
// Insert a chat provider and model config for the chats FK.
|
||||
{
|
||||
`INSERT INTO chat_providers (id, provider, display_name, api_key, enabled, created_at, updated_at)
|
||||
VALUES ($1, $2, $3, $4, $5, $6, $7)`,
|
||||
[]any{providerID, "openai", "OpenAI", "", true, now, now},
|
||||
},
|
||||
{
|
||||
`INSERT INTO chat_model_configs (id, provider, model, display_name, enabled, context_limit, compression_threshold, created_at, updated_at)
|
||||
VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9)`,
|
||||
[]any{modelConfigID, "openai", "gpt-4", "GPT 4", true, 100000, 70, now, now},
|
||||
},
|
||||
// Insert chats for users A, B, and D (not C).
|
||||
{
|
||||
`INSERT INTO chats (id, owner_id, last_model_config_id, title, created_at, updated_at)
|
||||
VALUES ($1, $2, $3, $4, $5, $6)`,
|
||||
[]any{uuid.New(), userWithChat, modelConfigID, "Chat A", now, now},
|
||||
},
|
||||
{
|
||||
`INSERT INTO chats (id, owner_id, last_model_config_id, title, created_at, updated_at)
|
||||
VALUES ($1, $2, $3, $4, $5, $6)`,
|
||||
[]any{uuid.New(), userAlreadyHasRole, modelConfigID, "Chat B", now, now},
|
||||
},
|
||||
{
|
||||
`INSERT INTO chats (id, owner_id, last_model_config_id, title, created_at, updated_at)
|
||||
VALUES ($1, $2, $3, $4, $5, $6)`,
|
||||
[]any{uuid.New(), userWithChatAndRoles, modelConfigID, "Chat D", now, now},
|
||||
},
|
||||
}
|
||||
|
||||
for i, f := range fixtures {
|
||||
_, err := tx.ExecContext(ctx, f.query, f.args...)
|
||||
require.NoError(t, err, "fixture %d", i)
|
||||
}
|
||||
require.NoError(t, tx.Commit())
|
||||
|
||||
// Run the migration.
|
||||
version, _, err := next()
|
||||
require.NoError(t, err)
|
||||
require.EqualValues(t, migrationVersion, version)
|
||||
|
||||
// Helper to get rbac_roles for a user.
|
||||
getRoles := func(t *testing.T, userID uuid.UUID) []string {
|
||||
t.Helper()
|
||||
var roles pq.StringArray
|
||||
err := sqlDB.QueryRowContext(ctx,
|
||||
"SELECT rbac_roles FROM users WHERE id = $1", userID,
|
||||
).Scan(&roles)
|
||||
require.NoError(t, err)
|
||||
return roles
|
||||
}
|
||||
|
||||
// Verify: user with chat gets agents-access.
|
||||
roles := getRoles(t, userWithChat)
|
||||
require.Contains(t, roles, "agents-access",
|
||||
"user with chat should get agents-access")
|
||||
|
||||
// Verify: user who already had agents-access has no duplicate.
|
||||
roles = getRoles(t, userAlreadyHasRole)
|
||||
count := 0
|
||||
for _, r := range roles {
|
||||
if r == "agents-access" {
|
||||
count++
|
||||
}
|
||||
}
|
||||
require.Equal(t, 1, count,
|
||||
"user who already had agents-access should not get a duplicate")
|
||||
|
||||
// Verify: user without chat does NOT get agents-access.
|
||||
roles = getRoles(t, userNoChat)
|
||||
require.NotContains(t, roles, "agents-access",
|
||||
"user without chat should not get agents-access")
|
||||
|
||||
// Verify: user with chat and existing roles gets agents-access
|
||||
// appended while preserving existing roles.
|
||||
roles = getRoles(t, userWithChatAndRoles)
|
||||
require.Contains(t, roles, "agents-access",
|
||||
"user with chat and other roles should get agents-access")
|
||||
require.Contains(t, roles, "template-admin",
|
||||
"existing roles should be preserved")
|
||||
}
|
||||
|
||||
+16
@@ -0,0 +1,16 @@
|
||||
INSERT INTO user_chat_provider_keys (
|
||||
user_id,
|
||||
chat_provider_id,
|
||||
api_key,
|
||||
created_at,
|
||||
updated_at
|
||||
)
|
||||
SELECT
|
||||
id,
|
||||
'0a8b2f84-b5a8-4c44-8c9f-e58c44a534a7',
|
||||
'fixture-test-key',
|
||||
'2025-01-01 00:00:00+00',
|
||||
'2025-01-01 00:00:00+00'
|
||||
FROM users
|
||||
ORDER BY created_at, id
|
||||
LIMIT 1;
|
||||
@@ -795,6 +795,7 @@ func (q *sqlQuerier) GetAuthorizedChats(ctx context.Context, arg GetChatsParams,
|
||||
&i.Chat.AgentID,
|
||||
&i.Chat.PinOrder,
|
||||
&i.Chat.LastReadMessageID,
|
||||
&i.Chat.LastInjectedContext,
|
||||
&i.HasUnread); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -864,6 +865,7 @@ func (q *sqlQuerier) ListAuthorizedAIBridgeInterceptions(ctx context.Context, ar
|
||||
&i.AIBridgeInterception.ThreadRootID,
|
||||
&i.AIBridgeInterception.ClientSessionID,
|
||||
&i.AIBridgeInterception.SessionID,
|
||||
&i.AIBridgeInterception.ProviderName,
|
||||
&i.VisibleUser.ID,
|
||||
&i.VisibleUser.Username,
|
||||
&i.VisibleUser.Name,
|
||||
@@ -995,8 +997,6 @@ func (q *sqlQuerier) ListAuthorizedAIBridgeSessions(ctx context.Context, arg Lis
|
||||
query := fmt.Sprintf("-- name: ListAuthorizedAIBridgeSessions :many\n%s", filtered)
|
||||
rows, err := q.db.QueryContext(ctx, query,
|
||||
arg.AfterSessionID,
|
||||
arg.Offset,
|
||||
arg.Limit,
|
||||
arg.StartedAfter,
|
||||
arg.StartedBefore,
|
||||
arg.InitiatorID,
|
||||
@@ -1004,6 +1004,8 @@ func (q *sqlQuerier) ListAuthorizedAIBridgeSessions(ctx context.Context, arg Lis
|
||||
arg.Model,
|
||||
arg.Client,
|
||||
arg.SessionID,
|
||||
arg.Offset,
|
||||
arg.Limit,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@@ -1124,6 +1126,7 @@ func (q *sqlQuerier) ListAuthorizedAIBridgeSessionThreads(ctx context.Context, a
|
||||
&i.AIBridgeInterception.ThreadRootID,
|
||||
&i.AIBridgeInterception.ClientSessionID,
|
||||
&i.AIBridgeInterception.SessionID,
|
||||
&i.AIBridgeInterception.ProviderName,
|
||||
); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
+54
-37
@@ -4038,6 +4038,8 @@ type AIBridgeInterception struct {
|
||||
ClientSessionID sql.NullString `db:"client_session_id" json:"client_session_id"`
|
||||
// Groups related interceptions into a logical session. Determined by a priority chain: (1) client_session_id — an explicit session identifier supplied by the calling client (e.g. Claude Code); (2) thread_root_id — the root of an agentic thread detected by Bridge through tool-call correlation, used when the client does not supply its own session ID; (3) id — the interception's own ID, used as a last resort so every interception belongs to exactly one session even if it is standalone. This is a generated column stored on disk so it can be indexed and joined without recomputing the COALESCE on every query.
|
||||
SessionID string `db:"session_id" json:"session_id"`
|
||||
// The provider instance name which may differ from provider when multiple instances of the same provider type exist.
|
||||
ProviderName string `db:"provider_name" json:"provider_name"`
|
||||
}
|
||||
|
||||
// Audit log of model thinking in intercepted requests in AI Bridge
|
||||
@@ -4153,28 +4155,29 @@ type BoundaryUsageStat struct {
|
||||
}
|
||||
|
||||
type Chat struct {
|
||||
ID uuid.UUID `db:"id" json:"id"`
|
||||
OwnerID uuid.UUID `db:"owner_id" json:"owner_id"`
|
||||
WorkspaceID uuid.NullUUID `db:"workspace_id" json:"workspace_id"`
|
||||
Title string `db:"title" json:"title"`
|
||||
Status ChatStatus `db:"status" json:"status"`
|
||||
WorkerID uuid.NullUUID `db:"worker_id" json:"worker_id"`
|
||||
StartedAt sql.NullTime `db:"started_at" json:"started_at"`
|
||||
HeartbeatAt sql.NullTime `db:"heartbeat_at" json:"heartbeat_at"`
|
||||
CreatedAt time.Time `db:"created_at" json:"created_at"`
|
||||
UpdatedAt time.Time `db:"updated_at" json:"updated_at"`
|
||||
ParentChatID uuid.NullUUID `db:"parent_chat_id" json:"parent_chat_id"`
|
||||
RootChatID uuid.NullUUID `db:"root_chat_id" json:"root_chat_id"`
|
||||
LastModelConfigID uuid.UUID `db:"last_model_config_id" json:"last_model_config_id"`
|
||||
Archived bool `db:"archived" json:"archived"`
|
||||
LastError sql.NullString `db:"last_error" json:"last_error"`
|
||||
Mode NullChatMode `db:"mode" json:"mode"`
|
||||
MCPServerIDs []uuid.UUID `db:"mcp_server_ids" json:"mcp_server_ids"`
|
||||
Labels StringMap `db:"labels" json:"labels"`
|
||||
BuildID uuid.NullUUID `db:"build_id" json:"build_id"`
|
||||
AgentID uuid.NullUUID `db:"agent_id" json:"agent_id"`
|
||||
PinOrder int32 `db:"pin_order" json:"pin_order"`
|
||||
LastReadMessageID sql.NullInt64 `db:"last_read_message_id" json:"last_read_message_id"`
|
||||
ID uuid.UUID `db:"id" json:"id"`
|
||||
OwnerID uuid.UUID `db:"owner_id" json:"owner_id"`
|
||||
WorkspaceID uuid.NullUUID `db:"workspace_id" json:"workspace_id"`
|
||||
Title string `db:"title" json:"title"`
|
||||
Status ChatStatus `db:"status" json:"status"`
|
||||
WorkerID uuid.NullUUID `db:"worker_id" json:"worker_id"`
|
||||
StartedAt sql.NullTime `db:"started_at" json:"started_at"`
|
||||
HeartbeatAt sql.NullTime `db:"heartbeat_at" json:"heartbeat_at"`
|
||||
CreatedAt time.Time `db:"created_at" json:"created_at"`
|
||||
UpdatedAt time.Time `db:"updated_at" json:"updated_at"`
|
||||
ParentChatID uuid.NullUUID `db:"parent_chat_id" json:"parent_chat_id"`
|
||||
RootChatID uuid.NullUUID `db:"root_chat_id" json:"root_chat_id"`
|
||||
LastModelConfigID uuid.UUID `db:"last_model_config_id" json:"last_model_config_id"`
|
||||
Archived bool `db:"archived" json:"archived"`
|
||||
LastError sql.NullString `db:"last_error" json:"last_error"`
|
||||
Mode NullChatMode `db:"mode" json:"mode"`
|
||||
MCPServerIDs []uuid.UUID `db:"mcp_server_ids" json:"mcp_server_ids"`
|
||||
Labels StringMap `db:"labels" json:"labels"`
|
||||
BuildID uuid.NullUUID `db:"build_id" json:"build_id"`
|
||||
AgentID uuid.NullUUID `db:"agent_id" json:"agent_id"`
|
||||
PinOrder int32 `db:"pin_order" json:"pin_order"`
|
||||
LastReadMessageID sql.NullInt64 `db:"last_read_message_id" json:"last_read_message_id"`
|
||||
LastInjectedContext pqtype.NullRawMessage `db:"last_injected_context" json:"last_injected_context"`
|
||||
}
|
||||
|
||||
type ChatDiffStatus struct {
|
||||
@@ -4261,12 +4264,15 @@ type ChatProvider struct {
|
||||
DisplayName string `db:"display_name" json:"display_name"`
|
||||
APIKey string `db:"api_key" json:"api_key"`
|
||||
// The ID of the key used to encrypt the provider API key. If this is NULL, the API key is not encrypted
|
||||
ApiKeyKeyID sql.NullString `db:"api_key_key_id" json:"api_key_key_id"`
|
||||
CreatedBy uuid.NullUUID `db:"created_by" json:"created_by"`
|
||||
Enabled bool `db:"enabled" json:"enabled"`
|
||||
CreatedAt time.Time `db:"created_at" json:"created_at"`
|
||||
UpdatedAt time.Time `db:"updated_at" json:"updated_at"`
|
||||
BaseUrl string `db:"base_url" json:"base_url"`
|
||||
ApiKeyKeyID sql.NullString `db:"api_key_key_id" json:"api_key_key_id"`
|
||||
CreatedBy uuid.NullUUID `db:"created_by" json:"created_by"`
|
||||
Enabled bool `db:"enabled" json:"enabled"`
|
||||
CreatedAt time.Time `db:"created_at" json:"created_at"`
|
||||
UpdatedAt time.Time `db:"updated_at" json:"updated_at"`
|
||||
BaseUrl string `db:"base_url" json:"base_url"`
|
||||
CentralApiKeyEnabled bool `db:"central_api_key_enabled" json:"central_api_key_enabled"`
|
||||
AllowUserApiKey bool `db:"allow_user_api_key" json:"allow_user_api_key"`
|
||||
AllowCentralApiKeyFallback bool `db:"allow_central_api_key_fallback" json:"allow_central_api_key_fallback"`
|
||||
}
|
||||
|
||||
type ChatQueuedMessage struct {
|
||||
@@ -5219,6 +5225,16 @@ type User struct {
|
||||
ChatSpendLimitMicros sql.NullInt64 `db:"chat_spend_limit_micros" json:"chat_spend_limit_micros"`
|
||||
}
|
||||
|
||||
type UserChatProviderKey struct {
|
||||
ID uuid.UUID `db:"id" json:"id"`
|
||||
UserID uuid.UUID `db:"user_id" json:"user_id"`
|
||||
ChatProviderID uuid.UUID `db:"chat_provider_id" json:"chat_provider_id"`
|
||||
APIKey string `db:"api_key" json:"api_key"`
|
||||
ApiKeyKeyID sql.NullString `db:"api_key_key_id" json:"api_key_key_id"`
|
||||
CreatedAt time.Time `db:"created_at" json:"created_at"`
|
||||
UpdatedAt time.Time `db:"updated_at" json:"updated_at"`
|
||||
}
|
||||
|
||||
type UserConfig struct {
|
||||
UserID uuid.UUID `db:"user_id" json:"user_id"`
|
||||
Key string `db:"key" json:"key"`
|
||||
@@ -5248,15 +5264,16 @@ type UserLink struct {
|
||||
}
|
||||
|
||||
type UserSecret struct {
|
||||
ID uuid.UUID `db:"id" json:"id"`
|
||||
UserID uuid.UUID `db:"user_id" json:"user_id"`
|
||||
Name string `db:"name" json:"name"`
|
||||
Description string `db:"description" json:"description"`
|
||||
Value string `db:"value" json:"value"`
|
||||
EnvName string `db:"env_name" json:"env_name"`
|
||||
FilePath string `db:"file_path" json:"file_path"`
|
||||
CreatedAt time.Time `db:"created_at" json:"created_at"`
|
||||
UpdatedAt time.Time `db:"updated_at" json:"updated_at"`
|
||||
ID uuid.UUID `db:"id" json:"id"`
|
||||
UserID uuid.UUID `db:"user_id" json:"user_id"`
|
||||
Name string `db:"name" json:"name"`
|
||||
Description string `db:"description" json:"description"`
|
||||
Value string `db:"value" json:"value"`
|
||||
EnvName string `db:"env_name" json:"env_name"`
|
||||
FilePath string `db:"file_path" json:"file_path"`
|
||||
CreatedAt time.Time `db:"created_at" json:"created_at"`
|
||||
UpdatedAt time.Time `db:"updated_at" json:"updated_at"`
|
||||
ValueKeyID sql.NullString `db:"value_key_id" json:"value_key_id"`
|
||||
}
|
||||
|
||||
// Tracks the history of user status changes
|
||||
|
||||
@@ -54,7 +54,7 @@ type sqlcQuerier interface {
|
||||
ActivityBumpWorkspace(ctx context.Context, arg ActivityBumpWorkspaceParams) error
|
||||
// AllUserIDs returns all UserIDs regardless of user status or deletion.
|
||||
AllUserIDs(ctx context.Context, includeSystem bool) ([]uuid.UUID, error)
|
||||
ArchiveChatByID(ctx context.Context, id uuid.UUID) error
|
||||
ArchiveChatByID(ctx context.Context, id uuid.UUID) ([]Chat, error)
|
||||
// Archiving templates is a soft delete action, so is reversible.
|
||||
// Archiving prevents the version from being used and discovered
|
||||
// by listing.
|
||||
@@ -150,6 +150,7 @@ type sqlcQuerier interface {
|
||||
DeleteTailnetTunnel(ctx context.Context, arg DeleteTailnetTunnelParams) (DeleteTailnetTunnelRow, error)
|
||||
DeleteTask(ctx context.Context, arg DeleteTaskParams) (uuid.UUID, error)
|
||||
DeleteUserChatCompactionThreshold(ctx context.Context, arg DeleteUserChatCompactionThresholdParams) error
|
||||
DeleteUserChatProviderKey(ctx context.Context, arg DeleteUserChatProviderKeyParams) error
|
||||
DeleteUserSecret(ctx context.Context, id uuid.UUID) error
|
||||
DeleteWebpushSubscriptionByUserIDAndEndpoint(ctx context.Context, arg DeleteWebpushSubscriptionByUserIDAndEndpointParams) error
|
||||
DeleteWebpushSubscriptions(ctx context.Context, ids []uuid.UUID) error
|
||||
@@ -577,6 +578,7 @@ type sqlcQuerier interface {
|
||||
GetUserByID(ctx context.Context, id uuid.UUID) (User, error)
|
||||
GetUserChatCompactionThreshold(ctx context.Context, arg GetUserChatCompactionThresholdParams) (string, error)
|
||||
GetUserChatCustomPrompt(ctx context.Context, userID uuid.UUID) (string, error)
|
||||
GetUserChatProviderKeys(ctx context.Context, userID uuid.UUID) ([]UserChatProviderKey, error)
|
||||
GetUserChatSpendInPeriod(ctx context.Context, arg GetUserChatSpendInPeriodParams) (int64, error)
|
||||
GetUserCount(ctx context.Context, includeSystem bool) (int64, error)
|
||||
// Returns the minimum (most restrictive) group limit for a user.
|
||||
@@ -788,6 +790,10 @@ type sqlcQuerier interface {
|
||||
// Returns paginated sessions with aggregated metadata, token counts, and
|
||||
// the most recent user prompt. A "session" is a logical grouping of
|
||||
// interceptions that share the same session_id (set by the client).
|
||||
//
|
||||
// Pagination-first strategy: identify the page of sessions cheaply via a
|
||||
// single GROUP BY scan, then do expensive lateral joins (tokens, prompts,
|
||||
// first-interception metadata) only for the ~page-size result set.
|
||||
ListAIBridgeSessions(ctx context.Context, arg ListAIBridgeSessionsParams) ([]ListAIBridgeSessionsRow, error)
|
||||
ListAIBridgeTokenUsagesByInterceptionIDs(ctx context.Context, interceptionIds []uuid.UUID) ([]AIBridgeTokenUsage, error)
|
||||
ListAIBridgeToolUsagesByInterceptionIDs(ctx context.Context, interceptionIds []uuid.UUID) ([]AIBridgeToolUsage, error)
|
||||
@@ -840,7 +846,7 @@ type sqlcQuerier interface {
|
||||
// This must be called from within a transaction. The lock will be automatically
|
||||
// released when the transaction ends.
|
||||
TryAcquireLock(ctx context.Context, pgTryAdvisoryXactLock int64) (bool, error)
|
||||
UnarchiveChatByID(ctx context.Context, id uuid.UUID) error
|
||||
UnarchiveChatByID(ctx context.Context, id uuid.UUID) ([]Chat, error)
|
||||
// This will always work regardless of the current state of the template version.
|
||||
UnarchiveTemplateVersion(ctx context.Context, arg UnarchiveTemplateVersionParams) error
|
||||
UnfavoriteWorkspace(ctx context.Context, id uuid.UUID) error
|
||||
@@ -854,6 +860,11 @@ type sqlcQuerier interface {
|
||||
// replicas know the worker is still alive.
|
||||
UpdateChatHeartbeat(ctx context.Context, arg UpdateChatHeartbeatParams) (int64, error)
|
||||
UpdateChatLabelsByID(ctx context.Context, arg UpdateChatLabelsByIDParams) (Chat, error)
|
||||
// Updates the cached injected context parts (AGENTS.md +
|
||||
// skills) on the chat row. Called only when context changes
|
||||
// (first workspace attach or agent change). updated_at is
|
||||
// intentionally not touched to avoid reordering the chat list.
|
||||
UpdateChatLastInjectedContext(ctx context.Context, arg UpdateChatLastInjectedContextParams) (Chat, error)
|
||||
UpdateChatLastModelConfigByID(ctx context.Context, arg UpdateChatLastModelConfigByIDParams) (Chat, error)
|
||||
// Updates the last read message ID for a chat. This is used to track
|
||||
// which messages the owner has seen, enabling unread indicators.
|
||||
@@ -918,6 +929,7 @@ type sqlcQuerier interface {
|
||||
UpdateUsageEventsPostPublish(ctx context.Context, arg UpdateUsageEventsPostPublishParams) error
|
||||
UpdateUserChatCompactionThreshold(ctx context.Context, arg UpdateUserChatCompactionThresholdParams) (UserConfig, error)
|
||||
UpdateUserChatCustomPrompt(ctx context.Context, arg UpdateUserChatCustomPromptParams) (UserConfig, error)
|
||||
UpdateUserChatProviderKey(ctx context.Context, arg UpdateUserChatProviderKeyParams) (UserChatProviderKey, error)
|
||||
UpdateUserDeletedByID(ctx context.Context, id uuid.UUID) error
|
||||
UpdateUserGithubComUserID(ctx context.Context, arg UpdateUserGithubComUserIDParams) error
|
||||
UpdateUserHashedOneTimePasscode(ctx context.Context, arg UpdateUserHashedOneTimePasscodeParams) error
|
||||
@@ -1006,6 +1018,7 @@ type sqlcQuerier interface {
|
||||
// used to store the data, and the minutes are summed for each user and template
|
||||
// combination. The result is stored in the template_usage_stats table.
|
||||
UpsertTemplateUsageStats(ctx context.Context) error
|
||||
UpsertUserChatProviderKey(ctx context.Context, arg UpsertUserChatProviderKeyParams) (UserChatProviderKey, error)
|
||||
UpsertWebpushVAPIDKeys(ctx context.Context, arg UpsertWebpushVAPIDKeysParams) error
|
||||
UpsertWorkspaceAgentPortShare(ctx context.Context, arg UpsertWorkspaceAgentPortShareParams) (WorkspaceAgentPortShare, error)
|
||||
UpsertWorkspaceApp(ctx context.Context, arg UpsertWorkspaceAppParams) (WorkspaceApp, error)
|
||||
|
||||
@@ -1251,16 +1251,21 @@ func TestGetAuthorizedChats(t *testing.T) {
|
||||
owner := dbgen.User(t, db, database.User{
|
||||
RBACRoles: []string{rbac.RoleOwner().String()},
|
||||
})
|
||||
member := dbgen.User(t, db, database.User{})
|
||||
secondMember := dbgen.User(t, db, database.User{})
|
||||
member := dbgen.User(t, db, database.User{
|
||||
RBACRoles: pq.StringArray{rbac.RoleAgentsAccess().String()},
|
||||
})
|
||||
secondMember := dbgen.User(t, db, database.User{
|
||||
RBACRoles: pq.StringArray{rbac.RoleAgentsAccess().String()},
|
||||
})
|
||||
|
||||
// Create FK dependencies: a chat provider and model config.
|
||||
ctx := testutil.Context(t, testutil.WaitMedium)
|
||||
_, err = db.InsertChatProvider(ctx, database.InsertChatProviderParams{
|
||||
Provider: "openai",
|
||||
DisplayName: "OpenAI",
|
||||
APIKey: "test-key",
|
||||
Enabled: true,
|
||||
Provider: "openai",
|
||||
DisplayName: "OpenAI",
|
||||
APIKey: "test-key",
|
||||
Enabled: true,
|
||||
CentralApiKeyEnabled: true,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
@@ -1281,6 +1286,7 @@ func TestGetAuthorizedChats(t *testing.T) {
|
||||
// Create 3 chats owned by owner.
|
||||
for i := range 3 {
|
||||
_, err := db.InsertChat(ctx, database.InsertChatParams{
|
||||
Status: database.ChatStatusWaiting,
|
||||
OwnerID: owner.ID,
|
||||
LastModelConfigID: modelCfg.ID,
|
||||
Title: fmt.Sprintf("owner chat %d", i+1),
|
||||
@@ -1291,6 +1297,7 @@ func TestGetAuthorizedChats(t *testing.T) {
|
||||
// Create 2 chats owned by member.
|
||||
for i := range 2 {
|
||||
_, err := db.InsertChat(ctx, database.InsertChatParams{
|
||||
Status: database.ChatStatusWaiting,
|
||||
OwnerID: member.ID,
|
||||
LastModelConfigID: modelCfg.ID,
|
||||
Title: fmt.Sprintf("member chat %d", i+1),
|
||||
@@ -1407,9 +1414,12 @@ func TestGetAuthorizedChats(t *testing.T) {
|
||||
|
||||
// Use a dedicated user for pagination to avoid interference
|
||||
// with the other parallel subtests.
|
||||
paginationUser := dbgen.User(t, db, database.User{})
|
||||
paginationUser := dbgen.User(t, db, database.User{
|
||||
RBACRoles: pq.StringArray{rbac.RoleAgentsAccess().String()},
|
||||
})
|
||||
for i := range 7 {
|
||||
_, err := db.InsertChat(ctx, database.InsertChatParams{
|
||||
Status: database.ChatStatusWaiting,
|
||||
OwnerID: paginationUser.ID,
|
||||
LastModelConfigID: modelCfg.ID,
|
||||
Title: fmt.Sprintf("pagination chat %d", i+1),
|
||||
@@ -9447,10 +9457,11 @@ func TestInsertChatMessages(t *testing.T) {
|
||||
provider := "openai"
|
||||
|
||||
_, err := store.InsertChatProvider(ctx, database.InsertChatProviderParams{
|
||||
Provider: provider,
|
||||
DisplayName: "OpenAI",
|
||||
APIKey: "test-key",
|
||||
Enabled: true,
|
||||
Provider: provider,
|
||||
DisplayName: "OpenAI",
|
||||
APIKey: "test-key",
|
||||
Enabled: true,
|
||||
CentralApiKeyEnabled: true,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
@@ -9466,6 +9477,7 @@ func TestInsertChatMessages(t *testing.T) {
|
||||
)
|
||||
|
||||
chat, err := store.InsertChat(ctx, database.InsertChatParams{
|
||||
Status: database.ChatStatusWaiting,
|
||||
OwnerID: user.ID,
|
||||
LastModelConfigID: modelConfigA.ID,
|
||||
Title: "test-chat-" + uuid.NewString(),
|
||||
@@ -9611,10 +9623,11 @@ func TestGetChatMessagesForPromptByChatID(t *testing.T) {
|
||||
|
||||
// A chat_providers row is required as a FK for model configs.
|
||||
_, err := db.InsertChatProvider(ctx, database.InsertChatProviderParams{
|
||||
Provider: "openai",
|
||||
DisplayName: "OpenAI",
|
||||
APIKey: "test-key",
|
||||
Enabled: true,
|
||||
Provider: "openai",
|
||||
DisplayName: "OpenAI",
|
||||
APIKey: "test-key",
|
||||
Enabled: true,
|
||||
CentralApiKeyEnabled: true,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
@@ -9635,6 +9648,7 @@ func TestGetChatMessagesForPromptByChatID(t *testing.T) {
|
||||
newChat := func(t *testing.T) database.Chat {
|
||||
t.Helper()
|
||||
chat, err := db.InsertChat(ctx, database.InsertChatParams{
|
||||
Status: database.ChatStatusWaiting,
|
||||
OwnerID: user.ID,
|
||||
LastModelConfigID: modelCfg.ID,
|
||||
Title: "test-chat-" + uuid.NewString(),
|
||||
@@ -9981,10 +9995,11 @@ func TestGetPRInsights(t *testing.T) {
|
||||
user := dbgen.User(t, store, database.User{})
|
||||
|
||||
_, err := store.InsertChatProvider(ctx, database.InsertChatProviderParams{
|
||||
Provider: "anthropic",
|
||||
DisplayName: "Anthropic",
|
||||
APIKey: "test-key",
|
||||
Enabled: true,
|
||||
Provider: "anthropic",
|
||||
DisplayName: "Anthropic",
|
||||
APIKey: "test-key",
|
||||
Enabled: true,
|
||||
CentralApiKeyEnabled: true,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
@@ -10008,6 +10023,7 @@ func TestGetPRInsights(t *testing.T) {
|
||||
createChat := func(t *testing.T, store database.Store, userID, mcID uuid.UUID, title string) database.Chat {
|
||||
t.Helper()
|
||||
chat, err := store.InsertChat(context.Background(), database.InsertChatParams{
|
||||
Status: database.ChatStatusWaiting,
|
||||
OwnerID: userID,
|
||||
LastModelConfigID: mcID,
|
||||
Title: title,
|
||||
@@ -10143,6 +10159,7 @@ func TestGetPRInsights(t *testing.T) {
|
||||
createChildChat := func(t *testing.T, store database.Store, userID, mcID, parentID, rootID uuid.UUID, title string) database.Chat {
|
||||
t.Helper()
|
||||
chat, err := store.InsertChat(context.Background(), database.InsertChatParams{
|
||||
Status: database.ChatStatusWaiting,
|
||||
OwnerID: userID,
|
||||
LastModelConfigID: mcID,
|
||||
Title: title,
|
||||
@@ -10503,10 +10520,11 @@ func TestChatPinOrderQueries(t *testing.T) {
|
||||
// timed test context doesn't tick during DB init.
|
||||
bg := context.Background()
|
||||
_, err := db.InsertChatProvider(bg, database.InsertChatProviderParams{
|
||||
Provider: "openai",
|
||||
DisplayName: "OpenAI",
|
||||
APIKey: "test-key",
|
||||
Enabled: true,
|
||||
Provider: "openai",
|
||||
DisplayName: "OpenAI",
|
||||
APIKey: "test-key",
|
||||
Enabled: true,
|
||||
CentralApiKeyEnabled: true,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
@@ -10532,6 +10550,7 @@ func TestChatPinOrderQueries(t *testing.T) {
|
||||
t.Helper()
|
||||
|
||||
chat, err := db.InsertChat(ctx, database.InsertChatParams{
|
||||
Status: database.ChatStatusWaiting,
|
||||
OwnerID: ownerID,
|
||||
LastModelConfigID: modelCfgID,
|
||||
Title: title,
|
||||
@@ -10640,7 +10659,8 @@ func TestChatPinOrderQueries(t *testing.T) {
|
||||
}
|
||||
|
||||
// Archive the middle pin.
|
||||
require.NoError(t, db.ArchiveChatByID(ctx, second.ID))
|
||||
_, err := db.ArchiveChatByID(ctx, second.ID)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Archived chat should have pin_order cleared. Remaining
|
||||
// pins keep their original positions; the next mutation
|
||||
@@ -10681,10 +10701,11 @@ func TestChatLabels(t *testing.T) {
|
||||
owner := dbgen.User(t, db, database.User{})
|
||||
|
||||
_, err = db.InsertChatProvider(ctx, database.InsertChatProviderParams{
|
||||
Provider: "openai",
|
||||
DisplayName: "OpenAI",
|
||||
APIKey: "test-key",
|
||||
Enabled: true,
|
||||
Provider: "openai",
|
||||
DisplayName: "OpenAI",
|
||||
APIKey: "test-key",
|
||||
Enabled: true,
|
||||
CentralApiKeyEnabled: true,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
@@ -10711,6 +10732,7 @@ func TestChatLabels(t *testing.T) {
|
||||
require.NoError(t, err)
|
||||
|
||||
chat, err := db.InsertChat(ctx, database.InsertChatParams{
|
||||
Status: database.ChatStatusWaiting,
|
||||
OwnerID: owner.ID,
|
||||
LastModelConfigID: modelCfg.ID,
|
||||
Title: "labeled-chat",
|
||||
@@ -10733,6 +10755,7 @@ func TestChatLabels(t *testing.T) {
|
||||
ctx := testutil.Context(t, testutil.WaitMedium)
|
||||
|
||||
chat, err := db.InsertChat(ctx, database.InsertChatParams{
|
||||
Status: database.ChatStatusWaiting,
|
||||
OwnerID: owner.ID,
|
||||
LastModelConfigID: modelCfg.ID,
|
||||
Title: "no-labels-chat",
|
||||
@@ -10748,6 +10771,7 @@ func TestChatLabels(t *testing.T) {
|
||||
ctx := testutil.Context(t, testutil.WaitMedium)
|
||||
|
||||
chat, err := db.InsertChat(ctx, database.InsertChatParams{
|
||||
Status: database.ChatStatusWaiting,
|
||||
OwnerID: owner.ID,
|
||||
LastModelConfigID: modelCfg.ID,
|
||||
Title: "update-labels-chat",
|
||||
@@ -10788,6 +10812,7 @@ func TestChatLabels(t *testing.T) {
|
||||
require.NoError(t, err)
|
||||
|
||||
chat, err := db.InsertChat(ctx, database.InsertChatParams{
|
||||
Status: database.ChatStatusWaiting,
|
||||
OwnerID: owner.ID,
|
||||
LastModelConfigID: modelCfg.ID,
|
||||
Title: "original-title",
|
||||
@@ -10824,6 +10849,7 @@ func TestChatLabels(t *testing.T) {
|
||||
labelsJSON, err := json.Marshal(tc.labels)
|
||||
require.NoError(t, err)
|
||||
_, err = db.InsertChat(ctx, database.InsertChatParams{
|
||||
Status: database.ChatStatusWaiting,
|
||||
OwnerID: owner.ID,
|
||||
LastModelConfigID: modelCfg.ID,
|
||||
Title: tc.title,
|
||||
@@ -10887,10 +10913,11 @@ func TestChatHasUnread(t *testing.T) {
|
||||
user := dbgen.User(t, store, database.User{})
|
||||
|
||||
_, err := store.InsertChatProvider(ctx, database.InsertChatProviderParams{
|
||||
Provider: "openai",
|
||||
DisplayName: "OpenAI",
|
||||
APIKey: "test-key",
|
||||
Enabled: true,
|
||||
Provider: "openai",
|
||||
DisplayName: "OpenAI",
|
||||
APIKey: "test-key",
|
||||
Enabled: true,
|
||||
CentralApiKeyEnabled: true,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
@@ -10909,6 +10936,7 @@ func TestChatHasUnread(t *testing.T) {
|
||||
require.NoError(t, err)
|
||||
|
||||
chat, err := store.InsertChat(ctx, database.InsertChatParams{
|
||||
Status: database.ChatStatusWaiting,
|
||||
OwnerID: user.ID,
|
||||
LastModelConfigID: modelCfg.ID,
|
||||
Title: "test-chat-" + uuid.NewString(),
|
||||
|
||||
+483
-146
File diff suppressed because it is too large
Load Diff
@@ -1,8 +1,8 @@
|
||||
-- name: InsertAIBridgeInterception :one
|
||||
INSERT INTO aibridge_interceptions (
|
||||
id, api_key_id, initiator_id, provider, model, metadata, started_at, client, client_session_id, thread_parent_id, thread_root_id
|
||||
id, api_key_id, initiator_id, provider, provider_name, model, metadata, started_at, client, client_session_id, thread_parent_id, thread_root_id
|
||||
) VALUES (
|
||||
@id, @api_key_id, @initiator_id, @provider, @model, COALESCE(@metadata::jsonb, '{}'::jsonb), @started_at, @client, sqlc.narg('client_session_id'), sqlc.narg('thread_parent_interception_id')::uuid, sqlc.narg('thread_root_interception_id')::uuid
|
||||
@id, @api_key_id, @initiator_id, @provider, @provider_name, @model, COALESCE(@metadata::jsonb, '{}'::jsonb), @started_at, @client, sqlc.narg('client_session_id'), sqlc.narg('thread_parent_interception_id')::uuid, sqlc.narg('thread_root_interception_id')::uuid
|
||||
)
|
||||
RETURNING *;
|
||||
|
||||
@@ -454,95 +454,91 @@ WHERE
|
||||
-- Returns paginated sessions with aggregated metadata, token counts, and
|
||||
-- the most recent user prompt. A "session" is a logical grouping of
|
||||
-- interceptions that share the same session_id (set by the client).
|
||||
WITH filtered_interceptions AS (
|
||||
--
|
||||
-- Pagination-first strategy: identify the page of sessions cheaply via a
|
||||
-- single GROUP BY scan, then do expensive lateral joins (tokens, prompts,
|
||||
-- first-interception metadata) only for the ~page-size result set.
|
||||
WITH cursor_pos AS (
|
||||
-- Resolve the cursor's started_at once, outside the HAVING clause,
|
||||
-- so the planner cannot accidentally re-evaluate it per group.
|
||||
SELECT MIN(aibridge_interceptions.started_at) AS started_at
|
||||
FROM aibridge_interceptions
|
||||
WHERE aibridge_interceptions.session_id = @after_session_id AND aibridge_interceptions.ended_at IS NOT NULL
|
||||
),
|
||||
session_page AS (
|
||||
-- Paginate at the session level first; only cheap aggregates here.
|
||||
SELECT
|
||||
aibridge_interceptions.*
|
||||
ai.session_id,
|
||||
ai.initiator_id,
|
||||
MIN(ai.started_at) AS started_at,
|
||||
MAX(ai.ended_at) AS ended_at,
|
||||
COUNT(*) FILTER (WHERE ai.thread_root_id IS NULL) AS threads
|
||||
FROM
|
||||
aibridge_interceptions
|
||||
aibridge_interceptions ai
|
||||
WHERE
|
||||
-- Remove inflight interceptions (ones which lack an ended_at value).
|
||||
aibridge_interceptions.ended_at IS NOT NULL
|
||||
ai.ended_at IS NOT NULL
|
||||
-- Filter by time frame
|
||||
AND CASE
|
||||
WHEN @started_after::timestamptz != '0001-01-01 00:00:00+00'::timestamptz THEN aibridge_interceptions.started_at >= @started_after::timestamptz
|
||||
WHEN @started_after::timestamptz != '0001-01-01 00:00:00+00'::timestamptz THEN ai.started_at >= @started_after::timestamptz
|
||||
ELSE true
|
||||
END
|
||||
AND CASE
|
||||
WHEN @started_before::timestamptz != '0001-01-01 00:00:00+00'::timestamptz THEN aibridge_interceptions.started_at <= @started_before::timestamptz
|
||||
WHEN @started_before::timestamptz != '0001-01-01 00:00:00+00'::timestamptz THEN ai.started_at <= @started_before::timestamptz
|
||||
ELSE true
|
||||
END
|
||||
-- Filter initiator_id
|
||||
AND CASE
|
||||
WHEN @initiator_id::uuid != '00000000-0000-0000-0000-000000000000'::uuid THEN aibridge_interceptions.initiator_id = @initiator_id::uuid
|
||||
WHEN @initiator_id::uuid != '00000000-0000-0000-0000-000000000000'::uuid THEN ai.initiator_id = @initiator_id::uuid
|
||||
ELSE true
|
||||
END
|
||||
-- Filter provider
|
||||
AND CASE
|
||||
WHEN @provider::text != '' THEN aibridge_interceptions.provider = @provider::text
|
||||
WHEN @provider::text != '' THEN ai.provider = @provider::text
|
||||
ELSE true
|
||||
END
|
||||
-- Filter model
|
||||
AND CASE
|
||||
WHEN @model::text != '' THEN aibridge_interceptions.model = @model::text
|
||||
WHEN @model::text != '' THEN ai.model = @model::text
|
||||
ELSE true
|
||||
END
|
||||
-- Filter client
|
||||
AND CASE
|
||||
WHEN @client::text != '' THEN COALESCE(aibridge_interceptions.client, 'Unknown') = @client::text
|
||||
WHEN @client::text != '' THEN COALESCE(ai.client, 'Unknown') = @client::text
|
||||
ELSE true
|
||||
END
|
||||
-- Filter session_id
|
||||
AND CASE
|
||||
WHEN @session_id::text != '' THEN aibridge_interceptions.session_id = @session_id::text
|
||||
WHEN @session_id::text != '' THEN ai.session_id = @session_id::text
|
||||
ELSE true
|
||||
END
|
||||
-- Authorize Filter clause will be injected below in ListAuthorizedAIBridgeSessions
|
||||
-- @authorize_filter
|
||||
),
|
||||
session_tokens AS (
|
||||
-- Aggregate token usage across all interceptions in each session.
|
||||
-- Group by (session_id, initiator_id) to avoid merging sessions from
|
||||
-- different users who happen to share the same client_session_id.
|
||||
SELECT
|
||||
fi.session_id,
|
||||
fi.initiator_id,
|
||||
COALESCE(SUM(tu.input_tokens), 0)::bigint AS input_tokens,
|
||||
COALESCE(SUM(tu.output_tokens), 0)::bigint AS output_tokens
|
||||
-- TODO: add extra token types once https://github.com/coder/aibridge/issues/150 lands.
|
||||
FROM
|
||||
filtered_interceptions fi
|
||||
LEFT JOIN
|
||||
aibridge_token_usages tu ON fi.id = tu.interception_id
|
||||
GROUP BY
|
||||
fi.session_id, fi.initiator_id
|
||||
),
|
||||
session_root AS (
|
||||
-- Build one summary row per session. Group by (session_id, initiator_id)
|
||||
-- to avoid merging sessions from different users who happen to share the
|
||||
-- same client_session_id. The ARRAY_AGG with ORDER BY picks values from
|
||||
-- the chronologically first interception for fields that should represent
|
||||
-- the session as a whole (client, metadata). Threads are counted as
|
||||
-- distinct root interception IDs: an interception with a NULL
|
||||
-- thread_root_id is itself a thread root.
|
||||
SELECT
|
||||
fi.session_id,
|
||||
fi.initiator_id,
|
||||
(ARRAY_AGG(fi.client ORDER BY fi.started_at, fi.id))[1] AS client,
|
||||
(ARRAY_AGG(fi.metadata ORDER BY fi.started_at, fi.id))[1] AS metadata,
|
||||
ARRAY_AGG(DISTINCT fi.provider ORDER BY fi.provider) AS providers,
|
||||
ARRAY_AGG(DISTINCT fi.model ORDER BY fi.model) AS models,
|
||||
MIN(fi.started_at) AS started_at,
|
||||
MAX(fi.ended_at) AS ended_at,
|
||||
COUNT(DISTINCT COALESCE(fi.thread_root_id, fi.id)) AS threads,
|
||||
-- Collect IDs for lateral prompt lookup.
|
||||
ARRAY_AGG(fi.id) AS interception_ids
|
||||
FROM
|
||||
filtered_interceptions fi
|
||||
GROUP BY
|
||||
fi.session_id, fi.initiator_id
|
||||
ai.session_id, ai.initiator_id
|
||||
HAVING
|
||||
-- Cursor pagination: uses a composite (started_at, session_id)
|
||||
-- cursor to support keyset pagination. The less-than comparison
|
||||
-- matches the DESC sort order so rows after the cursor come
|
||||
-- later in results. The cursor value comes from cursor_pos to
|
||||
-- guarantee single evaluation.
|
||||
CASE
|
||||
WHEN @after_session_id::text != '' THEN (
|
||||
(MIN(ai.started_at), ai.session_id) < (
|
||||
(SELECT started_at FROM cursor_pos),
|
||||
@after_session_id::text
|
||||
)
|
||||
)
|
||||
ELSE true
|
||||
END
|
||||
ORDER BY
|
||||
MIN(ai.started_at) DESC,
|
||||
ai.session_id DESC
|
||||
LIMIT COALESCE(NULLIF(@limit_::integer, 0), 100)
|
||||
OFFSET @offset_
|
||||
)
|
||||
SELECT
|
||||
sr.session_id,
|
||||
sp.session_id,
|
||||
visible_users.id AS user_id,
|
||||
visible_users.username AS user_username,
|
||||
visible_users.name AS user_name,
|
||||
@@ -551,45 +547,48 @@ SELECT
|
||||
sr.models::text[] AS models,
|
||||
COALESCE(sr.client, '')::varchar(64) AS client,
|
||||
sr.metadata::jsonb AS metadata,
|
||||
sr.started_at::timestamptz AS started_at,
|
||||
sr.ended_at::timestamptz AS ended_at,
|
||||
sr.threads,
|
||||
sp.started_at::timestamptz AS started_at,
|
||||
sp.ended_at::timestamptz AS ended_at,
|
||||
sp.threads,
|
||||
COALESCE(st.input_tokens, 0)::bigint AS input_tokens,
|
||||
COALESCE(st.output_tokens, 0)::bigint AS output_tokens,
|
||||
COALESCE(slp.prompt, '') AS last_prompt
|
||||
FROM
|
||||
session_root sr
|
||||
session_page sp
|
||||
JOIN
|
||||
visible_users ON visible_users.id = sr.initiator_id
|
||||
LEFT JOIN
|
||||
session_tokens st ON st.session_id = sr.session_id AND st.initiator_id = sr.initiator_id
|
||||
visible_users ON visible_users.id = sp.initiator_id
|
||||
LEFT JOIN LATERAL (
|
||||
-- Lateral join to efficiently fetch only the most recent user prompt
|
||||
-- across all interceptions in the session, avoiding a full aggregation.
|
||||
SELECT
|
||||
(ARRAY_AGG(ai.client ORDER BY ai.started_at, ai.id))[1] AS client,
|
||||
(ARRAY_AGG(ai.metadata ORDER BY ai.started_at, ai.id))[1] AS metadata,
|
||||
ARRAY_AGG(DISTINCT ai.provider ORDER BY ai.provider) AS providers,
|
||||
ARRAY_AGG(DISTINCT ai.model ORDER BY ai.model) AS models,
|
||||
ARRAY_AGG(ai.id) AS interception_ids
|
||||
FROM aibridge_interceptions ai
|
||||
WHERE ai.session_id = sp.session_id
|
||||
AND ai.initiator_id = sp.initiator_id
|
||||
AND ai.ended_at IS NOT NULL
|
||||
) sr ON true
|
||||
LEFT JOIN LATERAL (
|
||||
-- Aggregate tokens only for this session's interceptions.
|
||||
SELECT
|
||||
COALESCE(SUM(tu.input_tokens), 0)::bigint AS input_tokens,
|
||||
COALESCE(SUM(tu.output_tokens), 0)::bigint AS output_tokens
|
||||
FROM aibridge_token_usages tu
|
||||
WHERE tu.interception_id = ANY(sr.interception_ids)
|
||||
) st ON true
|
||||
LEFT JOIN LATERAL (
|
||||
-- Fetch only the most recent user prompt across all interceptions
|
||||
-- in the session.
|
||||
SELECT up.prompt
|
||||
FROM aibridge_user_prompts up
|
||||
WHERE up.interception_id = ANY(sr.interception_ids)
|
||||
ORDER BY up.created_at DESC, up.id DESC
|
||||
LIMIT 1
|
||||
) slp ON true
|
||||
WHERE
|
||||
-- Cursor pagination: uses a composite (started_at, session_id) cursor
|
||||
-- to support keyset pagination. The less-than comparison matches the
|
||||
-- DESC sort order so that rows after the cursor come later in results.
|
||||
CASE
|
||||
WHEN @after_session_id::text != '' THEN (
|
||||
(sr.started_at, sr.session_id) < (
|
||||
(SELECT started_at FROM session_root WHERE session_id = @after_session_id),
|
||||
@after_session_id::text
|
||||
)
|
||||
)
|
||||
ELSE true
|
||||
END
|
||||
ORDER BY
|
||||
sr.started_at DESC,
|
||||
sr.session_id DESC
|
||||
LIMIT COALESCE(NULLIF(@limit_::integer, 0), 100)
|
||||
OFFSET @offset_
|
||||
sp.started_at DESC,
|
||||
sp.session_id DESC
|
||||
;
|
||||
|
||||
-- name: ListAIBridgeSessionThreads :many
|
||||
|
||||
@@ -40,7 +40,10 @@ INSERT INTO chat_providers (
|
||||
base_url,
|
||||
api_key_key_id,
|
||||
created_by,
|
||||
enabled
|
||||
enabled,
|
||||
central_api_key_enabled,
|
||||
allow_user_api_key,
|
||||
allow_central_api_key_fallback
|
||||
) VALUES (
|
||||
@provider::text,
|
||||
@display_name::text,
|
||||
@@ -48,7 +51,10 @@ INSERT INTO chat_providers (
|
||||
@base_url::text,
|
||||
sqlc.narg('api_key_key_id')::text,
|
||||
sqlc.narg('created_by')::uuid,
|
||||
@enabled::boolean
|
||||
@enabled::boolean,
|
||||
@central_api_key_enabled::boolean,
|
||||
@allow_user_api_key::boolean,
|
||||
@allow_central_api_key_fallback::boolean
|
||||
)
|
||||
RETURNING
|
||||
*;
|
||||
@@ -62,6 +68,9 @@ SET
|
||||
base_url = @base_url::text,
|
||||
api_key_key_id = sqlc.narg('api_key_key_id')::text,
|
||||
enabled = @enabled::boolean,
|
||||
central_api_key_enabled = @central_api_key_enabled::boolean,
|
||||
allow_user_api_key = @allow_user_api_key::boolean,
|
||||
allow_central_api_key_fallback = @allow_central_api_key_fallback::boolean,
|
||||
updated_at = NOW()
|
||||
WHERE
|
||||
id = @id::uuid
|
||||
|
||||
@@ -1,9 +1,24 @@
|
||||
-- name: ArchiveChatByID :exec
|
||||
UPDATE chats SET archived = true, pin_order = 0, updated_at = NOW()
|
||||
WHERE id = @id OR root_chat_id = @id;
|
||||
-- name: ArchiveChatByID :many
|
||||
WITH chats AS (
|
||||
UPDATE chats
|
||||
SET archived = true, pin_order = 0, updated_at = NOW()
|
||||
WHERE id = @id::uuid OR root_chat_id = @id::uuid
|
||||
RETURNING *
|
||||
)
|
||||
SELECT *
|
||||
FROM chats
|
||||
ORDER BY (id = @id::uuid) DESC, created_at ASC, id ASC;
|
||||
|
||||
-- name: UnarchiveChatByID :exec
|
||||
UPDATE chats SET archived = false, updated_at = NOW() WHERE id = @id::uuid;
|
||||
-- name: UnarchiveChatByID :many
|
||||
WITH chats AS (
|
||||
UPDATE chats
|
||||
SET archived = false, updated_at = NOW()
|
||||
WHERE id = @id::uuid OR root_chat_id = @id::uuid
|
||||
RETURNING *
|
||||
)
|
||||
SELECT *
|
||||
FROM chats
|
||||
ORDER BY (id = @id::uuid) DESC, created_at ASC, id ASC;
|
||||
|
||||
-- name: PinChatByID :exec
|
||||
WITH target_chat AS (
|
||||
@@ -377,6 +392,7 @@ INSERT INTO chats (
|
||||
last_model_config_id,
|
||||
title,
|
||||
mode,
|
||||
status,
|
||||
mcp_server_ids,
|
||||
labels
|
||||
) VALUES (
|
||||
@@ -389,6 +405,7 @@ INSERT INTO chats (
|
||||
@last_model_config_id::uuid,
|
||||
@title::text,
|
||||
sqlc.narg('mode')::chat_mode,
|
||||
@status::chat_status,
|
||||
COALESCE(@mcp_server_ids::uuid[], '{}'::uuid[]),
|
||||
COALESCE(sqlc.narg('labels')::jsonb, '{}'::jsonb)
|
||||
)
|
||||
@@ -528,6 +545,17 @@ WHERE
|
||||
id = @id::uuid
|
||||
RETURNING *;
|
||||
|
||||
-- name: UpdateChatLastInjectedContext :one
|
||||
-- Updates the cached injected context parts (AGENTS.md +
|
||||
-- skills) on the chat row. Called only when context changes
|
||||
-- (first workspace attach or agent change). updated_at is
|
||||
-- intentionally not touched to avoid reordering the chat list.
|
||||
UPDATE chats SET
|
||||
last_injected_context = sqlc.narg('last_injected_context')::jsonb
|
||||
WHERE
|
||||
id = @id::uuid
|
||||
RETURNING *;
|
||||
|
||||
-- name: UpdateChatMCPServerIDs :one
|
||||
UPDATE
|
||||
chats
|
||||
|
||||
@@ -0,0 +1,20 @@
|
||||
-- name: GetUserChatProviderKeys :many
|
||||
SELECT * FROM user_chat_provider_keys WHERE user_id = @user_id ORDER BY created_at ASC, id ASC;
|
||||
|
||||
-- name: UpsertUserChatProviderKey :one
|
||||
INSERT INTO user_chat_provider_keys (user_id, chat_provider_id, api_key, api_key_key_id)
|
||||
VALUES (@user_id, @chat_provider_id, @api_key, sqlc.narg('api_key_key_id')::text)
|
||||
ON CONFLICT (user_id, chat_provider_id) DO UPDATE SET
|
||||
api_key = @api_key,
|
||||
api_key_key_id = sqlc.narg('api_key_key_id')::text,
|
||||
updated_at = NOW()
|
||||
RETURNING *;
|
||||
|
||||
-- name: UpdateUserChatProviderKey :one
|
||||
UPDATE user_chat_provider_keys
|
||||
SET api_key = @api_key, api_key_key_id = sqlc.narg('api_key_key_id')::text, updated_at = NOW()
|
||||
WHERE user_id = @user_id AND chat_provider_id = @chat_provider_id
|
||||
RETURNING *;
|
||||
|
||||
-- name: DeleteUserChatProviderKey :exec
|
||||
DELETE FROM user_chat_provider_keys WHERE user_id = @user_id AND chat_provider_id = @chat_provider_id;
|
||||
@@ -90,6 +90,8 @@ const (
|
||||
UniqueTemplatesPkey UniqueConstraint = "templates_pkey" // ALTER TABLE ONLY templates ADD CONSTRAINT templates_pkey PRIMARY KEY (id);
|
||||
UniqueUsageEventsDailyPkey UniqueConstraint = "usage_events_daily_pkey" // ALTER TABLE ONLY usage_events_daily ADD CONSTRAINT usage_events_daily_pkey PRIMARY KEY (day, event_type);
|
||||
UniqueUsageEventsPkey UniqueConstraint = "usage_events_pkey" // ALTER TABLE ONLY usage_events ADD CONSTRAINT usage_events_pkey PRIMARY KEY (id);
|
||||
UniqueUserChatProviderKeysPkey UniqueConstraint = "user_chat_provider_keys_pkey" // ALTER TABLE ONLY user_chat_provider_keys ADD CONSTRAINT user_chat_provider_keys_pkey PRIMARY KEY (id);
|
||||
UniqueUserChatProviderKeysUserIDChatProviderIDKey UniqueConstraint = "user_chat_provider_keys_user_id_chat_provider_id_key" // ALTER TABLE ONLY user_chat_provider_keys ADD CONSTRAINT user_chat_provider_keys_user_id_chat_provider_id_key UNIQUE (user_id, chat_provider_id);
|
||||
UniqueUserConfigsPkey UniqueConstraint = "user_configs_pkey" // ALTER TABLE ONLY user_configs ADD CONSTRAINT user_configs_pkey PRIMARY KEY (user_id, key);
|
||||
UniqueUserDeletedPkey UniqueConstraint = "user_deleted_pkey" // ALTER TABLE ONLY user_deleted ADD CONSTRAINT user_deleted_pkey PRIMARY KEY (id);
|
||||
UniqueUserLinksPkey UniqueConstraint = "user_links_pkey" // ALTER TABLE ONLY user_links ADD CONSTRAINT user_links_pkey PRIMARY KEY (user_id, login_type);
|
||||
|
||||
+488
-138
@@ -56,7 +56,6 @@ import (
|
||||
)
|
||||
|
||||
const (
|
||||
chatDiffStatusTTL = gitsync.DiffStatusTTL
|
||||
chatStreamBatchSize = 256
|
||||
|
||||
chatContextLimitModelConfigKey = "context_limit"
|
||||
@@ -67,11 +66,12 @@ const (
|
||||
maxSystemPromptLenBytes = 131072 // 128 KiB
|
||||
)
|
||||
|
||||
// chatGitRef holds the branch and remote origin reported by the
|
||||
// workspace agent during a git operation.
|
||||
// chatGitRef holds the branch, remote origin, and optional chat
|
||||
// ID reported by the workspace agent during a git operation.
|
||||
type chatGitRef struct {
|
||||
Branch string
|
||||
RemoteOrigin string
|
||||
ChatID uuid.UUID
|
||||
}
|
||||
|
||||
type chatRepositoryRef struct {
|
||||
@@ -393,6 +393,11 @@ func (api *API) postChats(rw http.ResponseWriter, r *http.Request) {
|
||||
ctx := r.Context()
|
||||
apiKey := httpmw.APIKey(r)
|
||||
|
||||
if !api.Authorize(r, policy.ActionCreate, rbac.ResourceChat.WithOwner(apiKey.UserID.String())) {
|
||||
httpapi.Forbidden(rw)
|
||||
return
|
||||
}
|
||||
|
||||
var req codersdk.CreateChatRequest
|
||||
if !httpapi.Read(ctx, rw, r, &req) {
|
||||
return
|
||||
@@ -498,6 +503,10 @@ func (api *API) postChats(rw http.ResponseWriter, r *http.Request) {
|
||||
})
|
||||
return
|
||||
}
|
||||
if dbauthz.IsNotAuthorizedError(err) {
|
||||
httpapi.Forbidden(rw)
|
||||
return
|
||||
}
|
||||
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
|
||||
Message: "Failed to create chat.",
|
||||
Detail: err.Error(),
|
||||
@@ -511,6 +520,7 @@ func (api *API) postChats(rw http.ResponseWriter, r *http.Request) {
|
||||
// EXPERIMENTAL: this endpoint is experimental and is subject to change.
|
||||
func (api *API) listChatModels(rw http.ResponseWriter, r *http.Request) {
|
||||
ctx := r.Context()
|
||||
apiKey := httpmw.APIKey(r)
|
||||
//nolint:gocritic // System context required to read enabled chat models.
|
||||
systemCtx := dbauthz.AsSystemRestricted(ctx)
|
||||
|
||||
@@ -546,14 +556,24 @@ func (api *API) listChatModels(rw http.ResponseWriter, r *http.Request) {
|
||||
configuredProviders := make(
|
||||
[]chatprovider.ConfiguredProvider, 0, len(enabledProviders),
|
||||
)
|
||||
enabledProviderNames := make(map[string]struct{}, len(enabledProviders))
|
||||
for _, provider := range enabledProviders {
|
||||
configuredProviders = append(
|
||||
configuredProviders, chatprovider.ConfiguredProvider{
|
||||
Provider: provider.Provider,
|
||||
APIKey: provider.APIKey,
|
||||
BaseURL: provider.BaseUrl,
|
||||
ProviderID: provider.ID,
|
||||
Provider: provider.Provider,
|
||||
APIKey: provider.APIKey,
|
||||
BaseURL: provider.BaseUrl,
|
||||
CentralAPIKeyEnabled: provider.CentralApiKeyEnabled,
|
||||
AllowUserAPIKey: provider.AllowUserApiKey,
|
||||
AllowCentralAPIKeyFallback: provider.AllowCentralApiKeyFallback,
|
||||
},
|
||||
)
|
||||
normalizedProvider := chatprovider.NormalizeProvider(provider.Provider)
|
||||
if normalizedProvider == "" {
|
||||
continue
|
||||
}
|
||||
enabledProviderNames[normalizedProvider] = struct{}{}
|
||||
}
|
||||
configuredModels := make(
|
||||
[]chatprovider.ConfiguredModel, 0, len(enabledModels),
|
||||
@@ -566,18 +586,38 @@ func (api *API) listChatModels(rw http.ResponseWriter, r *http.Request) {
|
||||
})
|
||||
}
|
||||
|
||||
keys := chatprovider.MergeProviderAPIKeys(
|
||||
chatProviderAPIKeysFromDeploymentValues(api.DeploymentValues),
|
||||
userKeyRows, err := api.Database.GetUserChatProviderKeys(ctx, apiKey.UserID)
|
||||
if err != nil {
|
||||
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
|
||||
Message: "Failed to load user chat provider keys.",
|
||||
Detail: err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
userKeys := make([]chatprovider.UserProviderKey, 0, len(userKeyRows))
|
||||
for _, userKey := range userKeyRows {
|
||||
userKeys = append(userKeys, chatprovider.UserProviderKey{
|
||||
ChatProviderID: userKey.ChatProviderID,
|
||||
APIKey: userKey.APIKey,
|
||||
})
|
||||
}
|
||||
|
||||
_, providerAvailability := chatprovider.ResolveUserProviderKeys(
|
||||
ChatProviderAPIKeysFromDeploymentValues(api.DeploymentValues),
|
||||
configuredProviders,
|
||||
userKeys,
|
||||
)
|
||||
catalog := chatprovider.NewModelCatalog(keys)
|
||||
catalog := chatprovider.NewModelCatalog()
|
||||
var response codersdk.ChatModelsResponse
|
||||
if configured, ok := catalog.ListConfiguredModels(
|
||||
configuredProviders, configuredModels,
|
||||
configuredProviders, configuredModels, providerAvailability, enabledProviderNames,
|
||||
); ok {
|
||||
response = configured
|
||||
} else {
|
||||
response = catalog.ListConfiguredProviderAvailability(configuredProviders)
|
||||
response = catalog.ListConfiguredProviderAvailability(
|
||||
providerAvailability,
|
||||
enabledProviderNames,
|
||||
)
|
||||
}
|
||||
|
||||
httpapi.Write(ctx, rw, http.StatusOK, response)
|
||||
@@ -616,6 +656,10 @@ func (api *API) chatCostSummary(rw http.ResponseWriter, r *http.Request) {
|
||||
EndDate: endDate,
|
||||
})
|
||||
if err != nil {
|
||||
if dbauthz.IsNotAuthorizedError(err) {
|
||||
httpapi.Forbidden(rw)
|
||||
return
|
||||
}
|
||||
httpapi.InternalServerError(rw, err)
|
||||
return
|
||||
}
|
||||
@@ -626,6 +670,10 @@ func (api *API) chatCostSummary(rw http.ResponseWriter, r *http.Request) {
|
||||
EndDate: endDate,
|
||||
})
|
||||
if err != nil {
|
||||
if dbauthz.IsNotAuthorizedError(err) {
|
||||
httpapi.Forbidden(rw)
|
||||
return
|
||||
}
|
||||
httpapi.InternalServerError(rw, err)
|
||||
return
|
||||
}
|
||||
@@ -636,6 +684,10 @@ func (api *API) chatCostSummary(rw http.ResponseWriter, r *http.Request) {
|
||||
EndDate: endDate,
|
||||
})
|
||||
if err != nil {
|
||||
if dbauthz.IsNotAuthorizedError(err) {
|
||||
httpapi.Forbidden(rw)
|
||||
return
|
||||
}
|
||||
httpapi.InternalServerError(rw, err)
|
||||
return
|
||||
}
|
||||
@@ -1222,10 +1274,18 @@ func (api *API) getChat(rw http.ResponseWriter, r *http.Request) {
|
||||
ctx := r.Context()
|
||||
chat := httpmw.ChatParam(r)
|
||||
|
||||
diffStatus, err := api.resolveChatDiffStatus(ctx, chat)
|
||||
if err != nil {
|
||||
// Log but don't fail - diff status is supplementary.
|
||||
api.Logger.Error(ctx, "failed to resolve chat diff status",
|
||||
// Use the cached diff status from the database rather than
|
||||
// resolving it inline. Inline resolution calls out to the
|
||||
// git provider API (e.g. GitHub) on every request which
|
||||
// blocks the response for 200-800ms. The background gitsync
|
||||
// worker keeps the cached status fresh.
|
||||
var diffStatus *database.ChatDiffStatus
|
||||
status, err := api.Database.GetChatDiffStatusByChatID(ctx, chat.ID)
|
||||
switch {
|
||||
case err == nil:
|
||||
diffStatus = &status
|
||||
case !xerrors.Is(err, sql.ErrNoRows):
|
||||
api.Logger.Error(ctx, "failed to get cached chat diff status",
|
||||
slog.F("chat_id", chat.ID),
|
||||
slog.Error(err),
|
||||
)
|
||||
@@ -1620,20 +1680,20 @@ func (api *API) patchChat(rw http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
|
||||
var err error
|
||||
// Use chatDaemon when available so it can notify active
|
||||
// subscribers. Fall back to direct DB for the simple
|
||||
// archive flag — no streaming state is involved.
|
||||
// Use chatDaemon when available so it can interrupt active
|
||||
// processing before broadcasting archive state. Fall back to
|
||||
// direct DB when no daemon is running.
|
||||
if archived {
|
||||
if api.chatDaemon != nil {
|
||||
err = api.chatDaemon.ArchiveChat(ctx, chat)
|
||||
} else {
|
||||
err = api.Database.ArchiveChatByID(ctx, chat.ID)
|
||||
_, err = api.Database.ArchiveChatByID(ctx, chat.ID)
|
||||
}
|
||||
} else {
|
||||
if api.chatDaemon != nil {
|
||||
err = api.chatDaemon.UnarchiveChat(ctx, chat)
|
||||
} else {
|
||||
err = api.Database.UnarchiveChatByID(ctx, chat.ID)
|
||||
_, err = api.Database.UnarchiveChatByID(ctx, chat.ID)
|
||||
}
|
||||
}
|
||||
if err != nil {
|
||||
@@ -2337,68 +2397,6 @@ func chatWorkspaceAuditStatus(err error) int {
|
||||
return http.StatusInternalServerError
|
||||
}
|
||||
|
||||
func (api *API) resolveChatDiffStatus(
|
||||
ctx context.Context,
|
||||
chat database.Chat,
|
||||
) (*database.ChatDiffStatus, error) {
|
||||
status, found, err := api.getCachedChatDiffStatus(ctx, chat.ID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
now := time.Now().UTC()
|
||||
|
||||
reference, err := api.resolveChatDiffReference(ctx, chat, found, status)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if reference.PullRequestURL != "" {
|
||||
if !found || !strings.EqualFold(strings.TrimSpace(status.Url.String), reference.PullRequestURL) {
|
||||
status, err = api.upsertChatDiffStatusReference(ctx, chat.ID, reference.PullRequestURL, now.Add(-time.Second))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
found = true
|
||||
}
|
||||
}
|
||||
|
||||
if !found {
|
||||
return nil, nil //nolint:nilnil // Callers handle nil status explicitly.
|
||||
}
|
||||
if !chatDiffStatusIsStale(status, now) {
|
||||
return &status, nil
|
||||
}
|
||||
|
||||
// Use the same refresh pipeline as the background worker
|
||||
// so both paths share identical provider/token resolution.
|
||||
refreshed, err := api.gitSyncWorker.RefreshChat(
|
||||
ctx, status, chat.OwnerID,
|
||||
)
|
||||
if err == nil && refreshed != nil {
|
||||
return refreshed, nil
|
||||
}
|
||||
if err == nil {
|
||||
// No PR exists yet; return what we have.
|
||||
return &status, nil
|
||||
}
|
||||
|
||||
api.Logger.Warn(ctx, "failed to refresh chat diff status",
|
||||
slog.F("chat_id", chat.ID),
|
||||
slog.Error(err),
|
||||
)
|
||||
|
||||
backoffStatus, backoffErr := api.upsertChatDiffStatusReference(ctx, chat.ID, reference.PullRequestURL, now.Add(chatDiffStatusTTL))
|
||||
if backoffErr != nil {
|
||||
api.Logger.Warn(ctx, "failed to extend chat diff status stale timestamp",
|
||||
slog.F("chat_id", chat.ID),
|
||||
slog.Error(backoffErr),
|
||||
)
|
||||
return &status, nil
|
||||
}
|
||||
|
||||
return &backoffStatus, nil
|
||||
}
|
||||
|
||||
func (api *API) resolveChatDiffContents(
|
||||
ctx context.Context,
|
||||
chat database.Chat,
|
||||
@@ -2665,13 +2663,6 @@ func (api *API) resolveGitProvider(origin string) gitprovider.Provider {
|
||||
return gp
|
||||
}
|
||||
|
||||
func chatDiffStatusIsStale(status database.ChatDiffStatus, now time.Time) bool {
|
||||
if !status.RefreshedAt.Valid {
|
||||
return true
|
||||
}
|
||||
return !status.StaleAt.After(now)
|
||||
}
|
||||
|
||||
func (api *API) resolveChatGitAccessToken(
|
||||
ctx context.Context,
|
||||
userID uuid.UUID,
|
||||
@@ -3966,9 +3957,13 @@ func (api *API) listChatProviders(rw http.ResponseWriter, r *http.Request) {
|
||||
[]chatprovider.ConfiguredProvider, 0, len(enabledProviders),
|
||||
)
|
||||
for _, provider := range enabledProviders {
|
||||
normalizedProvider := normalizeChatProvider(provider.Provider)
|
||||
if normalizedProvider == "" {
|
||||
continue
|
||||
}
|
||||
enabledConfiguredProviders = append(
|
||||
enabledConfiguredProviders, chatprovider.ConfiguredProvider{
|
||||
Provider: provider.Provider,
|
||||
Provider: normalizedProvider,
|
||||
APIKey: provider.APIKey,
|
||||
BaseURL: provider.BaseUrl,
|
||||
},
|
||||
@@ -3976,7 +3971,7 @@ func (api *API) listChatProviders(rw http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
|
||||
effectiveKeys := chatprovider.MergeProviderAPIKeys(
|
||||
chatProviderAPIKeysFromDeploymentValues(api.DeploymentValues),
|
||||
ChatProviderAPIKeysFromDeploymentValues(api.DeploymentValues),
|
||||
enabledConfiguredProviders,
|
||||
)
|
||||
effectiveKeys = chatprovider.MergeProviderAPIKeys(
|
||||
@@ -3992,7 +3987,7 @@ func (api *API) listChatProviders(rw http.ResponseWriter, r *http.Request) {
|
||||
resp,
|
||||
convertChatProviderConfig(
|
||||
configured,
|
||||
effectiveKeys.APIKey(provider) != "",
|
||||
api.hasEffectiveProviderAPIKey(ctx, configured),
|
||||
codersdk.ChatProviderConfigSourceDatabase,
|
||||
),
|
||||
)
|
||||
@@ -4008,13 +4003,16 @@ func (api *API) listChatProviders(rw http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
|
||||
resp = append(resp, codersdk.ChatProviderConfig{
|
||||
ID: uuid.Nil,
|
||||
Provider: provider,
|
||||
DisplayName: chatprovider.ProviderDisplayName(provider),
|
||||
Enabled: enabled,
|
||||
HasAPIKey: hasAPIKey,
|
||||
BaseURL: effectiveKeys.BaseURL(provider),
|
||||
Source: source,
|
||||
ID: uuid.Nil,
|
||||
Provider: provider,
|
||||
DisplayName: chatprovider.ProviderDisplayName(provider),
|
||||
Enabled: enabled,
|
||||
HasAPIKey: hasAPIKey,
|
||||
CentralAPIKeyEnabled: true,
|
||||
AllowUserAPIKey: false,
|
||||
AllowCentralAPIKeyFallback: false,
|
||||
BaseURL: effectiveKeys.BaseURL(provider),
|
||||
Source: source,
|
||||
})
|
||||
}
|
||||
|
||||
@@ -4024,6 +4022,7 @@ func (api *API) listChatProviders(rw http.ResponseWriter, r *http.Request) {
|
||||
func (api *API) createChatProvider(rw http.ResponseWriter, r *http.Request) {
|
||||
ctx := r.Context()
|
||||
apiKey := httpmw.APIKey(r)
|
||||
var inserted database.ChatProvider
|
||||
if !api.Authorize(r, policy.ActionUpdate, rbac.ResourceDeploymentConfig) {
|
||||
httpapi.Forbidden(rw)
|
||||
return
|
||||
@@ -4043,6 +4042,14 @@ func (api *API) createChatProvider(rw http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
if err := validateChatProviderAPIKeySize(strings.TrimSpace(req.APIKey)); err != nil {
|
||||
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
|
||||
Message: "API key too large.",
|
||||
Detail: err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
enabled := true
|
||||
if req.Enabled != nil {
|
||||
enabled = *req.Enabled
|
||||
@@ -4056,14 +4063,57 @@ func (api *API) createChatProvider(rw http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
inserted, err := api.Database.InsertChatProvider(ctx, database.InsertChatProviderParams{
|
||||
Provider: provider,
|
||||
DisplayName: strings.TrimSpace(req.DisplayName),
|
||||
APIKey: strings.TrimSpace(req.APIKey),
|
||||
BaseUrl: baseURL,
|
||||
ApiKeyKeyID: sql.NullString{},
|
||||
CreatedBy: uuid.NullUUID{UUID: apiKey.UserID, Valid: apiKey.UserID != uuid.Nil},
|
||||
Enabled: enabled,
|
||||
centralAPIKeyEnabled := true
|
||||
if req.CentralAPIKeyEnabled != nil {
|
||||
centralAPIKeyEnabled = *req.CentralAPIKeyEnabled
|
||||
}
|
||||
allowUserAPIKey := false
|
||||
if req.AllowUserAPIKey != nil {
|
||||
allowUserAPIKey = *req.AllowUserAPIKey
|
||||
}
|
||||
allowCentralAPIKeyFallback := false
|
||||
if req.AllowCentralAPIKeyFallback != nil {
|
||||
allowCentralAPIKeyFallback = *req.AllowCentralAPIKeyFallback
|
||||
}
|
||||
|
||||
if err := validateChatProviderCredentialPolicy(
|
||||
centralAPIKeyEnabled,
|
||||
allowUserAPIKey,
|
||||
allowCentralAPIKeyFallback,
|
||||
); err != nil {
|
||||
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
|
||||
Message: "Invalid credential policy.",
|
||||
Detail: err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
if err := validateChatProviderCentralAPIKey(
|
||||
centralAPIKeyEnabled,
|
||||
api.hasEffectiveCentralProviderAPIKey(ctx, database.ChatProvider{
|
||||
Provider: provider,
|
||||
APIKey: strings.TrimSpace(req.APIKey),
|
||||
BaseUrl: baseURL,
|
||||
CentralApiKeyEnabled: centralAPIKeyEnabled,
|
||||
}, uuid.Nil),
|
||||
); err != nil {
|
||||
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
|
||||
Message: err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
inserted, err = api.Database.InsertChatProvider(ctx, database.InsertChatProviderParams{
|
||||
Provider: provider,
|
||||
DisplayName: strings.TrimSpace(req.DisplayName),
|
||||
APIKey: strings.TrimSpace(req.APIKey),
|
||||
BaseUrl: baseURL,
|
||||
ApiKeyKeyID: sql.NullString{},
|
||||
CreatedBy: uuid.NullUUID{UUID: apiKey.UserID, Valid: apiKey.UserID != uuid.Nil},
|
||||
Enabled: enabled,
|
||||
CentralApiKeyEnabled: centralAPIKeyEnabled,
|
||||
AllowUserApiKey: allowUserAPIKey,
|
||||
AllowCentralApiKeyFallback: allowCentralAPIKeyFallback,
|
||||
})
|
||||
if err != nil {
|
||||
switch {
|
||||
@@ -4104,6 +4154,10 @@ func (api *API) createChatProvider(rw http.ResponseWriter, r *http.Request) {
|
||||
|
||||
func (api *API) updateChatProvider(rw http.ResponseWriter, r *http.Request) {
|
||||
ctx := r.Context()
|
||||
var (
|
||||
existing database.ChatProvider
|
||||
updated database.ChatProvider
|
||||
)
|
||||
if !api.Authorize(r, policy.ActionUpdate, rbac.ResourceDeploymentConfig) {
|
||||
httpapi.Forbidden(rw)
|
||||
return
|
||||
@@ -4145,7 +4199,17 @@ func (api *API) updateChatProvider(rw http.ResponseWriter, r *http.Request) {
|
||||
apiKey := existing.APIKey
|
||||
apiKeyKeyID := existing.ApiKeyKeyID
|
||||
if req.APIKey != nil {
|
||||
apiKey = strings.TrimSpace(*req.APIKey)
|
||||
trimmedAPIKey := strings.TrimSpace(*req.APIKey)
|
||||
if trimmedAPIKey != "" {
|
||||
if err := validateChatProviderAPIKeySize(trimmedAPIKey); err != nil {
|
||||
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
|
||||
Message: "API key too large.",
|
||||
Detail: err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
}
|
||||
apiKey = trimmedAPIKey
|
||||
apiKeyKeyID = sql.NullString{}
|
||||
}
|
||||
baseURL := existing.BaseUrl
|
||||
@@ -4160,13 +4224,57 @@ func (api *API) updateChatProvider(rw http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
}
|
||||
|
||||
updated, err := api.Database.UpdateChatProvider(ctx, database.UpdateChatProviderParams{
|
||||
DisplayName: displayName,
|
||||
APIKey: apiKey,
|
||||
BaseUrl: baseURL,
|
||||
ApiKeyKeyID: apiKeyKeyID,
|
||||
Enabled: enabled,
|
||||
ID: existing.ID,
|
||||
centralAPIKeyEnabled := existing.CentralApiKeyEnabled
|
||||
if req.CentralAPIKeyEnabled != nil {
|
||||
centralAPIKeyEnabled = *req.CentralAPIKeyEnabled
|
||||
}
|
||||
allowUserAPIKey := existing.AllowUserApiKey
|
||||
if req.AllowUserAPIKey != nil {
|
||||
allowUserAPIKey = *req.AllowUserAPIKey
|
||||
}
|
||||
allowCentralAPIKeyFallback := existing.AllowCentralApiKeyFallback
|
||||
if req.AllowCentralAPIKeyFallback != nil {
|
||||
allowCentralAPIKeyFallback = *req.AllowCentralAPIKeyFallback
|
||||
}
|
||||
|
||||
if err := validateChatProviderCredentialPolicy(
|
||||
centralAPIKeyEnabled,
|
||||
allowUserAPIKey,
|
||||
allowCentralAPIKeyFallback,
|
||||
); err != nil {
|
||||
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
|
||||
Message: "Invalid credential policy.",
|
||||
Detail: err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
if err := validateChatProviderCentralAPIKey(
|
||||
centralAPIKeyEnabled,
|
||||
api.hasEffectiveCentralProviderAPIKey(ctx, database.ChatProvider{
|
||||
ID: existing.ID,
|
||||
Provider: existing.Provider,
|
||||
APIKey: apiKey,
|
||||
BaseUrl: baseURL,
|
||||
CentralApiKeyEnabled: centralAPIKeyEnabled,
|
||||
}, existing.ID),
|
||||
); err != nil {
|
||||
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
|
||||
Message: err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
updated, err = api.Database.UpdateChatProvider(ctx, database.UpdateChatProviderParams{
|
||||
DisplayName: displayName,
|
||||
APIKey: apiKey,
|
||||
BaseUrl: baseURL,
|
||||
ApiKeyKeyID: apiKeyKeyID,
|
||||
Enabled: enabled,
|
||||
CentralApiKeyEnabled: centralAPIKeyEnabled,
|
||||
AllowUserApiKey: allowUserAPIKey,
|
||||
AllowCentralApiKeyFallback: allowCentralAPIKeyFallback,
|
||||
ID: existing.ID,
|
||||
})
|
||||
if err != nil {
|
||||
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
|
||||
@@ -4237,6 +4345,169 @@ func (api *API) deleteChatProvider(rw http.ResponseWriter, r *http.Request) {
|
||||
rw.WriteHeader(http.StatusNoContent)
|
||||
}
|
||||
|
||||
func (api *API) listUserChatProviderConfigs(rw http.ResponseWriter, r *http.Request) {
|
||||
var (
|
||||
ctx = r.Context()
|
||||
apiKey = httpmw.APIKey(r)
|
||||
)
|
||||
|
||||
//nolint:gocritic // Non-admin users need to read provider configs to manage their own chat credentials.
|
||||
chatdCtx := dbauthz.AsChatd(ctx)
|
||||
providers, err := api.Database.GetChatProviders(chatdCtx)
|
||||
if err != nil {
|
||||
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
|
||||
Message: "Failed to list chat providers.",
|
||||
Detail: err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
userKeys, err := api.Database.GetUserChatProviderKeys(ctx, apiKey.UserID)
|
||||
if err != nil {
|
||||
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
|
||||
Message: "Failed to list user chat provider keys.",
|
||||
Detail: err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
hasUserAPIKeyByProviderID := make(map[uuid.UUID]bool, len(userKeys))
|
||||
for _, userKey := range userKeys {
|
||||
hasUserAPIKeyByProviderID[userKey.ChatProviderID] = true
|
||||
}
|
||||
|
||||
resp := make([]codersdk.UserChatProviderConfig, 0, len(providers))
|
||||
for _, provider := range providers {
|
||||
if !provider.Enabled || !provider.AllowUserApiKey {
|
||||
continue
|
||||
}
|
||||
hasUserAPIKey := hasUserAPIKeyByProviderID[provider.ID]
|
||||
hasCentralAPIKeyFallback := provider.Enabled &&
|
||||
provider.AllowCentralApiKeyFallback &&
|
||||
api.hasEffectiveCentralProviderAPIKey(ctx, provider, uuid.Nil)
|
||||
resp = append(
|
||||
resp,
|
||||
convertUserChatProviderConfig(
|
||||
provider,
|
||||
hasUserAPIKey,
|
||||
hasCentralAPIKeyFallback,
|
||||
),
|
||||
)
|
||||
}
|
||||
|
||||
httpapi.Write(ctx, rw, http.StatusOK, resp)
|
||||
}
|
||||
|
||||
func (api *API) upsertUserChatProviderKey(rw http.ResponseWriter, r *http.Request) {
|
||||
var (
|
||||
ctx = r.Context()
|
||||
apiKey = httpmw.APIKey(r)
|
||||
)
|
||||
|
||||
providerID, ok := parseChatProviderID(rw, r)
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
|
||||
//nolint:gocritic // Non-admin users need to validate provider availability before storing their own key.
|
||||
provider, err := api.Database.GetChatProviderByID(dbauthz.AsChatd(ctx), providerID)
|
||||
if err != nil {
|
||||
if httpapi.Is404Error(err) {
|
||||
httpapi.ResourceNotFound(rw)
|
||||
return
|
||||
}
|
||||
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
|
||||
Message: "Failed to get chat provider.",
|
||||
Detail: err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
if !provider.Enabled {
|
||||
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
|
||||
Message: "Provider is disabled.",
|
||||
})
|
||||
return
|
||||
}
|
||||
if !provider.AllowUserApiKey {
|
||||
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
|
||||
Message: "Provider does not allow user API keys.",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
var req codersdk.CreateUserChatProviderKeyRequest
|
||||
if !httpapi.Read(ctx, rw, r, &req) {
|
||||
return
|
||||
}
|
||||
|
||||
trimmedAPIKey := strings.TrimSpace(req.APIKey)
|
||||
if err := validateChatProviderAPIKeySize(trimmedAPIKey); err != nil {
|
||||
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
|
||||
Message: "API key too large.",
|
||||
Detail: err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
if trimmedAPIKey == "" {
|
||||
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
|
||||
Message: "API key is required.",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
if _, err := api.Database.UpsertUserChatProviderKey(ctx, database.UpsertUserChatProviderKeyParams{
|
||||
UserID: apiKey.UserID,
|
||||
ChatProviderID: providerID,
|
||||
APIKey: trimmedAPIKey,
|
||||
ApiKeyKeyID: sql.NullString{},
|
||||
}); err != nil {
|
||||
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
|
||||
Message: "Failed to save user chat provider key.",
|
||||
Detail: err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
hasCentralAPIKeyFallback := provider.Enabled &&
|
||||
provider.AllowCentralApiKeyFallback &&
|
||||
api.hasEffectiveCentralProviderAPIKey(ctx, provider, uuid.Nil)
|
||||
httpapi.Write(
|
||||
ctx,
|
||||
rw,
|
||||
http.StatusOK,
|
||||
convertUserChatProviderConfig(
|
||||
provider,
|
||||
true,
|
||||
hasCentralAPIKeyFallback,
|
||||
),
|
||||
)
|
||||
}
|
||||
|
||||
func (api *API) deleteUserChatProviderKey(rw http.ResponseWriter, r *http.Request) {
|
||||
var (
|
||||
ctx = r.Context()
|
||||
apiKey = httpmw.APIKey(r)
|
||||
)
|
||||
|
||||
providerID, ok := parseChatProviderID(rw, r)
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
|
||||
if err := api.Database.DeleteUserChatProviderKey(ctx, database.DeleteUserChatProviderKeyParams{
|
||||
UserID: apiKey.UserID,
|
||||
ChatProviderID: providerID,
|
||||
}); err != nil {
|
||||
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
|
||||
Message: "Failed to delete user chat provider key.",
|
||||
Detail: err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
rw.WriteHeader(http.StatusNoContent)
|
||||
}
|
||||
|
||||
func (api *API) listChatModelConfigs(rw http.ResponseWriter, r *http.Request) {
|
||||
ctx := r.Context()
|
||||
|
||||
@@ -4777,15 +5048,37 @@ func convertChatProviderConfig(
|
||||
}
|
||||
|
||||
return codersdk.ChatProviderConfig{
|
||||
ID: provider.ID,
|
||||
Provider: provider.Provider,
|
||||
DisplayName: displayName,
|
||||
Enabled: provider.Enabled,
|
||||
HasAPIKey: hasAPIKey,
|
||||
BaseURL: strings.TrimSpace(provider.BaseUrl),
|
||||
Source: source,
|
||||
CreatedAt: provider.CreatedAt,
|
||||
UpdatedAt: provider.UpdatedAt,
|
||||
ID: provider.ID,
|
||||
Provider: provider.Provider,
|
||||
DisplayName: displayName,
|
||||
Enabled: provider.Enabled,
|
||||
HasAPIKey: hasAPIKey,
|
||||
CentralAPIKeyEnabled: provider.CentralApiKeyEnabled,
|
||||
AllowUserAPIKey: provider.AllowUserApiKey,
|
||||
AllowCentralAPIKeyFallback: provider.AllowCentralApiKeyFallback,
|
||||
BaseURL: strings.TrimSpace(provider.BaseUrl),
|
||||
Source: source,
|
||||
CreatedAt: provider.CreatedAt,
|
||||
UpdatedAt: provider.UpdatedAt,
|
||||
}
|
||||
}
|
||||
|
||||
func convertUserChatProviderConfig(
|
||||
provider database.ChatProvider,
|
||||
hasUserAPIKey bool,
|
||||
hasCentralAPIKeyFallback bool,
|
||||
) codersdk.UserChatProviderConfig {
|
||||
displayName := strings.TrimSpace(provider.DisplayName)
|
||||
if displayName == "" {
|
||||
displayName = chatprovider.ProviderDisplayName(provider.Provider)
|
||||
}
|
||||
|
||||
return codersdk.UserChatProviderConfig{
|
||||
ProviderID: provider.ID,
|
||||
Provider: provider.Provider,
|
||||
DisplayName: displayName,
|
||||
HasUserAPIKey: hasUserAPIKey,
|
||||
HasCentralAPIKeyFallback: hasCentralAPIKeyFallback,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -4944,26 +5237,80 @@ func chatProviderValidationDetail() string {
|
||||
return "Provider must be one of: " + strings.Join(chatprovider.SupportedProviders(), ", ") + "."
|
||||
}
|
||||
|
||||
func chatProviderAPIKeysFromDeploymentValues(
|
||||
deploymentValues *codersdk.DeploymentValues,
|
||||
) chatprovider.ProviderAPIKeys {
|
||||
_ = deploymentValues
|
||||
// For now, we'll just manage configs in the UI.
|
||||
// We should probably not be reusing the AI bridge configs anyways.
|
||||
return chatprovider.ProviderAPIKeys{
|
||||
// OpenAI: deploymentValues.AI.BridgeConfig.OpenAI.Key.Value(),
|
||||
// Anthropic: deploymentValues.AI.BridgeConfig.Anthropic.Key.Value(),
|
||||
// BaseURLByProvider: map[string]string{
|
||||
// "openai": deploymentValues.AI.BridgeConfig.OpenAI.BaseURL.Value(),
|
||||
// "anthropic": deploymentValues.AI.BridgeConfig.Anthropic.BaseURL.Value(),
|
||||
// },
|
||||
const maxChatProviderAPIKeySize = 10240 // 10 KB
|
||||
|
||||
func validateChatProviderAPIKeySize(apiKey string) error {
|
||||
if len(apiKey) > maxChatProviderAPIKeySize {
|
||||
return xerrors.Errorf("API key exceeds maximum size of %d bytes", maxChatProviderAPIKeySize)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
//nolint:revive // This helper validates the explicit credential policy tuple.
|
||||
func validateChatProviderCredentialPolicy(
|
||||
centralEnabled, allowUserKey, allowFallback bool,
|
||||
) error {
|
||||
if !centralEnabled && !allowUserKey {
|
||||
return xerrors.New(
|
||||
"At least one credential source must be enabled: central API key or user API key.",
|
||||
)
|
||||
}
|
||||
if allowFallback && !centralEnabled {
|
||||
return xerrors.New(
|
||||
"Central API key fallback requires central API key to be enabled.",
|
||||
)
|
||||
}
|
||||
if allowFallback && !allowUserKey {
|
||||
return xerrors.New(
|
||||
"Central API key fallback requires user API key to be enabled.",
|
||||
)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
//nolint:revive // This helper validates central-key requirements.
|
||||
func validateChatProviderCentralAPIKey(
|
||||
centralEnabled bool,
|
||||
hasCentralAPIKey bool,
|
||||
) error {
|
||||
if centralEnabled && !hasCentralAPIKey {
|
||||
return xerrors.New(
|
||||
"API key is required when central API key is enabled.",
|
||||
)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// ChatProviderAPIKeysFromDeploymentValues returns deployment-backed chat
|
||||
// provider API keys.
|
||||
func ChatProviderAPIKeysFromDeploymentValues(
|
||||
_ *codersdk.DeploymentValues,
|
||||
) chatprovider.ProviderAPIKeys {
|
||||
// AI bridge deployment config is intentionally not reused for chat
|
||||
// provider credentials. Bridge keys serve the AI task subsystem and
|
||||
// should not silently broaden into chat execution paths.
|
||||
return chatprovider.ProviderAPIKeys{}
|
||||
}
|
||||
|
||||
func (api *API) hasEffectiveProviderAPIKey(ctx context.Context, provider database.ChatProvider) bool {
|
||||
return api.hasEffectiveCentralProviderAPIKey(ctx, provider, uuid.Nil)
|
||||
}
|
||||
|
||||
func (api *API) hasEffectiveCentralProviderAPIKey(
|
||||
ctx context.Context,
|
||||
provider database.ChatProvider,
|
||||
excludeProviderID uuid.UUID,
|
||||
) bool {
|
||||
if !provider.CentralApiKeyEnabled {
|
||||
return false
|
||||
}
|
||||
if strings.TrimSpace(provider.APIKey) != "" {
|
||||
return true
|
||||
}
|
||||
deploymentKeys := ChatProviderAPIKeysFromDeploymentValues(api.DeploymentValues)
|
||||
if deploymentKeys.APIKey(provider.Provider) != "" {
|
||||
return true
|
||||
}
|
||||
if api.chatDaemon == nil {
|
||||
return false
|
||||
}
|
||||
@@ -4985,6 +5332,9 @@ func (api *API) hasEffectiveProviderAPIKey(ctx context.Context, provider databas
|
||||
[]chatprovider.ConfiguredProvider, 0, len(enabledProviders),
|
||||
)
|
||||
for _, configured := range enabledProviders {
|
||||
if excludeProviderID != uuid.Nil && configured.ID == excludeProviderID {
|
||||
continue
|
||||
}
|
||||
enabledConfiguredProviders = append(
|
||||
enabledConfiguredProviders, chatprovider.ConfiguredProvider{
|
||||
Provider: configured.Provider,
|
||||
@@ -4995,7 +5345,7 @@ func (api *API) hasEffectiveProviderAPIKey(ctx context.Context, provider databas
|
||||
}
|
||||
|
||||
effectiveKeys := chatprovider.MergeProviderAPIKeys(
|
||||
chatProviderAPIKeysFromDeploymentValues(api.DeploymentValues),
|
||||
deploymentKeys,
|
||||
enabledConfiguredProviders,
|
||||
)
|
||||
return effectiveKeys.APIKey(provider.Provider) != ""
|
||||
|
||||
+1290
-146
File diff suppressed because it is too large
Load Diff
@@ -39,13 +39,14 @@ func TestChatParam(t *testing.T) {
|
||||
t.Helper()
|
||||
|
||||
_, err := db.InsertChatProvider(context.Background(), database.InsertChatProviderParams{
|
||||
Provider: "openai",
|
||||
DisplayName: "OpenAI",
|
||||
APIKey: "test-api-key",
|
||||
BaseUrl: "https://api.openai.com/v1",
|
||||
ApiKeyKeyID: sql.NullString{},
|
||||
CreatedBy: uuid.NullUUID{UUID: ownerID, Valid: true},
|
||||
Enabled: true,
|
||||
Provider: "openai",
|
||||
DisplayName: "OpenAI",
|
||||
APIKey: "test-api-key",
|
||||
BaseUrl: "https://api.openai.com/v1",
|
||||
ApiKeyKeyID: sql.NullString{},
|
||||
CreatedBy: uuid.NullUUID{UUID: ownerID, Valid: true},
|
||||
Enabled: true,
|
||||
CentralApiKeyEnabled: true,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
@@ -62,6 +63,7 @@ func TestChatParam(t *testing.T) {
|
||||
require.NoError(t, err)
|
||||
|
||||
chat, err := db.InsertChat(context.Background(), database.InsertChatParams{
|
||||
Status: database.ChatStatusWaiting,
|
||||
OwnerID: ownerID,
|
||||
WorkspaceID: uuid.NullUUID{},
|
||||
ParentChatID: uuid.NullUUID{},
|
||||
|
||||
+29
-16
@@ -2,6 +2,7 @@ package prebuilds
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"sync"
|
||||
|
||||
"github.com/google/uuid"
|
||||
@@ -22,7 +23,11 @@ type PubsubWorkspaceClaimPublisher struct {
|
||||
|
||||
func (p PubsubWorkspaceClaimPublisher) PublishWorkspaceClaim(claim agentsdk.ReinitializationEvent) error {
|
||||
channel := agentsdk.PrebuildClaimedChannel(claim.WorkspaceID)
|
||||
if err := p.ps.Publish(channel, []byte(claim.Reason)); err != nil {
|
||||
payload, err := json.Marshal(claim)
|
||||
if err != nil {
|
||||
return xerrors.Errorf("marshal claim event: %w", err)
|
||||
}
|
||||
if err := p.ps.Publish(channel, payload); err != nil {
|
||||
return xerrors.Errorf("failed to trigger prebuilt workspace agent reinitialization: %w", err)
|
||||
}
|
||||
return nil
|
||||
@@ -37,33 +42,41 @@ type PubsubWorkspaceClaimListener struct {
|
||||
ps pubsub.Pubsub
|
||||
}
|
||||
|
||||
// ListenForWorkspaceClaims subscribes to a pubsub channel and sends any received events on the chan that it returns.
|
||||
// pubsub.Pubsub does not communicate when its last callback has been called after it has been closed. As such the chan
|
||||
// returned by this method is never closed. Call the returned cancel() function to close the subscription when it is no longer needed.
|
||||
// cancel() will be called if ctx expires or is canceled.
|
||||
func (p PubsubWorkspaceClaimListener) ListenForWorkspaceClaims(ctx context.Context, workspaceID uuid.UUID, reinitEvents chan<- agentsdk.ReinitializationEvent) (func(), error) {
|
||||
// ListenForWorkspaceClaims subscribes to a pubsub channel and returns a
|
||||
// receive-only channel that emits claim events for the given workspace.
|
||||
// The returned channel is owned by this function and is never closed,
|
||||
// because pubsub.Pubsub does not guarantee that all in-flight callbacks
|
||||
// have returned after unsubscribe. Call the returned cancel function to
|
||||
// unsubscribe when events are no longer needed; cancel is also called
|
||||
// automatically if ctx expires or is canceled.
|
||||
func (p PubsubWorkspaceClaimListener) ListenForWorkspaceClaims(ctx context.Context, workspaceID uuid.UUID) (<-chan agentsdk.ReinitializationEvent, func(), error) {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return func() {}, ctx.Err()
|
||||
return nil, func() {}, ctx.Err()
|
||||
default:
|
||||
}
|
||||
|
||||
cancelSub, err := p.ps.Subscribe(agentsdk.PrebuildClaimedChannel(workspaceID), func(inner context.Context, reason []byte) {
|
||||
claim := agentsdk.ReinitializationEvent{
|
||||
WorkspaceID: workspaceID,
|
||||
Reason: agentsdk.ReinitializationReason(reason),
|
||||
reinitEvents := make(chan agentsdk.ReinitializationEvent, 1)
|
||||
|
||||
cancelSub, err := p.ps.Subscribe(agentsdk.PrebuildClaimedChannel(workspaceID), func(inner context.Context, payload []byte) {
|
||||
var event agentsdk.ReinitializationEvent
|
||||
if err := json.Unmarshal(payload, &event); err != nil {
|
||||
// Rolling upgrade: old publishers send the raw reason
|
||||
// string instead of JSON.
|
||||
event = agentsdk.ReinitializationEvent{
|
||||
WorkspaceID: workspaceID,
|
||||
Reason: agentsdk.ReinitializationReason(payload),
|
||||
}
|
||||
}
|
||||
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case <-inner.Done():
|
||||
return
|
||||
case reinitEvents <- claim:
|
||||
case reinitEvents <- event:
|
||||
}
|
||||
})
|
||||
if err != nil {
|
||||
return func() {}, xerrors.Errorf("failed to subscribe to prebuild claimed channel: %w", err)
|
||||
return nil, func() {}, xerrors.Errorf("failed to subscribe to prebuild claimed channel: %w", err)
|
||||
}
|
||||
|
||||
var once sync.Once
|
||||
@@ -78,5 +91,5 @@ func (p PubsubWorkspaceClaimListener) ListenForWorkspaceClaims(ctx context.Conte
|
||||
cancel()
|
||||
}()
|
||||
|
||||
return cancel, nil
|
||||
return reinitEvents, cancel, nil
|
||||
}
|
||||
|
||||
@@ -25,24 +25,26 @@ func TestPubsubWorkspaceClaimPublisher(t *testing.T) {
|
||||
logger := testutil.Logger(t)
|
||||
ps := pubsub.NewInMemory()
|
||||
workspaceID := uuid.New()
|
||||
reinitEvents := make(chan agentsdk.ReinitializationEvent, 1)
|
||||
publisher := prebuilds.NewPubsubWorkspaceClaimPublisher(ps)
|
||||
listener := prebuilds.NewPubsubWorkspaceClaimListener(ps, logger)
|
||||
|
||||
cancel, err := listener.ListenForWorkspaceClaims(ctx, workspaceID, reinitEvents)
|
||||
events, cancel, err := listener.ListenForWorkspaceClaims(ctx, workspaceID)
|
||||
require.NoError(t, err)
|
||||
defer cancel()
|
||||
|
||||
userID := uuid.New()
|
||||
claim := agentsdk.ReinitializationEvent{
|
||||
WorkspaceID: workspaceID,
|
||||
Reason: agentsdk.ReinitializeReasonPrebuildClaimed,
|
||||
OwnerID: userID,
|
||||
}
|
||||
err = publisher.PublishWorkspaceClaim(claim)
|
||||
require.NoError(t, err)
|
||||
|
||||
gotEvent := testutil.RequireReceive(ctx, t, reinitEvents)
|
||||
gotEvent := testutil.RequireReceive(ctx, t, events)
|
||||
require.Equal(t, workspaceID, gotEvent.WorkspaceID)
|
||||
require.Equal(t, claim.Reason, gotEvent.Reason)
|
||||
require.Equal(t, userID, gotEvent.OwnerID)
|
||||
})
|
||||
|
||||
t.Run("fail to publish claim", func(t *testing.T) {
|
||||
@@ -69,10 +71,8 @@ func TestPubsubWorkspaceClaimListener(t *testing.T) {
|
||||
ps := pubsub.NewInMemory()
|
||||
listener := prebuilds.NewPubsubWorkspaceClaimListener(ps, slogtest.Make(t, nil))
|
||||
|
||||
claims := make(chan agentsdk.ReinitializationEvent, 1) // Buffer to avoid messing with goroutines in the rest of the test
|
||||
|
||||
workspaceID := uuid.New()
|
||||
cancelFunc, err := listener.ListenForWorkspaceClaims(context.Background(), workspaceID, claims)
|
||||
events, cancelFunc, err := listener.ListenForWorkspaceClaims(context.Background(), workspaceID)
|
||||
require.NoError(t, err)
|
||||
defer cancelFunc()
|
||||
|
||||
@@ -84,9 +84,10 @@ func TestPubsubWorkspaceClaimListener(t *testing.T) {
|
||||
|
||||
// Verify we receive the claim
|
||||
ctx := testutil.Context(t, testutil.WaitShort)
|
||||
claim := testutil.RequireReceive(ctx, t, claims)
|
||||
claim := testutil.RequireReceive(ctx, t, events)
|
||||
require.Equal(t, workspaceID, claim.WorkspaceID)
|
||||
require.Equal(t, reason, claim.Reason)
|
||||
require.Equal(t, uuid.Nil, claim.OwnerID)
|
||||
})
|
||||
|
||||
t.Run("ignores claim events for other workspaces", func(t *testing.T) {
|
||||
@@ -95,10 +96,9 @@ func TestPubsubWorkspaceClaimListener(t *testing.T) {
|
||||
ps := pubsub.NewInMemory()
|
||||
listener := prebuilds.NewPubsubWorkspaceClaimListener(ps, slogtest.Make(t, nil))
|
||||
|
||||
claims := make(chan agentsdk.ReinitializationEvent)
|
||||
workspaceID := uuid.New()
|
||||
otherWorkspaceID := uuid.New()
|
||||
cancelFunc, err := listener.ListenForWorkspaceClaims(context.Background(), workspaceID, claims)
|
||||
events, cancelFunc, err := listener.ListenForWorkspaceClaims(context.Background(), workspaceID)
|
||||
require.NoError(t, err)
|
||||
defer cancelFunc()
|
||||
|
||||
@@ -109,7 +109,7 @@ func TestPubsubWorkspaceClaimListener(t *testing.T) {
|
||||
|
||||
// Verify we don't receive the claim
|
||||
select {
|
||||
case <-claims:
|
||||
case <-events:
|
||||
t.Fatal("received claim for wrong workspace")
|
||||
case <-time.After(100 * time.Millisecond):
|
||||
// Expected - no claim received
|
||||
@@ -119,11 +119,10 @@ func TestPubsubWorkspaceClaimListener(t *testing.T) {
|
||||
t.Run("communicates the error if it can't subscribe", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
claims := make(chan agentsdk.ReinitializationEvent)
|
||||
ps := &brokenPubsub{}
|
||||
listener := prebuilds.NewPubsubWorkspaceClaimListener(ps, slogtest.Make(t, nil))
|
||||
|
||||
_, err := listener.ListenForWorkspaceClaims(context.Background(), uuid.New(), claims)
|
||||
_, _, err := listener.ListenForWorkspaceClaims(context.Background(), uuid.New())
|
||||
require.ErrorContains(t, err, "failed to subscribe to prebuild claimed channel")
|
||||
})
|
||||
}
|
||||
|
||||
@@ -2539,6 +2539,7 @@ func (s *server) completeWorkspaceBuildJob(ctx context.Context, job database.Pro
|
||||
err = prebuilds.NewPubsubWorkspaceClaimPublisher(s.Pubsub).PublishWorkspaceClaim(agentsdk.ReinitializationEvent{
|
||||
WorkspaceID: workspace.ID,
|
||||
Reason: agentsdk.ReinitializeReasonPrebuildClaimed,
|
||||
OwnerID: workspace.OwnerID,
|
||||
})
|
||||
if err != nil {
|
||||
s.Logger.Error(ctx, "failed to publish workspace claim event", slog.Error(err))
|
||||
|
||||
@@ -51,7 +51,6 @@ import (
|
||||
"github.com/coder/coder/v2/coderd/usage/usagetypes"
|
||||
"github.com/coder/coder/v2/coderd/wspubsub"
|
||||
"github.com/coder/coder/v2/codersdk"
|
||||
"github.com/coder/coder/v2/codersdk/agentsdk"
|
||||
"github.com/coder/coder/v2/provisionerd/proto"
|
||||
"github.com/coder/coder/v2/provisionersdk"
|
||||
sdkproto "github.com/coder/coder/v2/provisionersdk/proto"
|
||||
@@ -2787,8 +2786,7 @@ func TestCompleteJob(t *testing.T) {
|
||||
require.NoError(t, err)
|
||||
|
||||
// GIVEN something is listening to process workspace reinitialization:
|
||||
reinitChan := make(chan agentsdk.ReinitializationEvent, 1) // Buffered to simplify test structure
|
||||
cancel, err := agplprebuilds.NewPubsubWorkspaceClaimListener(ps, testutil.Logger(t)).ListenForWorkspaceClaims(ctx, workspace.ID, reinitChan)
|
||||
reinitChan, cancel, err := agplprebuilds.NewPubsubWorkspaceClaimListener(ps, testutil.Logger(t)).ListenForWorkspaceClaims(ctx, workspace.ID)
|
||||
require.NoError(t, err)
|
||||
defer cancel()
|
||||
|
||||
|
||||
+38
-4
@@ -21,6 +21,7 @@ const (
|
||||
templateAdmin string = "template-admin"
|
||||
userAdmin string = "user-admin"
|
||||
auditor string = "auditor"
|
||||
agentsAccess string = "agents-access"
|
||||
// customSiteRole is a placeholder for all custom site roles.
|
||||
// This is used for what roles can assign other roles.
|
||||
// TODO: Make this more dynamic to allow other roles to grant.
|
||||
@@ -142,6 +143,7 @@ func RoleTemplateAdmin() RoleIdentifier { return RoleIdentifier{Name: templateAd
|
||||
func RoleUserAdmin() RoleIdentifier { return RoleIdentifier{Name: userAdmin} }
|
||||
func RoleMember() RoleIdentifier { return RoleIdentifier{Name: member} }
|
||||
func RoleAuditor() RoleIdentifier { return RoleIdentifier{Name: auditor} }
|
||||
func RoleAgentsAccess() RoleIdentifier { return RoleIdentifier{Name: agentsAccess} }
|
||||
|
||||
func RoleOrgAdmin() string {
|
||||
return orgAdmin
|
||||
@@ -316,7 +318,7 @@ func ReloadBuiltinRoles(opts *RoleOptions) {
|
||||
denyPermissions...,
|
||||
),
|
||||
User: append(
|
||||
allPermsExcept(ResourceWorkspaceDormant, ResourcePrebuiltWorkspace, ResourceWorkspace, ResourceUser, ResourceOrganizationMember, ResourceOrganizationMember, ResourceBoundaryUsage, ResourceAibridgeInterception),
|
||||
allPermsExcept(ResourceWorkspaceDormant, ResourcePrebuiltWorkspace, ResourceWorkspace, ResourceUser, ResourceOrganizationMember, ResourceBoundaryUsage, ResourceAibridgeInterception, ResourceChat),
|
||||
Permissions(map[string][]policy.Action{
|
||||
// Users cannot do create/update/delete on themselves, but they
|
||||
// can read their own details.
|
||||
@@ -402,6 +404,21 @@ func ReloadBuiltinRoles(opts *RoleOptions) {
|
||||
ByOrgID: map[string]OrgPermissions{},
|
||||
}.withCachedRegoValue()
|
||||
|
||||
agentsAccessRole := Role{
|
||||
Identifier: RoleAgentsAccess(),
|
||||
DisplayName: "Coder Agents User",
|
||||
Site: []Permission{},
|
||||
User: Permissions(map[string][]policy.Action{
|
||||
ResourceChat.Type: {
|
||||
policy.ActionCreate,
|
||||
policy.ActionRead,
|
||||
policy.ActionUpdate,
|
||||
policy.ActionDelete,
|
||||
},
|
||||
}),
|
||||
ByOrgID: map[string]OrgPermissions{},
|
||||
}.withCachedRegoValue()
|
||||
|
||||
builtInRoles = map[string]func(orgID uuid.UUID) Role{
|
||||
// admin grants all actions to all resources.
|
||||
owner: func(_ uuid.UUID) Role {
|
||||
@@ -428,6 +445,13 @@ func ReloadBuiltinRoles(opts *RoleOptions) {
|
||||
return userAdminRole
|
||||
},
|
||||
|
||||
// agentsAccess grants all actions on chat resources owned
|
||||
// by the user. Without this role, members cannot create
|
||||
// or interact with chats.
|
||||
agentsAccess: func(_ uuid.UUID) Role {
|
||||
return agentsAccessRole
|
||||
},
|
||||
|
||||
// orgAdmin returns a role with all actions allows in a given
|
||||
// organization scope.
|
||||
orgAdmin: func(organizationID uuid.UUID) Role {
|
||||
@@ -600,6 +624,7 @@ var assignRoles = map[string]map[string]bool{
|
||||
userAdmin: true,
|
||||
customSiteRole: true,
|
||||
customOrganizationRole: true,
|
||||
agentsAccess: true,
|
||||
},
|
||||
owner: {
|
||||
owner: true,
|
||||
@@ -615,10 +640,12 @@ var assignRoles = map[string]map[string]bool{
|
||||
userAdmin: true,
|
||||
customSiteRole: true,
|
||||
customOrganizationRole: true,
|
||||
agentsAccess: true,
|
||||
},
|
||||
userAdmin: {
|
||||
member: true,
|
||||
orgMember: true,
|
||||
member: true,
|
||||
orgMember: true,
|
||||
agentsAccess: true,
|
||||
},
|
||||
orgAdmin: {
|
||||
orgAdmin: true,
|
||||
@@ -854,13 +881,20 @@ func SiteBuiltInRoles() []Role {
|
||||
for _, roleF := range builtInRoles {
|
||||
// Must provide some non-nil uuid to filter out org roles.
|
||||
role := roleF(uuid.New())
|
||||
if !role.Identifier.IsOrgRole() {
|
||||
if !role.Identifier.IsOrgRole() && role.Identifier != RoleAgentsAccess() {
|
||||
roles = append(roles, role)
|
||||
}
|
||||
}
|
||||
return roles
|
||||
}
|
||||
|
||||
// AgentsAccessRole returns the agents-access role for use by callers
|
||||
// that need to include it conditionally (e.g. when the agents
|
||||
// experiment is enabled).
|
||||
func AgentsAccessRole() Role {
|
||||
return builtInRoles[agentsAccess](uuid.Nil)
|
||||
}
|
||||
|
||||
// ChangeRoleSet is a helper function that finds the difference of 2 sets of
|
||||
// roles. When setting a user's new roles, it is equivalent to adding and
|
||||
// removing roles. This set determines the changes, so that the appropriate
|
||||
|
||||
+87
-82
@@ -49,6 +49,11 @@ func TestBuiltInRoles(t *testing.T) {
|
||||
require.NoError(t, r.Valid(), "invalid role")
|
||||
})
|
||||
}
|
||||
|
||||
t.Run("agents-access", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
require.NoError(t, rbac.AgentsAccessRole().Valid(), "invalid role")
|
||||
})
|
||||
}
|
||||
|
||||
// permissionGranted checks whether a permission list contains a
|
||||
@@ -199,6 +204,7 @@ func TestRolePermissions(t *testing.T) {
|
||||
orgUserAdmin := authSubject{Name: "org_user_admin", Actor: rbac.Subject{ID: templateAdminID.String(), Roles: rbac.RoleIdentifiers{rbac.RoleMember(), rbac.ScopedRoleOrgUserAdmin(orgID)}, Scope: rbac.ScopeAll}.WithCachedASTValue()}
|
||||
orgTemplateAdmin := authSubject{Name: "org_template_admin", Actor: rbac.Subject{ID: userAdminID.String(), Roles: rbac.RoleIdentifiers{rbac.RoleMember(), rbac.ScopedRoleOrgTemplateAdmin(orgID)}, Scope: rbac.ScopeAll}.WithCachedASTValue()}
|
||||
orgAdminBanWorkspace := authSubject{Name: "org_admin_workspace_ban", Actor: rbac.Subject{ID: adminID.String(), Roles: rbac.RoleIdentifiers{rbac.RoleMember(), rbac.ScopedRoleOrgAdmin(orgID), rbac.ScopedRoleOrgWorkspaceCreationBan(orgID)}, Scope: rbac.ScopeAll}.WithCachedASTValue()}
|
||||
agentsAccessUser := authSubject{Name: "chat_access", Actor: rbac.Subject{ID: currentUser.String(), Roles: rbac.RoleIdentifiers{rbac.RoleMember(), rbac.RoleAgentsAccess()}, Scope: rbac.ScopeAll}.WithCachedASTValue()}
|
||||
setOrgNotMe := authSubjectSet{orgAdmin, orgAuditor, orgUserAdmin, orgTemplateAdmin}
|
||||
|
||||
otherOrgAdmin := authSubject{Name: "org_admin_other", Actor: rbac.Subject{ID: uuid.NewString(), Roles: rbac.RoleIdentifiers{rbac.RoleMember(), rbac.ScopedRoleOrgAdmin(otherOrg)}, Scope: rbac.ScopeAll}.WithCachedASTValue()}
|
||||
@@ -210,7 +216,7 @@ func TestRolePermissions(t *testing.T) {
|
||||
// requiredSubjects are required to be asserted in each test case. This is
|
||||
// to make sure one is not forgotten.
|
||||
requiredSubjects := []authSubject{
|
||||
memberMe, owner,
|
||||
memberMe, owner, agentsAccessUser,
|
||||
orgAdmin, otherOrgAdmin, orgAuditor, orgUserAdmin, orgTemplateAdmin,
|
||||
templateAdmin, userAdmin, otherOrgAuditor, otherOrgUserAdmin, otherOrgTemplateAdmin,
|
||||
}
|
||||
@@ -233,7 +239,7 @@ func TestRolePermissions(t *testing.T) {
|
||||
Actions: []policy.Action{policy.ActionRead},
|
||||
Resource: rbac.ResourceUserObject(currentUser),
|
||||
AuthorizeMap: map[bool][]hasAuthSubjects{
|
||||
true: {owner, memberMe, templateAdmin, userAdmin, orgUserAdmin, otherOrgAdmin, otherOrgUserAdmin, orgAdmin},
|
||||
true: {owner, memberMe, agentsAccessUser, templateAdmin, userAdmin, orgUserAdmin, otherOrgAdmin, otherOrgUserAdmin, orgAdmin},
|
||||
false: {
|
||||
orgTemplateAdmin, orgAuditor,
|
||||
otherOrgAuditor, otherOrgTemplateAdmin,
|
||||
@@ -246,7 +252,7 @@ func TestRolePermissions(t *testing.T) {
|
||||
Resource: rbac.ResourceUser,
|
||||
AuthorizeMap: map[bool][]hasAuthSubjects{
|
||||
true: {owner, userAdmin},
|
||||
false: {setOtherOrg, setOrgNotMe, memberMe, templateAdmin},
|
||||
false: {setOtherOrg, setOrgNotMe, memberMe, agentsAccessUser, templateAdmin},
|
||||
},
|
||||
},
|
||||
{
|
||||
@@ -256,7 +262,7 @@ func TestRolePermissions(t *testing.T) {
|
||||
Resource: rbac.ResourceWorkspace.WithID(workspaceID).InOrg(orgID).WithOwner(currentUser.String()),
|
||||
AuthorizeMap: map[bool][]hasAuthSubjects{
|
||||
true: {owner, orgAdmin, templateAdmin, orgTemplateAdmin, orgAdminBanWorkspace},
|
||||
false: {setOtherOrg, memberMe, userAdmin, orgAuditor, orgUserAdmin},
|
||||
false: {setOtherOrg, memberMe, agentsAccessUser, userAdmin, orgAuditor, orgUserAdmin},
|
||||
},
|
||||
},
|
||||
{
|
||||
@@ -266,7 +272,7 @@ func TestRolePermissions(t *testing.T) {
|
||||
Resource: rbac.ResourceWorkspace.WithID(workspaceID).InOrg(orgID).WithOwner(currentUser.String()),
|
||||
AuthorizeMap: map[bool][]hasAuthSubjects{
|
||||
true: {owner, orgAdmin, orgAdminBanWorkspace},
|
||||
false: {setOtherOrg, memberMe, userAdmin, templateAdmin, orgTemplateAdmin, orgUserAdmin, orgAuditor},
|
||||
false: {setOtherOrg, memberMe, agentsAccessUser, userAdmin, templateAdmin, orgTemplateAdmin, orgUserAdmin, orgAuditor},
|
||||
},
|
||||
},
|
||||
{
|
||||
@@ -276,7 +282,7 @@ func TestRolePermissions(t *testing.T) {
|
||||
Resource: rbac.ResourceWorkspace.WithID(workspaceID).InOrg(orgID).WithOwner(currentUser.String()),
|
||||
AuthorizeMap: map[bool][]hasAuthSubjects{
|
||||
true: {owner, orgAdmin},
|
||||
false: {setOtherOrg, memberMe, userAdmin, templateAdmin, orgTemplateAdmin, orgUserAdmin, orgAuditor, orgAdminBanWorkspace},
|
||||
false: {setOtherOrg, memberMe, agentsAccessUser, userAdmin, templateAdmin, orgTemplateAdmin, orgUserAdmin, orgAuditor, orgAdminBanWorkspace},
|
||||
},
|
||||
},
|
||||
{
|
||||
@@ -286,7 +292,7 @@ func TestRolePermissions(t *testing.T) {
|
||||
Resource: rbac.ResourceWorkspace.InOrg(orgID).WithOwner(policy.WildcardSymbol),
|
||||
AuthorizeMap: map[bool][]hasAuthSubjects{
|
||||
true: {owner, orgAdmin},
|
||||
false: {setOtherOrg, orgUserAdmin, orgAuditor, memberMe, userAdmin, templateAdmin, orgTemplateAdmin},
|
||||
false: {setOtherOrg, orgUserAdmin, orgAuditor, memberMe, agentsAccessUser, userAdmin, templateAdmin, orgTemplateAdmin},
|
||||
},
|
||||
},
|
||||
{
|
||||
@@ -296,7 +302,7 @@ func TestRolePermissions(t *testing.T) {
|
||||
Resource: rbac.ResourceWorkspace.WithID(workspaceID).InOrg(orgID).WithOwner(currentUser.String()),
|
||||
AuthorizeMap: map[bool][]hasAuthSubjects{
|
||||
true: {owner},
|
||||
false: {setOtherOrg, setOrgNotMe, memberMe, templateAdmin, userAdmin},
|
||||
false: {setOtherOrg, setOrgNotMe, memberMe, agentsAccessUser, templateAdmin, userAdmin},
|
||||
},
|
||||
},
|
||||
{
|
||||
@@ -306,7 +312,7 @@ func TestRolePermissions(t *testing.T) {
|
||||
Resource: rbac.ResourceWorkspace.WithID(workspaceID).InOrg(orgID).WithOwner(currentUser.String()),
|
||||
AuthorizeMap: map[bool][]hasAuthSubjects{
|
||||
true: {owner},
|
||||
false: {setOtherOrg, setOrgNotMe, memberMe, templateAdmin, userAdmin},
|
||||
false: {setOtherOrg, setOrgNotMe, memberMe, agentsAccessUser, templateAdmin, userAdmin},
|
||||
},
|
||||
},
|
||||
{
|
||||
@@ -315,7 +321,7 @@ func TestRolePermissions(t *testing.T) {
|
||||
Resource: rbac.ResourceWorkspace.WithID(workspaceID).InOrg(orgID).WithOwner(currentUser.String()),
|
||||
AuthorizeMap: map[bool][]hasAuthSubjects{
|
||||
true: {owner, orgAdmin},
|
||||
false: {setOtherOrg, memberMe, userAdmin, templateAdmin, orgTemplateAdmin, orgUserAdmin, orgAuditor, orgAdminBanWorkspace},
|
||||
false: {setOtherOrg, memberMe, agentsAccessUser, userAdmin, templateAdmin, orgTemplateAdmin, orgUserAdmin, orgAuditor, orgAdminBanWorkspace},
|
||||
},
|
||||
},
|
||||
{
|
||||
@@ -324,7 +330,7 @@ func TestRolePermissions(t *testing.T) {
|
||||
Resource: rbac.ResourceWorkspace.WithID(workspaceID).InOrg(orgID).WithOwner(currentUser.String()),
|
||||
AuthorizeMap: map[bool][]hasAuthSubjects{
|
||||
true: {owner, orgAdmin, orgAdminBanWorkspace},
|
||||
false: {setOtherOrg, memberMe, userAdmin, templateAdmin, orgTemplateAdmin, orgUserAdmin, orgAuditor},
|
||||
false: {setOtherOrg, memberMe, agentsAccessUser, userAdmin, templateAdmin, orgTemplateAdmin, orgUserAdmin, orgAuditor},
|
||||
},
|
||||
},
|
||||
{
|
||||
@@ -337,7 +343,7 @@ func TestRolePermissions(t *testing.T) {
|
||||
AuthorizeMap: map[bool][]hasAuthSubjects{
|
||||
true: {owner, orgAdmin, orgAdminBanWorkspace},
|
||||
false: {
|
||||
memberMe, setOtherOrg,
|
||||
memberMe, agentsAccessUser, setOtherOrg,
|
||||
templateAdmin, userAdmin,
|
||||
orgTemplateAdmin, orgUserAdmin, orgAuditor,
|
||||
},
|
||||
@@ -354,7 +360,7 @@ func TestRolePermissions(t *testing.T) {
|
||||
true: {},
|
||||
false: {
|
||||
orgAdmin, owner, setOtherOrg,
|
||||
userAdmin, memberMe,
|
||||
userAdmin, memberMe, agentsAccessUser,
|
||||
templateAdmin, orgTemplateAdmin, orgUserAdmin, orgAuditor,
|
||||
orgAdminBanWorkspace,
|
||||
},
|
||||
@@ -366,7 +372,7 @@ func TestRolePermissions(t *testing.T) {
|
||||
Resource: rbac.ResourceTemplate.WithID(templateID).InOrg(orgID),
|
||||
AuthorizeMap: map[bool][]hasAuthSubjects{
|
||||
true: {owner, orgAdmin, templateAdmin, orgTemplateAdmin},
|
||||
false: {setOtherOrg, orgUserAdmin, orgAuditor, memberMe, userAdmin},
|
||||
false: {setOtherOrg, orgUserAdmin, orgAuditor, memberMe, agentsAccessUser, userAdmin},
|
||||
},
|
||||
},
|
||||
{
|
||||
@@ -375,7 +381,7 @@ func TestRolePermissions(t *testing.T) {
|
||||
Resource: rbac.ResourceTemplate.InOrg(orgID),
|
||||
AuthorizeMap: map[bool][]hasAuthSubjects{
|
||||
true: {owner, orgAuditor, orgAdmin, templateAdmin, orgTemplateAdmin},
|
||||
false: {setOtherOrg, orgUserAdmin, memberMe, userAdmin},
|
||||
false: {setOtherOrg, orgUserAdmin, memberMe, agentsAccessUser, userAdmin},
|
||||
},
|
||||
},
|
||||
{
|
||||
@@ -386,7 +392,7 @@ func TestRolePermissions(t *testing.T) {
|
||||
}),
|
||||
AuthorizeMap: map[bool][]hasAuthSubjects{
|
||||
true: {owner, orgAdmin, templateAdmin, orgTemplateAdmin},
|
||||
false: {setOtherOrg, orgAuditor, orgUserAdmin, memberMe, userAdmin},
|
||||
false: {setOtherOrg, orgAuditor, orgUserAdmin, memberMe, agentsAccessUser, userAdmin},
|
||||
},
|
||||
},
|
||||
{
|
||||
@@ -397,7 +403,7 @@ func TestRolePermissions(t *testing.T) {
|
||||
true: {owner, templateAdmin},
|
||||
// Org template admins can only read org scoped files.
|
||||
// File scope is currently not org scoped :cry:
|
||||
false: {setOtherOrg, orgTemplateAdmin, orgAdmin, memberMe, userAdmin, orgAuditor, orgUserAdmin},
|
||||
false: {setOtherOrg, orgTemplateAdmin, orgAdmin, memberMe, agentsAccessUser, userAdmin, orgAuditor, orgUserAdmin},
|
||||
},
|
||||
},
|
||||
{
|
||||
@@ -405,7 +411,7 @@ func TestRolePermissions(t *testing.T) {
|
||||
Actions: []policy.Action{policy.ActionCreate, policy.ActionRead},
|
||||
Resource: rbac.ResourceFile.WithID(fileID).WithOwner(currentUser.String()),
|
||||
AuthorizeMap: map[bool][]hasAuthSubjects{
|
||||
true: {owner, memberMe, templateAdmin},
|
||||
true: {owner, memberMe, agentsAccessUser, templateAdmin},
|
||||
false: {setOtherOrg, setOrgNotMe, userAdmin},
|
||||
},
|
||||
},
|
||||
@@ -415,7 +421,7 @@ func TestRolePermissions(t *testing.T) {
|
||||
Resource: rbac.ResourceOrganization,
|
||||
AuthorizeMap: map[bool][]hasAuthSubjects{
|
||||
true: {owner},
|
||||
false: {setOtherOrg, setOrgNotMe, memberMe, templateAdmin, userAdmin},
|
||||
false: {setOtherOrg, setOrgNotMe, memberMe, agentsAccessUser, templateAdmin, userAdmin},
|
||||
},
|
||||
},
|
||||
{
|
||||
@@ -424,7 +430,7 @@ func TestRolePermissions(t *testing.T) {
|
||||
Resource: rbac.ResourceOrganization.WithID(orgID).InOrg(orgID),
|
||||
AuthorizeMap: map[bool][]hasAuthSubjects{
|
||||
true: {owner, orgAdmin},
|
||||
false: {setOtherOrg, orgTemplateAdmin, orgUserAdmin, orgAuditor, memberMe, templateAdmin, userAdmin},
|
||||
false: {setOtherOrg, orgTemplateAdmin, orgUserAdmin, orgAuditor, memberMe, agentsAccessUser, templateAdmin, userAdmin},
|
||||
},
|
||||
},
|
||||
{
|
||||
@@ -433,7 +439,7 @@ func TestRolePermissions(t *testing.T) {
|
||||
Resource: rbac.ResourceOrganization.WithID(orgID).InOrg(orgID),
|
||||
AuthorizeMap: map[bool][]hasAuthSubjects{
|
||||
true: {owner, orgAdmin, templateAdmin, orgTemplateAdmin, auditor, orgAuditor, userAdmin, orgUserAdmin},
|
||||
false: {setOtherOrg, memberMe},
|
||||
false: {setOtherOrg, memberMe, agentsAccessUser},
|
||||
},
|
||||
},
|
||||
{
|
||||
@@ -442,7 +448,7 @@ func TestRolePermissions(t *testing.T) {
|
||||
Resource: rbac.ResourceAssignOrgRole,
|
||||
AuthorizeMap: map[bool][]hasAuthSubjects{
|
||||
true: {owner},
|
||||
false: {setOtherOrg, setOrgNotMe, userAdmin, memberMe, templateAdmin},
|
||||
false: {setOtherOrg, setOrgNotMe, userAdmin, memberMe, agentsAccessUser, templateAdmin},
|
||||
},
|
||||
},
|
||||
{
|
||||
@@ -451,7 +457,7 @@ func TestRolePermissions(t *testing.T) {
|
||||
Resource: rbac.ResourceAssignRole,
|
||||
AuthorizeMap: map[bool][]hasAuthSubjects{
|
||||
true: {owner, userAdmin},
|
||||
false: {setOtherOrg, setOrgNotMe, memberMe, templateAdmin},
|
||||
false: {setOtherOrg, setOrgNotMe, memberMe, agentsAccessUser, templateAdmin},
|
||||
},
|
||||
},
|
||||
{
|
||||
@@ -459,7 +465,7 @@ func TestRolePermissions(t *testing.T) {
|
||||
Actions: []policy.Action{policy.ActionRead},
|
||||
Resource: rbac.ResourceAssignRole,
|
||||
AuthorizeMap: map[bool][]hasAuthSubjects{
|
||||
true: {setOtherOrg, setOrgNotMe, owner, memberMe, templateAdmin, userAdmin},
|
||||
true: {setOtherOrg, setOrgNotMe, owner, memberMe, agentsAccessUser, templateAdmin, userAdmin},
|
||||
false: {},
|
||||
},
|
||||
},
|
||||
@@ -469,7 +475,7 @@ func TestRolePermissions(t *testing.T) {
|
||||
Resource: rbac.ResourceAssignOrgRole.InOrg(orgID),
|
||||
AuthorizeMap: map[bool][]hasAuthSubjects{
|
||||
true: {owner, orgAdmin, userAdmin, orgUserAdmin},
|
||||
false: {setOtherOrg, memberMe, templateAdmin, orgTemplateAdmin, orgAuditor},
|
||||
false: {setOtherOrg, memberMe, agentsAccessUser, templateAdmin, orgTemplateAdmin, orgAuditor},
|
||||
},
|
||||
},
|
||||
{
|
||||
@@ -478,7 +484,7 @@ func TestRolePermissions(t *testing.T) {
|
||||
Resource: rbac.ResourceAssignOrgRole.InOrg(orgID),
|
||||
AuthorizeMap: map[bool][]hasAuthSubjects{
|
||||
true: {owner, orgAdmin},
|
||||
false: {setOtherOrg, orgUserAdmin, orgTemplateAdmin, orgAuditor, memberMe, templateAdmin, userAdmin},
|
||||
false: {setOtherOrg, orgUserAdmin, orgTemplateAdmin, orgAuditor, memberMe, agentsAccessUser, templateAdmin, userAdmin},
|
||||
},
|
||||
},
|
||||
{
|
||||
@@ -487,7 +493,7 @@ func TestRolePermissions(t *testing.T) {
|
||||
Resource: rbac.ResourceAssignOrgRole.InOrg(orgID),
|
||||
AuthorizeMap: map[bool][]hasAuthSubjects{
|
||||
true: {owner, orgAdmin, orgUserAdmin, userAdmin, templateAdmin},
|
||||
false: {setOtherOrg, memberMe, orgAuditor, orgTemplateAdmin},
|
||||
false: {setOtherOrg, memberMe, agentsAccessUser, orgAuditor, orgTemplateAdmin},
|
||||
},
|
||||
},
|
||||
{
|
||||
@@ -495,7 +501,7 @@ func TestRolePermissions(t *testing.T) {
|
||||
Actions: []policy.Action{policy.ActionCreate, policy.ActionRead, policy.ActionDelete, policy.ActionUpdate},
|
||||
Resource: rbac.ResourceApiKey.WithID(apiKeyID).WithOwner(currentUser.String()),
|
||||
AuthorizeMap: map[bool][]hasAuthSubjects{
|
||||
true: {owner, memberMe},
|
||||
true: {owner, memberMe, agentsAccessUser},
|
||||
false: {setOtherOrg, setOrgNotMe, templateAdmin, userAdmin},
|
||||
},
|
||||
},
|
||||
@@ -507,7 +513,7 @@ func TestRolePermissions(t *testing.T) {
|
||||
Resource: rbac.ResourceInboxNotification.WithID(uuid.New()).InOrg(orgID).WithOwner(currentUser.String()),
|
||||
AuthorizeMap: map[bool][]hasAuthSubjects{
|
||||
true: {owner, orgAdmin},
|
||||
false: {setOtherOrg, orgUserAdmin, orgTemplateAdmin, orgAuditor, templateAdmin, userAdmin, memberMe},
|
||||
false: {setOtherOrg, orgUserAdmin, orgTemplateAdmin, orgAuditor, templateAdmin, userAdmin, memberMe, agentsAccessUser},
|
||||
},
|
||||
},
|
||||
{
|
||||
@@ -515,7 +521,7 @@ func TestRolePermissions(t *testing.T) {
|
||||
Actions: []policy.Action{policy.ActionReadPersonal, policy.ActionUpdatePersonal},
|
||||
Resource: rbac.ResourceUserObject(currentUser),
|
||||
AuthorizeMap: map[bool][]hasAuthSubjects{
|
||||
true: {owner, memberMe, userAdmin},
|
||||
true: {owner, memberMe, agentsAccessUser, userAdmin},
|
||||
false: {setOtherOrg, setOrgNotMe, templateAdmin},
|
||||
},
|
||||
},
|
||||
@@ -525,7 +531,7 @@ func TestRolePermissions(t *testing.T) {
|
||||
Resource: rbac.ResourceOrganizationMember.WithID(currentUser).InOrg(orgID).WithOwner(currentUser.String()),
|
||||
AuthorizeMap: map[bool][]hasAuthSubjects{
|
||||
true: {owner, orgAdmin, userAdmin, orgUserAdmin},
|
||||
false: {setOtherOrg, orgTemplateAdmin, orgAuditor, memberMe, templateAdmin},
|
||||
false: {setOtherOrg, orgTemplateAdmin, orgAuditor, memberMe, agentsAccessUser, templateAdmin},
|
||||
},
|
||||
},
|
||||
{
|
||||
@@ -534,7 +540,7 @@ func TestRolePermissions(t *testing.T) {
|
||||
Resource: rbac.ResourceOrganizationMember.WithID(currentUser).InOrg(orgID).WithOwner(currentUser.String()),
|
||||
AuthorizeMap: map[bool][]hasAuthSubjects{
|
||||
true: {owner, orgAuditor, orgAdmin, userAdmin, templateAdmin, orgUserAdmin, orgTemplateAdmin},
|
||||
false: {memberMe, setOtherOrg},
|
||||
false: {memberMe, agentsAccessUser, setOtherOrg},
|
||||
},
|
||||
},
|
||||
{
|
||||
@@ -547,7 +553,7 @@ func TestRolePermissions(t *testing.T) {
|
||||
|
||||
AuthorizeMap: map[bool][]hasAuthSubjects{
|
||||
true: {owner, orgAdmin, templateAdmin, orgUserAdmin, orgTemplateAdmin, orgAuditor},
|
||||
false: {setOtherOrg, memberMe, userAdmin},
|
||||
false: {setOtherOrg, memberMe, agentsAccessUser, userAdmin},
|
||||
},
|
||||
},
|
||||
{
|
||||
@@ -560,7 +566,7 @@ func TestRolePermissions(t *testing.T) {
|
||||
}),
|
||||
AuthorizeMap: map[bool][]hasAuthSubjects{
|
||||
true: {owner, orgAdmin, userAdmin, orgUserAdmin},
|
||||
false: {setOtherOrg, memberMe, templateAdmin, orgTemplateAdmin, orgAuditor},
|
||||
false: {setOtherOrg, memberMe, agentsAccessUser, templateAdmin, orgTemplateAdmin, orgAuditor},
|
||||
},
|
||||
},
|
||||
{
|
||||
@@ -573,7 +579,7 @@ func TestRolePermissions(t *testing.T) {
|
||||
}),
|
||||
AuthorizeMap: map[bool][]hasAuthSubjects{
|
||||
true: {owner, orgAdmin, userAdmin, templateAdmin, orgTemplateAdmin, orgUserAdmin, orgAuditor},
|
||||
false: {setOtherOrg, memberMe},
|
||||
false: {setOtherOrg, memberMe, agentsAccessUser},
|
||||
},
|
||||
},
|
||||
{
|
||||
@@ -582,7 +588,7 @@ func TestRolePermissions(t *testing.T) {
|
||||
Resource: rbac.ResourceGroupMember.WithID(currentUser).InOrg(orgID).WithOwner(currentUser.String()),
|
||||
AuthorizeMap: map[bool][]hasAuthSubjects{
|
||||
true: {owner, orgAuditor, orgAdmin, userAdmin, templateAdmin, orgTemplateAdmin, orgUserAdmin},
|
||||
false: {setOtherOrg, memberMe},
|
||||
false: {setOtherOrg, memberMe, agentsAccessUser},
|
||||
},
|
||||
},
|
||||
{
|
||||
@@ -591,7 +597,7 @@ func TestRolePermissions(t *testing.T) {
|
||||
Resource: rbac.ResourceGroupMember.WithID(adminID).InOrg(orgID).WithOwner(adminID.String()),
|
||||
AuthorizeMap: map[bool][]hasAuthSubjects{
|
||||
true: {owner, orgAuditor, orgAdmin, userAdmin, templateAdmin, orgTemplateAdmin, orgUserAdmin},
|
||||
false: {setOtherOrg, memberMe},
|
||||
false: {setOtherOrg, memberMe, agentsAccessUser},
|
||||
},
|
||||
},
|
||||
{
|
||||
@@ -600,7 +606,7 @@ func TestRolePermissions(t *testing.T) {
|
||||
Resource: rbac.ResourceWorkspaceDormant.WithID(uuid.New()).InOrg(orgID).WithOwner(memberMe.Actor.ID),
|
||||
AuthorizeMap: map[bool][]hasAuthSubjects{
|
||||
true: {orgAdmin, owner},
|
||||
false: {setOtherOrg, userAdmin, memberMe, templateAdmin, orgTemplateAdmin, orgUserAdmin, orgAuditor},
|
||||
false: {setOtherOrg, userAdmin, memberMe, agentsAccessUser, templateAdmin, orgTemplateAdmin, orgUserAdmin, orgAuditor},
|
||||
},
|
||||
},
|
||||
{
|
||||
@@ -609,7 +615,7 @@ func TestRolePermissions(t *testing.T) {
|
||||
Resource: rbac.ResourceWorkspaceDormant.WithID(uuid.New()).InOrg(orgID).WithOwner(memberMe.Actor.ID),
|
||||
AuthorizeMap: map[bool][]hasAuthSubjects{
|
||||
true: {},
|
||||
false: {setOtherOrg, setOrgNotMe, memberMe, userAdmin, owner, templateAdmin},
|
||||
false: {setOtherOrg, setOrgNotMe, memberMe, agentsAccessUser, userAdmin, owner, templateAdmin},
|
||||
},
|
||||
},
|
||||
{
|
||||
@@ -618,7 +624,7 @@ func TestRolePermissions(t *testing.T) {
|
||||
Resource: rbac.ResourceWorkspace.WithID(uuid.New()).InOrg(orgID).WithOwner(memberMe.Actor.ID),
|
||||
AuthorizeMap: map[bool][]hasAuthSubjects{
|
||||
true: {owner, orgAdmin},
|
||||
false: {setOtherOrg, userAdmin, templateAdmin, memberMe, orgTemplateAdmin, orgUserAdmin, orgAuditor},
|
||||
false: {setOtherOrg, userAdmin, templateAdmin, memberMe, agentsAccessUser, orgTemplateAdmin, orgUserAdmin, orgAuditor},
|
||||
},
|
||||
},
|
||||
{
|
||||
@@ -627,7 +633,7 @@ func TestRolePermissions(t *testing.T) {
|
||||
Resource: rbac.ResourcePrebuiltWorkspace.WithID(uuid.New()).InOrg(orgID).WithOwner(database.PrebuildsSystemUserID.String()),
|
||||
AuthorizeMap: map[bool][]hasAuthSubjects{
|
||||
true: {owner, orgAdmin, templateAdmin, orgTemplateAdmin},
|
||||
false: {setOtherOrg, userAdmin, memberMe, orgUserAdmin, orgAuditor},
|
||||
false: {setOtherOrg, userAdmin, memberMe, agentsAccessUser, orgUserAdmin, orgAuditor},
|
||||
},
|
||||
},
|
||||
{
|
||||
@@ -636,7 +642,7 @@ func TestRolePermissions(t *testing.T) {
|
||||
Resource: rbac.ResourceTask.WithID(uuid.New()).InOrg(orgID).WithOwner(memberMe.Actor.ID),
|
||||
AuthorizeMap: map[bool][]hasAuthSubjects{
|
||||
true: {owner, orgAdmin},
|
||||
false: {setOtherOrg, userAdmin, templateAdmin, memberMe, orgTemplateAdmin, orgUserAdmin, orgAuditor},
|
||||
false: {setOtherOrg, userAdmin, templateAdmin, memberMe, agentsAccessUser, orgTemplateAdmin, orgUserAdmin, orgAuditor},
|
||||
},
|
||||
},
|
||||
// Some admin style resources
|
||||
@@ -646,7 +652,7 @@ func TestRolePermissions(t *testing.T) {
|
||||
Resource: rbac.ResourceLicense,
|
||||
AuthorizeMap: map[bool][]hasAuthSubjects{
|
||||
true: {owner},
|
||||
false: {setOtherOrg, setOrgNotMe, memberMe, templateAdmin, userAdmin},
|
||||
false: {setOtherOrg, setOrgNotMe, memberMe, agentsAccessUser, templateAdmin, userAdmin},
|
||||
},
|
||||
},
|
||||
{
|
||||
@@ -655,7 +661,7 @@ func TestRolePermissions(t *testing.T) {
|
||||
Resource: rbac.ResourceDeploymentStats,
|
||||
AuthorizeMap: map[bool][]hasAuthSubjects{
|
||||
true: {owner},
|
||||
false: {setOtherOrg, setOrgNotMe, memberMe, templateAdmin, userAdmin},
|
||||
false: {setOtherOrg, setOrgNotMe, memberMe, agentsAccessUser, templateAdmin, userAdmin},
|
||||
},
|
||||
},
|
||||
{
|
||||
@@ -664,7 +670,7 @@ func TestRolePermissions(t *testing.T) {
|
||||
Resource: rbac.ResourceDeploymentConfig,
|
||||
AuthorizeMap: map[bool][]hasAuthSubjects{
|
||||
true: {owner},
|
||||
false: {setOtherOrg, setOrgNotMe, memberMe, templateAdmin, userAdmin},
|
||||
false: {setOtherOrg, setOrgNotMe, memberMe, agentsAccessUser, templateAdmin, userAdmin},
|
||||
},
|
||||
},
|
||||
{
|
||||
@@ -673,7 +679,7 @@ func TestRolePermissions(t *testing.T) {
|
||||
Resource: rbac.ResourceDebugInfo,
|
||||
AuthorizeMap: map[bool][]hasAuthSubjects{
|
||||
true: {owner},
|
||||
false: {setOtherOrg, setOrgNotMe, memberMe, templateAdmin, userAdmin},
|
||||
false: {setOtherOrg, setOrgNotMe, memberMe, agentsAccessUser, templateAdmin, userAdmin},
|
||||
},
|
||||
},
|
||||
{
|
||||
@@ -682,7 +688,7 @@ func TestRolePermissions(t *testing.T) {
|
||||
Resource: rbac.ResourceReplicas,
|
||||
AuthorizeMap: map[bool][]hasAuthSubjects{
|
||||
true: {owner},
|
||||
false: {setOtherOrg, setOrgNotMe, memberMe, templateAdmin, userAdmin},
|
||||
false: {setOtherOrg, setOrgNotMe, memberMe, agentsAccessUser, templateAdmin, userAdmin},
|
||||
},
|
||||
},
|
||||
{
|
||||
@@ -691,7 +697,7 @@ func TestRolePermissions(t *testing.T) {
|
||||
Resource: rbac.ResourceTailnetCoordinator,
|
||||
AuthorizeMap: map[bool][]hasAuthSubjects{
|
||||
true: {owner},
|
||||
false: {setOtherOrg, setOrgNotMe, memberMe, templateAdmin, userAdmin},
|
||||
false: {setOtherOrg, setOrgNotMe, memberMe, agentsAccessUser, templateAdmin, userAdmin},
|
||||
},
|
||||
},
|
||||
{
|
||||
@@ -700,7 +706,7 @@ func TestRolePermissions(t *testing.T) {
|
||||
Resource: rbac.ResourceAuditLog,
|
||||
AuthorizeMap: map[bool][]hasAuthSubjects{
|
||||
true: {owner},
|
||||
false: {setOtherOrg, setOrgNotMe, memberMe, templateAdmin, userAdmin},
|
||||
false: {setOtherOrg, setOrgNotMe, memberMe, agentsAccessUser, templateAdmin, userAdmin},
|
||||
},
|
||||
},
|
||||
{
|
||||
@@ -709,7 +715,7 @@ func TestRolePermissions(t *testing.T) {
|
||||
Resource: rbac.ResourceProvisionerDaemon.InOrg(orgID),
|
||||
AuthorizeMap: map[bool][]hasAuthSubjects{
|
||||
true: {owner, templateAdmin, orgAdmin, orgTemplateAdmin},
|
||||
false: {setOtherOrg, orgAuditor, orgUserAdmin, memberMe, userAdmin},
|
||||
false: {setOtherOrg, orgAuditor, orgUserAdmin, memberMe, agentsAccessUser, userAdmin},
|
||||
},
|
||||
},
|
||||
{
|
||||
@@ -718,7 +724,7 @@ func TestRolePermissions(t *testing.T) {
|
||||
Resource: rbac.ResourceProvisionerDaemon.InOrg(orgID),
|
||||
AuthorizeMap: map[bool][]hasAuthSubjects{
|
||||
true: {owner, templateAdmin, orgAdmin, orgTemplateAdmin},
|
||||
false: {setOtherOrg, memberMe, userAdmin, orgAuditor, orgUserAdmin},
|
||||
false: {setOtherOrg, memberMe, agentsAccessUser, userAdmin, orgAuditor, orgUserAdmin},
|
||||
},
|
||||
},
|
||||
{
|
||||
@@ -727,7 +733,7 @@ func TestRolePermissions(t *testing.T) {
|
||||
Resource: rbac.ResourceProvisionerDaemon.WithOwner(currentUser.String()).InOrg(orgID),
|
||||
AuthorizeMap: map[bool][]hasAuthSubjects{
|
||||
true: {owner, templateAdmin, orgTemplateAdmin, orgAdmin},
|
||||
false: {setOtherOrg, memberMe, userAdmin, orgUserAdmin, orgAuditor},
|
||||
false: {setOtherOrg, memberMe, agentsAccessUser, userAdmin, orgUserAdmin, orgAuditor},
|
||||
},
|
||||
},
|
||||
{
|
||||
@@ -736,7 +742,7 @@ func TestRolePermissions(t *testing.T) {
|
||||
Resource: rbac.ResourceProvisionerJobs.InOrg(orgID),
|
||||
AuthorizeMap: map[bool][]hasAuthSubjects{
|
||||
true: {owner, orgTemplateAdmin, orgAdmin},
|
||||
false: {setOtherOrg, memberMe, templateAdmin, userAdmin, orgUserAdmin, orgAuditor},
|
||||
false: {setOtherOrg, memberMe, agentsAccessUser, templateAdmin, userAdmin, orgUserAdmin, orgAuditor},
|
||||
},
|
||||
},
|
||||
{
|
||||
@@ -745,7 +751,7 @@ func TestRolePermissions(t *testing.T) {
|
||||
Resource: rbac.ResourceSystem,
|
||||
AuthorizeMap: map[bool][]hasAuthSubjects{
|
||||
true: {owner},
|
||||
false: {setOtherOrg, setOrgNotMe, memberMe, templateAdmin, userAdmin},
|
||||
false: {setOtherOrg, setOrgNotMe, memberMe, agentsAccessUser, templateAdmin, userAdmin},
|
||||
},
|
||||
},
|
||||
{
|
||||
@@ -754,7 +760,7 @@ func TestRolePermissions(t *testing.T) {
|
||||
Resource: rbac.ResourceOauth2App,
|
||||
AuthorizeMap: map[bool][]hasAuthSubjects{
|
||||
true: {owner},
|
||||
false: {setOtherOrg, setOrgNotMe, memberMe, templateAdmin, userAdmin},
|
||||
false: {setOtherOrg, setOrgNotMe, memberMe, agentsAccessUser, templateAdmin, userAdmin},
|
||||
},
|
||||
},
|
||||
{
|
||||
@@ -762,7 +768,7 @@ func TestRolePermissions(t *testing.T) {
|
||||
Actions: []policy.Action{policy.ActionRead},
|
||||
Resource: rbac.ResourceOauth2App,
|
||||
AuthorizeMap: map[bool][]hasAuthSubjects{
|
||||
true: {owner, setOrgNotMe, setOtherOrg, memberMe, templateAdmin, userAdmin},
|
||||
true: {owner, setOrgNotMe, setOtherOrg, memberMe, agentsAccessUser, templateAdmin, userAdmin},
|
||||
false: {},
|
||||
},
|
||||
},
|
||||
@@ -772,7 +778,7 @@ func TestRolePermissions(t *testing.T) {
|
||||
Resource: rbac.ResourceOauth2AppSecret,
|
||||
AuthorizeMap: map[bool][]hasAuthSubjects{
|
||||
true: {owner},
|
||||
false: {setOrgNotMe, setOtherOrg, memberMe, templateAdmin, userAdmin},
|
||||
false: {setOrgNotMe, setOtherOrg, memberMe, agentsAccessUser, templateAdmin, userAdmin},
|
||||
},
|
||||
},
|
||||
{
|
||||
@@ -781,7 +787,7 @@ func TestRolePermissions(t *testing.T) {
|
||||
Resource: rbac.ResourceOauth2AppCodeToken,
|
||||
AuthorizeMap: map[bool][]hasAuthSubjects{
|
||||
true: {owner},
|
||||
false: {setOrgNotMe, setOtherOrg, memberMe, templateAdmin, userAdmin},
|
||||
false: {setOrgNotMe, setOtherOrg, memberMe, agentsAccessUser, templateAdmin, userAdmin},
|
||||
},
|
||||
},
|
||||
{
|
||||
@@ -790,7 +796,7 @@ func TestRolePermissions(t *testing.T) {
|
||||
Resource: rbac.ResourceWorkspaceProxy,
|
||||
AuthorizeMap: map[bool][]hasAuthSubjects{
|
||||
true: {owner},
|
||||
false: {setOrgNotMe, setOtherOrg, memberMe, templateAdmin, userAdmin},
|
||||
false: {setOrgNotMe, setOtherOrg, memberMe, agentsAccessUser, templateAdmin, userAdmin},
|
||||
},
|
||||
},
|
||||
{
|
||||
@@ -798,7 +804,7 @@ func TestRolePermissions(t *testing.T) {
|
||||
Actions: []policy.Action{policy.ActionRead},
|
||||
Resource: rbac.ResourceWorkspaceProxy,
|
||||
AuthorizeMap: map[bool][]hasAuthSubjects{
|
||||
true: {owner, setOrgNotMe, setOtherOrg, memberMe, templateAdmin, userAdmin},
|
||||
true: {owner, setOrgNotMe, setOtherOrg, memberMe, agentsAccessUser, templateAdmin, userAdmin},
|
||||
false: {},
|
||||
},
|
||||
},
|
||||
@@ -809,7 +815,7 @@ func TestRolePermissions(t *testing.T) {
|
||||
Actions: []policy.Action{policy.ActionRead, policy.ActionUpdate},
|
||||
Resource: rbac.ResourceNotificationPreference.WithOwner(currentUser.String()),
|
||||
AuthorizeMap: map[bool][]hasAuthSubjects{
|
||||
true: {memberMe, owner},
|
||||
true: {memberMe, agentsAccessUser, owner},
|
||||
false: {
|
||||
userAdmin, orgUserAdmin, templateAdmin,
|
||||
orgAuditor, orgTemplateAdmin,
|
||||
@@ -826,7 +832,7 @@ func TestRolePermissions(t *testing.T) {
|
||||
AuthorizeMap: map[bool][]hasAuthSubjects{
|
||||
true: {owner},
|
||||
false: {
|
||||
memberMe, userAdmin, orgUserAdmin, templateAdmin,
|
||||
memberMe, agentsAccessUser, userAdmin, orgUserAdmin, templateAdmin,
|
||||
orgAuditor, orgTemplateAdmin,
|
||||
otherOrgAuditor, otherOrgUserAdmin, otherOrgTemplateAdmin,
|
||||
orgAdmin, otherOrgAdmin,
|
||||
@@ -840,7 +846,7 @@ func TestRolePermissions(t *testing.T) {
|
||||
AuthorizeMap: map[bool][]hasAuthSubjects{
|
||||
true: {owner},
|
||||
false: {
|
||||
memberMe,
|
||||
memberMe, agentsAccessUser,
|
||||
orgAdmin, otherOrgAdmin,
|
||||
orgAuditor, otherOrgAuditor,
|
||||
templateAdmin, orgTemplateAdmin, otherOrgTemplateAdmin,
|
||||
@@ -858,7 +864,7 @@ func TestRolePermissions(t *testing.T) {
|
||||
AuthorizeMap: map[bool][]hasAuthSubjects{
|
||||
true: {owner},
|
||||
false: {
|
||||
memberMe, templateAdmin, orgUserAdmin, userAdmin,
|
||||
memberMe, agentsAccessUser, templateAdmin, orgUserAdmin, userAdmin,
|
||||
orgAdmin, orgAuditor, orgTemplateAdmin,
|
||||
otherOrgAuditor, otherOrgUserAdmin, otherOrgTemplateAdmin,
|
||||
otherOrgAdmin,
|
||||
@@ -871,7 +877,7 @@ func TestRolePermissions(t *testing.T) {
|
||||
Actions: []policy.Action{policy.ActionCreate, policy.ActionRead, policy.ActionDelete},
|
||||
Resource: rbac.ResourceWebpushSubscription.WithOwner(currentUser.String()),
|
||||
AuthorizeMap: map[bool][]hasAuthSubjects{
|
||||
true: {owner, memberMe},
|
||||
true: {owner, memberMe, agentsAccessUser},
|
||||
false: {orgAdmin, otherOrgAdmin, orgAuditor, otherOrgAuditor, templateAdmin, orgTemplateAdmin, otherOrgTemplateAdmin, userAdmin, orgUserAdmin, otherOrgUserAdmin},
|
||||
},
|
||||
},
|
||||
@@ -883,7 +889,7 @@ func TestRolePermissions(t *testing.T) {
|
||||
AuthorizeMap: map[bool][]hasAuthSubjects{
|
||||
true: {owner, userAdmin, orgAdmin, otherOrgAdmin, orgUserAdmin, otherOrgUserAdmin},
|
||||
false: {
|
||||
memberMe, templateAdmin,
|
||||
memberMe, agentsAccessUser, templateAdmin,
|
||||
orgTemplateAdmin, orgAuditor,
|
||||
otherOrgAuditor, otherOrgTemplateAdmin,
|
||||
},
|
||||
@@ -896,7 +902,7 @@ func TestRolePermissions(t *testing.T) {
|
||||
AuthorizeMap: map[bool][]hasAuthSubjects{
|
||||
true: {owner, templateAdmin, orgTemplateAdmin, otherOrgTemplateAdmin, orgAdmin, otherOrgAdmin},
|
||||
false: {
|
||||
userAdmin, memberMe,
|
||||
userAdmin, memberMe, agentsAccessUser,
|
||||
orgAuditor, orgUserAdmin,
|
||||
otherOrgAuditor, otherOrgUserAdmin,
|
||||
},
|
||||
@@ -909,7 +915,7 @@ func TestRolePermissions(t *testing.T) {
|
||||
AuthorizeMap: map[bool][]hasAuthSubjects{
|
||||
true: {owner, orgAdmin, otherOrgAdmin},
|
||||
false: {
|
||||
memberMe, userAdmin, templateAdmin,
|
||||
memberMe, agentsAccessUser, userAdmin, templateAdmin,
|
||||
orgAuditor, orgUserAdmin, orgTemplateAdmin,
|
||||
otherOrgAuditor, otherOrgUserAdmin, otherOrgTemplateAdmin,
|
||||
},
|
||||
@@ -921,7 +927,7 @@ func TestRolePermissions(t *testing.T) {
|
||||
Resource: rbac.ResourceCryptoKey,
|
||||
AuthorizeMap: map[bool][]hasAuthSubjects{
|
||||
true: {owner},
|
||||
false: {setOtherOrg, setOrgNotMe, memberMe, templateAdmin, userAdmin},
|
||||
false: {setOtherOrg, setOrgNotMe, memberMe, agentsAccessUser, templateAdmin, userAdmin},
|
||||
},
|
||||
},
|
||||
{
|
||||
@@ -932,7 +938,7 @@ func TestRolePermissions(t *testing.T) {
|
||||
true: {owner, orgAdmin, orgUserAdmin, userAdmin},
|
||||
false: {
|
||||
otherOrgAdmin,
|
||||
memberMe, templateAdmin,
|
||||
memberMe, agentsAccessUser, templateAdmin,
|
||||
orgAuditor, orgTemplateAdmin,
|
||||
otherOrgAuditor, otherOrgUserAdmin, otherOrgTemplateAdmin,
|
||||
},
|
||||
@@ -947,7 +953,7 @@ func TestRolePermissions(t *testing.T) {
|
||||
false: {
|
||||
orgAdmin, orgUserAdmin,
|
||||
otherOrgAdmin,
|
||||
memberMe, templateAdmin,
|
||||
memberMe, agentsAccessUser, templateAdmin,
|
||||
orgAuditor, orgTemplateAdmin,
|
||||
otherOrgAuditor, otherOrgUserAdmin, otherOrgTemplateAdmin,
|
||||
},
|
||||
@@ -960,7 +966,7 @@ func TestRolePermissions(t *testing.T) {
|
||||
AuthorizeMap: map[bool][]hasAuthSubjects{
|
||||
true: {owner},
|
||||
false: {
|
||||
memberMe,
|
||||
memberMe, agentsAccessUser,
|
||||
orgAdmin, otherOrgAdmin,
|
||||
orgAuditor, otherOrgAuditor,
|
||||
templateAdmin, orgTemplateAdmin, otherOrgTemplateAdmin,
|
||||
@@ -975,7 +981,7 @@ func TestRolePermissions(t *testing.T) {
|
||||
AuthorizeMap: map[bool][]hasAuthSubjects{
|
||||
true: {owner},
|
||||
false: {
|
||||
memberMe,
|
||||
memberMe, agentsAccessUser,
|
||||
orgAdmin, otherOrgAdmin,
|
||||
orgAuditor, otherOrgAuditor,
|
||||
templateAdmin, orgTemplateAdmin, otherOrgTemplateAdmin,
|
||||
@@ -989,7 +995,7 @@ func TestRolePermissions(t *testing.T) {
|
||||
Resource: rbac.ResourceConnectionLog,
|
||||
AuthorizeMap: map[bool][]hasAuthSubjects{
|
||||
true: {owner},
|
||||
false: {setOtherOrg, setOrgNotMe, memberMe, templateAdmin, userAdmin},
|
||||
false: {setOtherOrg, setOrgNotMe, memberMe, agentsAccessUser, templateAdmin, userAdmin},
|
||||
},
|
||||
},
|
||||
// Only the user themselves can access their own secrets — no one else.
|
||||
@@ -998,7 +1004,7 @@ func TestRolePermissions(t *testing.T) {
|
||||
Actions: []policy.Action{policy.ActionCreate, policy.ActionRead, policy.ActionUpdate, policy.ActionDelete},
|
||||
Resource: rbac.ResourceUserSecret.WithOwner(currentUser.String()),
|
||||
AuthorizeMap: map[bool][]hasAuthSubjects{
|
||||
true: {memberMe},
|
||||
true: {memberMe, agentsAccessUser},
|
||||
false: {
|
||||
owner, orgAdmin,
|
||||
otherOrgAdmin, orgAuditor, orgUserAdmin, orgTemplateAdmin,
|
||||
@@ -1014,7 +1020,7 @@ func TestRolePermissions(t *testing.T) {
|
||||
true: {},
|
||||
false: {
|
||||
owner,
|
||||
memberMe,
|
||||
memberMe, agentsAccessUser,
|
||||
orgAdmin, otherOrgAdmin,
|
||||
orgAuditor, otherOrgAuditor,
|
||||
templateAdmin, orgTemplateAdmin, otherOrgTemplateAdmin,
|
||||
@@ -1028,7 +1034,7 @@ func TestRolePermissions(t *testing.T) {
|
||||
Actions: []policy.Action{policy.ActionCreate, policy.ActionUpdate},
|
||||
Resource: rbac.ResourceAibridgeInterception.WithOwner(currentUser.String()),
|
||||
AuthorizeMap: map[bool][]hasAuthSubjects{
|
||||
true: {owner, memberMe},
|
||||
true: {owner, memberMe, agentsAccessUser},
|
||||
false: {
|
||||
orgAdmin, otherOrgAdmin,
|
||||
orgAuditor, otherOrgAuditor,
|
||||
@@ -1045,7 +1051,7 @@ func TestRolePermissions(t *testing.T) {
|
||||
AuthorizeMap: map[bool][]hasAuthSubjects{
|
||||
true: {owner, auditor},
|
||||
false: {
|
||||
memberMe,
|
||||
memberMe, agentsAccessUser,
|
||||
orgAdmin, otherOrgAdmin,
|
||||
orgAuditor, otherOrgAuditor,
|
||||
templateAdmin, orgTemplateAdmin, otherOrgTemplateAdmin,
|
||||
@@ -1058,7 +1064,7 @@ func TestRolePermissions(t *testing.T) {
|
||||
Actions: []policy.Action{policy.ActionRead, policy.ActionUpdate, policy.ActionDelete},
|
||||
Resource: rbac.ResourceBoundaryUsage,
|
||||
AuthorizeMap: map[bool][]hasAuthSubjects{
|
||||
false: {owner, setOtherOrg, setOrgNotMe, memberMe, templateAdmin, userAdmin},
|
||||
false: {owner, setOtherOrg, setOrgNotMe, memberMe, agentsAccessUser, templateAdmin, userAdmin},
|
||||
},
|
||||
},
|
||||
{
|
||||
@@ -1066,8 +1072,9 @@ func TestRolePermissions(t *testing.T) {
|
||||
Actions: []policy.Action{policy.ActionCreate, policy.ActionRead, policy.ActionUpdate, policy.ActionDelete},
|
||||
Resource: rbac.ResourceChat.WithOwner(currentUser.String()),
|
||||
AuthorizeMap: map[bool][]hasAuthSubjects{
|
||||
true: {owner, memberMe},
|
||||
true: {owner, agentsAccessUser},
|
||||
false: {
|
||||
memberMe,
|
||||
orgAdmin, otherOrgAdmin,
|
||||
orgAuditor, otherOrgAuditor,
|
||||
templateAdmin, orgTemplateAdmin, otherOrgTemplateAdmin,
|
||||
@@ -1076,7 +1083,6 @@ func TestRolePermissions(t *testing.T) {
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
// Build coverage set from test case definitions statically,
|
||||
// so we don't need shared mutable state during execution.
|
||||
// This allows subtests to run in parallel.
|
||||
@@ -1217,7 +1223,6 @@ func TestListRoles(t *testing.T) {
|
||||
"user-admin",
|
||||
},
|
||||
siteRoleNames)
|
||||
|
||||
orgID := uuid.New()
|
||||
orgRoles := rbac.OrganizationRoles(orgID)
|
||||
orgRoleNames := make([]string, 0, len(orgRoles))
|
||||
|
||||
+11
-1
@@ -5,6 +5,7 @@ import (
|
||||
|
||||
"github.com/google/uuid"
|
||||
|
||||
"github.com/coder/coder/v2/buildinfo"
|
||||
"github.com/coder/coder/v2/coderd/database"
|
||||
"github.com/coder/coder/v2/coderd/database/db2sdk"
|
||||
"github.com/coder/coder/v2/coderd/httpapi"
|
||||
@@ -43,7 +44,16 @@ func (api *API) AssignableSiteRoles(rw http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
httpapi.Write(ctx, rw, http.StatusOK, assignableRoles(actorRoles.Roles, rbac.SiteBuiltInRoles(), dbCustomRoles))
|
||||
siteRoles := rbac.SiteBuiltInRoles()
|
||||
// Include the agents-access role only when the agents
|
||||
// experiment is enabled or this is a dev build, matching
|
||||
// the RequireExperimentWithDevBypass gate on chat routes.
|
||||
if api.Experiments.Enabled(codersdk.ExperimentAgents) || buildinfo.IsDev() {
|
||||
siteRoles = append(siteRoles, rbac.AgentsAccessRole())
|
||||
}
|
||||
|
||||
httpapi.Write(ctx, rw, http.StatusOK,
|
||||
assignableRoles(actorRoles.Roles, siteRoles, dbCustomRoles))
|
||||
}
|
||||
|
||||
// assignableOrgRoles returns all org wide roles that can be assigned.
|
||||
|
||||
@@ -1619,6 +1619,18 @@ func (api *API) CreateUser(ctx context.Context, store database.Store, req Create
|
||||
rbacRoles = req.RBACRoles
|
||||
}
|
||||
|
||||
// When the agents experiment is enabled, auto-assign the
|
||||
// agents-access role so new users can use Coder Agents
|
||||
// without manual admin intervention. Skip this for OIDC
|
||||
// users when site role sync is enabled, because the sync
|
||||
// will overwrite roles on every login anyway — those
|
||||
// admins should use --oidc-user-role-default instead.
|
||||
if api.Experiments.Enabled(codersdk.ExperimentAgents) &&
|
||||
!(req.LoginType == database.LoginTypeOIDC && api.IDPSync.SiteRoleSyncEnabled()) &&
|
||||
!slices.Contains(rbacRoles, codersdk.RoleAgentsAccess) {
|
||||
rbacRoles = append(rbacRoles, codersdk.RoleAgentsAccess)
|
||||
}
|
||||
|
||||
var user database.User
|
||||
err := store.InTx(func(tx database.Store) error {
|
||||
orgRoles := make([]string, 0)
|
||||
|
||||
@@ -758,6 +758,35 @@ func TestPostUsers(t *testing.T) {
|
||||
assert.Equal(t, firstUser.OrganizationID, user.OrganizationIDs[0])
|
||||
})
|
||||
|
||||
// CreateWithAgentsExperiment verifies that new users
|
||||
// are auto-assigned the agents-access role when the
|
||||
// experiment is enabled. The experiment-disabled case
|
||||
// is implicitly covered by TestInitialRoles, which
|
||||
// asserts exactly [owner] with no experiment — it
|
||||
// would fail if agents-access leaked through.
|
||||
t.Run("CreateWithAgentsExperiment", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
dv := coderdtest.DeploymentValues(t)
|
||||
dv.Experiments = []string{string(codersdk.ExperimentAgents)}
|
||||
client := coderdtest.New(t, &coderdtest.Options{DeploymentValues: dv})
|
||||
firstUser := coderdtest.CreateFirstUser(t, client)
|
||||
|
||||
ctx := testutil.Context(t, testutil.WaitLong)
|
||||
|
||||
user, err := client.CreateUserWithOrgs(ctx, codersdk.CreateUserRequestWithOrgs{
|
||||
OrganizationIDs: []uuid.UUID{firstUser.OrganizationID},
|
||||
Email: "another@user.org",
|
||||
Username: "someone-else",
|
||||
Password: "SomeSecurePassword!",
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
roles, err := client.UserRoles(ctx, user.Username)
|
||||
require.NoError(t, err)
|
||||
require.Contains(t, roles.Roles, codersdk.RoleAgentsAccess,
|
||||
"new user should have agents-access role when agents experiment is enabled")
|
||||
})
|
||||
|
||||
t.Run("CreateWithStatus", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
auditor := audit.NewMock()
|
||||
|
||||
+120
-5
@@ -42,6 +42,7 @@ import (
|
||||
"github.com/coder/coder/v2/coderd/telemetry"
|
||||
maputil "github.com/coder/coder/v2/coderd/util/maps"
|
||||
"github.com/coder/coder/v2/coderd/wspubsub"
|
||||
"github.com/coder/coder/v2/coderd/x/gitsync"
|
||||
"github.com/coder/coder/v2/codersdk"
|
||||
"github.com/coder/coder/v2/codersdk/agentsdk"
|
||||
"github.com/coder/coder/v2/codersdk/workspacesdk"
|
||||
@@ -1464,7 +1465,9 @@ func (api *API) workspaceAgentPostLogSource(rw http.ResponseWriter, r *http.Requ
|
||||
// @Security CoderSessionToken
|
||||
// @Produce json
|
||||
// @Tags Agents
|
||||
// @Param wait query bool false "Opt in to durable reinit checks"
|
||||
// @Success 200 {object} agentsdk.ReinitializationEvent
|
||||
// @Failure 409 {object} codersdk.Response
|
||||
// @Router /workspaceagents/me/reinit [get]
|
||||
func (api *API) workspaceAgentReinit(rw http.ResponseWriter, r *http.Request) {
|
||||
// Allow us to interrupt watch via cancel.
|
||||
@@ -1481,18 +1484,113 @@ func (api *API) workspaceAgentReinit(rw http.ResponseWriter, r *http.Request) {
|
||||
if err != nil {
|
||||
log.Error(ctx, "failed to retrieve workspace from agent token", slog.Error(err))
|
||||
httpapi.InternalServerError(rw, xerrors.New("failed to determine workspace from agent token"))
|
||||
return
|
||||
}
|
||||
log = log.With(slog.F("workspace_id", workspace.ID))
|
||||
|
||||
log.Info(ctx, "agent waiting for reinit instruction")
|
||||
|
||||
reinitEvents := make(chan agentsdk.ReinitializationEvent)
|
||||
cancel, err = prebuilds.NewPubsubWorkspaceClaimListener(api.Pubsub, log).ListenForWorkspaceClaims(ctx, workspace.ID, reinitEvents)
|
||||
// Subscribe to claim events BEFORE any durable checks to avoid a
|
||||
// TOCTOU race: without this, a claim could fire between the
|
||||
// IsPrebuild() check and the subscribe call, and we'd miss the
|
||||
// pubsub event entirely. By subscribing first, any event that
|
||||
// fires during the checks below is buffered in the channel.
|
||||
pubsubCh, cancelSub, err := prebuilds.NewPubsubWorkspaceClaimListener(api.Pubsub, log).ListenForWorkspaceClaims(ctx, workspace.ID)
|
||||
if err != nil {
|
||||
log.Error(ctx, "subscribe to prebuild claimed channel", slog.Error(err))
|
||||
httpapi.InternalServerError(rw, xerrors.New("failed to subscribe to prebuild claimed channel"))
|
||||
return
|
||||
}
|
||||
defer cancel()
|
||||
defer cancelSub()
|
||||
|
||||
reinitEvents := pubsubCh
|
||||
|
||||
// Only perform the durable claim check when the agent opts in via
|
||||
// the "wait" query parameter. Older agents don't send the
|
||||
// "wait" query parameter and lack the duplicate-reinit guard, so
|
||||
// they would enter an infinite reinit loop if we pre-seeded the
|
||||
// channel on every connection.
|
||||
waitParam, _ := strconv.ParseBool(r.URL.Query().Get("wait"))
|
||||
if waitParam && !workspace.IsPrebuild() {
|
||||
firstBuild, err := api.Database.GetWorkspaceBuildByWorkspaceIDAndBuildNumber(ctx,
|
||||
database.GetWorkspaceBuildByWorkspaceIDAndBuildNumberParams{
|
||||
WorkspaceID: workspace.ID,
|
||||
BuildNumber: 1,
|
||||
})
|
||||
if err != nil {
|
||||
log.Error(ctx, "failed to get first workspace build", slog.Error(err))
|
||||
httpapi.InternalServerError(rw, xerrors.New("failed to get first workspace build"))
|
||||
return
|
||||
}
|
||||
if firstBuild.InitiatorID != database.PrebuildsSystemUserID {
|
||||
// Not a claimed prebuild — this is a regular workspace.
|
||||
// Return 409 so the agent stops reconnecting to this
|
||||
// endpoint.
|
||||
httpapi.Write(ctx, rw, http.StatusConflict, codersdk.Response{
|
||||
Message: "Workspace is not a prebuilt workspace waiting to be claimed.",
|
||||
Detail: "This endpoint is only for agents running in prebuilt workspaces.",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// This workspace was a prebuild that got claimed. Check if
|
||||
// the claim build completed successfully before sending
|
||||
// reinit. We assume the latest build is the claim build
|
||||
// (build 2). If a third build (e.g. a restart) starts
|
||||
// between the claim and the agent's reconnection, this
|
||||
// would check that build instead. The window is extremely
|
||||
// small in practice, and a restart would trigger its own
|
||||
// reinit path.
|
||||
latestBuild, err := api.Database.GetLatestWorkspaceBuildByWorkspaceID(ctx, workspace.ID)
|
||||
if err != nil {
|
||||
log.Error(ctx, "failed to get latest workspace build", slog.Error(err))
|
||||
httpapi.InternalServerError(rw, xerrors.New("failed to get latest workspace build"))
|
||||
return
|
||||
}
|
||||
job, err := api.Database.GetProvisionerJobByID(ctx, latestBuild.JobID)
|
||||
if err != nil {
|
||||
log.Error(ctx, "failed to get provisioner job", slog.Error(err))
|
||||
httpapi.InternalServerError(rw, xerrors.New("failed to get provisioner job"))
|
||||
return
|
||||
}
|
||||
|
||||
if job.CompletedAt.Valid && !job.Error.Valid {
|
||||
// Claim build succeeded — cancel the pubsub
|
||||
// subscription (no longer needed) and swap in a
|
||||
// pre-seeded channel so the transmitter delivers
|
||||
// exactly one reinit event.
|
||||
cancelSub()
|
||||
seeded := make(chan agentsdk.ReinitializationEvent, 1)
|
||||
seeded <- agentsdk.ReinitializationEvent{
|
||||
WorkspaceID: workspace.ID,
|
||||
Reason: agentsdk.ReinitializeReasonPrebuildClaimed,
|
||||
OwnerID: workspace.OwnerID,
|
||||
}
|
||||
reinitEvents = seeded
|
||||
} else if job.CompletedAt.Valid && job.Error.Valid {
|
||||
// Claim build failed permanently. Return 409 so the
|
||||
// agent treats this as terminal and stops retrying
|
||||
// (WaitForReinitLoop exits on any 409).
|
||||
cancelSub()
|
||||
log.Warn(ctx, "claim build failed",
|
||||
slog.F("job_id", job.ID),
|
||||
slog.F("error", job.Error.String))
|
||||
httpapi.Write(ctx, rw, http.StatusConflict, codersdk.Response{
|
||||
Message: "Claim build failed permanently.",
|
||||
Detail: job.Error.String,
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// Claim build still in progress — fall through to the
|
||||
// transmitter. The pubsub subscription (set up above)
|
||||
// will deliver the event when the build completes
|
||||
// successfully. Note: FailJob does not publish a claim
|
||||
// event, so a failed in-progress build will leave the
|
||||
// agent blocking here until it disconnects and
|
||||
// reconnects (at which point the durable check above
|
||||
// handles it).
|
||||
}
|
||||
|
||||
transmitter := agentsdk.NewSSEAgentReinitTransmitter(log, rw, r)
|
||||
|
||||
@@ -1840,6 +1938,11 @@ func (api *API) workspaceAgentsExternalAuth(rw http.ResponseWriter, r *http.Requ
|
||||
Branch: strings.TrimSpace(query.Get("git_branch")),
|
||||
RemoteOrigin: strings.TrimSpace(query.Get("git_remote_origin")),
|
||||
}
|
||||
if raw := strings.TrimSpace(query.Get("chat_id")); raw != "" {
|
||||
if parsed, err := uuid.Parse(raw); err == nil {
|
||||
gitRef.ChatID = parsed
|
||||
}
|
||||
}
|
||||
// Either match or configID must be provided!
|
||||
match := query.Get("match")
|
||||
if match == "" {
|
||||
@@ -1938,7 +2041,13 @@ func (api *API) workspaceAgentsExternalAuth(rw http.ResponseWriter, r *http.Requ
|
||||
// context is retained even if the flow requires an out-of-band login.
|
||||
if gitRef.Branch != "" && gitRef.RemoteOrigin != "" {
|
||||
//nolint:gocritic // Chat processor context required for cross-user chat lookup
|
||||
api.gitSyncWorker.MarkStale(dbauthz.AsChatd(ctx), workspace.ID, workspace.OwnerID, gitRef.Branch, gitRef.RemoteOrigin)
|
||||
api.gitSyncWorker.MarkStale(dbauthz.AsChatd(ctx), gitsync.MarkStaleParams{
|
||||
WorkspaceID: workspace.ID,
|
||||
OwnerID: workspace.OwnerID,
|
||||
Branch: gitRef.Branch,
|
||||
Origin: gitRef.RemoteOrigin,
|
||||
ChatID: gitRef.ChatID,
|
||||
})
|
||||
}
|
||||
|
||||
var previousToken *database.ExternalAuthLink
|
||||
@@ -2087,7 +2196,13 @@ func (api *API) workspaceAgentsExternalAuthListen(ctx context.Context, rw http.R
|
||||
}
|
||||
// MarkStale will trigger a refresh by coderd/gitsync.
|
||||
//nolint:gocritic // Chat processor context required for cross-user chat lookup
|
||||
api.gitSyncWorker.MarkStale(dbauthz.AsChatd(ctx), workspace.ID, workspace.OwnerID, gitRef.Branch, gitRef.RemoteOrigin)
|
||||
api.gitSyncWorker.MarkStale(dbauthz.AsChatd(ctx), gitsync.MarkStaleParams{
|
||||
WorkspaceID: workspace.ID,
|
||||
OwnerID: workspace.OwnerID,
|
||||
Branch: gitRef.Branch,
|
||||
Origin: gitRef.RemoteOrigin,
|
||||
ChatID: gitRef.ChatID,
|
||||
})
|
||||
httpapi.Write(ctx, rw, http.StatusOK, resp)
|
||||
return
|
||||
}
|
||||
|
||||
+190
-35
@@ -2,6 +2,7 @@ package coderd_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
@@ -3278,51 +3279,205 @@ func TestAgentConnectionInfo(t *testing.T) {
|
||||
func TestReinit(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
db, ps := dbtestutil.NewDB(t)
|
||||
pubsubSpy := pubsubReinitSpy{
|
||||
Pubsub: ps,
|
||||
triedToSubscribe: make(chan string),
|
||||
// Helper to create the prebuilds system user's workspace (an
|
||||
// unclaimed prebuild) and return the build result. The first
|
||||
// build's InitiatorID defaults to PrebuildsSystemUserID via
|
||||
// dbfake.
|
||||
setupPrebuildWorkspace := func(t *testing.T, db database.Store, orgID uuid.UUID) dbfake.WorkspaceResponse {
|
||||
t.Helper()
|
||||
return dbfake.WorkspaceBuild(t, db, database.WorkspaceTable{
|
||||
OrganizationID: orgID,
|
||||
OwnerID: database.PrebuildsSystemUserID,
|
||||
}).WithAgent().Do()
|
||||
}
|
||||
client := coderdtest.New(t, &coderdtest.Options{
|
||||
Database: db,
|
||||
Pubsub: &pubsubSpy,
|
||||
|
||||
// Helper to simulate claiming a prebuild: change the workspace
|
||||
// owner to the real user and create a second (claim) build.
|
||||
claimPrebuild := func(t *testing.T, db database.Store, sqlDB *sql.DB, ws database.WorkspaceTable, claimerID uuid.UUID, templateVersionID uuid.UUID, complete bool) dbfake.WorkspaceResponse {
|
||||
t.Helper()
|
||||
// Change the workspace owner to the claiming user.
|
||||
_, err := sqlDB.Exec("UPDATE workspaces SET owner_id = $1 WHERE id = $2", claimerID, ws.ID)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Update the in-memory workspace to reflect the new owner
|
||||
// so that dbfake uses it for the second build.
|
||||
ws.OwnerID = claimerID
|
||||
|
||||
builder := dbfake.WorkspaceBuild(t, db, ws).
|
||||
Seed(database.WorkspaceBuild{
|
||||
TemplateVersionID: templateVersionID,
|
||||
BuildNumber: 2,
|
||||
InitiatorID: claimerID,
|
||||
Transition: database.WorkspaceTransitionStart,
|
||||
}).
|
||||
WithAgent()
|
||||
if !complete {
|
||||
builder = builder.Starting()
|
||||
}
|
||||
return builder.Do()
|
||||
}
|
||||
|
||||
t.Run("unclaimed prebuild receives reinit via pubsub", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
db, ps := dbtestutil.NewDB(t)
|
||||
pubsubSpy := pubsubReinitSpy{
|
||||
Pubsub: ps,
|
||||
triedToSubscribe: make(chan string),
|
||||
}
|
||||
client := coderdtest.New(t, &coderdtest.Options{
|
||||
Database: db,
|
||||
Pubsub: &pubsubSpy,
|
||||
})
|
||||
user := coderdtest.CreateFirstUser(t, client)
|
||||
|
||||
r := setupPrebuildWorkspace(t, db, user.OrganizationID)
|
||||
|
||||
pubsubSpy.Lock()
|
||||
pubsubSpy.expectedEvent = agentsdk.PrebuildClaimedChannel(r.Workspace.ID)
|
||||
pubsubSpy.Unlock()
|
||||
|
||||
agentCtx := testutil.Context(t, testutil.WaitShort)
|
||||
agentClient := agentsdk.New(client.URL, agentsdk.WithFixedToken(r.AgentToken))
|
||||
|
||||
agentReinitializedCh := make(chan *agentsdk.ReinitializationEvent)
|
||||
go func() {
|
||||
reinitEvent, err := agentClient.WaitForReinit(agentCtx)
|
||||
assert.NoError(t, err)
|
||||
agentReinitializedCh <- reinitEvent
|
||||
}()
|
||||
|
||||
// We need to subscribe before we publish, lest we miss the
|
||||
// event.
|
||||
ctx := testutil.Context(t, testutil.WaitShort)
|
||||
testutil.TryReceive(ctx, t, pubsubSpy.triedToSubscribe)
|
||||
|
||||
// Now that we're subscribed, publish the event.
|
||||
err := prebuilds.NewPubsubWorkspaceClaimPublisher(ps).PublishWorkspaceClaim(agentsdk.ReinitializationEvent{
|
||||
WorkspaceID: r.Workspace.ID,
|
||||
Reason: agentsdk.ReinitializeReasonPrebuildClaimed,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
ctx = testutil.Context(t, testutil.WaitShort)
|
||||
reinitEvent := testutil.TryReceive(ctx, t, agentReinitializedCh)
|
||||
require.NotNil(t, reinitEvent)
|
||||
require.Equal(t, r.Workspace.ID, reinitEvent.WorkspaceID)
|
||||
})
|
||||
user := coderdtest.CreateFirstUser(t, client)
|
||||
|
||||
r := dbfake.WorkspaceBuild(t, db, database.WorkspaceTable{
|
||||
OrganizationID: user.OrganizationID,
|
||||
OwnerID: user.UserID,
|
||||
}).WithAgent().Do()
|
||||
// Verifies the durable claim check: when an agent reconnects
|
||||
// after missing the pubsub event, the handler detects that the
|
||||
// workspace was originally a prebuild (first build initiated by
|
||||
// PrebuildsSystemUserID), is now claimed (owner changed), and
|
||||
// the claim build completed, so it sends a one-shot reinit
|
||||
// event immediately.
|
||||
t.Run("claimed prebuild receives one-shot reinit on reconnect", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
pubsubSpy.Lock()
|
||||
pubsubSpy.expectedEvent = agentsdk.PrebuildClaimedChannel(r.Workspace.ID)
|
||||
pubsubSpy.Unlock()
|
||||
db, ps, sqlDB := dbtestutil.NewDBWithSQLDB(t)
|
||||
client := coderdtest.New(t, &coderdtest.Options{
|
||||
Database: db,
|
||||
Pubsub: ps,
|
||||
})
|
||||
user := coderdtest.CreateFirstUser(t, client)
|
||||
|
||||
agentCtx := testutil.Context(t, testutil.WaitShort)
|
||||
agentClient := agentsdk.New(client.URL, agentsdk.WithFixedToken(r.AgentToken))
|
||||
// Create an unclaimed prebuild (build 1, completed).
|
||||
r := setupPrebuildWorkspace(t, db, user.OrganizationID)
|
||||
|
||||
agentReinitializedCh := make(chan *agentsdk.ReinitializationEvent)
|
||||
go func() {
|
||||
reinitEvent, err := agentClient.WaitForReinit(agentCtx)
|
||||
assert.NoError(t, err)
|
||||
agentReinitializedCh <- reinitEvent
|
||||
}()
|
||||
// Claim it: change owner + create build 2 (completed).
|
||||
claimR := claimPrebuild(t, db, sqlDB, r.Workspace, user.UserID, r.TemplateVersion.ID, true)
|
||||
|
||||
// We need to subscribe before we publish, lest we miss the event
|
||||
ctx := testutil.Context(t, testutil.WaitShort)
|
||||
testutil.TryReceive(ctx, t, pubsubSpy.triedToSubscribe)
|
||||
agentCtx := testutil.Context(t, testutil.WaitShort)
|
||||
agentClient := agentsdk.New(client.URL, agentsdk.WithFixedToken(claimR.AgentToken))
|
||||
|
||||
// Now that we're subscribed, publish the event
|
||||
err := prebuilds.NewPubsubWorkspaceClaimPublisher(ps).PublishWorkspaceClaim(agentsdk.ReinitializationEvent{
|
||||
WorkspaceID: r.Workspace.ID,
|
||||
Reason: agentsdk.ReinitializeReasonPrebuildClaimed,
|
||||
agentReinitializedCh := make(chan *agentsdk.ReinitializationEvent)
|
||||
go func() {
|
||||
reinitEvent, err := agentClient.WaitForReinit(agentCtx)
|
||||
assert.NoError(t, err)
|
||||
agentReinitializedCh <- reinitEvent
|
||||
}()
|
||||
|
||||
// The agent should receive a reinit event immediately from
|
||||
// the durable claim check — no pubsub publish needed.
|
||||
ctx := testutil.Context(t, testutil.WaitShort)
|
||||
reinitEvent := testutil.TryReceive(ctx, t, agentReinitializedCh)
|
||||
require.NotNil(t, reinitEvent)
|
||||
require.Equal(t, r.Workspace.ID, reinitEvent.WorkspaceID)
|
||||
require.Equal(t, agentsdk.ReinitializeReasonPrebuildClaimed, reinitEvent.Reason)
|
||||
require.Equal(t, user.UserID, reinitEvent.OwnerID)
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
ctx = testutil.Context(t, testutil.WaitShort)
|
||||
reinitEvent := testutil.TryReceive(ctx, t, agentReinitializedCh)
|
||||
require.NotNil(t, reinitEvent)
|
||||
require.Equal(t, r.Workspace.ID, reinitEvent.WorkspaceID)
|
||||
// Verifies that when the claim build completed with an error,
|
||||
// the handler returns 409 so the agent treats it as terminal
|
||||
// and stops retrying (WaitForReinitLoop exits on any 409).
|
||||
t.Run("failed claim build returns terminal 409", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
db, ps, sqlDB := dbtestutil.NewDBWithSQLDB(t)
|
||||
client := coderdtest.New(t, &coderdtest.Options{
|
||||
Database: db,
|
||||
Pubsub: ps,
|
||||
})
|
||||
user := coderdtest.CreateFirstUser(t, client)
|
||||
|
||||
// Create an unclaimed prebuild (build 1, completed).
|
||||
r := setupPrebuildWorkspace(t, db, user.OrganizationID)
|
||||
|
||||
// Claim it: create build 2 as completed (so agent rows
|
||||
// exist and the token is valid for auth).
|
||||
claimR := claimPrebuild(t, db, sqlDB, r.Workspace, user.UserID, r.TemplateVersion.ID, true)
|
||||
|
||||
// Simulate a claim build failure: set an error on the
|
||||
// provisioner job. This models the case where terraform
|
||||
// apply partially succeeded (creating resources/agents)
|
||||
// but ultimately errored.
|
||||
_, err := sqlDB.Exec(
|
||||
"UPDATE provisioner_jobs SET error = 'simulated claim failure' WHERE id = $1",
|
||||
claimR.Build.JobID,
|
||||
)
|
||||
require.NoError(t, err)
|
||||
|
||||
agentCtx := testutil.Context(t, testutil.WaitShort)
|
||||
agentClient := agentsdk.New(client.URL, agentsdk.WithFixedToken(claimR.AgentToken))
|
||||
|
||||
_, err = agentClient.WaitForReinit(agentCtx)
|
||||
require.Error(t, err)
|
||||
var sdkErr *codersdk.Error
|
||||
require.ErrorAs(t, err, &sdkErr)
|
||||
require.Equal(t, http.StatusConflict, sdkErr.StatusCode())
|
||||
})
|
||||
|
||||
// Verifies that a regular workspace (never a prebuild) gets a
|
||||
// 409 Conflict response, causing the agent's reinit loop to
|
||||
// close the channel gracefully.
|
||||
t.Run("regular workspace gets 409", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
db, ps := dbtestutil.NewDB(t)
|
||||
client := coderdtest.New(t, &coderdtest.Options{
|
||||
Database: db,
|
||||
Pubsub: ps,
|
||||
})
|
||||
user := coderdtest.CreateFirstUser(t, client)
|
||||
|
||||
// Create a regular workspace (not a prebuild). The first
|
||||
// build's initiator will be the user, not the prebuilds
|
||||
// system user.
|
||||
r := dbfake.WorkspaceBuild(t, db, database.WorkspaceTable{
|
||||
OrganizationID: user.OrganizationID,
|
||||
OwnerID: user.UserID,
|
||||
}).WithAgent().Do()
|
||||
|
||||
agentCtx := testutil.Context(t, testutil.WaitShort)
|
||||
agentClient := agentsdk.New(client.URL, agentsdk.WithFixedToken(r.AgentToken))
|
||||
|
||||
// WaitForReinit should return an error wrapping a 409.
|
||||
_, err := agentClient.WaitForReinit(agentCtx)
|
||||
require.Error(t, err)
|
||||
var sdkErr *codersdk.Error
|
||||
require.ErrorAs(t, err, &sdkErr)
|
||||
require.Equal(t, http.StatusConflict, sdkErr.StatusCode())
|
||||
})
|
||||
}
|
||||
|
||||
type pubsubReinitSpy struct {
|
||||
|
||||
@@ -515,7 +515,11 @@ func createWorkspaceWithApps(t *testing.T, client *codersdk.Client, orgID uuid.U
|
||||
primaryAppHost, err := client.AppHost(appHostCtx)
|
||||
require.NoError(t, err)
|
||||
if primaryAppHost.Host != "" {
|
||||
rpcConn, err := agentClient.ConnectRPC(appHostCtx)
|
||||
// Fetch the manifest without marking this short-lived helper
|
||||
// connection as the workspace agent. Closing a monitored RPC
|
||||
// connection races with the real agent startup and can
|
||||
// transiently mark the agent disconnected.
|
||||
rpcConn, err := agentClient.ConnectRPCWithRole(appHostCtx, "apptest-manifest")
|
||||
require.NoError(t, err)
|
||||
aAPI := agentproto.NewDRPCAgentClient(rpcConn)
|
||||
manifest, err := aAPI.GetManifest(appHostCtx, &agentproto.GetManifestRequest{})
|
||||
|
||||
+334
-101
@@ -8,6 +8,7 @@ import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"path"
|
||||
"slices"
|
||||
"strconv"
|
||||
"strings"
|
||||
@@ -70,6 +71,13 @@ const (
|
||||
// events cached per chat for same-replica stream catch-up.
|
||||
maxDurableMessageCacheSize = 256
|
||||
|
||||
// maxConcurrentRecordingUploads caps the number of recording
|
||||
// stop-and-store operations that can run concurrently. Each
|
||||
// slot buffers up to MaxRecordingSize (100 MB) in memory, so
|
||||
// this value implicitly bounds memory to roughly
|
||||
// maxConcurrentRecordingUploads * 100 MB.
|
||||
maxConcurrentRecordingUploads = 25
|
||||
|
||||
// staleRecoveryIntervalDivisor determines how often the stale
|
||||
// recovery loop runs relative to the stale threshold. A value
|
||||
// of 5 means recovery runs at 1/5 of the stale-after duration.
|
||||
@@ -92,7 +100,7 @@ const (
|
||||
defaultSubagentInstruction = "You are running as a delegated sub-agent chat. Complete the delegated task and provide clear, concise assistant responses for the parent agent."
|
||||
)
|
||||
|
||||
var errChatHasNoWorkspaceAgent = xerrors.New("chat has no workspace agent")
|
||||
var errChatHasNoWorkspaceAgent = xerrors.New("workspace has no running agent: the workspace is likely stopped. Use the start_workspace tool to start it")
|
||||
|
||||
// Server handles background processing of pending chats.
|
||||
type Server struct {
|
||||
@@ -128,6 +136,7 @@ type Server struct {
|
||||
|
||||
usageTracker *workspacestats.UsageTracker
|
||||
clock quartz.Clock
|
||||
recordingSem chan struct{}
|
||||
|
||||
// Configuration
|
||||
pendingChatAcquireInterval time.Duration
|
||||
@@ -343,7 +352,7 @@ func (c *turnWorkspaceContext) loadWorkspaceAgentLocked(
|
||||
}
|
||||
|
||||
if !chatSnapshot.WorkspaceID.Valid {
|
||||
return chatSnapshot, database.WorkspaceAgent{}, xerrors.New("chat has no workspace")
|
||||
return chatSnapshot, database.WorkspaceAgent{}, xerrors.New("no workspace is associated with this chat. Use the create_workspace tool to create one")
|
||||
}
|
||||
|
||||
if chatSnapshot.AgentID.Valid {
|
||||
@@ -850,7 +859,10 @@ func (p *Server) CreateChat(ctx context.Context, opts CreateOptions) (database.C
|
||||
LastModelConfigID: opts.ModelConfigID,
|
||||
Title: opts.Title,
|
||||
Mode: opts.ChatMode,
|
||||
MCPServerIDs: opts.MCPServerIDs,
|
||||
// Chats created with an initial user message start pending.
|
||||
// Waiting is reserved for idle chats with no pending work.
|
||||
Status: database.ChatStatusPending,
|
||||
MCPServerIDs: opts.MCPServerIDs,
|
||||
Labels: pqtype.NullRawMessage{
|
||||
RawMessage: labelsJSON,
|
||||
Valid: true,
|
||||
@@ -919,10 +931,7 @@ func (p *Server) CreateChat(ctx context.Context, opts CreateOptions) (database.C
|
||||
return xerrors.Errorf("insert initial chat messages: %w", err)
|
||||
}
|
||||
|
||||
chat, err = setChatPendingWithStore(ctx, tx, insertedChat.ID)
|
||||
if err != nil {
|
||||
return xerrors.Errorf("set chat pending: %w", err)
|
||||
}
|
||||
chat = insertedChat
|
||||
|
||||
if !chat.RootChatID.Valid && !chat.ParentChatID.Valid {
|
||||
chat.RootChatID = uuid.NullUUID{UUID: chat.ID, Valid: true}
|
||||
@@ -1244,32 +1253,90 @@ func (p *Server) EditMessage(
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// ArchiveChat archives a chat and all descendants, then broadcasts a deleted event.
|
||||
// ArchiveChat archives a chat family and broadcasts deleted events for each
|
||||
// affected chat so watching clients converge without a full refetch. If the
|
||||
// target chat is pending or running, it first transitions the chat back to
|
||||
// waiting so active processing stops before the archive is broadcast.
|
||||
func (p *Server) ArchiveChat(ctx context.Context, chat database.Chat) error {
|
||||
if chat.ID == uuid.Nil {
|
||||
return xerrors.New("chat_id is required")
|
||||
}
|
||||
|
||||
if err := p.db.ArchiveChatByID(ctx, chat.ID); err != nil {
|
||||
return xerrors.Errorf("archive chat: %w", err)
|
||||
statusChat := chat
|
||||
interrupted := false
|
||||
var archivedChats []database.Chat
|
||||
if err := p.db.InTx(func(tx database.Store) error {
|
||||
lockedChat, err := tx.GetChatByIDForUpdate(ctx, chat.ID)
|
||||
if err != nil {
|
||||
return xerrors.Errorf("lock chat for archive: %w", err)
|
||||
}
|
||||
statusChat = lockedChat
|
||||
|
||||
// We do not call setChatWaiting here because it intentionally preserves
|
||||
// pending chats so queued-message promotion can win. Archiving is a
|
||||
// harder stop: both pending and running chats must transition to waiting.
|
||||
if lockedChat.Status == database.ChatStatusPending || lockedChat.Status == database.ChatStatusRunning {
|
||||
statusChat, err = tx.UpdateChatStatus(ctx, database.UpdateChatStatusParams{
|
||||
ID: chat.ID,
|
||||
Status: database.ChatStatusWaiting,
|
||||
WorkerID: uuid.NullUUID{},
|
||||
StartedAt: sql.NullTime{},
|
||||
HeartbeatAt: sql.NullTime{},
|
||||
LastError: sql.NullString{},
|
||||
})
|
||||
if err != nil {
|
||||
return xerrors.Errorf("set chat waiting before archive: %w", err)
|
||||
}
|
||||
interrupted = true
|
||||
}
|
||||
|
||||
archivedChats, err = tx.ArchiveChatByID(ctx, chat.ID)
|
||||
if err != nil {
|
||||
return xerrors.Errorf("archive chat: %w", err)
|
||||
}
|
||||
return nil
|
||||
}, nil); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
p.publishChatPubsubEvent(chat, coderdpubsub.ChatEventKindDeleted, nil)
|
||||
if interrupted {
|
||||
p.publishStatus(chat.ID, statusChat.Status, statusChat.WorkerID)
|
||||
p.publishChatPubsubEvent(statusChat, coderdpubsub.ChatEventKindStatusChange, nil)
|
||||
}
|
||||
|
||||
p.publishChatPubsubEvents(archivedChats, coderdpubsub.ChatEventKindDeleted)
|
||||
return nil
|
||||
}
|
||||
|
||||
// UnarchiveChat unarchives a chat and publishes a created event so sidebar
|
||||
// clients are notified that the chat has reappeared.
|
||||
// UnarchiveChat unarchives a chat family and publishes created events for
|
||||
// each affected chat so watching clients see every chat that reappeared.
|
||||
func (p *Server) UnarchiveChat(ctx context.Context, chat database.Chat) error {
|
||||
if chat.ID == uuid.Nil {
|
||||
return xerrors.New("chat_id is required")
|
||||
}
|
||||
|
||||
if err := p.db.UnarchiveChatByID(ctx, chat.ID); err != nil {
|
||||
return xerrors.Errorf("unarchive chat: %w", err)
|
||||
return p.applyChatLifecycleTransition(
|
||||
ctx,
|
||||
chat.ID,
|
||||
"unarchive",
|
||||
coderdpubsub.ChatEventKindCreated,
|
||||
p.db.UnarchiveChatByID,
|
||||
)
|
||||
}
|
||||
|
||||
func (p *Server) applyChatLifecycleTransition(
|
||||
ctx context.Context,
|
||||
chatID uuid.UUID,
|
||||
action string,
|
||||
kind coderdpubsub.ChatEventKind,
|
||||
transition func(context.Context, uuid.UUID) ([]database.Chat, error),
|
||||
) error {
|
||||
updatedChats, err := transition(ctx, chatID)
|
||||
if err != nil {
|
||||
return xerrors.Errorf("%s chat: %w", action, err)
|
||||
}
|
||||
|
||||
p.publishChatPubsubEvent(chat, coderdpubsub.ChatEventKindCreated, nil)
|
||||
p.publishChatPubsubEvents(updatedChats, kind)
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -1518,17 +1585,17 @@ func (p *Server) acquireManualTitleLock(ctx context.Context, chatID uuid.UUID) e
|
||||
if err != nil {
|
||||
return xerrors.Errorf("lock chat for manual title regeneration: %w", err)
|
||||
}
|
||||
if isFreshManualTitleLock(lockedChat, now) {
|
||||
// Only a fresh manual lock or a chat without a real worker should
|
||||
// block title regeneration. Running chats with a real worker may
|
||||
// regenerate their title concurrently, and last write wins.
|
||||
hasRealWorker := lockedChat.Status == database.ChatStatusRunning &&
|
||||
lockedChat.WorkerID.Valid &&
|
||||
lockedChat.WorkerID.UUID != manualTitleLockWorkerID
|
||||
if lockedChat.Status == database.ChatStatusPending ||
|
||||
(lockedChat.Status == database.ChatStatusRunning && !hasRealWorker) ||
|
||||
isFreshManualTitleLock(lockedChat, now) {
|
||||
return ErrManualTitleRegenerationInProgress
|
||||
}
|
||||
|
||||
// Only write the lock marker when no real worker owns WorkerID.
|
||||
// When a real worker is running, we skip the DB lock but still
|
||||
// allow regeneration. The frontend prevents same-browser
|
||||
// double-clicks, and concurrent regeneration from different
|
||||
// replicas is harmless, last write wins.
|
||||
hasRealWorker := lockedChat.WorkerID.Valid &&
|
||||
lockedChat.WorkerID.UUID != manualTitleLockWorkerID
|
||||
if hasRealWorker {
|
||||
return nil
|
||||
}
|
||||
@@ -1591,7 +1658,7 @@ func (p *Server) RegenerateChatTitle(
|
||||
// keeping chat ownership authorization at the HTTP layer.
|
||||
//nolint:gocritic // Non-admin users need chatd-scoped config reads here.
|
||||
chatdCtx := dbauthz.AsChatd(ctx)
|
||||
keys, err := p.resolveProviderAPIKeys(chatdCtx)
|
||||
keys, err := p.resolveUserProviderAPIKeys(chatdCtx, chat.OwnerID)
|
||||
if err != nil {
|
||||
return database.Chat{}, xerrors.Errorf("resolve chat providers: %w", err)
|
||||
}
|
||||
@@ -1938,33 +2005,6 @@ func (p *Server) RefreshStatus(ctx context.Context, chatID uuid.UUID) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func setChatPendingWithStore(
|
||||
ctx context.Context,
|
||||
store database.Store,
|
||||
chatID uuid.UUID,
|
||||
) (database.Chat, error) {
|
||||
chat, err := store.GetChatByID(ctx, chatID)
|
||||
if err != nil {
|
||||
return database.Chat{}, xerrors.Errorf("get chat: %w", err)
|
||||
}
|
||||
if chat.Status == database.ChatStatusPending {
|
||||
return chat, nil
|
||||
}
|
||||
|
||||
updatedChat, err := store.UpdateChatStatus(ctx, database.UpdateChatStatusParams{
|
||||
ID: chat.ID,
|
||||
Status: database.ChatStatusPending,
|
||||
WorkerID: uuid.NullUUID{},
|
||||
StartedAt: sql.NullTime{},
|
||||
HeartbeatAt: sql.NullTime{},
|
||||
LastError: sql.NullString{},
|
||||
})
|
||||
if err != nil {
|
||||
return database.Chat{}, xerrors.Errorf("set chat pending: %w", err)
|
||||
}
|
||||
return updatedChat, nil
|
||||
}
|
||||
|
||||
func (p *Server) setChatWaiting(ctx context.Context, chatID uuid.UUID) (database.Chat, error) {
|
||||
var updatedChat database.Chat
|
||||
err := p.db.InTx(func(tx database.Store) error {
|
||||
@@ -2340,6 +2380,7 @@ func New(cfg Config) *Server {
|
||||
chatHeartbeatInterval: chatHeartbeatInterval,
|
||||
usageTracker: cfg.UsageTracker,
|
||||
clock: clk,
|
||||
recordingSem: make(chan struct{}, maxConcurrentRecordingUploads),
|
||||
wakeCh: make(chan struct{}, 1),
|
||||
}
|
||||
|
||||
@@ -3099,6 +3140,13 @@ func (p *Server) publishChatStreamNotify(chatID uuid.UUID, notify coderdpubsub.C
|
||||
}
|
||||
}
|
||||
|
||||
// publishChatPubsubEvents broadcasts a lifecycle event for each affected chat.
|
||||
func (p *Server) publishChatPubsubEvents(chats []database.Chat, kind coderdpubsub.ChatEventKind) {
|
||||
for _, chat := range chats {
|
||||
p.publishChatPubsubEvent(chat, kind, nil)
|
||||
}
|
||||
}
|
||||
|
||||
// publishChatPubsubEvent broadcasts a chat lifecycle event via PostgreSQL
|
||||
// pubsub so that all replicas can push updates to watching clients.
|
||||
func (p *Server) publishChatPubsubEvent(chat database.Chat, kind coderdpubsub.ChatEventKind, diffStatus *codersdk.ChatDiffStatus) {
|
||||
@@ -3447,7 +3495,25 @@ func (p *Server) processChat(ctx context.Context, chat database.Chat) {
|
||||
chatCtx, cancel := context.WithCancelCause(ctx)
|
||||
defer cancel(nil)
|
||||
|
||||
controlCancel := p.subscribeChatControl(chatCtx, chat.ID, cancel, logger)
|
||||
// Gate the control subscriber behind a channel that is closed
|
||||
// after we publish "running" status. This prevents stale
|
||||
// pubsub notifications (e.g. the "pending" notification from
|
||||
// SendMessage that triggered this processing) from
|
||||
// interrupting us before we start work. Due to async
|
||||
// PostgreSQL NOTIFY delivery, a notification published before
|
||||
// subscribeChatControl registers its queue can still arrive
|
||||
// after registration.
|
||||
controlArmed := make(chan struct{})
|
||||
gatedCancel := func(cause error) {
|
||||
select {
|
||||
case <-controlArmed:
|
||||
cancel(cause)
|
||||
default:
|
||||
logger.Debug(ctx, "ignoring control notification before armed")
|
||||
}
|
||||
}
|
||||
|
||||
controlCancel := p.subscribeChatControl(chatCtx, chat.ID, gatedCancel, logger)
|
||||
defer func() {
|
||||
if controlCancel != nil {
|
||||
controlCancel()
|
||||
@@ -3508,6 +3574,12 @@ func (p *Server) processChat(ctx context.Context, chat database.Chat) {
|
||||
Valid: true,
|
||||
})
|
||||
|
||||
// Arm the control subscriber. Closing the channel is a
|
||||
// happens-before guarantee in the Go memory model — any
|
||||
// notification dispatched after this point will correctly
|
||||
// interrupt processing.
|
||||
close(controlArmed)
|
||||
|
||||
// Determine the final status and last error to set when we're done.
|
||||
status := database.ChatStatusWaiting
|
||||
wasInterrupted := false
|
||||
@@ -3563,9 +3635,10 @@ func (p *Server) processChat(ctx context.Context, chat database.Chat) {
|
||||
// the worker and let the processor pick it back up.
|
||||
if latestChat.Status == database.ChatStatusPending {
|
||||
status = database.ChatStatusPending
|
||||
} else if status == database.ChatStatusWaiting {
|
||||
} else if status == database.ChatStatusWaiting && !latestChat.Archived {
|
||||
// Queued messages were already admitted through SendMessage,
|
||||
// so auto-promotion only preserves FIFO order here.
|
||||
// so auto-promotion only preserves FIFO order here. Archived
|
||||
// chats skip promotion so archiving behaves like a hard stop.
|
||||
var promoteErr error
|
||||
promotedMessage, remainingQueuedMessages, shouldPublishQueueUpdate, promoteErr = p.tryAutoPromoteQueuedMessage(cleanupCtx, tx, latestChat)
|
||||
if promoteErr != nil {
|
||||
@@ -3868,6 +3941,7 @@ func (p *Server) runChat(
|
||||
mcpCleanup func()
|
||||
workspaceMCPTools []fantasy.AgentTool
|
||||
skills []chattool.SkillMeta
|
||||
skillMetaFile = workspacesdk.DefaultSkillMetaFile
|
||||
)
|
||||
// Check if instruction files need to be (re-)persisted.
|
||||
// This happens when no context-file parts exist yet, or when
|
||||
@@ -3890,7 +3964,7 @@ func (p *Server) runChat(
|
||||
if needsInstructionPersist {
|
||||
g2.Go(func() error {
|
||||
var persistErr error
|
||||
instruction, skills, persistErr = p.persistInstructionFiles(
|
||||
instruction, skills, skillMetaFile, persistErr = p.persistInstructionFiles(
|
||||
ctx,
|
||||
chat,
|
||||
modelConfig.ID,
|
||||
@@ -3917,6 +3991,9 @@ func (p *Server) runChat(
|
||||
// those messages. No workspace dial needed.
|
||||
instruction = instructionFromContextFiles(messages)
|
||||
skills = skillsFromParts(messages)
|
||||
if restored := skillMetaFileFromParts(messages); restored != "" {
|
||||
skillMetaFile = restored
|
||||
}
|
||||
}
|
||||
g2.Go(func() error {
|
||||
resolvedUserPrompt = p.resolveUserPrompt(ctx, chat.OwnerID)
|
||||
@@ -4373,7 +4450,7 @@ func (p *Server) runChat(
|
||||
workspaceCtx.chatStateMu.Unlock()
|
||||
|
||||
if !chatSnapshot.WorkspaceID.Valid {
|
||||
return uuid.Nil, xerrors.New("chat has no workspace")
|
||||
return uuid.Nil, xerrors.New("no workspace is associated with this chat. Use the create_workspace tool to create one")
|
||||
}
|
||||
|
||||
ws, err := p.db.GetWorkspaceByID(ctx, chatSnapshot.WorkspaceID.UUID)
|
||||
@@ -4407,6 +4484,7 @@ func (p *Server) runChat(
|
||||
GetSkills: func() []chattool.SkillMeta {
|
||||
return skills
|
||||
},
|
||||
SkillMetaFile: skillMetaFile,
|
||||
}
|
||||
tools = append(tools,
|
||||
chattool.ReadSkill(skillOpts),
|
||||
@@ -4730,7 +4808,7 @@ func (p *Server) resolveChatModel(
|
||||
})
|
||||
g.Go(func() error {
|
||||
var err error
|
||||
keys, err = p.resolveProviderAPIKeys(ctx)
|
||||
keys, err = p.resolveUserProviderAPIKeys(ctx, chat.OwnerID)
|
||||
if err != nil {
|
||||
return xerrors.Errorf("resolve provider API keys: %w", err)
|
||||
}
|
||||
@@ -4752,8 +4830,9 @@ func (p *Server) resolveChatModel(
|
||||
return model, dbConfig, keys, nil
|
||||
}
|
||||
|
||||
func (p *Server) resolveProviderAPIKeys(
|
||||
func (p *Server) resolveUserProviderAPIKeys(
|
||||
ctx context.Context,
|
||||
ownerID uuid.UUID,
|
||||
) (chatprovider.ProviderAPIKeys, error) {
|
||||
providers, err := p.configCache.EnabledProviders(ctx)
|
||||
if err != nil {
|
||||
@@ -4762,17 +4841,62 @@ func (p *Server) resolveProviderAPIKeys(
|
||||
err,
|
||||
)
|
||||
}
|
||||
dbProviders := make(
|
||||
configuredProviders := make(
|
||||
[]chatprovider.ConfiguredProvider, 0, len(providers),
|
||||
)
|
||||
for _, provider := range providers {
|
||||
dbProviders = append(dbProviders, chatprovider.ConfiguredProvider{
|
||||
Provider: provider.Provider,
|
||||
APIKey: provider.APIKey,
|
||||
BaseURL: provider.BaseUrl,
|
||||
})
|
||||
configuredProviders = append(
|
||||
configuredProviders, chatprovider.ConfiguredProvider{
|
||||
ProviderID: provider.ID,
|
||||
Provider: provider.Provider,
|
||||
APIKey: provider.APIKey,
|
||||
BaseURL: provider.BaseUrl,
|
||||
CentralAPIKeyEnabled: provider.CentralApiKeyEnabled,
|
||||
AllowUserAPIKey: provider.AllowUserApiKey,
|
||||
AllowCentralAPIKeyFallback: provider.AllowCentralApiKeyFallback,
|
||||
},
|
||||
)
|
||||
}
|
||||
return chatprovider.MergeProviderAPIKeys(p.providerAPIKeys, dbProviders), nil
|
||||
allowAnyUserAPIKey := false
|
||||
for _, provider := range configuredProviders {
|
||||
if provider.AllowUserAPIKey {
|
||||
allowAnyUserAPIKey = true
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
userKeys := []chatprovider.UserProviderKey{}
|
||||
if allowAnyUserAPIKey {
|
||||
userKeyRows, err := p.db.GetUserChatProviderKeys(ctx, ownerID)
|
||||
if err != nil {
|
||||
return chatprovider.ProviderAPIKeys{}, xerrors.Errorf(
|
||||
"get user chat provider keys: %w",
|
||||
err,
|
||||
)
|
||||
}
|
||||
userKeys = make([]chatprovider.UserProviderKey, 0, len(userKeyRows))
|
||||
for _, userKey := range userKeyRows {
|
||||
userKeys = append(userKeys, chatprovider.UserProviderKey{
|
||||
ChatProviderID: userKey.ChatProviderID,
|
||||
APIKey: userKey.APIKey,
|
||||
})
|
||||
}
|
||||
}
|
||||
keys, _ := chatprovider.ResolveUserProviderKeys(
|
||||
p.providerAPIKeys,
|
||||
configuredProviders,
|
||||
userKeys,
|
||||
)
|
||||
enabledProviders := make(map[string]struct{}, len(configuredProviders))
|
||||
for _, provider := range configuredProviders {
|
||||
normalizedProvider := chatprovider.NormalizeProvider(provider.Provider)
|
||||
if normalizedProvider == "" {
|
||||
continue
|
||||
}
|
||||
enabledProviders[normalizedProvider] = struct{}{}
|
||||
}
|
||||
chatprovider.PruneDisabledProviderKeys(&keys, enabledProviders)
|
||||
return keys, nil
|
||||
}
|
||||
|
||||
// resolveModelConfig looks up the chat's model config by its
|
||||
@@ -4859,22 +4983,22 @@ func contextFileAgentID(messages []database.ChatMessage) (uuid.UUID, bool) {
|
||||
// skills from the workspace agent, persisting both as message
|
||||
// parts. This is called once when a workspace is first attached
|
||||
// to a chat (or when the agent changes). Returns the formatted
|
||||
// instruction string and skill index for injection into the
|
||||
// current turn's prompt.
|
||||
// instruction string, skill index, and the skill meta file name
|
||||
// for injection into the current turn's prompt.
|
||||
func (p *Server) persistInstructionFiles(
|
||||
ctx context.Context,
|
||||
chat database.Chat,
|
||||
modelConfigID uuid.UUID,
|
||||
getWorkspaceAgent func(context.Context) (database.WorkspaceAgent, error),
|
||||
getWorkspaceConn func(context.Context) (workspacesdk.AgentConn, error),
|
||||
) (string, []chattool.SkillMeta, error) {
|
||||
) (instruction string, skills []chattool.SkillMeta, skillMetaFile string, err error) {
|
||||
if !chat.WorkspaceID.Valid || getWorkspaceAgent == nil {
|
||||
return "", nil, nil
|
||||
return "", nil, workspacesdk.DefaultSkillMetaFile, nil
|
||||
}
|
||||
|
||||
agent, err := getWorkspaceAgent(ctx)
|
||||
if err != nil {
|
||||
return "", nil, nil
|
||||
return "", nil, workspacesdk.DefaultSkillMetaFile, nil
|
||||
}
|
||||
|
||||
directory := agent.ExpandedDirectory
|
||||
@@ -4887,7 +5011,17 @@ func (p *Server) persistInstructionFiles(
|
||||
sections []instructionFileSection
|
||||
workspaceConnOK bool
|
||||
)
|
||||
if getWorkspaceConn != nil {
|
||||
|
||||
// Fetch context configuration from the agent. This tells
|
||||
// us where instruction files, skills, and MCP configs live.
|
||||
// Fall back to the pre-context-config behavior for older
|
||||
// agents that don't support the endpoint.
|
||||
agentCfg := workspacesdk.ContextConfigResponse{
|
||||
InstructionsFile: workspacesdk.DefaultInstructionsFile,
|
||||
SkillMetaFile: workspacesdk.DefaultSkillMetaFile,
|
||||
}
|
||||
|
||||
if getWorkspaceConn != nil { //nolint:nestif // Existing high-complexity block; config fallback logic adds unavoidable branches.
|
||||
instructionCtx, cancel := context.WithTimeout(ctx, homeInstructionLookupTimeout)
|
||||
defer cancel()
|
||||
|
||||
@@ -4899,14 +5033,57 @@ func (p *Server) persistInstructionFiles(
|
||||
)
|
||||
} else {
|
||||
workspaceConnOK = true
|
||||
if content, source, truncated, readErr := readHomeInstructionFile(instructionCtx, conn); readErr != nil {
|
||||
p.logger.Debug(ctx, "failed to load home instruction file",
|
||||
slog.F("chat_id", chat.ID), slog.Error(readErr))
|
||||
} else if content != "" {
|
||||
sections = append(sections, instructionFileSection{content, source, truncated})
|
||||
|
||||
// Fetch resolved context config from agent.
|
||||
var cfgErr error
|
||||
agentCfg, cfgErr = conn.ContextConfig(instructionCtx)
|
||||
if cfgErr != nil {
|
||||
p.logger.Debug(ctx, "agent does not support context-config endpoint, using defaults",
|
||||
slog.F("chat_id", chat.ID), slog.Error(cfgErr))
|
||||
// Fall back to the pre-context-config behavior:
|
||||
// read instruction file from home dir using
|
||||
// LSRelativityHome and discover skills from the
|
||||
// working directory.
|
||||
agentCfg = workspacesdk.ContextConfigResponse{
|
||||
InstructionsFile: workspacesdk.DefaultInstructionsFile,
|
||||
SkillMetaFile: workspacesdk.DefaultSkillMetaFile,
|
||||
}
|
||||
if content, source, truncated, readErr := readHomeInstructionFile(
|
||||
instructionCtx, conn, ".coder", agentCfg.InstructionsFile,
|
||||
); readErr != nil {
|
||||
p.logger.Debug(ctx, "failed to load home instruction file",
|
||||
slog.F("chat_id", chat.ID), slog.Error(readErr))
|
||||
} else if content != "" {
|
||||
sections = append(sections, instructionFileSection{content, source, truncated})
|
||||
}
|
||||
if directory != "" {
|
||||
agentCfg.SkillsDirs = []string{path.Join(directory, ".agents/skills")}
|
||||
}
|
||||
}
|
||||
|
||||
if pwdPath := pwdInstructionFilePath(directory); pwdPath != "" {
|
||||
// Read instruction files from each configured
|
||||
// instruction directory. Track seen paths to
|
||||
// avoid reading the same file twice when the
|
||||
// user duplicates entries.
|
||||
seenDirs := make(map[string]struct{}, len(agentCfg.InstructionsDirs))
|
||||
for _, absDir := range agentCfg.InstructionsDirs {
|
||||
if _, ok := seenDirs[absDir]; ok {
|
||||
continue
|
||||
}
|
||||
seenDirs[absDir] = struct{}{}
|
||||
if content, source, truncated, readErr := readInstructionDirFile(instructionCtx, conn, absDir, agentCfg.InstructionsFile); readErr != nil {
|
||||
p.logger.Debug(ctx, "failed to load instruction file from dir",
|
||||
slog.F("chat_id", chat.ID), slog.F("dir", absDir), slog.Error(readErr))
|
||||
} else if content != "" {
|
||||
sections = append(sections, instructionFileSection{content, source, truncated})
|
||||
}
|
||||
}
|
||||
|
||||
// Also check the working directory for the
|
||||
// instruction file, unless it was already
|
||||
// covered by InstructionsDirs.
|
||||
_, pwdSeen := seenDirs[directory]
|
||||
if pwdPath := pwdInstructionFilePath(directory, agentCfg.InstructionsFile); pwdPath != "" && !pwdSeen {
|
||||
if content, source, truncated, readErr := readInstructionFile(instructionCtx, conn, pwdPath); readErr != nil {
|
||||
p.logger.Debug(ctx, "failed to load working directory instruction file",
|
||||
slog.F("chat_id", chat.ID), slog.F("directory", directory), slog.Error(readErr))
|
||||
@@ -4917,15 +5094,15 @@ func (p *Server) persistInstructionFiles(
|
||||
}
|
||||
}
|
||||
|
||||
// Discover skills from the workspace while we have a
|
||||
// connection. Errors are non-fatal — a chat without skills
|
||||
// still works, it just won't list them in the prompt.
|
||||
// Discover skills from each configured skills directory.
|
||||
// Errors are non-fatal. A chat without skills still works,
|
||||
// it just won't list them in the prompt.
|
||||
var discoveredSkills []chattool.SkillMeta
|
||||
if workspaceConnOK {
|
||||
if workspaceConnOK && len(agentCfg.SkillsDirs) > 0 {
|
||||
conn, connErr := getWorkspaceConn(ctx)
|
||||
if connErr == nil {
|
||||
var discoverErr error
|
||||
discoveredSkills, discoverErr = chattool.DiscoverSkills(ctx, conn, directory)
|
||||
discoveredSkills, discoverErr = chattool.DiscoverSkills(ctx, p.logger, conn, agentCfg.SkillsDirs, agentCfg.SkillMetaFile)
|
||||
if discoverErr != nil {
|
||||
p.logger.Debug(ctx, "failed to discover skills",
|
||||
slog.F("chat_id", chat.ID),
|
||||
@@ -4937,14 +5114,15 @@ func (p *Server) persistInstructionFiles(
|
||||
|
||||
if len(sections) == 0 {
|
||||
if !workspaceConnOK {
|
||||
return "", nil, nil
|
||||
return "", nil, agentCfg.SkillMetaFile, nil
|
||||
}
|
||||
// Persist a sentinel (plus any discovered skill parts)
|
||||
// so subsequent turns skip the workspace agent dial.
|
||||
parts := []codersdk.ChatMessagePart{{
|
||||
Type: codersdk.ChatMessagePartTypeContextFile,
|
||||
ContextFilePath: "",
|
||||
ContextFileAgentID: uuid.NullUUID{UUID: agent.ID, Valid: true},
|
||||
Type: codersdk.ChatMessagePartTypeContextFile,
|
||||
ContextFilePath: "",
|
||||
ContextFileAgentID: uuid.NullUUID{UUID: agent.ID, Valid: true},
|
||||
ContextFileSkillMetaFile: agentCfg.SkillMetaFile,
|
||||
}}
|
||||
for _, s := range discoveredSkills {
|
||||
parts = append(parts, codersdk.ChatMessagePart{
|
||||
@@ -4957,7 +5135,7 @@ func (p *Server) persistInstructionFiles(
|
||||
}
|
||||
content, err := chatprompt.MarshalParts(parts)
|
||||
if err != nil {
|
||||
return "", nil, nil
|
||||
return "", nil, agentCfg.SkillMetaFile, nil
|
||||
}
|
||||
msgParams := database.InsertChatMessagesParams{ //nolint:exhaustruct // Fields populated by appendChatMessage.
|
||||
ChatID: chat.ID,
|
||||
@@ -4970,21 +5148,38 @@ func (p *Server) persistInstructionFiles(
|
||||
chatprompt.CurrentContentVersion,
|
||||
))
|
||||
_, _ = p.db.InsertChatMessages(ctx, msgParams)
|
||||
return "", discoveredSkills, nil
|
||||
// Update the cache column: persist skills if any
|
||||
// exist, or clear to NULL so stale data from a
|
||||
// previous agent doesn't linger.
|
||||
if len(discoveredSkills) > 0 {
|
||||
skillParts := make([]codersdk.ChatMessagePart, 0, len(discoveredSkills))
|
||||
for _, s := range discoveredSkills {
|
||||
skillParts = append(skillParts, codersdk.ChatMessagePart{
|
||||
Type: codersdk.ChatMessagePartTypeSkill,
|
||||
SkillName: s.Name,
|
||||
SkillDescription: s.Description,
|
||||
ContextFileAgentID: uuid.NullUUID{UUID: agent.ID, Valid: true},
|
||||
})
|
||||
}
|
||||
p.updateLastInjectedContext(ctx, chat.ID, skillParts)
|
||||
} else {
|
||||
p.updateLastInjectedContext(ctx, chat.ID, nil)
|
||||
}
|
||||
return "", discoveredSkills, agentCfg.SkillMetaFile, nil
|
||||
}
|
||||
|
||||
// Build context-file parts (one per instruction file) and
|
||||
// skill parts (one per discovered skill).
|
||||
parts := make([]codersdk.ChatMessagePart, 0, len(sections)+len(discoveredSkills))
|
||||
for _, s := range sections {
|
||||
parts = append(parts, codersdk.ChatMessagePart{
|
||||
Type: codersdk.ChatMessagePartTypeContextFile,
|
||||
ContextFilePath: s.source,
|
||||
ContextFileContent: s.content,
|
||||
ContextFileTruncated: s.truncated,
|
||||
ContextFileAgentID: uuid.NullUUID{UUID: agent.ID, Valid: true},
|
||||
ContextFileOS: agent.OperatingSystem,
|
||||
ContextFileDirectory: directory,
|
||||
Type: codersdk.ChatMessagePartTypeContextFile,
|
||||
ContextFilePath: s.source,
|
||||
ContextFileContent: s.content,
|
||||
ContextFileTruncated: s.truncated,
|
||||
ContextFileAgentID: uuid.NullUUID{UUID: agent.ID, Valid: true},
|
||||
ContextFileOS: agent.OperatingSystem,
|
||||
ContextFileDirectory: directory,
|
||||
ContextFileSkillMetaFile: agentCfg.SkillMetaFile,
|
||||
})
|
||||
}
|
||||
for _, s := range discoveredSkills {
|
||||
@@ -4999,7 +5194,7 @@ func (p *Server) persistInstructionFiles(
|
||||
|
||||
content, err := chatprompt.MarshalParts(parts)
|
||||
if err != nil {
|
||||
return "", nil, xerrors.Errorf("marshal context-file parts: %w", err)
|
||||
return "", nil, agentCfg.SkillMetaFile, xerrors.Errorf("marshal context-file parts: %w", err)
|
||||
}
|
||||
|
||||
msgParams := database.InsertChatMessagesParams{ //nolint:exhaustruct // Fields populated by appendChatMessage.
|
||||
@@ -5013,13 +5208,51 @@ func (p *Server) persistInstructionFiles(
|
||||
chatprompt.CurrentContentVersion,
|
||||
))
|
||||
if _, err := p.db.InsertChatMessages(ctx, msgParams); err != nil {
|
||||
return "", nil, xerrors.Errorf("persist instruction files: %w", err)
|
||||
return "", nil, agentCfg.SkillMetaFile, xerrors.Errorf("persist instruction files: %w", err)
|
||||
}
|
||||
// Build stripped copies for the cache column so internal
|
||||
// fields (full file content, OS, directory, skill paths)
|
||||
// are never persisted or returned to API clients.
|
||||
stripped := make([]codersdk.ChatMessagePart, len(parts))
|
||||
copy(stripped, parts)
|
||||
for i := range stripped {
|
||||
stripped[i].StripInternal()
|
||||
}
|
||||
p.updateLastInjectedContext(ctx, chat.ID, stripped)
|
||||
|
||||
// Return the formatted instruction text and discovered skills
|
||||
// so the caller can inject them into this turn's prompt (since
|
||||
// the prompt was built before we persisted).
|
||||
return formatSystemInstructions(agent.OperatingSystem, directory, sections), discoveredSkills, nil
|
||||
return formatSystemInstructions(agent.OperatingSystem, directory, sections), discoveredSkills, agentCfg.SkillMetaFile, nil
|
||||
}
|
||||
|
||||
// updateLastInjectedContext persists the injected context
|
||||
// parts (AGENTS.md files and skills) on the chat row so they
|
||||
// are directly queryable without scanning messages. This is
|
||||
// best-effort — a failure here is logged but does not block
|
||||
// the turn.
|
||||
func (p *Server) updateLastInjectedContext(ctx context.Context, chatID uuid.UUID, parts []codersdk.ChatMessagePart) {
|
||||
param := pqtype.NullRawMessage{Valid: false}
|
||||
if parts != nil {
|
||||
raw, err := json.Marshal(parts)
|
||||
if err != nil {
|
||||
p.logger.Warn(ctx, "failed to marshal injected context",
|
||||
slog.F("chat_id", chatID),
|
||||
slog.Error(err),
|
||||
)
|
||||
return
|
||||
}
|
||||
param = pqtype.NullRawMessage{RawMessage: raw, Valid: true}
|
||||
}
|
||||
if _, err := p.db.UpdateChatLastInjectedContext(ctx, database.UpdateChatLastInjectedContextParams{
|
||||
ID: chatID,
|
||||
LastInjectedContext: param,
|
||||
}); err != nil {
|
||||
p.logger.Warn(ctx, "failed to update injected context",
|
||||
slog.F("chat_id", chatID),
|
||||
slog.Error(err),
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
// resolveUserCompactionThreshold looks up the user's per-model
|
||||
|
||||
@@ -23,6 +23,7 @@ import (
|
||||
dbpubsub "github.com/coder/coder/v2/coderd/database/pubsub"
|
||||
coderdpubsub "github.com/coder/coder/v2/coderd/pubsub"
|
||||
"github.com/coder/coder/v2/coderd/x/chatd/chaterror"
|
||||
"github.com/coder/coder/v2/coderd/x/chatd/chatprovider"
|
||||
"github.com/coder/coder/v2/coderd/x/chatd/chattest"
|
||||
"github.com/coder/coder/v2/coderd/x/chatd/chattool"
|
||||
"github.com/coder/coder/v2/codersdk"
|
||||
@@ -62,8 +63,8 @@ func TestRegenerateChatTitle_PersistsAndBroadcasts(t *testing.T) {
|
||||
}
|
||||
modelConfig := database.ChatModelConfig{
|
||||
ID: modelConfigID,
|
||||
Provider: "anthropic",
|
||||
Model: "claude-haiku-4-5",
|
||||
Provider: "openai",
|
||||
Model: "gpt-4o-mini",
|
||||
ContextLimit: 8192,
|
||||
}
|
||||
updatedChat := chat
|
||||
@@ -85,9 +86,9 @@ func TestRegenerateChatTitle_PersistsAndBroadcasts(t *testing.T) {
|
||||
require.NoError(t, err)
|
||||
defer cancelSub()
|
||||
|
||||
serverURL := chattest.NewAnthropic(t, func(req *chattest.AnthropicRequest) chattest.AnthropicResponse {
|
||||
require.Equal(t, "claude-haiku-4-5", req.Model)
|
||||
return chattest.AnthropicNonStreamingResponse(wantTitle)
|
||||
serverURL := chattest.NewOpenAI(t, func(req *chattest.OpenAIRequest) chattest.OpenAIResponse {
|
||||
require.Equal(t, "gpt-4o-mini", req.Model)
|
||||
return chattest.OpenAINonStreamingResponse("{\"title\":\"" + wantTitle + "\"}")
|
||||
})
|
||||
|
||||
server := &Server{
|
||||
@@ -99,9 +100,10 @@ func TestRegenerateChatTitle_PersistsAndBroadcasts(t *testing.T) {
|
||||
|
||||
db.EXPECT().GetChatModelConfigByID(gomock.Any(), modelConfigID).Return(modelConfig, nil)
|
||||
db.EXPECT().GetEnabledChatProviders(gomock.Any()).Return([]database.ChatProvider{{
|
||||
Provider: "anthropic",
|
||||
APIKey: "test-key",
|
||||
BaseUrl: serverURL,
|
||||
Provider: "openai",
|
||||
CentralApiKeyEnabled: true,
|
||||
APIKey: "test-key",
|
||||
BaseUrl: serverURL,
|
||||
}}, nil)
|
||||
db.EXPECT().GetChatUsageLimitConfig(gomock.Any()).Return(database.ChatUsageLimitConfig{}, sql.ErrNoRows)
|
||||
db.EXPECT().GetChatMessagesByChatIDAscPaginated(
|
||||
@@ -221,8 +223,8 @@ func TestRegenerateChatTitle_PersistsAndBroadcasts_IdleChatReleasesManualLock(t
|
||||
lockedChat.StartedAt = sql.NullTime{Time: time.Now(), Valid: true}
|
||||
modelConfig := database.ChatModelConfig{
|
||||
ID: modelConfigID,
|
||||
Provider: "anthropic",
|
||||
Model: "claude-haiku-4-5",
|
||||
Provider: "openai",
|
||||
Model: "gpt-4o-mini",
|
||||
ContextLimit: 8192,
|
||||
}
|
||||
updatedChat := lockedChat
|
||||
@@ -247,9 +249,9 @@ func TestRegenerateChatTitle_PersistsAndBroadcasts_IdleChatReleasesManualLock(t
|
||||
require.NoError(t, err)
|
||||
defer cancelSub()
|
||||
|
||||
serverURL := chattest.NewAnthropic(t, func(req *chattest.AnthropicRequest) chattest.AnthropicResponse {
|
||||
require.Equal(t, "claude-haiku-4-5", req.Model)
|
||||
return chattest.AnthropicNonStreamingResponse(wantTitle)
|
||||
serverURL := chattest.NewOpenAI(t, func(req *chattest.OpenAIRequest) chattest.OpenAIResponse {
|
||||
require.Equal(t, "gpt-4o-mini", req.Model)
|
||||
return chattest.OpenAINonStreamingResponse("{\"title\":\"" + wantTitle + "\"}")
|
||||
})
|
||||
|
||||
server := &Server{
|
||||
@@ -261,9 +263,10 @@ func TestRegenerateChatTitle_PersistsAndBroadcasts_IdleChatReleasesManualLock(t
|
||||
|
||||
db.EXPECT().GetChatModelConfigByID(gomock.Any(), modelConfigID).Return(modelConfig, nil)
|
||||
db.EXPECT().GetEnabledChatProviders(gomock.Any()).Return([]database.ChatProvider{{
|
||||
Provider: "anthropic",
|
||||
APIKey: "test-key",
|
||||
BaseUrl: serverURL,
|
||||
Provider: "openai",
|
||||
CentralApiKeyEnabled: true,
|
||||
APIKey: "test-key",
|
||||
BaseUrl: serverURL,
|
||||
}}, nil)
|
||||
db.EXPECT().GetChatUsageLimitConfig(gomock.Any()).Return(database.ChatUsageLimitConfig{}, sql.ErrNoRows)
|
||||
db.EXPECT().GetChatMessagesByChatIDAscPaginated(
|
||||
@@ -378,6 +381,87 @@ func TestRegenerateChatTitle_PersistsAndBroadcasts_IdleChatReleasesManualLock(t
|
||||
}
|
||||
}
|
||||
|
||||
func TestResolveUserProviderAPIKeys_StripsDisabledFallbackKeys(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctx := testutil.Context(t, testutil.WaitShort)
|
||||
ctrl := gomock.NewController(t)
|
||||
db := dbmock.NewMockStore(ctrl)
|
||||
ownerID := uuid.New()
|
||||
|
||||
server := &Server{
|
||||
db: db,
|
||||
configCache: newChatConfigCache(
|
||||
context.Background(),
|
||||
db,
|
||||
quartz.NewReal(),
|
||||
),
|
||||
providerAPIKeys: chatprovider.ProviderAPIKeys{
|
||||
OpenAI: "openai-deployment-key",
|
||||
Anthropic: "anthropic-deployment-key",
|
||||
ByProvider: map[string]string{
|
||||
"openai": "openai-deployment-key",
|
||||
"anthropic": "anthropic-deployment-key",
|
||||
},
|
||||
BaseURLByProvider: map[string]string{
|
||||
"openai": "https://openai.example.com",
|
||||
"anthropic": "https://anthropic.example.com",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
db.EXPECT().GetEnabledChatProviders(gomock.Any()).Return([]database.ChatProvider{{
|
||||
Provider: "anthropic",
|
||||
CentralApiKeyEnabled: true,
|
||||
AllowCentralApiKeyFallback: true,
|
||||
}}, nil)
|
||||
|
||||
keys, err := server.resolveUserProviderAPIKeys(ctx, ownerID)
|
||||
require.NoError(t, err)
|
||||
require.Empty(t, keys.OpenAI)
|
||||
require.Empty(t, keys.APIKey("openai"))
|
||||
require.Empty(t, keys.BaseURL("openai"))
|
||||
require.Equal(t, "anthropic-deployment-key", keys.Anthropic)
|
||||
require.Equal(t, "anthropic-deployment-key", keys.APIKey("anthropic"))
|
||||
require.Equal(t, "https://anthropic.example.com", keys.BaseURL("anthropic"))
|
||||
require.Equal(t, map[string]string{"anthropic": "anthropic-deployment-key"}, keys.ByProvider)
|
||||
require.Equal(t, map[string]string{"anthropic": "https://anthropic.example.com"}, keys.BaseURLByProvider)
|
||||
}
|
||||
|
||||
func TestResolveUserProviderAPIKeys_SkipsUserKeyLookupWhenNoProviderAllowsUserKeys(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctx := testutil.Context(t, testutil.WaitShort)
|
||||
ctrl := gomock.NewController(t)
|
||||
db := dbmock.NewMockStore(ctrl)
|
||||
ownerID := uuid.New()
|
||||
|
||||
server := &Server{
|
||||
db: db,
|
||||
configCache: newChatConfigCache(
|
||||
context.Background(),
|
||||
db,
|
||||
quartz.NewReal(),
|
||||
),
|
||||
providerAPIKeys: chatprovider.ProviderAPIKeys{
|
||||
OpenAI: "openai-deployment-key",
|
||||
ByProvider: map[string]string{
|
||||
"openai": "openai-deployment-key",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
db.EXPECT().GetEnabledChatProviders(gomock.Any()).Return([]database.ChatProvider{{
|
||||
Provider: "openai",
|
||||
CentralApiKeyEnabled: true,
|
||||
}}, nil)
|
||||
|
||||
keys, err := server.resolveUserProviderAPIKeys(ctx, ownerID)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "openai-deployment-key", keys.OpenAI)
|
||||
require.Equal(t, "openai-deployment-key", keys.APIKey("openai"))
|
||||
}
|
||||
|
||||
func TestRefreshChatWorkspaceSnapshot_NoReloadWhenWorkspacePresent(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
@@ -484,18 +568,50 @@ func TestPersistInstructionFilesIncludesAgentMetadata(t *testing.T) {
|
||||
agentID,
|
||||
).Return(workspaceAgent, nil).Times(1)
|
||||
db.EXPECT().InsertChatMessages(gomock.Any(), gomock.Any()).Return(nil, nil).AnyTimes()
|
||||
db.EXPECT().UpdateChatLastInjectedContext(gomock.Any(),
|
||||
gomock.Cond(func(x any) bool {
|
||||
arg, ok := x.(database.UpdateChatLastInjectedContextParams)
|
||||
if !ok || arg.ID != chat.ID {
|
||||
return false
|
||||
}
|
||||
if !arg.LastInjectedContext.Valid {
|
||||
return false
|
||||
}
|
||||
var parts []codersdk.ChatMessagePart
|
||||
if err := json.Unmarshal(arg.LastInjectedContext.RawMessage, &parts); err != nil {
|
||||
return false
|
||||
}
|
||||
// Expect at least one context-file part for the
|
||||
// working-directory AGENTS.md, with internal fields
|
||||
// stripped (no content, OS, or directory).
|
||||
for _, p := range parts {
|
||||
if p.Type == codersdk.ChatMessagePartTypeContextFile && p.ContextFilePath != "" {
|
||||
return p.ContextFileContent == "" &&
|
||||
p.ContextFileOS == "" &&
|
||||
p.ContextFileDirectory == ""
|
||||
}
|
||||
}
|
||||
return false
|
||||
}),
|
||||
).Return(database.Chat{}, nil).Times(1)
|
||||
|
||||
conn := agentconnmock.NewMockAgentConn(ctrl)
|
||||
conn.EXPECT().SetExtraHeaders(gomock.Any()).Times(1)
|
||||
conn.EXPECT().ContextConfig(gomock.Any()).Return(workspacesdk.ContextConfigResponse{
|
||||
InstructionsDirs: []string{"/home/coder/.coder"},
|
||||
InstructionsFile: "AGENTS.md",
|
||||
SkillsDirs: []string{"/home/coder/project/.agents/skills"},
|
||||
SkillMetaFile: "SKILL.md",
|
||||
}, nil).AnyTimes()
|
||||
conn.EXPECT().LS(gomock.Any(), "", gomock.Any()).Return(
|
||||
workspacesdk.LSResponse{},
|
||||
codersdk.NewTestError(404, "POST", "/api/v0/list-directory"),
|
||||
).AnyTimes()
|
||||
conn.EXPECT().ReadFile(gomock.Any(),
|
||||
conn.EXPECT().ReadFile(
|
||||
gomock.Any(),
|
||||
"/home/coder/project/AGENTS.md",
|
||||
int64(0),
|
||||
int64(maxInstructionFileBytes+1),
|
||||
).Return(
|
||||
int64(maxInstructionFileBytes+1)).Return(
|
||||
io.NopCloser(strings.NewReader("# Project instructions")),
|
||||
"",
|
||||
nil,
|
||||
@@ -520,7 +636,7 @@ func TestPersistInstructionFilesIncludesAgentMetadata(t *testing.T) {
|
||||
}
|
||||
t.Cleanup(workspaceCtx.close)
|
||||
|
||||
instruction, _, err := server.persistInstructionFiles(
|
||||
instruction, _, _, err := server.persistInstructionFiles(
|
||||
ctx,
|
||||
chat,
|
||||
uuid.New(),
|
||||
@@ -532,6 +648,157 @@ func TestPersistInstructionFilesIncludesAgentMetadata(t *testing.T) {
|
||||
require.Contains(t, instruction, "Working Directory: /home/coder/project")
|
||||
}
|
||||
|
||||
func TestPersistInstructionFilesFallbackOnOlderAgent(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
// When the agent doesn't support the context-config endpoint
|
||||
// (returns an error), the fallback path should:
|
||||
// 1. Read instruction files from ~/.coder using LSRelativityHome
|
||||
// 2. Discover skills from the working directory
|
||||
// 3. Return the default skill meta file name
|
||||
|
||||
ctx := context.Background()
|
||||
ctrl := gomock.NewController(t)
|
||||
db := dbmock.NewMockStore(ctrl)
|
||||
|
||||
workspaceID := uuid.New()
|
||||
agentID := uuid.New()
|
||||
chat := database.Chat{
|
||||
ID: uuid.New(),
|
||||
WorkspaceID: uuid.NullUUID{
|
||||
UUID: workspaceID,
|
||||
Valid: true,
|
||||
},
|
||||
AgentID: uuid.NullUUID{
|
||||
UUID: agentID,
|
||||
Valid: true,
|
||||
},
|
||||
}
|
||||
workspaceAgent := database.WorkspaceAgent{
|
||||
ID: agentID,
|
||||
OperatingSystem: "linux",
|
||||
Directory: "/home/coder/project",
|
||||
ExpandedDirectory: "/home/coder/project",
|
||||
}
|
||||
|
||||
db.EXPECT().GetWorkspaceAgentByID(
|
||||
gomock.Any(),
|
||||
agentID,
|
||||
).Return(workspaceAgent, nil).Times(1)
|
||||
db.EXPECT().InsertChatMessages(gomock.Any(), gomock.Any()).Return(nil, nil).AnyTimes()
|
||||
db.EXPECT().UpdateChatLastInjectedContext(gomock.Any(), gomock.Any()).Return(database.Chat{}, nil).AnyTimes()
|
||||
|
||||
conn := agentconnmock.NewMockAgentConn(ctrl)
|
||||
conn.EXPECT().SetExtraHeaders(gomock.Any()).Times(1)
|
||||
|
||||
// ContextConfig returns error — simulating an older agent.
|
||||
conn.EXPECT().ContextConfig(gomock.Any()).Return(
|
||||
workspacesdk.ContextConfigResponse{},
|
||||
codersdk.NewTestError(404, "GET", "/api/v0/context-config"),
|
||||
).Times(1)
|
||||
|
||||
// Fallback: readHomeInstructionFile uses LSRelativityHome
|
||||
// to read from ~/.coder directory.
|
||||
conn.EXPECT().LS(gomock.Any(), "",
|
||||
gomock.Cond(func(x any) bool {
|
||||
req, ok := x.(workspacesdk.LSRequest)
|
||||
return ok && req.Relativity == workspacesdk.LSRelativityHome &&
|
||||
len(req.Path) == 1 && req.Path[0] == ".coder"
|
||||
}),
|
||||
).Return(workspacesdk.LSResponse{
|
||||
Contents: []workspacesdk.LSFile{{
|
||||
Name: "AGENTS.md",
|
||||
AbsolutePathString: "/home/user/.coder/AGENTS.md",
|
||||
IsDir: false,
|
||||
}},
|
||||
}, nil).Times(1)
|
||||
|
||||
// ReadFile for the home instruction file.
|
||||
conn.EXPECT().ReadFile(gomock.Any(),
|
||||
"/home/user/.coder/AGENTS.md",
|
||||
int64(0),
|
||||
int64(maxInstructionFileBytes+1),
|
||||
).Return(
|
||||
io.NopCloser(strings.NewReader("# Home instructions")),
|
||||
"",
|
||||
nil,
|
||||
).Times(1)
|
||||
|
||||
// Working directory instruction file: 404.
|
||||
conn.EXPECT().ReadFile(gomock.Any(),
|
||||
"/home/coder/project/AGENTS.md",
|
||||
int64(0),
|
||||
int64(maxInstructionFileBytes+1),
|
||||
).Return(
|
||||
nil, "",
|
||||
codersdk.NewTestError(404, "GET", "/api/v0/read-file"),
|
||||
).Times(1)
|
||||
|
||||
// Skills directory: fallback constructs path from working dir.
|
||||
conn.EXPECT().LS(gomock.Any(), "",
|
||||
gomock.Cond(func(x any) bool {
|
||||
req, ok := x.(workspacesdk.LSRequest)
|
||||
return ok && req.Relativity == workspacesdk.LSRelativityRoot &&
|
||||
len(req.Path) == 1 && req.Path[0] == "/home/coder/project/.agents/skills"
|
||||
}),
|
||||
).Return(workspacesdk.LSResponse{
|
||||
Contents: []workspacesdk.LSFile{{
|
||||
Name: "fallback-skill",
|
||||
AbsolutePathString: "/home/coder/project/.agents/skills/fallback-skill",
|
||||
IsDir: true,
|
||||
}},
|
||||
}, nil).Times(1)
|
||||
|
||||
skillContent := "---\nname: fallback-skill\ndescription: Discovered via fallback\n---\nBody"
|
||||
conn.EXPECT().ReadFile(gomock.Any(),
|
||||
"/home/coder/project/.agents/skills/fallback-skill/SKILL.md",
|
||||
int64(0),
|
||||
int64(64*1024+1),
|
||||
).Return(
|
||||
io.NopCloser(strings.NewReader(skillContent)),
|
||||
"",
|
||||
nil,
|
||||
).Times(1)
|
||||
|
||||
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true})
|
||||
server := &Server{
|
||||
db: db,
|
||||
logger: logger,
|
||||
agentConnFn: func(context.Context, uuid.UUID) (workspacesdk.AgentConn, func(), error) {
|
||||
return conn, func() {}, nil
|
||||
},
|
||||
}
|
||||
|
||||
chatStateMu := &sync.Mutex{}
|
||||
currentChat := chat
|
||||
workspaceCtx := turnWorkspaceContext{
|
||||
server: server,
|
||||
chatStateMu: chatStateMu,
|
||||
currentChat: ¤tChat,
|
||||
loadChatSnapshot: func(context.Context, uuid.UUID) (database.Chat, error) { return database.Chat{}, nil },
|
||||
}
|
||||
t.Cleanup(workspaceCtx.close)
|
||||
|
||||
instruction, skills, skillMeta, err := server.persistInstructionFiles(
|
||||
ctx,
|
||||
chat,
|
||||
uuid.New(),
|
||||
workspaceCtx.getWorkspaceAgent,
|
||||
workspaceCtx.getWorkspaceConn,
|
||||
)
|
||||
require.NoError(t, err)
|
||||
// Instruction should contain the home instruction file content.
|
||||
require.Contains(t, instruction, "Home instructions")
|
||||
// OS and directory metadata should be present.
|
||||
require.Contains(t, instruction, "Operating System: linux")
|
||||
require.Contains(t, instruction, "Working Directory: /home/coder/project")
|
||||
// Skills should be discovered from the working directory.
|
||||
require.Len(t, skills, 1)
|
||||
require.Equal(t, "fallback-skill", skills[0].Name)
|
||||
// Skill meta file should be the default.
|
||||
require.Equal(t, workspacesdk.DefaultSkillMetaFile, skillMeta)
|
||||
}
|
||||
|
||||
func TestPersistInstructionFilesSkipsSentinelWhenWorkspaceUnavailable(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
@@ -551,7 +818,7 @@ func TestPersistInstructionFilesSkipsSentinelWhenWorkspaceUnavailable(t *testing
|
||||
logger: slogtest.Make(t, &slogtest.Options{IgnoreErrors: true}),
|
||||
}
|
||||
|
||||
instruction, _, err := server.persistInstructionFiles(
|
||||
instruction, _, _, err := server.persistInstructionFiles(
|
||||
ctx,
|
||||
chat,
|
||||
uuid.New(),
|
||||
@@ -569,6 +836,261 @@ func TestPersistInstructionFilesSkipsSentinelWhenWorkspaceUnavailable(t *testing
|
||||
require.Empty(t, instruction)
|
||||
}
|
||||
|
||||
func TestPersistInstructionFilesSentinelWithSkills(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctx := context.Background()
|
||||
ctrl := gomock.NewController(t)
|
||||
db := dbmock.NewMockStore(ctrl)
|
||||
|
||||
workspaceID := uuid.New()
|
||||
agentID := uuid.New()
|
||||
chat := database.Chat{
|
||||
ID: uuid.New(),
|
||||
WorkspaceID: uuid.NullUUID{
|
||||
UUID: workspaceID,
|
||||
Valid: true,
|
||||
},
|
||||
AgentID: uuid.NullUUID{
|
||||
UUID: agentID,
|
||||
Valid: true,
|
||||
},
|
||||
}
|
||||
workspaceAgent := database.WorkspaceAgent{
|
||||
ID: agentID,
|
||||
OperatingSystem: "linux",
|
||||
Directory: "/home/coder/project",
|
||||
ExpandedDirectory: "/home/coder/project",
|
||||
}
|
||||
|
||||
db.EXPECT().GetWorkspaceAgentByID(
|
||||
gomock.Any(),
|
||||
agentID,
|
||||
).Return(workspaceAgent, nil).Times(1)
|
||||
db.EXPECT().InsertChatMessages(gomock.Any(), gomock.Any()).Return(nil, nil).AnyTimes()
|
||||
db.EXPECT().UpdateChatLastInjectedContext(gomock.Any(),
|
||||
gomock.Cond(func(x any) bool {
|
||||
arg, ok := x.(database.UpdateChatLastInjectedContextParams)
|
||||
if !ok || arg.ID != chat.ID {
|
||||
return false
|
||||
}
|
||||
if !arg.LastInjectedContext.Valid {
|
||||
return false
|
||||
}
|
||||
var parts []codersdk.ChatMessagePart
|
||||
if err := json.Unmarshal(arg.LastInjectedContext.RawMessage, &parts); err != nil {
|
||||
return false
|
||||
}
|
||||
// The sentinel path should persist only skill parts
|
||||
// with ContextFileAgentID set.
|
||||
for _, p := range parts {
|
||||
if p.Type == codersdk.ChatMessagePartTypeSkill &&
|
||||
p.SkillName == "my-skill" &&
|
||||
p.ContextFileAgentID == (uuid.NullUUID{UUID: agentID, Valid: true}) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}),
|
||||
).Return(database.Chat{}, nil).Times(1)
|
||||
|
||||
conn := agentconnmock.NewMockAgentConn(ctrl)
|
||||
conn.EXPECT().SetExtraHeaders(gomock.Any()).Times(1)
|
||||
conn.EXPECT().ContextConfig(gomock.Any()).Return(workspacesdk.ContextConfigResponse{
|
||||
InstructionsDirs: []string{"/home/coder/.coder"},
|
||||
InstructionsFile: "AGENTS.md",
|
||||
SkillsDirs: []string{"/home/coder/project/.agents/skills"},
|
||||
SkillMetaFile: "SKILL.md",
|
||||
}, nil).AnyTimes()
|
||||
|
||||
// Instruction dir LS (.coder directory): return 404 so no
|
||||
// instruction file is found from the configured dir.
|
||||
conn.EXPECT().LS(gomock.Any(), "",
|
||||
gomock.Cond(func(x any) bool {
|
||||
req, ok := x.(workspacesdk.LSRequest)
|
||||
return ok && req.Relativity == workspacesdk.LSRelativityRoot &&
|
||||
len(req.Path) == 1 && req.Path[0] == "/home/coder/.coder"
|
||||
}),
|
||||
).Return(
|
||||
workspacesdk.LSResponse{},
|
||||
codersdk.NewTestError(404, "POST", "/api/v0/list-directory"),
|
||||
).Times(1)
|
||||
|
||||
// Pwd AGENTS.md: return 404 so no working-directory
|
||||
// instruction file is found either.
|
||||
conn.EXPECT().ReadFile(gomock.Any(),
|
||||
"/home/coder/project/AGENTS.md",
|
||||
int64(0),
|
||||
int64(maxInstructionFileBytes+1),
|
||||
).Return(
|
||||
nil, "",
|
||||
codersdk.NewTestError(404, "GET", "/api/v0/read-file"),
|
||||
).Times(1)
|
||||
|
||||
// Skills LS (.agents/skills directory): return one skill
|
||||
// directory so DiscoverSkills finds it.
|
||||
conn.EXPECT().LS(gomock.Any(), "",
|
||||
gomock.Cond(func(x any) bool {
|
||||
req, ok := x.(workspacesdk.LSRequest)
|
||||
return ok && req.Relativity == workspacesdk.LSRelativityRoot &&
|
||||
len(req.Path) == 1 && req.Path[0] == "/home/coder/project/.agents/skills"
|
||||
}),
|
||||
).Return(workspacesdk.LSResponse{
|
||||
Contents: []workspacesdk.LSFile{{
|
||||
Name: "my-skill",
|
||||
AbsolutePathString: "/home/coder/project/.agents/skills/my-skill",
|
||||
IsDir: true,
|
||||
}},
|
||||
}, nil).Times(1)
|
||||
|
||||
// Skills SKILL.md ReadFile: return valid frontmatter.
|
||||
skillContent := "---\nname: my-skill\ndescription: A test skill\n---\nSkill body"
|
||||
conn.EXPECT().ReadFile(gomock.Any(),
|
||||
"/home/coder/project/.agents/skills/my-skill/SKILL.md",
|
||||
int64(0),
|
||||
int64(64*1024+1),
|
||||
).Return(
|
||||
io.NopCloser(strings.NewReader(skillContent)),
|
||||
"",
|
||||
nil,
|
||||
).Times(1)
|
||||
|
||||
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true})
|
||||
server := &Server{
|
||||
db: db,
|
||||
logger: logger,
|
||||
agentConnFn: func(context.Context, uuid.UUID) (workspacesdk.AgentConn, func(), error) {
|
||||
return conn, func() {}, nil
|
||||
},
|
||||
}
|
||||
|
||||
chatStateMu := &sync.Mutex{}
|
||||
currentChat := chat
|
||||
workspaceCtx := turnWorkspaceContext{
|
||||
server: server,
|
||||
chatStateMu: chatStateMu,
|
||||
currentChat: ¤tChat,
|
||||
loadChatSnapshot: func(context.Context, uuid.UUID) (database.Chat, error) { return database.Chat{}, nil },
|
||||
}
|
||||
t.Cleanup(workspaceCtx.close)
|
||||
|
||||
instruction, skills, _, err := server.persistInstructionFiles(
|
||||
ctx,
|
||||
chat,
|
||||
uuid.New(),
|
||||
workspaceCtx.getWorkspaceAgent,
|
||||
workspaceCtx.getWorkspaceConn,
|
||||
)
|
||||
require.NoError(t, err)
|
||||
// Sentinel path returns empty instruction string.
|
||||
require.Empty(t, instruction)
|
||||
// Skills are still discovered and returned.
|
||||
require.Len(t, skills, 1)
|
||||
require.Equal(t, "my-skill", skills[0].Name)
|
||||
}
|
||||
|
||||
func TestPersistInstructionFilesSentinelNoSkillsClearsColumn(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctx := context.Background()
|
||||
ctrl := gomock.NewController(t)
|
||||
db := dbmock.NewMockStore(ctrl)
|
||||
|
||||
workspaceID := uuid.New()
|
||||
agentID := uuid.New()
|
||||
chat := database.Chat{
|
||||
ID: uuid.New(),
|
||||
WorkspaceID: uuid.NullUUID{
|
||||
UUID: workspaceID,
|
||||
Valid: true,
|
||||
},
|
||||
AgentID: uuid.NullUUID{
|
||||
UUID: agentID,
|
||||
Valid: true,
|
||||
},
|
||||
}
|
||||
workspaceAgent := database.WorkspaceAgent{
|
||||
ID: agentID,
|
||||
OperatingSystem: "linux",
|
||||
Directory: "/home/coder/project",
|
||||
ExpandedDirectory: "/home/coder/project",
|
||||
}
|
||||
|
||||
db.EXPECT().GetWorkspaceAgentByID(
|
||||
gomock.Any(),
|
||||
agentID,
|
||||
).Return(workspaceAgent, nil).Times(1)
|
||||
db.EXPECT().InsertChatMessages(gomock.Any(), gomock.Any()).Return(nil, nil).AnyTimes()
|
||||
db.EXPECT().UpdateChatLastInjectedContext(gomock.Any(),
|
||||
gomock.Cond(func(x any) bool {
|
||||
arg, ok := x.(database.UpdateChatLastInjectedContextParams)
|
||||
if !ok || arg.ID != chat.ID {
|
||||
return false
|
||||
}
|
||||
// No skills discovered, so the column should be
|
||||
// cleared to NULL.
|
||||
return !arg.LastInjectedContext.Valid
|
||||
}),
|
||||
).Return(database.Chat{}, nil).Times(1)
|
||||
|
||||
conn := agentconnmock.NewMockAgentConn(ctrl)
|
||||
conn.EXPECT().SetExtraHeaders(gomock.Any()).Times(1)
|
||||
conn.EXPECT().ContextConfig(gomock.Any()).Return(workspacesdk.ContextConfigResponse{
|
||||
InstructionsDirs: []string{"/home/coder/.coder"},
|
||||
InstructionsFile: "AGENTS.md",
|
||||
SkillsDirs: []string{"/home/coder/project/.agents/skills"},
|
||||
SkillMetaFile: "SKILL.md",
|
||||
}, nil).AnyTimes()
|
||||
|
||||
// All LS calls return 404: no .coder directory and no
|
||||
// .agents/skills directory.
|
||||
conn.EXPECT().LS(gomock.Any(), "", gomock.Any()).Return(
|
||||
workspacesdk.LSResponse{},
|
||||
codersdk.NewTestError(404, "POST", "/api/v0/list-directory"),
|
||||
).AnyTimes()
|
||||
|
||||
// Pwd AGENTS.md: return 404.
|
||||
conn.EXPECT().ReadFile(gomock.Any(),
|
||||
"/home/coder/project/AGENTS.md",
|
||||
int64(0),
|
||||
int64(maxInstructionFileBytes+1),
|
||||
).Return(
|
||||
nil, "",
|
||||
codersdk.NewTestError(404, "GET", "/api/v0/read-file"),
|
||||
).Times(1)
|
||||
|
||||
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true})
|
||||
server := &Server{
|
||||
db: db,
|
||||
logger: logger,
|
||||
agentConnFn: func(context.Context, uuid.UUID) (workspacesdk.AgentConn, func(), error) {
|
||||
return conn, func() {}, nil
|
||||
},
|
||||
}
|
||||
|
||||
chatStateMu := &sync.Mutex{}
|
||||
currentChat := chat
|
||||
workspaceCtx := turnWorkspaceContext{
|
||||
server: server,
|
||||
chatStateMu: chatStateMu,
|
||||
currentChat: ¤tChat,
|
||||
loadChatSnapshot: func(context.Context, uuid.UUID) (database.Chat, error) { return database.Chat{}, nil },
|
||||
}
|
||||
t.Cleanup(workspaceCtx.close)
|
||||
|
||||
instruction, skills, _, err := server.persistInstructionFiles(
|
||||
ctx,
|
||||
chat,
|
||||
uuid.New(),
|
||||
workspaceCtx.getWorkspaceAgent,
|
||||
workspaceCtx.getWorkspaceConn,
|
||||
)
|
||||
require.NoError(t, err)
|
||||
// Sentinel path: empty instruction, no skills.
|
||||
require.Empty(t, instruction)
|
||||
require.Empty(t, skills)
|
||||
}
|
||||
|
||||
func TestTurnWorkspaceContext_BindingFirstPath(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
@@ -1751,3 +2273,95 @@ func chatMessageWithParts(parts []codersdk.ChatMessagePart) database.ChatMessage
|
||||
Content: pqtype.NullRawMessage{RawMessage: raw, Valid: true},
|
||||
}
|
||||
}
|
||||
|
||||
// TestProcessChat_IgnoresStaleControlNotification verifies that
|
||||
// processChat is not interrupted by a "pending" notification
|
||||
// published before processing begins. This is the race that caused
|
||||
// TestOpenAIReasoningWithWebSearchRoundTripStoreFalse to flake:
|
||||
// SendMessage publishes "pending" via PostgreSQL NOTIFY, and due
|
||||
// to async delivery the notification can arrive at the control
|
||||
// subscriber after it registers but before the processor publishes
|
||||
// "running".
|
||||
func TestProcessChat_IgnoresStaleControlNotification(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctx := testutil.Context(t, testutil.WaitShort)
|
||||
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true})
|
||||
ctrl := gomock.NewController(t)
|
||||
db := dbmock.NewMockStore(ctrl)
|
||||
ps := dbpubsub.NewInMemory()
|
||||
clock := quartz.NewMock(t)
|
||||
|
||||
chatID := uuid.New()
|
||||
workerID := uuid.New()
|
||||
|
||||
server := &Server{
|
||||
db: db,
|
||||
logger: logger,
|
||||
pubsub: ps,
|
||||
clock: clock,
|
||||
workerID: workerID,
|
||||
chatHeartbeatInterval: time.Minute,
|
||||
configCache: newChatConfigCache(ctx, db, clock),
|
||||
}
|
||||
|
||||
// Publish a stale "pending" notification on the control channel
|
||||
// BEFORE processChat subscribes. In production this is the
|
||||
// notification from SendMessage that triggered the processing.
|
||||
staleNotify, err := json.Marshal(coderdpubsub.ChatStreamNotifyMessage{
|
||||
Status: string(database.ChatStatusPending),
|
||||
})
|
||||
require.NoError(t, err)
|
||||
err = ps.Publish(coderdpubsub.ChatStreamNotifyChannel(chatID), staleNotify)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Track which status processChat writes during cleanup.
|
||||
var finalStatus database.ChatStatus
|
||||
cleanupDone := make(chan struct{})
|
||||
|
||||
// The deferred cleanup in processChat runs a transaction.
|
||||
db.EXPECT().InTx(gomock.Any(), gomock.Any()).DoAndReturn(
|
||||
func(fn func(database.Store) error, _ *database.TxOptions) error {
|
||||
return fn(db)
|
||||
},
|
||||
)
|
||||
db.EXPECT().GetChatByIDForUpdate(gomock.Any(), chatID).Return(
|
||||
database.Chat{ID: chatID, Status: database.ChatStatusRunning, WorkerID: uuid.NullUUID{UUID: workerID, Valid: true}}, nil,
|
||||
)
|
||||
db.EXPECT().UpdateChatStatus(gomock.Any(), gomock.Any()).DoAndReturn(
|
||||
func(_ context.Context, params database.UpdateChatStatusParams) (database.Chat, error) {
|
||||
finalStatus = params.Status
|
||||
close(cleanupDone)
|
||||
return database.Chat{ID: chatID, Status: params.Status}, nil
|
||||
},
|
||||
)
|
||||
|
||||
// resolveChatModel fails immediately — that's fine, we only
|
||||
// need processChat to get past initialization without being
|
||||
// interrupted by the stale notification.
|
||||
db.EXPECT().GetChatModelConfigByID(gomock.Any(), gomock.Any()).Return(
|
||||
database.ChatModelConfig{}, xerrors.New("no model configured"),
|
||||
).AnyTimes()
|
||||
db.EXPECT().GetEnabledChatProviders(gomock.Any()).Return(nil, nil).AnyTimes()
|
||||
db.EXPECT().GetEnabledChatModelConfigs(gomock.Any()).Return(nil, nil).AnyTimes()
|
||||
db.EXPECT().GetChatUsageLimitConfig(gomock.Any()).Return(
|
||||
database.ChatUsageLimitConfig{}, sql.ErrNoRows,
|
||||
).AnyTimes()
|
||||
db.EXPECT().GetChatMessagesForPromptByChatID(gomock.Any(), chatID).Return(nil, nil).AnyTimes()
|
||||
|
||||
chat := database.Chat{ID: chatID, LastModelConfigID: uuid.New()}
|
||||
go server.processChat(ctx, chat)
|
||||
|
||||
select {
|
||||
case <-cleanupDone:
|
||||
case <-ctx.Done():
|
||||
t.Fatal("processChat did not complete")
|
||||
}
|
||||
|
||||
// If the stale notification interrupted us, status would be
|
||||
// "waiting" (the ErrInterrupted path). Since the gate blocked
|
||||
// it, processChat reached runChat, which failed on model
|
||||
// resolution → status is "error".
|
||||
require.Equal(t, database.ChatStatusError, finalStatus,
|
||||
"processChat should have reached runChat (error), not been interrupted (waiting)")
|
||||
}
|
||||
|
||||
+567
-61
@@ -297,6 +297,180 @@ func TestInterruptChatClearsWorkerInDatabase(t *testing.T) {
|
||||
require.False(t, fromDB.WorkerID.Valid)
|
||||
}
|
||||
|
||||
func TestArchiveChatMovesPendingChatToWaiting(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
db, ps := dbtestutil.NewDB(t)
|
||||
replica := newTestServer(t, db, ps, uuid.New())
|
||||
|
||||
ctx := testutil.Context(t, testutil.WaitLong)
|
||||
user, model := seedChatDependencies(ctx, t, db)
|
||||
|
||||
chat, err := replica.CreateChat(ctx, chatd.CreateOptions{
|
||||
OwnerID: user.ID,
|
||||
Title: "archive-pending",
|
||||
ModelConfigID: model.ID,
|
||||
InitialUserContent: []codersdk.ChatMessagePart{codersdk.ChatMessageText("hello")},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
chat, err = db.UpdateChatStatus(ctx, database.UpdateChatStatusParams{
|
||||
ID: chat.ID,
|
||||
Status: database.ChatStatusPending,
|
||||
WorkerID: uuid.NullUUID{},
|
||||
StartedAt: sql.NullTime{},
|
||||
HeartbeatAt: sql.NullTime{},
|
||||
LastError: sql.NullString{},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
err = replica.ArchiveChat(ctx, chat)
|
||||
require.NoError(t, err)
|
||||
|
||||
fromDB, err := db.GetChatByID(ctx, chat.ID)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, database.ChatStatusWaiting, fromDB.Status)
|
||||
require.False(t, fromDB.WorkerID.Valid)
|
||||
require.False(t, fromDB.StartedAt.Valid)
|
||||
require.False(t, fromDB.HeartbeatAt.Valid)
|
||||
require.True(t, fromDB.Archived)
|
||||
require.Zero(t, fromDB.PinOrder)
|
||||
}
|
||||
|
||||
func TestArchiveChatInterruptsActiveProcessing(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
db, ps := dbtestutil.NewDB(t)
|
||||
ctx := testutil.Context(t, testutil.WaitLong)
|
||||
|
||||
streamStarted := make(chan struct{})
|
||||
streamCanceled := make(chan struct{})
|
||||
openAIURL := chattest.NewOpenAI(t, func(req *chattest.OpenAIRequest) chattest.OpenAIResponse {
|
||||
if !req.Stream {
|
||||
return chattest.OpenAINonStreamingResponse("title")
|
||||
}
|
||||
chunks := make(chan chattest.OpenAIChunk, 1)
|
||||
go func() {
|
||||
defer close(chunks)
|
||||
chunks <- chattest.OpenAITextChunks("partial")[0]
|
||||
select {
|
||||
case <-streamStarted:
|
||||
default:
|
||||
close(streamStarted)
|
||||
}
|
||||
<-req.Context().Done()
|
||||
select {
|
||||
case <-streamCanceled:
|
||||
default:
|
||||
close(streamCanceled)
|
||||
}
|
||||
}()
|
||||
return chattest.OpenAIResponse{StreamingChunks: chunks}
|
||||
})
|
||||
|
||||
server := newActiveTestServer(t, db, ps)
|
||||
user, model := seedChatDependencies(ctx, t, db)
|
||||
setOpenAIProviderBaseURL(ctx, t, db, openAIURL)
|
||||
|
||||
chat, err := server.CreateChat(ctx, chatd.CreateOptions{
|
||||
OwnerID: user.ID,
|
||||
Title: "archive-interrupt",
|
||||
ModelConfigID: model.ID,
|
||||
InitialUserContent: []codersdk.ChatMessagePart{codersdk.ChatMessageText("hello")},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
testutil.Eventually(ctx, t, func(ctx context.Context) bool {
|
||||
fromDB, dbErr := db.GetChatByID(ctx, chat.ID)
|
||||
if dbErr != nil {
|
||||
return false
|
||||
}
|
||||
return fromDB.Status == database.ChatStatusRunning && fromDB.WorkerID.Valid
|
||||
}, testutil.IntervalFast)
|
||||
|
||||
testutil.Eventually(ctx, t, func(ctx context.Context) bool {
|
||||
select {
|
||||
case <-streamStarted:
|
||||
return true
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}, testutil.IntervalFast)
|
||||
|
||||
_, events, cancel, ok := server.Subscribe(ctx, chat.ID, nil, 0)
|
||||
require.True(t, ok)
|
||||
defer cancel()
|
||||
|
||||
queuedResult, err := server.SendMessage(ctx, chatd.SendMessageOptions{
|
||||
ChatID: chat.ID,
|
||||
Content: []codersdk.ChatMessagePart{codersdk.ChatMessageText("queued")},
|
||||
BusyBehavior: chatd.SendMessageBusyBehaviorQueue,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.True(t, queuedResult.Queued)
|
||||
require.NotNil(t, queuedResult.QueuedMessage)
|
||||
|
||||
err = server.ArchiveChat(ctx, chat)
|
||||
require.NoError(t, err)
|
||||
|
||||
testutil.Eventually(ctx, t, func(ctx context.Context) bool {
|
||||
select {
|
||||
case <-streamCanceled:
|
||||
return true
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}, testutil.IntervalFast)
|
||||
|
||||
gotWaitingStatus := false
|
||||
testutil.Eventually(ctx, t, func(ctx context.Context) bool {
|
||||
for {
|
||||
select {
|
||||
case ev := <-events:
|
||||
if ev.Type == codersdk.ChatStreamEventTypeStatus &&
|
||||
ev.Status != nil &&
|
||||
ev.Status.Status == codersdk.ChatStatusWaiting {
|
||||
gotWaitingStatus = true
|
||||
return true
|
||||
}
|
||||
default:
|
||||
return gotWaitingStatus
|
||||
}
|
||||
}
|
||||
}, testutil.IntervalFast)
|
||||
require.True(t, gotWaitingStatus, "expected a waiting status event after archive")
|
||||
|
||||
testutil.Eventually(ctx, t, func(ctx context.Context) bool {
|
||||
fromDB, dbErr := db.GetChatByID(ctx, chat.ID)
|
||||
if dbErr != nil {
|
||||
return false
|
||||
}
|
||||
return fromDB.Archived &&
|
||||
fromDB.Status == database.ChatStatusWaiting &&
|
||||
!fromDB.WorkerID.Valid &&
|
||||
!fromDB.StartedAt.Valid &&
|
||||
!fromDB.HeartbeatAt.Valid
|
||||
}, testutil.IntervalFast)
|
||||
|
||||
queuedMessages, err := db.GetChatQueuedMessages(ctx, chat.ID)
|
||||
require.NoError(t, err)
|
||||
require.Len(t, queuedMessages, 1)
|
||||
require.Equal(t, queuedResult.QueuedMessage.ID, queuedMessages[0].ID)
|
||||
|
||||
messages, err := db.GetChatMessagesByChatID(ctx, database.GetChatMessagesByChatIDParams{
|
||||
ChatID: chat.ID,
|
||||
AfterID: 0,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
userMessages := 0
|
||||
for _, msg := range messages {
|
||||
if msg.Role == database.ChatMessageRoleUser {
|
||||
userMessages++
|
||||
}
|
||||
}
|
||||
require.Equal(t, 1, userMessages, "expected queued message to stay queued after archive")
|
||||
}
|
||||
|
||||
func TestUpdateChatHeartbeatRequiresOwnership(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
@@ -473,6 +647,11 @@ func TestSendMessageInterruptBehaviorQueuesAndInterruptsWhenBusy(t *testing.T) {
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
// CreateChat calls signalWake which triggers processOnce in
|
||||
// the background. Wait for that processing to finish so it
|
||||
// doesn't race with the manual status update below.
|
||||
waitForChatProcessed(ctx, t, db, chat.ID, replica)
|
||||
|
||||
chat, err = db.UpdateChatStatus(ctx, database.UpdateChatStatusParams{
|
||||
ID: chat.ID,
|
||||
Status: database.ChatStatusRunning,
|
||||
@@ -729,6 +908,7 @@ func TestCreateChatRejectsWhenUsageLimitReached(t *testing.T) {
|
||||
require.NoError(t, err)
|
||||
|
||||
existingChat, err := db.InsertChat(ctx, database.InsertChatParams{
|
||||
Status: database.ChatStatusWaiting,
|
||||
OwnerID: user.ID,
|
||||
Title: "existing-limit-chat",
|
||||
LastModelConfigID: model.ID,
|
||||
@@ -817,6 +997,11 @@ func TestPromoteQueuedAllowsAlreadyQueuedMessageWhenUsageLimitReached(t *testing
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
// CreateChat calls signalWake which triggers processOnce in
|
||||
// the background. Wait for that processing to finish so it
|
||||
// doesn't race with the manual status update below.
|
||||
waitForChatProcessed(ctx, t, db, chat.ID, replica)
|
||||
|
||||
chat, err = db.UpdateChatStatus(ctx, database.UpdateChatStatusParams{
|
||||
ID: chat.ID,
|
||||
Status: database.ChatStatusRunning,
|
||||
@@ -879,10 +1064,6 @@ func TestPromoteQueuedAllowsAlreadyQueuedMessageWhenUsageLimitReached(t *testing
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, database.ChatMessageRoleUser, result.PromotedMessage.Role)
|
||||
|
||||
chat, err = db.GetChatByID(ctx, chat.ID)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, database.ChatStatusPending, chat.Status)
|
||||
|
||||
queued, err := db.GetChatQueuedMessages(ctx, chat.ID)
|
||||
require.NoError(t, err)
|
||||
require.Empty(t, queued)
|
||||
@@ -1018,6 +1199,7 @@ func TestInterruptAutoPromotionIgnoresLaterUsageLimitIncrease(t *testing.T) {
|
||||
require.NotNil(t, laterQueuedResult.QueuedMessage)
|
||||
|
||||
spendChat, err := db.InsertChat(ctx, database.InsertChatParams{
|
||||
Status: database.ChatStatusWaiting,
|
||||
OwnerID: user.ID,
|
||||
WorkspaceID: uuid.NullUUID{},
|
||||
ParentChatID: uuid.NullUUID{},
|
||||
@@ -1268,6 +1450,7 @@ func TestRecoverStaleChatsPeriodically(t *testing.T) {
|
||||
// to running with a heartbeat in the past.
|
||||
deadWorkerID := uuid.New()
|
||||
chat, err := db.InsertChat(ctx, database.InsertChatParams{
|
||||
Status: database.ChatStatusWaiting,
|
||||
OwnerID: user.ID,
|
||||
Title: "stale-recovery-periodic",
|
||||
LastModelConfigID: model.ID,
|
||||
@@ -1313,6 +1496,7 @@ func TestRecoverStaleChatsPeriodically(t *testing.T) {
|
||||
// This tests the periodic recovery, not just the startup one.
|
||||
deadWorkerID2 := uuid.New()
|
||||
chat2, err := db.InsertChat(ctx, database.InsertChatParams{
|
||||
Status: database.ChatStatusWaiting,
|
||||
OwnerID: user.ID,
|
||||
Title: "stale-recovery-periodic-2",
|
||||
LastModelConfigID: model.ID,
|
||||
@@ -1351,6 +1535,7 @@ func TestNewReplicaRecoversStaleChatFromDeadReplica(t *testing.T) {
|
||||
// heartbeat (well beyond the stale threshold).
|
||||
deadReplicaID := uuid.New()
|
||||
chat, err := db.InsertChat(ctx, database.InsertChatParams{
|
||||
Status: database.ChatStatusWaiting,
|
||||
OwnerID: user.ID,
|
||||
Title: "orphaned-chat",
|
||||
LastModelConfigID: model.ID,
|
||||
@@ -1393,6 +1578,7 @@ func TestWaitingChatsAreNotRecoveredAsStale(t *testing.T) {
|
||||
// Create a chat in waiting status — this should NOT be touched
|
||||
// by stale recovery.
|
||||
chat, err := db.InsertChat(ctx, database.InsertChatParams{
|
||||
Status: database.ChatStatusWaiting,
|
||||
OwnerID: user.ID,
|
||||
Title: "waiting-chat",
|
||||
LastModelConfigID: model.ID,
|
||||
@@ -1435,6 +1621,7 @@ func TestUpdateChatStatusPersistsLastError(t *testing.T) {
|
||||
user, model := seedChatDependencies(ctx, t, db)
|
||||
|
||||
chat, err := db.InsertChat(ctx, database.InsertChatParams{
|
||||
Status: database.ChatStatusWaiting,
|
||||
OwnerID: user.ID,
|
||||
Title: "error-persisted",
|
||||
LastModelConfigID: model.ID,
|
||||
@@ -1566,6 +1753,10 @@ func TestPersistToolResultWithBinaryData(t *testing.T) {
|
||||
mockConn.EXPECT().
|
||||
SetExtraHeaders(gomock.Any()).
|
||||
AnyTimes()
|
||||
mockConn.EXPECT().
|
||||
ContextConfig(gomock.Any()).
|
||||
Return(workspacesdk.ContextConfigResponse{}, xerrors.New("not supported")).
|
||||
AnyTimes()
|
||||
mockConn.EXPECT().
|
||||
ListMCPTools(gomock.Any()).
|
||||
Return(workspacesdk.ListMCPToolsResponse{}, nil).
|
||||
@@ -1709,13 +1900,9 @@ func TestSubscribeNoPubsubNoDuplicateMessageParts(t *testing.T) {
|
||||
// subscribing, so the snapshot captures the final state.
|
||||
// The wake signal may trigger processOnce which will fail
|
||||
// (no LLM configured) and set the chat to error status.
|
||||
// Poll until the chat leaves pending status, then wait for
|
||||
// the goroutine to finish.
|
||||
require.Eventually(t, func() bool {
|
||||
c, err := db.GetChatByID(ctx, chat.ID)
|
||||
return err == nil && c.Status != database.ChatStatusPending
|
||||
}, testutil.WaitShort, testutil.IntervalFast)
|
||||
chatd.WaitUntilIdleForTest(replica)
|
||||
// Poll until the chat reaches a terminal state (not pending
|
||||
// and not running), then wait for the goroutine to finish.
|
||||
waitForChatProcessed(ctx, t, db, chat.ID, replica)
|
||||
|
||||
snapshot, events, cancel, ok := replica.Subscribe(ctx, chat.ID, nil, 0)
|
||||
require.True(t, ok)
|
||||
@@ -2299,7 +2486,7 @@ func TestStoppedWorkspaceWithPersistedAgentBindingDoesNotBlockChat(t *testing.T)
|
||||
if message.Role != "tool" {
|
||||
continue
|
||||
}
|
||||
if strings.Contains(message.Content, "chat has no workspace agent") {
|
||||
if strings.Contains(message.Content, "workspace has no running agent") {
|
||||
foundUnavailableToolResult = true
|
||||
break
|
||||
}
|
||||
@@ -2312,8 +2499,8 @@ func TestStoppedWorkspaceWithPersistedAgentBindingDoesNotBlockChat(t *testing.T)
|
||||
}
|
||||
errMsg, _ := toolResult["error"].(string)
|
||||
outputMsg, _ := toolResult["output"].(string)
|
||||
if strings.Contains(errMsg, "chat has no workspace agent") ||
|
||||
strings.Contains(outputMsg, "chat has no workspace agent") {
|
||||
if strings.Contains(errMsg, "workspace has no running agent") ||
|
||||
strings.Contains(outputMsg, "workspace has no running agent") {
|
||||
foundUnavailableToolResult = true
|
||||
break
|
||||
}
|
||||
@@ -2346,7 +2533,7 @@ func TestStoppedWorkspaceWithPersistedAgentBindingDoesNotBlockChat(t *testing.T)
|
||||
require.Equal(t, codersdk.ChatMessagePartTypeToolResult, parts[0].Type)
|
||||
require.Equal(t, "execute", parts[0].ToolName)
|
||||
require.True(t, parts[0].IsError)
|
||||
require.Contains(t, string(parts[0].Result), "chat has no workspace agent")
|
||||
require.Contains(t, string(parts[0].Result), "workspace has no running agent")
|
||||
}
|
||||
|
||||
func TestHeartbeatBumpsWorkspaceUsage(t *testing.T) {
|
||||
@@ -2598,6 +2785,39 @@ func TestHeartbeatNoWorkspaceNoBump(t *testing.T) {
|
||||
require.Equal(t, 0, count, "expected no workspaces to be flushed when chat has no workspace")
|
||||
}
|
||||
|
||||
// waitForChatProcessed waits for a wake-triggered processOnce to
|
||||
// fully complete for the given chat. It polls until the chat leaves
|
||||
// both pending and running states (meaning processChat has finished
|
||||
// its cleanup and updated the DB), then calls WaitUntilIdleForTest.
|
||||
//
|
||||
// Waiting for a terminal state (not just "not pending") avoids a
|
||||
// WaitGroup Add/Wait race: AcquireChats changes the DB status to
|
||||
// running before processOnce calls inflight.Add(1). If we only
|
||||
// waited for status != pending, we could call Wait() while Add(1)
|
||||
// hasn't happened yet.
|
||||
func waitForChatProcessed(
|
||||
ctx context.Context,
|
||||
t *testing.T,
|
||||
db database.Store,
|
||||
chatID uuid.UUID,
|
||||
server *chatd.Server,
|
||||
) {
|
||||
t.Helper()
|
||||
require.Eventually(t, func() bool {
|
||||
c, err := db.GetChatByID(ctx, chatID)
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
// Wait until the chat reaches a terminal state — neither
|
||||
// pending (waiting to be acquired) nor running (being
|
||||
// processed). This guarantees that inflight.Add(1) has
|
||||
// already been called by processOnce.
|
||||
return c.Status != database.ChatStatusPending &&
|
||||
c.Status != database.ChatStatusRunning
|
||||
}, testutil.WaitShort, testutil.IntervalFast)
|
||||
chatd.WaitUntilIdleForTest(server)
|
||||
}
|
||||
|
||||
func newTestServer(
|
||||
t *testing.T,
|
||||
db database.Store,
|
||||
@@ -2673,12 +2893,13 @@ func seedChatDependenciesWithProvider(
|
||||
|
||||
user := dbgen.User(t, db, database.User{})
|
||||
_, err := db.InsertChatProvider(ctx, database.InsertChatProviderParams{
|
||||
Provider: provider,
|
||||
DisplayName: provider,
|
||||
APIKey: "test-key",
|
||||
BaseUrl: baseURL,
|
||||
CreatedBy: uuid.NullUUID{UUID: user.ID, Valid: true},
|
||||
Enabled: true,
|
||||
Provider: provider,
|
||||
DisplayName: provider,
|
||||
APIKey: "test-key",
|
||||
BaseUrl: baseURL,
|
||||
CreatedBy: uuid.NullUUID{UUID: user.ID, Valid: true},
|
||||
Enabled: true,
|
||||
CentralApiKeyEnabled: true,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
model, err := db.InsertChatModelConfig(ctx, database.InsertChatModelConfigParams{
|
||||
@@ -2697,6 +2918,102 @@ func seedChatDependenciesWithProvider(
|
||||
return user, model
|
||||
}
|
||||
|
||||
func seedChatDependenciesWithProviderPolicy(
|
||||
ctx context.Context,
|
||||
t *testing.T,
|
||||
db database.Store,
|
||||
provider string,
|
||||
baseURL string,
|
||||
apiKey string,
|
||||
centralAPIKeyEnabled bool,
|
||||
allowUserAPIKey bool,
|
||||
allowCentralAPIKeyFallback bool,
|
||||
) (database.User, database.ChatProvider, database.ChatModelConfig) {
|
||||
t.Helper()
|
||||
|
||||
user := dbgen.User(t, db, database.User{})
|
||||
providerConfig, err := db.InsertChatProvider(ctx, database.InsertChatProviderParams{
|
||||
Provider: provider,
|
||||
DisplayName: provider,
|
||||
APIKey: apiKey,
|
||||
BaseUrl: baseURL,
|
||||
CreatedBy: uuid.NullUUID{UUID: user.ID, Valid: true},
|
||||
Enabled: true,
|
||||
CentralApiKeyEnabled: centralAPIKeyEnabled,
|
||||
AllowUserApiKey: allowUserAPIKey,
|
||||
AllowCentralApiKeyFallback: allowCentralAPIKeyFallback,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
model, err := db.InsertChatModelConfig(ctx, database.InsertChatModelConfigParams{
|
||||
Provider: provider,
|
||||
Model: "gpt-4o-mini",
|
||||
DisplayName: "Test Model",
|
||||
CreatedBy: uuid.NullUUID{UUID: user.ID, Valid: true},
|
||||
UpdatedBy: uuid.NullUUID{UUID: user.ID, Valid: true},
|
||||
Enabled: true,
|
||||
IsDefault: true,
|
||||
ContextLimit: 128000,
|
||||
CompressionThreshold: 70,
|
||||
Options: json.RawMessage(`{}`),
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
return user, providerConfig, model
|
||||
}
|
||||
|
||||
func waitForTerminalChatStatusEvent(
|
||||
ctx context.Context,
|
||||
t *testing.T,
|
||||
events <-chan codersdk.ChatStreamEvent,
|
||||
) codersdk.ChatStatus {
|
||||
t.Helper()
|
||||
|
||||
var terminalStatus codersdk.ChatStatus
|
||||
testutil.Eventually(ctx, t, func(context.Context) bool {
|
||||
for {
|
||||
select {
|
||||
case event, ok := <-events:
|
||||
if !ok {
|
||||
return false
|
||||
}
|
||||
if event.Type != codersdk.ChatStreamEventTypeStatus || event.Status == nil {
|
||||
continue
|
||||
}
|
||||
if event.Status.Status == codersdk.ChatStatusWaiting || event.Status.Status == codersdk.ChatStatusError {
|
||||
terminalStatus = event.Status.Status
|
||||
return true
|
||||
}
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
}, testutil.IntervalFast)
|
||||
|
||||
return terminalStatus
|
||||
}
|
||||
|
||||
func waitForTerminalChat(
|
||||
ctx context.Context,
|
||||
t *testing.T,
|
||||
db database.Store,
|
||||
chatID uuid.UUID,
|
||||
) database.Chat {
|
||||
t.Helper()
|
||||
|
||||
var chatResult database.Chat
|
||||
testutil.Eventually(ctx, t, func(ctx context.Context) bool {
|
||||
got, err := db.GetChatByID(ctx, chatID)
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
chatResult = got
|
||||
return got.Status == database.ChatStatusWaiting || got.Status == database.ChatStatusError
|
||||
}, testutil.IntervalFast)
|
||||
|
||||
return chatResult
|
||||
}
|
||||
|
||||
// seedWorkspaceWithAgent creates a full workspace chain with a connected
|
||||
// agent. This is the common setup needed by tests that exercise tool
|
||||
// execution against a workspace.
|
||||
@@ -2753,12 +3070,15 @@ func setOpenAIProviderBaseURL(
|
||||
require.NoError(t, err)
|
||||
|
||||
_, err = db.UpdateChatProvider(ctx, database.UpdateChatProviderParams{
|
||||
ID: provider.ID,
|
||||
DisplayName: provider.DisplayName,
|
||||
APIKey: provider.APIKey,
|
||||
BaseUrl: baseURL,
|
||||
ApiKeyKeyID: provider.ApiKeyKeyID,
|
||||
Enabled: provider.Enabled,
|
||||
ID: provider.ID,
|
||||
DisplayName: provider.DisplayName,
|
||||
APIKey: provider.APIKey,
|
||||
BaseUrl: baseURL,
|
||||
ApiKeyKeyID: provider.ApiKeyKeyID,
|
||||
Enabled: provider.Enabled,
|
||||
CentralApiKeyEnabled: provider.CentralApiKeyEnabled,
|
||||
AllowUserApiKey: provider.AllowUserApiKey,
|
||||
AllowCentralApiKeyFallback: provider.AllowCentralApiKeyFallback,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
}
|
||||
@@ -3332,12 +3652,13 @@ func TestComputerUseSubagentToolsAndModel(t *testing.T) {
|
||||
|
||||
// Add an Anthropic provider pointing to our mock server.
|
||||
_, err := db.InsertChatProvider(ctx, database.InsertChatProviderParams{
|
||||
Provider: "anthropic",
|
||||
DisplayName: "Anthropic",
|
||||
APIKey: "test-anthropic-key",
|
||||
BaseUrl: anthropicSrv.URL,
|
||||
CreatedBy: uuid.NullUUID{UUID: user.ID, Valid: true},
|
||||
Enabled: true,
|
||||
Provider: "anthropic",
|
||||
DisplayName: "Anthropic",
|
||||
APIKey: "test-anthropic-key",
|
||||
BaseUrl: anthropicSrv.URL,
|
||||
CreatedBy: uuid.NullUUID{UUID: user.ID, Valid: true},
|
||||
Enabled: true,
|
||||
CentralApiKeyEnabled: true,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
@@ -3367,6 +3688,10 @@ func TestComputerUseSubagentToolsAndModel(t *testing.T) {
|
||||
mockConn.EXPECT().
|
||||
SetExtraHeaders(gomock.Any()).
|
||||
AnyTimes()
|
||||
mockConn.EXPECT().
|
||||
ContextConfig(gomock.Any()).
|
||||
Return(workspacesdk.ContextConfigResponse{}, xerrors.New("not supported")).
|
||||
AnyTimes()
|
||||
mockConn.EXPECT().
|
||||
LS(gomock.Any(), gomock.Any(), gomock.Any()).
|
||||
Return(workspacesdk.LSResponse{}, xerrors.New("not found")).
|
||||
@@ -3617,6 +3942,135 @@ func TestInterruptChatPersistsPartialResponse(t *testing.T) {
|
||||
"partial assistant response should contain the streamed text")
|
||||
}
|
||||
|
||||
func TestProcessChat_UserProviderKey_Success(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
db, ps := dbtestutil.NewDB(t)
|
||||
ctx := testutil.Context(t, testutil.WaitLong)
|
||||
|
||||
const userAPIKey = "user-test-key"
|
||||
|
||||
var authHeadersMu sync.Mutex
|
||||
authHeaders := make([]string, 0, 1)
|
||||
openAIURL := chattest.NewOpenAI(t, func(req *chattest.OpenAIRequest) chattest.OpenAIResponse {
|
||||
authHeadersMu.Lock()
|
||||
authHeaders = append(authHeaders, req.Header.Get("Authorization"))
|
||||
authHeadersMu.Unlock()
|
||||
|
||||
if !req.Stream {
|
||||
return chattest.OpenAINonStreamingResponse("user provider key success")
|
||||
}
|
||||
return chattest.OpenAIStreamingResponse(
|
||||
chattest.OpenAITextChunks("hello from the saved user key")...,
|
||||
)
|
||||
})
|
||||
|
||||
user, provider, model := seedChatDependenciesWithProviderPolicy(
|
||||
ctx,
|
||||
t,
|
||||
db,
|
||||
"openai-compat",
|
||||
openAIURL,
|
||||
"",
|
||||
false,
|
||||
true,
|
||||
false,
|
||||
)
|
||||
_, err := db.UpsertUserChatProviderKey(ctx, database.UpsertUserChatProviderKeyParams{
|
||||
UserID: user.ID,
|
||||
ChatProviderID: provider.ID,
|
||||
APIKey: userAPIKey,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
creator := newTestServer(t, db, ps, uuid.New())
|
||||
chat, err := creator.CreateChat(ctx, chatd.CreateOptions{
|
||||
OwnerID: user.ID,
|
||||
Title: "user-provider-key-success",
|
||||
ModelConfigID: model.ID,
|
||||
InitialUserContent: []codersdk.ChatMessagePart{
|
||||
codersdk.ChatMessageText("say hello"),
|
||||
},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
_, events, cancel, ok := creator.Subscribe(ctx, chat.ID, nil, 0)
|
||||
require.True(t, ok)
|
||||
t.Cleanup(cancel)
|
||||
|
||||
_ = newActiveTestServer(t, db, ps)
|
||||
|
||||
terminalStatus := waitForTerminalChatStatusEvent(ctx, t, events)
|
||||
require.Equal(t, codersdk.ChatStatusWaiting, terminalStatus)
|
||||
|
||||
chatResult := waitForTerminalChat(ctx, t, db, chat.ID)
|
||||
require.Equal(t, database.ChatStatusWaiting, chatResult.Status)
|
||||
require.False(t, chatResult.LastError.Valid)
|
||||
|
||||
authHeadersMu.Lock()
|
||||
recordedAuthHeaders := append([]string(nil), authHeaders...)
|
||||
authHeadersMu.Unlock()
|
||||
require.Contains(t, recordedAuthHeaders, "Bearer "+userAPIKey)
|
||||
}
|
||||
|
||||
func TestProcessChat_UserProviderKey_MissingKeyError(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
db, ps := dbtestutil.NewDB(t)
|
||||
ctx := testutil.Context(t, testutil.WaitLong)
|
||||
|
||||
var llmCalls atomic.Int32
|
||||
openAIURL := chattest.NewOpenAI(t, func(req *chattest.OpenAIRequest) chattest.OpenAIResponse {
|
||||
llmCalls.Add(1)
|
||||
if !req.Stream {
|
||||
return chattest.OpenAINonStreamingResponse("unexpected non-streaming request")
|
||||
}
|
||||
return chattest.OpenAIStreamingResponse(
|
||||
chattest.OpenAITextChunks("unexpected streaming request")...,
|
||||
)
|
||||
})
|
||||
|
||||
user, _, model := seedChatDependenciesWithProviderPolicy(
|
||||
ctx,
|
||||
t,
|
||||
db,
|
||||
"openai-compat",
|
||||
openAIURL,
|
||||
"",
|
||||
false,
|
||||
true,
|
||||
false,
|
||||
)
|
||||
|
||||
creator := newTestServer(t, db, ps, uuid.New())
|
||||
chat, err := creator.CreateChat(ctx, chatd.CreateOptions{
|
||||
OwnerID: user.ID,
|
||||
Title: "user-provider-key-missing",
|
||||
ModelConfigID: model.ID,
|
||||
InitialUserContent: []codersdk.ChatMessagePart{
|
||||
codersdk.ChatMessageText("say hello"),
|
||||
},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
_, events, cancel, ok := creator.Subscribe(ctx, chat.ID, nil, 0)
|
||||
require.True(t, ok)
|
||||
t.Cleanup(cancel)
|
||||
|
||||
_ = newActiveTestServer(t, db, ps)
|
||||
|
||||
terminalStatus := waitForTerminalChatStatusEvent(ctx, t, events)
|
||||
require.Equal(t, codersdk.ChatStatusError, terminalStatus)
|
||||
|
||||
chatResult := waitForTerminalChat(ctx, t, db, chat.ID)
|
||||
require.Equal(t, database.ChatStatusError, chatResult.Status)
|
||||
require.True(t, chatResult.LastError.Valid, "LastError should be set")
|
||||
require.NotEmpty(t, chatResult.LastError.String)
|
||||
require.NotContains(t, chatResult.LastError.String, "panicked")
|
||||
require.NotEqual(t, database.ChatStatusRunning, chatResult.Status)
|
||||
require.Zero(t, llmCalls.Load(), "missing user key should fail before any LLM request")
|
||||
}
|
||||
|
||||
func TestProcessChatPanicRecovery(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
@@ -3800,6 +4254,8 @@ func TestMCPServerToolInvocation(t *testing.T) {
|
||||
ctrl := gomock.NewController(t)
|
||||
mockConn := agentconnmock.NewMockAgentConn(ctrl)
|
||||
mockConn.EXPECT().SetExtraHeaders(gomock.Any()).AnyTimes()
|
||||
mockConn.EXPECT().ContextConfig(gomock.Any()).
|
||||
Return(workspacesdk.ContextConfigResponse{}, xerrors.New("not supported")).AnyTimes()
|
||||
mockConn.EXPECT().ListMCPTools(gomock.Any()).
|
||||
Return(workspacesdk.ListMCPToolsResponse{}, nil).AnyTimes()
|
||||
mockConn.EXPECT().LS(gomock.Any(), gomock.Any(), gomock.Any()).
|
||||
@@ -3815,11 +4271,10 @@ func TestMCPServerToolInvocation(t *testing.T) {
|
||||
})
|
||||
|
||||
chat, err := server.CreateChat(ctx, chatd.CreateOptions{
|
||||
OwnerID: user.ID,
|
||||
Title: "mcp-tool-test",
|
||||
ModelConfigID: model.ID,
|
||||
WorkspaceID: uuid.NullUUID{UUID: ws.ID, Valid: true},
|
||||
MCPServerIDs: []uuid.UUID{mcpConfig.ID},
|
||||
OwnerID: user.ID,
|
||||
Title: "mcp-tool-test", ModelConfigID: model.ID,
|
||||
WorkspaceID: uuid.NullUUID{UUID: ws.ID, Valid: true},
|
||||
MCPServerIDs: []uuid.UUID{mcpConfig.ID},
|
||||
InitialUserContent: []codersdk.ChatMessagePart{
|
||||
codersdk.ChatMessageText("Echo something via MCP."),
|
||||
},
|
||||
@@ -4043,13 +4498,14 @@ func TestMCPServerOAuth2TokenRefresh(t *testing.T) {
|
||||
ctrl := gomock.NewController(t)
|
||||
mockConn := agentconnmock.NewMockAgentConn(ctrl)
|
||||
mockConn.EXPECT().SetExtraHeaders(gomock.Any()).AnyTimes()
|
||||
mockConn.EXPECT().ContextConfig(gomock.Any()).
|
||||
Return(workspacesdk.ContextConfigResponse{}, xerrors.New("not supported")).AnyTimes()
|
||||
mockConn.EXPECT().ListMCPTools(gomock.Any()).
|
||||
Return(workspacesdk.ListMCPToolsResponse{}, nil).AnyTimes()
|
||||
mockConn.EXPECT().LS(gomock.Any(), gomock.Any(), gomock.Any()).
|
||||
Return(workspacesdk.LSResponse{}, nil).AnyTimes()
|
||||
mockConn.EXPECT().ReadFile(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).
|
||||
Return(io.NopCloser(strings.NewReader("")), "", nil).AnyTimes()
|
||||
|
||||
server := newActiveTestServer(t, db, ps, func(cfg *chatd.Config) {
|
||||
cfg.AgentConn = func(_ context.Context, agentID uuid.UUID) (workspacesdk.AgentConn, func(), error) {
|
||||
require.Equal(t, dbAgent.ID, agentID)
|
||||
@@ -4217,21 +4673,46 @@ func TestChatTemplateAllowlistEnforcement(t *testing.T) {
|
||||
ctx := testutil.Context(t, testutil.WaitLong)
|
||||
db, ps := dbtestutil.NewDB(t)
|
||||
|
||||
// Set up a mock OpenAI server. The first streaming call triggers
|
||||
// list_templates; subsequent calls respond with text.
|
||||
// Declare templates before the handler so the closure can
|
||||
// reference their IDs when building tool-call arguments.
|
||||
var tplAllowed, tplBlocked database.Template
|
||||
|
||||
// Set up a mock OpenAI server that chains tool calls:
|
||||
// 1. list_templates
|
||||
// 2. read_template (blocked template — should fail)
|
||||
// 3. read_template (allowed template — should succeed)
|
||||
// 4. create_workspace (blocked template — should fail)
|
||||
// 5. text response
|
||||
var callCount atomic.Int32
|
||||
openAIURL := chattest.NewOpenAI(t, func(req *chattest.OpenAIRequest) chattest.OpenAIResponse {
|
||||
if !req.Stream {
|
||||
return chattest.OpenAINonStreamingResponse("title")
|
||||
}
|
||||
if callCount.Add(1) == 1 {
|
||||
switch callCount.Add(1) {
|
||||
case 1:
|
||||
return chattest.OpenAIStreamingResponse(
|
||||
chattest.OpenAIToolCallChunk("list_templates", `{}`),
|
||||
)
|
||||
case 2:
|
||||
return chattest.OpenAIStreamingResponse(
|
||||
chattest.OpenAIToolCallChunk("read_template",
|
||||
fmt.Sprintf(`{"template_id":%q}`, tplBlocked.ID.String())),
|
||||
)
|
||||
case 3:
|
||||
return chattest.OpenAIStreamingResponse(
|
||||
chattest.OpenAIToolCallChunk("read_template",
|
||||
fmt.Sprintf(`{"template_id":%q}`, tplAllowed.ID.String())),
|
||||
)
|
||||
case 4:
|
||||
return chattest.OpenAIStreamingResponse(
|
||||
chattest.OpenAIToolCallChunk("create_workspace",
|
||||
fmt.Sprintf(`{"template_id":%q}`, tplBlocked.ID.String())),
|
||||
)
|
||||
default:
|
||||
return chattest.OpenAIStreamingResponse(
|
||||
chattest.OpenAITextChunks("Done testing.")...,
|
||||
)
|
||||
}
|
||||
return chattest.OpenAIStreamingResponse(
|
||||
chattest.OpenAITextChunks("Here are the templates.")...,
|
||||
)
|
||||
})
|
||||
|
||||
user, model := seedChatDependenciesWithProvider(ctx, t, db, "openai-compat", openAIURL)
|
||||
@@ -4242,12 +4723,12 @@ func TestChatTemplateAllowlistEnforcement(t *testing.T) {
|
||||
UserID: user.ID,
|
||||
OrganizationID: org.ID,
|
||||
})
|
||||
tplAllowed := dbgen.Template(t, db, database.Template{
|
||||
tplAllowed = dbgen.Template(t, db, database.Template{
|
||||
OrganizationID: org.ID,
|
||||
CreatedBy: user.ID,
|
||||
Name: "allowed-template",
|
||||
})
|
||||
tplBlocked := dbgen.Template(t, db, database.Template{
|
||||
tplBlocked = dbgen.Template(t, db, database.Template{
|
||||
OrganizationID: org.ID,
|
||||
CreatedBy: user.ID,
|
||||
Name: "blocked-template",
|
||||
@@ -4259,14 +4740,27 @@ func TestChatTemplateAllowlistEnforcement(t *testing.T) {
|
||||
err = db.UpsertChatTemplateAllowlist(dbauthz.AsSystemRestricted(ctx), string(allowlistJSON))
|
||||
require.NoError(t, err)
|
||||
|
||||
server := newActiveTestServer(t, db, ps)
|
||||
server := newActiveTestServer(t, db, ps, func(cfg *chatd.Config) {
|
||||
// Provide a CreateWorkspace function so the tool reaches
|
||||
// the allowlist check instead of bailing with "not
|
||||
// configured". If the allowlist is enforced correctly
|
||||
// this function will never be called.
|
||||
cfg.CreateWorkspace = func(
|
||||
_ context.Context,
|
||||
_ uuid.UUID,
|
||||
_ codersdk.CreateWorkspaceRequest,
|
||||
) (codersdk.Workspace, error) {
|
||||
t.Error("CreateWorkspace should not be called for a blocked template")
|
||||
return codersdk.Workspace{}, xerrors.New("unexpected call")
|
||||
}
|
||||
})
|
||||
|
||||
chat, err := server.CreateChat(ctx, chatd.CreateOptions{
|
||||
OwnerID: user.ID,
|
||||
Title: "allowlist-test",
|
||||
ModelConfigID: model.ID,
|
||||
InitialUserContent: []codersdk.ChatMessagePart{
|
||||
codersdk.ChatMessageText("List templates"),
|
||||
codersdk.ChatMessageText("Test allowlist enforcement"),
|
||||
},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
@@ -4286,9 +4780,11 @@ func TestChatTemplateAllowlistEnforcement(t *testing.T) {
|
||||
require.FailNowf(t, "chat run failed", "last_error=%q", chatResult.LastError.String)
|
||||
}
|
||||
|
||||
// Find the list_templates tool result in the persisted messages.
|
||||
var toolResult string
|
||||
// Collect all tool results keyed by tool name. Each tool may
|
||||
// have been called more than once, so we store a slice.
|
||||
var toolResults map[string][]string
|
||||
testutil.Eventually(ctx, t, func(ctx context.Context) bool {
|
||||
toolResults = map[string][]string{}
|
||||
messages, dbErr := db.GetChatMessagesByChatID(ctx, database.GetChatMessagesByChatIDParams{
|
||||
ChatID: chat.ID,
|
||||
AfterID: 0,
|
||||
@@ -4305,23 +4801,33 @@ func TestChatTemplateAllowlistEnforcement(t *testing.T) {
|
||||
continue
|
||||
}
|
||||
for _, part := range parts {
|
||||
if part.Type == codersdk.ChatMessagePartTypeToolResult &&
|
||||
part.ToolName == "list_templates" {
|
||||
toolResult = string(part.Result)
|
||||
return true
|
||||
if part.Type == codersdk.ChatMessagePartTypeToolResult {
|
||||
toolResults[part.ToolName] = append(
|
||||
toolResults[part.ToolName], string(part.Result))
|
||||
}
|
||||
}
|
||||
}
|
||||
return false
|
||||
// We expect results from all four tool calls.
|
||||
return len(toolResults["list_templates"]) >= 1 &&
|
||||
len(toolResults["read_template"]) >= 2 &&
|
||||
len(toolResults["create_workspace"]) >= 1
|
||||
}, testutil.IntervalFast)
|
||||
|
||||
require.NotEmpty(t, toolResult, "list_templates tool result should be persisted")
|
||||
|
||||
// The result should contain only the allowed template.
|
||||
require.Contains(t, toolResult, tplAllowed.ID.String(),
|
||||
// list_templates: only the allowed template should appear.
|
||||
require.Contains(t, toolResults["list_templates"][0], tplAllowed.ID.String(),
|
||||
"allowed template should appear in list_templates result")
|
||||
require.NotContains(t, toolResult, tplBlocked.ID.String(),
|
||||
require.NotContains(t, toolResults["list_templates"][0], tplBlocked.ID.String(),
|
||||
"blocked template should NOT appear in list_templates result")
|
||||
|
||||
// read_template: blocked ID → error, allowed ID → success.
|
||||
require.Contains(t, toolResults["read_template"][0], "not found",
|
||||
"read_template for blocked template should return not-found error")
|
||||
require.Contains(t, toolResults["read_template"][1], tplAllowed.ID.String(),
|
||||
"read_template for allowed template should return template details")
|
||||
|
||||
// create_workspace: blocked ID → rejected.
|
||||
require.Contains(t, toolResults["create_workspace"][0], "not available",
|
||||
"create_workspace for blocked template should be rejected")
|
||||
}
|
||||
|
||||
// TestSignalWakeImmediateAcquisition verifies that CreateChat triggers
|
||||
|
||||
@@ -10,6 +10,7 @@ import (
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
"unicode"
|
||||
|
||||
"charm.land/fantasy"
|
||||
fantasyanthropic "charm.land/fantasy/providers/anthropic"
|
||||
@@ -1317,11 +1318,25 @@ func isContextLimitKey(key string) bool {
|
||||
return true
|
||||
}
|
||||
|
||||
return strings.Contains(normalized, "context") &&
|
||||
(strings.Contains(normalized, "limit") ||
|
||||
strings.Contains(normalized, "window") ||
|
||||
strings.Contains(normalized, "length") ||
|
||||
strings.HasPrefix(normalized, "max"))
|
||||
words := metadataKeyWords(key)
|
||||
if !slices.Contains(words, "context") {
|
||||
return false
|
||||
}
|
||||
|
||||
if slices.Contains(words, "limit") {
|
||||
return true
|
||||
}
|
||||
|
||||
if slices.Contains(words, "window") {
|
||||
return slices.Contains(words, "size") || slices.Contains(words, "max")
|
||||
}
|
||||
|
||||
if slices.Contains(words, "length") {
|
||||
return slices.Contains(words, "max")
|
||||
}
|
||||
|
||||
return (slices.Contains(words, "token") || slices.Contains(words, "tokens")) &&
|
||||
(slices.Contains(words, "max") || slices.Contains(words, "limit"))
|
||||
}
|
||||
|
||||
func normalizeMetadataKey(key string) string {
|
||||
@@ -1342,6 +1357,40 @@ func normalizeMetadataKey(key string) string {
|
||||
return b.String()
|
||||
}
|
||||
|
||||
func metadataKeyWords(key string) []string {
|
||||
words := make([]string, 0, 4)
|
||||
var current strings.Builder
|
||||
|
||||
flush := func() {
|
||||
if current.Len() == 0 {
|
||||
return
|
||||
}
|
||||
words = append(words, current.String())
|
||||
current.Reset()
|
||||
}
|
||||
|
||||
var prev rune
|
||||
var hasPrev bool
|
||||
for _, r := range key {
|
||||
if !unicode.IsLetter(r) {
|
||||
flush()
|
||||
hasPrev = false
|
||||
continue
|
||||
}
|
||||
|
||||
if hasPrev && unicode.IsUpper(r) && unicode.IsLower(prev) {
|
||||
flush()
|
||||
}
|
||||
|
||||
_, _ = current.WriteRune(unicode.ToLower(r))
|
||||
prev = r
|
||||
hasPrev = true
|
||||
}
|
||||
|
||||
flush()
|
||||
return words
|
||||
}
|
||||
|
||||
func numericContextLimitValue(value any) (int64, bool) {
|
||||
switch typed := value.(type) {
|
||||
case int64:
|
||||
|
||||
@@ -53,6 +53,29 @@ func TestNormalizeMetadataKey(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestMetadataKeyWords(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
tests := []struct {
|
||||
key string
|
||||
want []string
|
||||
}{
|
||||
{"max_context_tokens", []string{"max", "context", "tokens"}},
|
||||
{"maxContextTokens", []string{"max", "context", "tokens"}},
|
||||
{"MAX_CONTEXT", []string{"max", "context"}},
|
||||
{"ContextWindow", []string{"context", "window"}},
|
||||
{"context2limit", []string{"context", "limit"}},
|
||||
{"", []string{}},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.key, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
got := metadataKeyWords(tt.key)
|
||||
require.Equal(t, tt.want, got)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestIsContextLimitKey(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
@@ -60,7 +83,6 @@ func TestIsContextLimitKey(t *testing.T) {
|
||||
name string
|
||||
key string
|
||||
want bool
|
||||
skip bool
|
||||
}{ // Exact matches after normalization.
|
||||
{name: "context_limit", key: "context_limit", want: true},
|
||||
{name: "context_window", key: "context_window", want: true},
|
||||
@@ -75,18 +97,22 @@ func TestIsContextLimitKey(t *testing.T) {
|
||||
{name: "Context-Window mixed case", key: "Context-Window", want: true},
|
||||
{name: "MAX_CONTEXT_TOKENS screaming", key: "MAX_CONTEXT_TOKENS", want: true},
|
||||
{name: "contextLimit camelCase", key: "contextLimit", want: true},
|
||||
{name: "modelContextLimit camelCase", key: "modelContextLimit", want: true},
|
||||
|
||||
// Fallback heuristic: contains "context" + limit/window/length.
|
||||
// Fallback heuristic: tokenized "context" + limit/window/length.
|
||||
{name: "model_context_limit", key: "model_context_limit", want: true},
|
||||
{name: "context_window_size", key: "context_window_size", want: true},
|
||||
{name: "context_length_max", key: "context_length_max", want: true},
|
||||
|
||||
// Fallback heuristic: starts with "max" + contains "context".
|
||||
// BUG(isContextLimitKey): "max_context_version" matches
|
||||
// because it contains "context" and starts with "max",
|
||||
// but a version field is not a context limit.
|
||||
// TODO: Fix the heuristic and remove this skip.
|
||||
{name: "max_context_version false positive", key: "max_context_version", want: false, skip: true}, // Non-matching keys.
|
||||
// Exact matches remain valid after separator stripping.
|
||||
{name: "max_context_", key: "max_context_", want: true},
|
||||
{name: "max_context_limit", key: "max_context_limit", want: true},
|
||||
|
||||
// Non-matching keys should not be treated as context limits.
|
||||
{name: "max_context_version false positive", key: "max_context_version", want: false},
|
||||
{name: "context_tokens_used false positive", key: "context_tokens_used", want: false},
|
||||
{name: "context_length_used false positive", key: "context_length_used", want: false},
|
||||
{name: "context_window_used false positive", key: "context_window_used", want: false},
|
||||
{name: "context_id no limit keyword", key: "context_id", want: false},
|
||||
{name: "empty string", key: "", want: false},
|
||||
{name: "unrelated key", key: "model_name", want: false},
|
||||
@@ -97,9 +123,6 @@ func TestIsContextLimitKey(t *testing.T) {
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
if tt.skip {
|
||||
t.Skip("known bug: isContextLimitKey false positive")
|
||||
}
|
||||
got := isContextLimitKey(tt.key)
|
||||
require.Equal(t, tt.want, got)
|
||||
})
|
||||
@@ -385,6 +408,19 @@ func TestExtractContextLimit(t *testing.T) {
|
||||
assert.False(t, result.Valid)
|
||||
})
|
||||
|
||||
t.Run("ContextUsageCountersIgnored", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
metadata := fantasy.ProviderMetadata{
|
||||
"openai": &testProviderData{
|
||||
data: map[string]any{
|
||||
"context_tokens_used": float64(64000),
|
||||
},
|
||||
},
|
||||
}
|
||||
result := extractContextLimit(metadata)
|
||||
assert.False(t, result.Valid)
|
||||
})
|
||||
|
||||
t.Run("NilMetadata", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
result := extractContextLimit(nil)
|
||||
|
||||
@@ -1459,11 +1459,12 @@ func TestNulEscapeRoundTrip(t *testing.T) {
|
||||
user := dbgen.User(t, db, database.User{})
|
||||
|
||||
_, err := db.InsertChatProvider(ctx, database.InsertChatProviderParams{
|
||||
Provider: "openai",
|
||||
DisplayName: "openai",
|
||||
APIKey: "test-key",
|
||||
CreatedBy: uuid.NullUUID{UUID: user.ID, Valid: true},
|
||||
Enabled: true,
|
||||
Provider: "openai",
|
||||
DisplayName: "openai",
|
||||
APIKey: "test-key",
|
||||
CreatedBy: uuid.NullUUID{UUID: user.ID, Valid: true},
|
||||
Enabled: true,
|
||||
CentralApiKeyEnabled: true,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
@@ -1482,6 +1483,7 @@ func TestNulEscapeRoundTrip(t *testing.T) {
|
||||
require.NoError(t, err)
|
||||
|
||||
chat, err := db.InsertChat(ctx, database.InsertChatParams{
|
||||
Status: database.ChatStatusWaiting,
|
||||
OwnerID: user.ID,
|
||||
LastModelConfigID: model.ID,
|
||||
Title: "nul-roundtrip-test",
|
||||
@@ -1942,11 +1944,12 @@ func TestMediaToolResultRoundTrip(t *testing.T) {
|
||||
user := dbgen.User(t, db, database.User{})
|
||||
|
||||
_, err := db.InsertChatProvider(ctx, database.InsertChatProviderParams{
|
||||
Provider: "anthropic",
|
||||
DisplayName: "anthropic",
|
||||
APIKey: "test-key",
|
||||
CreatedBy: uuid.NullUUID{UUID: user.ID, Valid: true},
|
||||
Enabled: true,
|
||||
Provider: "anthropic",
|
||||
DisplayName: "anthropic",
|
||||
APIKey: "test-key",
|
||||
CreatedBy: uuid.NullUUID{UUID: user.ID, Valid: true},
|
||||
Enabled: true,
|
||||
CentralApiKeyEnabled: true,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
@@ -1978,6 +1981,7 @@ func TestMediaToolResultRoundTrip(t *testing.T) {
|
||||
t.Helper()
|
||||
|
||||
chat, chatErr := db.InsertChat(ctx, database.InsertChatParams{
|
||||
Status: database.ChatStatusWaiting,
|
||||
OwnerID: user.ID,
|
||||
LastModelConfigID: model.ID,
|
||||
Title: "media-roundtrip-" + callID,
|
||||
|
||||
@@ -81,11 +81,28 @@ type ProviderAPIKeys struct {
|
||||
BaseURLByProvider map[string]string
|
||||
}
|
||||
|
||||
// UserProviderKey is a user-supplied API key for a specific provider.
|
||||
type UserProviderKey struct {
|
||||
ChatProviderID uuid.UUID
|
||||
APIKey string
|
||||
}
|
||||
|
||||
// ProviderAvailability describes whether a provider has a usable
|
||||
// API key and, if not, why.
|
||||
type ProviderAvailability struct {
|
||||
Available bool
|
||||
UnavailableReason codersdk.ChatModelProviderUnavailableReason
|
||||
}
|
||||
|
||||
// ConfiguredProvider is an enabled provider loaded from database config.
|
||||
type ConfiguredProvider struct {
|
||||
Provider string
|
||||
APIKey string
|
||||
BaseURL string
|
||||
ProviderID uuid.UUID
|
||||
Provider string
|
||||
APIKey string
|
||||
BaseURL string
|
||||
CentralAPIKeyEnabled bool
|
||||
AllowUserAPIKey bool
|
||||
AllowCentralAPIKeyFallback bool
|
||||
}
|
||||
|
||||
// ConfiguredModel is an enabled model loaded from database config.
|
||||
@@ -189,21 +206,146 @@ func MergeProviderAPIKeys(fallback ProviderAPIKeys, providers []ConfiguredProvid
|
||||
return merged
|
||||
}
|
||||
|
||||
type ModelCatalog struct {
|
||||
keys ProviderAPIKeys
|
||||
// ResolveUserProviderKeys computes effective API keys and per-provider
|
||||
// availability for a given user. It considers the provider's credential
|
||||
// policy flags alongside central (DB/deployment) keys and the user's
|
||||
// personal keys.
|
||||
func ResolveUserProviderKeys(
|
||||
fallback ProviderAPIKeys,
|
||||
providers []ConfiguredProvider,
|
||||
userKeys []UserProviderKey,
|
||||
) (ProviderAPIKeys, map[string]ProviderAvailability) {
|
||||
merged := ProviderAPIKeys{
|
||||
OpenAI: strings.TrimSpace(fallback.OpenAI),
|
||||
Anthropic: strings.TrimSpace(fallback.Anthropic),
|
||||
ByProvider: map[string]string{},
|
||||
BaseURLByProvider: map[string]string{},
|
||||
}
|
||||
for provider, apiKey := range fallback.ByProvider {
|
||||
normalizedProvider := NormalizeProvider(provider)
|
||||
if normalizedProvider == "" {
|
||||
continue
|
||||
}
|
||||
if key := strings.TrimSpace(apiKey); key != "" {
|
||||
merged.ByProvider[normalizedProvider] = key
|
||||
}
|
||||
}
|
||||
for provider, baseURL := range fallback.BaseURLByProvider {
|
||||
normalizedProvider := NormalizeProvider(provider)
|
||||
if normalizedProvider == "" {
|
||||
continue
|
||||
}
|
||||
if url := strings.TrimSpace(baseURL); url != "" {
|
||||
merged.BaseURLByProvider[normalizedProvider] = url
|
||||
}
|
||||
}
|
||||
if merged.OpenAI != "" {
|
||||
merged.ByProvider[fantasyopenai.Name] = merged.OpenAI
|
||||
}
|
||||
if merged.Anthropic != "" {
|
||||
merged.ByProvider[fantasyanthropic.Name] = merged.Anthropic
|
||||
}
|
||||
|
||||
userKeyByProviderID := make(map[uuid.UUID]string, len(userKeys))
|
||||
for _, userKey := range userKeys {
|
||||
if userKey.ChatProviderID == uuid.Nil {
|
||||
continue
|
||||
}
|
||||
if key := strings.TrimSpace(userKey.APIKey); key != "" {
|
||||
userKeyByProviderID[userKey.ChatProviderID] = key
|
||||
}
|
||||
}
|
||||
|
||||
availabilityByProvider := make(map[string]ProviderAvailability, len(providers))
|
||||
for _, provider := range providers {
|
||||
normalizedProvider := NormalizeProvider(provider.Provider)
|
||||
if normalizedProvider == "" {
|
||||
continue
|
||||
}
|
||||
|
||||
if url := strings.TrimSpace(provider.BaseURL); url != "" {
|
||||
merged.BaseURLByProvider[normalizedProvider] = url
|
||||
}
|
||||
|
||||
var userKey string
|
||||
if provider.ProviderID != uuid.Nil {
|
||||
userKey = userKeyByProviderID[provider.ProviderID]
|
||||
}
|
||||
|
||||
var centralKey string
|
||||
if provider.CentralAPIKeyEnabled {
|
||||
if key := strings.TrimSpace(provider.APIKey); key != "" {
|
||||
centralKey = key
|
||||
} else {
|
||||
centralKey = fallback.APIKey(normalizedProvider)
|
||||
}
|
||||
}
|
||||
|
||||
resolved := ProviderAvailability{}
|
||||
chosenKey := ""
|
||||
switch {
|
||||
case provider.AllowUserAPIKey && userKey != "":
|
||||
chosenKey = userKey
|
||||
resolved.Available = true
|
||||
case centralKey != "":
|
||||
if !provider.AllowUserAPIKey || provider.AllowCentralAPIKeyFallback {
|
||||
chosenKey = centralKey
|
||||
resolved.Available = true
|
||||
} else {
|
||||
resolved.UnavailableReason = codersdk.ChatModelProviderUnavailableReasonUserAPIKeyRequired
|
||||
}
|
||||
case provider.AllowUserAPIKey && provider.AllowCentralAPIKeyFallback && provider.CentralAPIKeyEnabled:
|
||||
// When users can add their own key, a missing central fallback key is
|
||||
// still something the user can remedy.
|
||||
resolved.UnavailableReason = codersdk.ChatModelProviderUnavailableReasonUserAPIKeyRequired
|
||||
case provider.AllowUserAPIKey:
|
||||
resolved.UnavailableReason = codersdk.ChatModelProviderUnavailableReasonUserAPIKeyRequired
|
||||
default:
|
||||
resolved.UnavailableReason = codersdk.ChatModelProviderUnavailableMissingAPIKey
|
||||
}
|
||||
|
||||
setResolvedProviderAPIKey(&merged, normalizedProvider, chosenKey)
|
||||
availabilityByProvider[normalizedProvider] = resolved
|
||||
}
|
||||
|
||||
return merged, availabilityByProvider
|
||||
}
|
||||
|
||||
func NewModelCatalog(keys ProviderAPIKeys) *ModelCatalog {
|
||||
return &ModelCatalog{
|
||||
keys: keys,
|
||||
func setResolvedProviderAPIKey(keys *ProviderAPIKeys, provider string, apiKey string) {
|
||||
normalizedProvider := NormalizeProvider(provider)
|
||||
if normalizedProvider == "" {
|
||||
return
|
||||
}
|
||||
if keys.ByProvider == nil {
|
||||
keys.ByProvider = map[string]string{}
|
||||
}
|
||||
|
||||
delete(keys.ByProvider, normalizedProvider)
|
||||
trimmedKey := strings.TrimSpace(apiKey)
|
||||
switch normalizedProvider {
|
||||
case fantasyopenai.Name:
|
||||
keys.OpenAI = trimmedKey
|
||||
case fantasyanthropic.Name:
|
||||
keys.Anthropic = trimmedKey
|
||||
}
|
||||
if trimmedKey != "" {
|
||||
keys.ByProvider[normalizedProvider] = trimmedKey
|
||||
}
|
||||
}
|
||||
|
||||
type ModelCatalog struct{}
|
||||
|
||||
func NewModelCatalog() *ModelCatalog {
|
||||
return &ModelCatalog{}
|
||||
}
|
||||
|
||||
// ListConfiguredModels returns a model catalog from enabled DB-backed model
|
||||
// configs. The second return value reports whether DB-backed models were used.
|
||||
func (c *ModelCatalog) ListConfiguredModels(
|
||||
func (*ModelCatalog) ListConfiguredModels(
|
||||
configuredProviders []ConfiguredProvider,
|
||||
configuredModels []ConfiguredModel,
|
||||
availabilityByProvider map[string]ProviderAvailability,
|
||||
enabledProviders map[string]struct{},
|
||||
) (codersdk.ChatModelsResponse, bool) {
|
||||
if len(configuredModels) == 0 {
|
||||
return codersdk.ChatModelsResponse{}, false
|
||||
@@ -247,11 +389,14 @@ func (c *ModelCatalog) ListConfiguredModels(
|
||||
return codersdk.ChatModelsResponse{}, false
|
||||
}
|
||||
|
||||
keys := MergeProviderAPIKeys(c.keys, configuredProviders)
|
||||
response := codersdk.ChatModelsResponse{
|
||||
Providers: make([]codersdk.ChatModelProvider, 0, len(providers)),
|
||||
}
|
||||
for _, provider := range providers {
|
||||
if _, ok := enabledProviders[provider]; !ok {
|
||||
continue
|
||||
}
|
||||
|
||||
models := modelsByProvider[provider]
|
||||
sortChatModels(models)
|
||||
|
||||
@@ -259,11 +404,14 @@ func (c *ModelCatalog) ListConfiguredModels(
|
||||
Provider: provider,
|
||||
Models: models,
|
||||
}
|
||||
if keys.APIKey(provider) == "" {
|
||||
if avail, ok := availabilityByProvider[provider]; ok {
|
||||
result.Available = avail.Available
|
||||
if !avail.Available {
|
||||
result.UnavailableReason = avail.UnavailableReason
|
||||
}
|
||||
} else {
|
||||
result.Available = false
|
||||
result.UnavailableReason = codersdk.ChatModelProviderUnavailableMissingAPIKey
|
||||
} else {
|
||||
result.Available = true
|
||||
}
|
||||
|
||||
response.Providers = append(response.Providers, result)
|
||||
@@ -273,25 +421,32 @@ func (c *ModelCatalog) ListConfiguredModels(
|
||||
}
|
||||
|
||||
// ListConfiguredProviderAvailability returns provider availability derived from
|
||||
// deployment/env keys merged with enabled DB provider keys.
|
||||
func (c *ModelCatalog) ListConfiguredProviderAvailability(
|
||||
configuredProviders []ConfiguredProvider,
|
||||
// the policy-aware availability map for enabled providers.
|
||||
func (*ModelCatalog) ListConfiguredProviderAvailability(
|
||||
availabilityByProvider map[string]ProviderAvailability,
|
||||
enabledProviders map[string]struct{},
|
||||
) codersdk.ChatModelsResponse {
|
||||
keys := MergeProviderAPIKeys(c.keys, configuredProviders)
|
||||
response := codersdk.ChatModelsResponse{
|
||||
Providers: make([]codersdk.ChatModelProvider, 0, len(supportedProviderNames)),
|
||||
}
|
||||
|
||||
for _, provider := range supportedProviderNames {
|
||||
if _, ok := enabledProviders[provider]; !ok {
|
||||
continue
|
||||
}
|
||||
|
||||
result := codersdk.ChatModelProvider{
|
||||
Provider: provider,
|
||||
Models: []codersdk.ChatModel{},
|
||||
}
|
||||
if keys.APIKey(provider) == "" {
|
||||
if avail, ok := availabilityByProvider[provider]; ok {
|
||||
result.Available = avail.Available
|
||||
if !avail.Available {
|
||||
result.UnavailableReason = avail.UnavailableReason
|
||||
}
|
||||
} else {
|
||||
result.Available = false
|
||||
result.UnavailableReason = codersdk.ChatModelProviderUnavailableMissingAPIKey
|
||||
} else {
|
||||
result.Available = true
|
||||
}
|
||||
|
||||
response.Providers = append(response.Providers, result)
|
||||
@@ -300,6 +455,27 @@ func (c *ModelCatalog) ListConfiguredProviderAvailability(
|
||||
return response
|
||||
}
|
||||
|
||||
// PruneDisabledProviderKeys removes entries from keys that do not
|
||||
// belong to an enabled provider. It clears ByProvider and
|
||||
// BaseURLByProvider entries for disabled providers and zeroes the
|
||||
// legacy OpenAI and Anthropic fields when those providers are not
|
||||
// enabled.
|
||||
func PruneDisabledProviderKeys(keys *ProviderAPIKeys, enabledProviders map[string]struct{}) {
|
||||
for provider := range keys.ByProvider {
|
||||
if _, ok := enabledProviders[provider]; ok {
|
||||
continue
|
||||
}
|
||||
delete(keys.ByProvider, provider)
|
||||
delete(keys.BaseURLByProvider, provider)
|
||||
}
|
||||
if _, ok := enabledProviders[NormalizeProvider("openai")]; !ok {
|
||||
keys.OpenAI = ""
|
||||
}
|
||||
if _, ok := enabledProviders[NormalizeProvider("anthropic")]; !ok {
|
||||
keys.Anthropic = ""
|
||||
}
|
||||
}
|
||||
|
||||
func newChatModel(provider, modelID, displayName string) codersdk.ChatModel {
|
||||
name := strings.TrimSpace(displayName)
|
||||
if name == "" {
|
||||
|
||||
@@ -21,6 +21,166 @@ import (
|
||||
"github.com/coder/coder/v2/testutil"
|
||||
)
|
||||
|
||||
func TestResolveUserProviderKeys(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
configuredProvider := func(id uuid.UUID, provider string, centralEnabled bool, centralKey string, allowUser bool, allowCentralFallback bool) chatprovider.ConfiguredProvider {
|
||||
return chatprovider.ConfiguredProvider{
|
||||
ProviderID: id,
|
||||
Provider: provider,
|
||||
APIKey: centralKey,
|
||||
CentralAPIKeyEnabled: centralEnabled,
|
||||
AllowUserAPIKey: allowUser,
|
||||
AllowCentralAPIKeyFallback: allowCentralFallback,
|
||||
}
|
||||
}
|
||||
|
||||
userProviderKey := func(id uuid.UUID, apiKey string) chatprovider.UserProviderKey {
|
||||
return chatprovider.UserProviderKey{
|
||||
ChatProviderID: id,
|
||||
APIKey: apiKey,
|
||||
}
|
||||
}
|
||||
|
||||
openAIProviderID := uuid.MustParse("00000000-0000-0000-0000-000000000001")
|
||||
anthropicProviderID := uuid.MustParse("00000000-0000-0000-0000-000000000002")
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
fallback chatprovider.ProviderAPIKeys
|
||||
providers []chatprovider.ConfiguredProvider
|
||||
userKeys []chatprovider.UserProviderKey
|
||||
wantAvailability map[string]chatprovider.ProviderAvailability
|
||||
wantKeys map[string]string
|
||||
}{
|
||||
{
|
||||
name: "CentralOnlyKeyPresent",
|
||||
providers: []chatprovider.ConfiguredProvider{configuredProvider(openAIProviderID, fantasyopenai.Name, true, "sk-central", false, false)},
|
||||
wantAvailability: map[string]chatprovider.ProviderAvailability{
|
||||
fantasyopenai.Name: {Available: true},
|
||||
},
|
||||
wantKeys: map[string]string{
|
||||
fantasyopenai.Name: "sk-central",
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "CentralOnlyKeyMissing",
|
||||
providers: []chatprovider.ConfiguredProvider{configuredProvider(openAIProviderID, fantasyopenai.Name, true, "", false, false)},
|
||||
wantAvailability: map[string]chatprovider.ProviderAvailability{
|
||||
fantasyopenai.Name: {Available: false, UnavailableReason: codersdk.ChatModelProviderUnavailableMissingAPIKey},
|
||||
},
|
||||
wantKeys: map[string]string{
|
||||
fantasyopenai.Name: "",
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "UserOnlyUserHasKey",
|
||||
providers: []chatprovider.ConfiguredProvider{configuredProvider(openAIProviderID, fantasyopenai.Name, false, "sk-central", true, false)},
|
||||
userKeys: []chatprovider.UserProviderKey{userProviderKey(openAIProviderID, "sk-user")},
|
||||
wantAvailability: map[string]chatprovider.ProviderAvailability{
|
||||
fantasyopenai.Name: {Available: true},
|
||||
},
|
||||
wantKeys: map[string]string{
|
||||
fantasyopenai.Name: "sk-user",
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "UserOnlyUserHasNoKey",
|
||||
providers: []chatprovider.ConfiguredProvider{configuredProvider(openAIProviderID, fantasyopenai.Name, false, "sk-central", true, false)},
|
||||
wantAvailability: map[string]chatprovider.ProviderAvailability{
|
||||
fantasyopenai.Name: {Available: false, UnavailableReason: codersdk.ChatModelProviderUnavailableReasonUserAPIKeyRequired},
|
||||
},
|
||||
wantKeys: map[string]string{
|
||||
fantasyopenai.Name: "",
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "BothEnabledFallbackOffUserHasKey",
|
||||
providers: []chatprovider.ConfiguredProvider{configuredProvider(openAIProviderID, fantasyopenai.Name, true, "sk-central", true, false)},
|
||||
userKeys: []chatprovider.UserProviderKey{userProviderKey(openAIProviderID, "sk-user")},
|
||||
wantAvailability: map[string]chatprovider.ProviderAvailability{
|
||||
fantasyopenai.Name: {Available: true},
|
||||
},
|
||||
wantKeys: map[string]string{
|
||||
fantasyopenai.Name: "sk-user",
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "BothEnabledFallbackOffUserHasNoKey",
|
||||
providers: []chatprovider.ConfiguredProvider{configuredProvider(openAIProviderID, fantasyopenai.Name, true, "sk-central", true, false)},
|
||||
wantAvailability: map[string]chatprovider.ProviderAvailability{
|
||||
fantasyopenai.Name: {Available: false, UnavailableReason: codersdk.ChatModelProviderUnavailableReasonUserAPIKeyRequired},
|
||||
},
|
||||
wantKeys: map[string]string{
|
||||
fantasyopenai.Name: "",
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "BothEnabledFallbackOnUserHasKey",
|
||||
providers: []chatprovider.ConfiguredProvider{configuredProvider(openAIProviderID, fantasyopenai.Name, true, "sk-central", true, true)},
|
||||
userKeys: []chatprovider.UserProviderKey{userProviderKey(openAIProviderID, "sk-user")},
|
||||
wantAvailability: map[string]chatprovider.ProviderAvailability{
|
||||
fantasyopenai.Name: {Available: true},
|
||||
},
|
||||
wantKeys: map[string]string{
|
||||
fantasyopenai.Name: "sk-user",
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "BothEnabledFallbackOnUserHasNoKey",
|
||||
providers: []chatprovider.ConfiguredProvider{configuredProvider(openAIProviderID, fantasyopenai.Name, true, "sk-central", true, true)},
|
||||
wantAvailability: map[string]chatprovider.ProviderAvailability{
|
||||
fantasyopenai.Name: {Available: true},
|
||||
},
|
||||
wantKeys: map[string]string{
|
||||
fantasyopenai.Name: "sk-central",
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "BothEnabledFallbackOnCentralKeyEmptyUserHasNoKey",
|
||||
providers: []chatprovider.ConfiguredProvider{configuredProvider(openAIProviderID, fantasyopenai.Name, true, "", true, true)},
|
||||
wantAvailability: map[string]chatprovider.ProviderAvailability{
|
||||
fantasyopenai.Name: {Available: false, UnavailableReason: codersdk.ChatModelProviderUnavailableReasonUserAPIKeyRequired},
|
||||
},
|
||||
wantKeys: map[string]string{
|
||||
fantasyopenai.Name: "",
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "MultipleProvidersDifferentPolicies",
|
||||
providers: []chatprovider.ConfiguredProvider{
|
||||
configuredProvider(openAIProviderID, fantasyopenai.Name, true, "sk-central", false, false),
|
||||
configuredProvider(anthropicProviderID, fantasyanthropic.Name, false, "", true, false),
|
||||
},
|
||||
wantAvailability: map[string]chatprovider.ProviderAvailability{
|
||||
fantasyopenai.Name: {Available: true},
|
||||
fantasyanthropic.Name: {Available: false, UnavailableReason: codersdk.ChatModelProviderUnavailableReasonUserAPIKeyRequired},
|
||||
},
|
||||
wantKeys: map[string]string{
|
||||
fantasyopenai.Name: "sk-central",
|
||||
fantasyanthropic.Name: "",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
tt := tt
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
keys, availability := chatprovider.ResolveUserProviderKeys(tt.fallback, tt.providers, tt.userKeys)
|
||||
|
||||
require.Len(t, availability, len(tt.wantAvailability))
|
||||
for provider, wantAvailability := range tt.wantAvailability {
|
||||
gotAvailability, ok := availability[provider]
|
||||
require.True(t, ok, "expected availability for provider %q", provider)
|
||||
require.Equal(t, wantAvailability, gotAvailability)
|
||||
require.Equal(t, tt.wantKeys[provider], keys.APIKey(provider))
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestReasoningEffortFromChat(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
@@ -91,6 +251,413 @@ func TestReasoningEffortFromChat(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestResolveUserProviderKeys_UnavailableReason(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
provider chatprovider.ConfiguredProvider
|
||||
wantReason codersdk.ChatModelProviderUnavailableReason
|
||||
}{
|
||||
{
|
||||
name: "FallbackConfiguredWithoutCentralKeyReturnsUserAPIKeyRequired",
|
||||
provider: chatprovider.ConfiguredProvider{
|
||||
Provider: "anthropic",
|
||||
CentralAPIKeyEnabled: true,
|
||||
AllowUserAPIKey: true,
|
||||
AllowCentralAPIKeyFallback: true,
|
||||
},
|
||||
wantReason: codersdk.ChatModelProviderUnavailableReasonUserAPIKeyRequired,
|
||||
},
|
||||
{
|
||||
name: "UserKeyRequiredWithoutFallback",
|
||||
provider: chatprovider.ConfiguredProvider{
|
||||
Provider: "anthropic",
|
||||
CentralAPIKeyEnabled: true,
|
||||
AllowUserAPIKey: true,
|
||||
},
|
||||
wantReason: codersdk.ChatModelProviderUnavailableReasonUserAPIKeyRequired,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
tt := tt
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
keys, availability := chatprovider.ResolveUserProviderKeys(
|
||||
chatprovider.ProviderAPIKeys{},
|
||||
[]chatprovider.ConfiguredProvider{tt.provider},
|
||||
nil,
|
||||
)
|
||||
|
||||
require.Empty(t, keys.APIKey(tt.provider.Provider))
|
||||
resolved, ok := availability[tt.provider.Provider]
|
||||
require.True(t, ok)
|
||||
require.False(t, resolved.Available)
|
||||
require.Equal(t, tt.wantReason, resolved.UnavailableReason)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestListConfiguredModels_PolicyAwareAvailability(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
configuredProvider := func(provider string, apiKey string) chatprovider.ConfiguredProvider {
|
||||
return chatprovider.ConfiguredProvider{
|
||||
ProviderID: uuid.New(),
|
||||
Provider: provider,
|
||||
APIKey: apiKey,
|
||||
}
|
||||
}
|
||||
enabledProviders := func(providers ...string) map[string]struct{} {
|
||||
result := make(map[string]struct{}, len(providers))
|
||||
for _, provider := range providers {
|
||||
result[chatprovider.NormalizeProvider(provider)] = struct{}{}
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
catalog := chatprovider.NewModelCatalog()
|
||||
tests := []struct {
|
||||
name string
|
||||
configuredProviders []chatprovider.ConfiguredProvider
|
||||
configuredModels []chatprovider.ConfiguredModel
|
||||
availabilityByProvider map[string]chatprovider.ProviderAvailability
|
||||
enabledProviders map[string]struct{}
|
||||
want codersdk.ChatModelsResponse
|
||||
}{
|
||||
{
|
||||
name: "PolicyUnavailableOverridesConfiguredKey",
|
||||
configuredProviders: []chatprovider.ConfiguredProvider{
|
||||
configuredProvider(fantasyopenai.Name, "sk-central"),
|
||||
},
|
||||
configuredModels: []chatprovider.ConfiguredModel{{
|
||||
Provider: fantasyopenai.Name,
|
||||
Model: "gpt-4",
|
||||
}},
|
||||
availabilityByProvider: map[string]chatprovider.ProviderAvailability{
|
||||
fantasyopenai.Name: {
|
||||
Available: false,
|
||||
UnavailableReason: codersdk.ChatModelProviderUnavailableReasonUserAPIKeyRequired,
|
||||
},
|
||||
},
|
||||
enabledProviders: enabledProviders(fantasyopenai.Name),
|
||||
want: codersdk.ChatModelsResponse{Providers: []codersdk.ChatModelProvider{{
|
||||
Provider: fantasyopenai.Name,
|
||||
Available: false,
|
||||
UnavailableReason: codersdk.ChatModelProviderUnavailableReasonUserAPIKeyRequired,
|
||||
Models: []codersdk.ChatModel{{
|
||||
ID: fantasyopenai.Name + ":gpt-4",
|
||||
Provider: fantasyopenai.Name,
|
||||
Model: "gpt-4",
|
||||
DisplayName: "gpt-4",
|
||||
}},
|
||||
}}},
|
||||
},
|
||||
{
|
||||
name: "PolicyAvailableMarksProviderAvailable",
|
||||
configuredProviders: []chatprovider.ConfiguredProvider{
|
||||
configuredProvider(fantasyanthropic.Name, "sk-central"),
|
||||
},
|
||||
configuredModels: []chatprovider.ConfiguredModel{{
|
||||
Provider: fantasyanthropic.Name,
|
||||
Model: "claude-3-5-sonnet",
|
||||
}},
|
||||
availabilityByProvider: map[string]chatprovider.ProviderAvailability{
|
||||
fantasyanthropic.Name: {Available: true},
|
||||
},
|
||||
enabledProviders: enabledProviders(fantasyanthropic.Name),
|
||||
want: codersdk.ChatModelsResponse{Providers: []codersdk.ChatModelProvider{{
|
||||
Provider: fantasyanthropic.Name,
|
||||
Available: true,
|
||||
Models: []codersdk.ChatModel{{
|
||||
ID: fantasyanthropic.Name + ":claude-3-5-sonnet",
|
||||
Provider: fantasyanthropic.Name,
|
||||
Model: "claude-3-5-sonnet",
|
||||
DisplayName: "claude-3-5-sonnet",
|
||||
}},
|
||||
}}},
|
||||
},
|
||||
{
|
||||
name: "DisabledProviderOmitted",
|
||||
configuredProviders: []chatprovider.ConfiguredProvider{
|
||||
configuredProvider(fantasyanthropic.Name, "sk-anthropic"),
|
||||
configuredProvider(fantasyopenai.Name, "sk-openai"),
|
||||
},
|
||||
configuredModels: []chatprovider.ConfiguredModel{
|
||||
{Provider: fantasyanthropic.Name, Model: "claude-3-5-sonnet"},
|
||||
{Provider: fantasyopenai.Name, Model: "gpt-4"},
|
||||
},
|
||||
availabilityByProvider: map[string]chatprovider.ProviderAvailability{
|
||||
fantasyanthropic.Name: {Available: true},
|
||||
fantasyopenai.Name: {Available: true},
|
||||
},
|
||||
enabledProviders: enabledProviders(fantasyopenai.Name),
|
||||
want: codersdk.ChatModelsResponse{Providers: []codersdk.ChatModelProvider{{
|
||||
Provider: fantasyopenai.Name,
|
||||
Available: true,
|
||||
Models: []codersdk.ChatModel{{
|
||||
ID: fantasyopenai.Name + ":gpt-4",
|
||||
Provider: fantasyopenai.Name,
|
||||
Model: "gpt-4",
|
||||
DisplayName: "gpt-4",
|
||||
}},
|
||||
}}},
|
||||
},
|
||||
{
|
||||
name: "MissingAvailabilityDefaultsToMissingAPIKey",
|
||||
configuredProviders: []chatprovider.ConfiguredProvider{
|
||||
configuredProvider(fantasyopenai.Name, "sk-central"),
|
||||
},
|
||||
configuredModels: []chatprovider.ConfiguredModel{{
|
||||
Provider: fantasyopenai.Name,
|
||||
Model: "gpt-4o",
|
||||
}},
|
||||
enabledProviders: enabledProviders(fantasyopenai.Name),
|
||||
want: codersdk.ChatModelsResponse{Providers: []codersdk.ChatModelProvider{{
|
||||
Provider: fantasyopenai.Name,
|
||||
Available: false,
|
||||
UnavailableReason: codersdk.ChatModelProviderUnavailableMissingAPIKey,
|
||||
Models: []codersdk.ChatModel{{
|
||||
ID: fantasyopenai.Name + ":gpt-4o",
|
||||
Provider: fantasyopenai.Name,
|
||||
Model: "gpt-4o",
|
||||
DisplayName: "gpt-4o",
|
||||
}},
|
||||
}}},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
tt := tt
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
got, ok := catalog.ListConfiguredModels(
|
||||
tt.configuredProviders,
|
||||
tt.configuredModels,
|
||||
tt.availabilityByProvider,
|
||||
tt.enabledProviders,
|
||||
)
|
||||
require.True(t, ok)
|
||||
require.Equal(t, tt.want, got)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestListConfiguredProviderAvailability_PolicyAwareFiltering(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
enabledProviders := func(providers ...string) map[string]struct{} {
|
||||
result := make(map[string]struct{}, len(providers))
|
||||
for _, provider := range providers {
|
||||
result[chatprovider.NormalizeProvider(provider)] = struct{}{}
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
catalog := chatprovider.NewModelCatalog()
|
||||
tests := []struct {
|
||||
name string
|
||||
availabilityByProvider map[string]chatprovider.ProviderAvailability
|
||||
enabledProviders map[string]struct{}
|
||||
want codersdk.ChatModelsResponse
|
||||
}{
|
||||
{
|
||||
name: "EnabledProvidersUsePolicyAvailability",
|
||||
availabilityByProvider: map[string]chatprovider.ProviderAvailability{
|
||||
fantasyanthropic.Name: {
|
||||
Available: false,
|
||||
UnavailableReason: codersdk.ChatModelProviderUnavailableReasonUserAPIKeyRequired,
|
||||
},
|
||||
fantasyopenai.Name: {Available: true},
|
||||
},
|
||||
enabledProviders: enabledProviders(fantasyanthropic.Name, fantasyopenai.Name),
|
||||
want: codersdk.ChatModelsResponse{Providers: []codersdk.ChatModelProvider{
|
||||
{
|
||||
Provider: fantasyanthropic.Name,
|
||||
Available: false,
|
||||
UnavailableReason: codersdk.ChatModelProviderUnavailableReasonUserAPIKeyRequired,
|
||||
Models: []codersdk.ChatModel{},
|
||||
},
|
||||
{
|
||||
Provider: fantasyopenai.Name,
|
||||
Available: true,
|
||||
Models: []codersdk.ChatModel{},
|
||||
},
|
||||
}},
|
||||
},
|
||||
{
|
||||
name: "DisabledSupportedProviderOmitted",
|
||||
availabilityByProvider: map[string]chatprovider.ProviderAvailability{
|
||||
fantasyanthropic.Name: {Available: true},
|
||||
fantasyopenai.Name: {Available: true},
|
||||
},
|
||||
enabledProviders: enabledProviders(fantasyopenai.Name),
|
||||
want: codersdk.ChatModelsResponse{Providers: []codersdk.ChatModelProvider{{
|
||||
Provider: fantasyopenai.Name,
|
||||
Available: true,
|
||||
Models: []codersdk.ChatModel{},
|
||||
}}},
|
||||
},
|
||||
{
|
||||
name: "MissingAvailabilityDefaultsToMissingAPIKey",
|
||||
enabledProviders: enabledProviders(fantasyopenai.Name),
|
||||
want: codersdk.ChatModelsResponse{Providers: []codersdk.ChatModelProvider{{
|
||||
Provider: fantasyopenai.Name,
|
||||
Available: false,
|
||||
UnavailableReason: codersdk.ChatModelProviderUnavailableMissingAPIKey,
|
||||
Models: []codersdk.ChatModel{},
|
||||
}}},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
tt := tt
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
got := catalog.ListConfiguredProviderAvailability(
|
||||
tt.availabilityByProvider,
|
||||
tt.enabledProviders,
|
||||
)
|
||||
require.Equal(t, tt.want, got)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestPruneDisabledProviderKeys(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
enabledProviders := func(providers ...string) map[string]struct{} {
|
||||
result := make(map[string]struct{}, len(providers))
|
||||
for _, provider := range providers {
|
||||
result[chatprovider.NormalizeProvider(provider)] = struct{}{}
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
keys chatprovider.ProviderAPIKeys
|
||||
enabledProviders map[string]struct{}
|
||||
want chatprovider.ProviderAPIKeys
|
||||
}{
|
||||
{
|
||||
name: "DisabledProviderEntriesRemoved",
|
||||
keys: chatprovider.ProviderAPIKeys{
|
||||
ByProvider: map[string]string{
|
||||
fantasyanthropic.Name: "sk-anthropic",
|
||||
fantasyopenai.Name: "sk-openai",
|
||||
},
|
||||
BaseURLByProvider: map[string]string{
|
||||
fantasyanthropic.Name: "https://anthropic.example.com",
|
||||
fantasyopenai.Name: "https://openai.example.com",
|
||||
},
|
||||
},
|
||||
enabledProviders: enabledProviders(fantasyopenai.Name),
|
||||
want: chatprovider.ProviderAPIKeys{
|
||||
ByProvider: map[string]string{
|
||||
fantasyopenai.Name: "sk-openai",
|
||||
},
|
||||
BaseURLByProvider: map[string]string{
|
||||
fantasyopenai.Name: "https://openai.example.com",
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "OpenAIDisabledClearsLegacyField",
|
||||
keys: chatprovider.ProviderAPIKeys{
|
||||
OpenAI: "sk-openai",
|
||||
Anthropic: "sk-anthropic",
|
||||
ByProvider: map[string]string{
|
||||
fantasyopenai.Name: "sk-openai",
|
||||
fantasyanthropic.Name: "sk-anthropic",
|
||||
},
|
||||
BaseURLByProvider: map[string]string{
|
||||
fantasyopenai.Name: "https://openai.example.com",
|
||||
fantasyanthropic.Name: "https://anthropic.example.com",
|
||||
},
|
||||
},
|
||||
enabledProviders: enabledProviders(fantasyanthropic.Name),
|
||||
want: chatprovider.ProviderAPIKeys{
|
||||
Anthropic: "sk-anthropic",
|
||||
ByProvider: map[string]string{
|
||||
fantasyanthropic.Name: "sk-anthropic",
|
||||
},
|
||||
BaseURLByProvider: map[string]string{
|
||||
fantasyanthropic.Name: "https://anthropic.example.com",
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "AnthropicDisabledClearsLegacyField",
|
||||
keys: chatprovider.ProviderAPIKeys{
|
||||
OpenAI: "sk-openai",
|
||||
Anthropic: "sk-anthropic",
|
||||
ByProvider: map[string]string{
|
||||
fantasyopenai.Name: "sk-openai",
|
||||
fantasyanthropic.Name: "sk-anthropic",
|
||||
},
|
||||
BaseURLByProvider: map[string]string{
|
||||
fantasyopenai.Name: "https://openai.example.com",
|
||||
fantasyanthropic.Name: "https://anthropic.example.com",
|
||||
},
|
||||
},
|
||||
enabledProviders: enabledProviders(fantasyopenai.Name),
|
||||
want: chatprovider.ProviderAPIKeys{
|
||||
OpenAI: "sk-openai",
|
||||
ByProvider: map[string]string{
|
||||
fantasyopenai.Name: "sk-openai",
|
||||
},
|
||||
BaseURLByProvider: map[string]string{
|
||||
fantasyopenai.Name: "https://openai.example.com",
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "AllEnabledLeavesKeysUnchanged",
|
||||
keys: chatprovider.ProviderAPIKeys{
|
||||
OpenAI: "sk-openai",
|
||||
Anthropic: "sk-anthropic",
|
||||
ByProvider: map[string]string{
|
||||
fantasyopenai.Name: "sk-openai",
|
||||
fantasyanthropic.Name: "sk-anthropic",
|
||||
},
|
||||
BaseURLByProvider: map[string]string{
|
||||
fantasyopenai.Name: "https://openai.example.com",
|
||||
fantasyanthropic.Name: "https://anthropic.example.com",
|
||||
},
|
||||
},
|
||||
enabledProviders: enabledProviders(fantasyopenai.Name, fantasyanthropic.Name),
|
||||
want: chatprovider.ProviderAPIKeys{
|
||||
OpenAI: "sk-openai",
|
||||
Anthropic: "sk-anthropic",
|
||||
ByProvider: map[string]string{
|
||||
fantasyopenai.Name: "sk-openai",
|
||||
fantasyanthropic.Name: "sk-anthropic",
|
||||
},
|
||||
BaseURLByProvider: map[string]string{
|
||||
fantasyopenai.Name: "https://openai.example.com",
|
||||
fantasyanthropic.Name: "https://anthropic.example.com",
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
tt := tt
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
keys := tt.keys
|
||||
chatprovider.PruneDisabledProviderKeys(&keys, tt.enabledProviders)
|
||||
require.Equal(t, tt.want, keys)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestCoderHeaders(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
|
||||
@@ -670,7 +670,11 @@ func (s *openAIServer) writeResponsesAPINonStreaming(w http.ResponseWriter, resp
|
||||
"created": resp.Created,
|
||||
"model": resp.Model,
|
||||
"output": outputs,
|
||||
"usage": resp.Usage,
|
||||
"usage": map[string]interface{}{
|
||||
"input_tokens": resp.Usage.PromptTokens,
|
||||
"output_tokens": resp.Usage.CompletionTokens,
|
||||
"total_tokens": resp.Usage.TotalTokens,
|
||||
},
|
||||
}
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
if err := json.NewEncoder(w).Encode(response); err != nil {
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user