Compare commits
234 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| f7cf3e7fc2 | |||
| cc143c8990 | |||
| 3def04a3ee | |||
| d4a9c63e91 | |||
| 00217fefa5 | |||
| fce05d0428 | |||
| 2ebc076b9e | |||
| e71dc6dd4d | |||
| ca3ae3643d | |||
| ed5c06f039 | |||
| 1221622bf0 | |||
| aa5ec0bfcc | |||
| 16add93908 | |||
| fe13fd065c | |||
| fb788530b3 | |||
| f4dc8f6b11 | |||
| bbeff0d4b5 | |||
| 78fa8094cc | |||
| a85e00eed0 | |||
| da50a34414 | |||
| cd784c755a | |||
| eb4860aac3 | |||
| 07fbe8ca7d | |||
| ba0a64d483 | |||
| 7757cd8e08 | |||
| fc1e0beb3b | |||
| 3a4a0b7270 | |||
| 11c1afb5e9 | |||
| be2e641162 | |||
| 83e2699914 | |||
| 515ba209fd | |||
| d15bfc2cb0 | |||
| 308053b0e4 | |||
| 7c29355e84 | |||
| 7c048d8eb4 | |||
| e81275a91c | |||
| 4a363b0d85 | |||
| 7dc81bdef1 | |||
| ee855f9618 | |||
| b1c42bb630 | |||
| c048a4093e | |||
| 1cc23a3144 | |||
| dee5ec51c0 | |||
| e3c59c00cd | |||
| ba734f8b10 | |||
| 28062862a0 | |||
| 129e3509a3 | |||
| 53a1b6d67e | |||
| 8c8b307b97 | |||
| 12f87acad6 | |||
| 7198f9040d | |||
| ddafdbcbce | |||
| 196dc51edf | |||
| 2a51687ff3 | |||
| 19e44f4136 | |||
| 2ea89e1f1b | |||
| faa5db0cf0 | |||
| 3758b02595 | |||
| 7861fcf1f6 | |||
| 4b52656958 | |||
| f5b98aa12d | |||
| 7ddde0887b | |||
| bec426b24f | |||
| d6df78c9b9 | |||
| 19390a5841 | |||
| 2d03f7fd3d | |||
| 153a66b579 | |||
| 5cba59af79 | |||
| 1d16ff1ca6 | |||
| b86161e0a6 | |||
| a164d508cf | |||
| b9f140e53e | |||
| 7f7b13f0ab | |||
| e2bbd12137 | |||
| e769d1bd7d | |||
| cccb680ec2 | |||
| e8fb418820 | |||
| 2c5e003c91 | |||
| f44a8994da | |||
| 84b94a8376 | |||
| 2a990ce758 | |||
| c86f1288f1 | |||
| 9440adf435 | |||
| 755e8be5ad | |||
| c9e335c453 | |||
| 2d1f35f8a6 | |||
| b0036af57b | |||
| 2953245862 | |||
| 5d07014f9f | |||
| 002e88fefc | |||
| bbf3fbc830 | |||
| 9fa103929a | |||
| acd2ff63a7 | |||
| e3e17e15f7 | |||
| af678606fc | |||
| 3190406de3 | |||
| 3ce82bb885 | |||
| 348a3bd693 | |||
| 75f1503b41 | |||
| c33cd19a05 | |||
| adcea865c7 | |||
| 5e3bccd96c | |||
| 3950947c58 | |||
| b3d5b8d13c | |||
| a00afe4b5a | |||
| a5cc579453 | |||
| ef3aade647 | |||
| 3cc31de57a | |||
| d2c308e481 | |||
| 953c3bdc0f | |||
| ca879ffae6 | |||
| 0880a4685b | |||
| 3f8e3007d8 | |||
| 8e57498a87 | |||
| 0fb3e5cba5 | |||
| 7fb93dbf0e | |||
| cf500b95b9 | |||
| 6a2f389110 | |||
| 027f93c913 | |||
| 509e89d5c4 | |||
| 378f11d6dc | |||
| f2845f6622 | |||
| 076e97aa66 | |||
| 2875053b83 | |||
| 548a648dcb | |||
| 7d0a49f54b | |||
| f77d0c1649 | |||
| 9f51c44772 | |||
| 73f6cd8169 | |||
| 4c97b63d79 | |||
| 28484536b6 | |||
| 7a5fd4c790 | |||
| 8f73e46c2f | |||
| 56171306ff | |||
| 0b07ce2a97 | |||
| f2a7fdacfe | |||
| 0e78156bcd | |||
| bc5e4b5d54 | |||
| 13dfc9a9bb | |||
| 54738e9e14 | |||
| 78986efed8 | |||
| 4d2b0a2f82 | |||
| f7aa46c4ba | |||
| 4bf46c4435 | |||
| be99b3cb74 | |||
| 588beb0a03 | |||
| bfeb91d9cd | |||
| a399aa8c0c | |||
| 386b449273 | |||
| 565cf846de | |||
| a2799560eb | |||
| 73bde99495 | |||
| a708e9d869 | |||
| 91217a97b9 | |||
| 399080e3bf | |||
| 50d9d510c5 | |||
| eda1bba969 | |||
| 808dd64ef6 | |||
| 04f7d19645 | |||
| 71a492a374 | |||
| 8c494e2a77 | |||
| 839165818b | |||
| 6b77fa74a1 | |||
| 25e9fa7120 | |||
| 60065f6c08 | |||
| bcdc35ee3e | |||
| a5c72ba396 | |||
| 3f55b35f68 | |||
| 97a27d3c09 | |||
| 4ed9094305 | |||
| d973a709df | |||
| 50c0c89503 | |||
| 0ec0f8faaf | |||
| 9b4d15db9b | |||
| 9e33035631 | |||
| 83b2f85d63 | |||
| c4ef94aacf | |||
| d678c6fb16 | |||
| 86c3983fc0 | |||
| 2312e5c428 | |||
| f35f2a28e6 | |||
| 4fab372bdc | |||
| aa81238cd0 | |||
| f87d6e6e82 | |||
| 113aaa79a0 | |||
| f3a8096ff6 | |||
| beece6d351 | |||
| 58f744a5c1 | |||
| 0f86c4237e | |||
| 02b58534a0 | |||
| e35fa8b9ee | |||
| 1358233c83 | |||
| ea4070c0ce | |||
| 1b2fab8306 | |||
| 94e5de22f7 | |||
| 6cbb7c6da7 | |||
| fc60a6bf9b | |||
| a52153968d | |||
| d18e700699 | |||
| 0234e8fffd | |||
| fea4560a64 | |||
| 6dee7cf11d | |||
| d4fc4e0837 | |||
| 8da45c14bc | |||
| bfee7e6245 | |||
| 52b5d5fdc6 | |||
| cc4cca90fd | |||
| 081d91982a | |||
| 00cd7b7346 | |||
| 801e57d430 | |||
| e937f89081 | |||
| 5c7057a67f | |||
| 249ef7c567 | |||
| 81fe7543b4 | |||
| 61d2a4a9b8 | |||
| b23c07cf23 | |||
| 87aafd4ae2 | |||
| 4d74603045 | |||
| 847a88c6ca | |||
| a0283ff775 | |||
| f164463c6a | |||
| 4f063cdc47 | |||
| d175e799da | |||
| 3fb7c6264f | |||
| 09d2588e2a | |||
| 8eade29e68 | |||
| 15f2fa55c6 | |||
| 2ff329b68a | |||
| ad3d934290 | |||
| 21c2acbad5 | |||
| 411714cd73 | |||
| 61e31ec5cc | |||
| 17aea0b19c | |||
| 5112ab7da9 |
@@ -111,8 +111,8 @@ Tier 2 file filters:
|
||||
|
||||
- **Modernization Reviewer**: one instance per language present in the diff. Filter by extension:
|
||||
- Go: `*.go` — reference `.claude/docs/GO.md` before reviewing.
|
||||
- TypeScript: `*.ts` `*.tsx`
|
||||
- React: `*.tsx` `*.jsx`
|
||||
- TypeScript: `*.ts` `*.tsx`: reference `.agents/skills/deep-review/references/typescript.md` before reviewing.
|
||||
- React: `*.tsx` `*.jsx`: reference `.agents/skills/deep-review/references/react.md` before reviewing.
|
||||
|
||||
`.tsx` files match both TypeScript and React filters. Spawn both instances when the diff contains `.tsx` changes — TS covers language-level patterns; React covers component and hooks patterns. Before spawning, verify each instance's filter produces a non-empty diff. Skip instances whose filtered diff is empty.
|
||||
|
||||
@@ -155,9 +155,11 @@ File scope: {filter from step 2}.
|
||||
Output file: {REVIEW_DIR}/{role-name}.md
|
||||
```
|
||||
|
||||
For the Modernization Reviewer (Go), add after the methodology line:
|
||||
For Modernization Reviewer instances, add the language reference after the methodology line:
|
||||
|
||||
> Read `.claude/docs/GO.md` as your Go language reference before reviewing.
|
||||
- **Go:** `Read .claude/docs/GO.md as your Go language reference before reviewing.`
|
||||
- **TypeScript:** `Read .agents/skills/deep-review/references/typescript.md as your TypeScript language reference before reviewing.`
|
||||
- **React:** `Read .agents/skills/deep-review/references/react.md as your React language reference before reviewing.`
|
||||
|
||||
For re-reviews, append to both Tier 1 and Tier 2 prompts:
|
||||
|
||||
|
||||
@@ -0,0 +1,305 @@
|
||||
# Modern React (18–19.2) + Compiler 1.0 — Reference
|
||||
|
||||
Reference for writing idiomatic React. Covers what changed, what it replaced, and what to reach for. Includes React Compiler patterns — what the compiler handles automatically, what it changes semantically, and how to verify its behavior empirically. Scope: client-side SPA patterns only. Server Components, `use server`, and `use client` directives are framework-specific and omitted. Check the project's React version and compiler config before reaching for newer APIs.
|
||||
|
||||
## How modern React thinks differently
|
||||
|
||||
**Concurrent rendering** (18): React can now pause, interrupt, and resume renders. This is the foundation everything else builds on. Most existing code "just works," but components that produce side effects during render (mutations, subscriptions, network calls in the render body) are unsafe and will misbehave. Concurrent features are opt-in — they only activate when you use a concurrent API like `startTransition` or `useDeferredValue`.
|
||||
|
||||
**Urgent vs. non-urgent updates** (18): The `startTransition` / `useTransition` API introduces a formal split between updates that must feel immediate (typing, clicking) and updates that can be interrupted (filtering a large list, navigating to a new screen). Non-urgent updates yield to urgent ones mid-render. Use this instead of `setTimeout` or manual debounce when you want the UI to stay responsive during expensive re-renders.
|
||||
|
||||
**Actions** (19): Async functions passed to `startTransition` are called "Actions." They automatically manage pending state, error handling, and optimistic updates as a unit. The `useActionState` hook and `<form action={fn}>` prop are built on this. The pattern replaces the hand-rolled `isPending/setIsPending` + `try/catch` + `setError` boilerplate that was previously necessary for every data mutation.
|
||||
|
||||
**Automatic batching** (18): State updates are now batched everywhere — inside `setTimeout`, `Promise.then`, native event handlers, etc. Previously batching only happened inside React-managed event handlers. If you genuinely need a synchronous flush, use `flushSync`.
|
||||
|
||||
**Automatic memoization** (Compiler 1.0): React Compiler is a build-time Babel plugin that automatically inserts memoization into components and hooks. It replaces manual `useMemo`, `useCallback`, and `React.memo` — including conditional memoization and memoization after early returns, which manual APIs cannot express. The compiler only processes components and hooks, not standalone functions. It understands data flow and mutability through its own HIR (High-level Intermediate Representation), so it can memoize more granularly than a human would. Projects adopt it incrementally — typically via path-based Babel overrides or the `"use memo"` directive. Components that violate the Rules of React are silently skipped (no build error), so the automated lint tools that check compiler compatibility matter.
|
||||
|
||||
## Replace these patterns
|
||||
|
||||
The left column reflects patterns common before React 18/19. Write the right column instead. The "Since" column tells you the minimum React version required.
|
||||
|
||||
| Old pattern | Modern replacement | Since |
|
||||
| ----------------------------------------------------------------- | ------------------------------------------------------------------------------ | ----- |
|
||||
| `ReactDOM.render(<App />, el)` | `createRoot(el).render(<App />)` | 18 |
|
||||
| `ReactDOM.hydrate(<App />, el)` | `hydrateRoot(el, <App />)` | 18 |
|
||||
| `ReactDOM.unmountComponentAtNode(el)` | `root.unmount()` | 18 |
|
||||
| `ReactDOM.findDOMNode(this)` | DOM ref: `const ref = useRef(); ref.current` | 18 |
|
||||
| `<Context.Provider value={v}>` | `<Context value={v}>` | 19 |
|
||||
| `React.forwardRef((props, ref) => ...)` | `function Comp({ ref, ...props }) { ... }` (ref as a regular prop) | 19 |
|
||||
| String ref `ref="input"` in class components | Callback ref or `createRef()` | 19 |
|
||||
| `Heading.propTypes = { ... }` | TypeScript / ES6 type annotations | 19 |
|
||||
| `Component.defaultProps = { ... }` on function components | ES6 default parameters `({ text = 'Hi' })` | 19 |
|
||||
| Legacy Context: `contextTypes` + `getChildContext` | `React.createContext()` + `contextType` | 19 |
|
||||
| `import { act } from 'react-dom/test-utils'` | `import { act } from 'react'` | 19 |
|
||||
| `import ShallowRenderer from 'react-test-renderer/shallow'` | `import ShallowRenderer from 'react-shallow-renderer'` | 19 |
|
||||
| Manual `isPending` state around async calls | `const [isPending, startTransition] = useTransition()` | 18 |
|
||||
| Manual optimistic state + revert logic | `useOptimistic(currentValue)` | 19 |
|
||||
| `useEffect` to subscribe to external stores | `useSyncExternalStore(subscribe, getSnapshot)` | 18 |
|
||||
| Hand-rolled unique ID (counter, random, index) | `useId()` — SSR-safe, hydration-safe | 18 |
|
||||
| `useEffect` to inject `<title>` or `<meta>` / `react-helmet` | Render `<title>`, `<meta>`, `<link>` directly in components; React hoists them | 19 |
|
||||
| `ReactDOM.useFormState(action, initial)` (Canary name) | `useActionState(action, initial)` | 19 |
|
||||
| `useReducer<React.Reducer<State, Action>>(reducer)` | `useReducer(reducer)` — infers from the reducer function | 19 |
|
||||
| `<div ref={current => (instance = current)} />` (implicit return) | `<div ref={current => { instance = current }} />` (explicit block body) | 19 |
|
||||
| `useRef<T>()` with no argument | `useRef<T>(undefined)` or `useRef<T \| null>(null)` — argument is now required | 19 |
|
||||
| `MutableRefObject<T>` type annotation | `RefObject<T>` — all refs are mutable now; `MutableRefObject` is deprecated | 19 |
|
||||
| `React.createFactory('button')` | `<button />` JSX | 19 |
|
||||
| `useMemo(() => expr, [deps])` in compiled components | `const val = expr;` — compiler memoizes automatically | C 1.0 |
|
||||
| `useCallback(fn, [deps])` in compiled components | `const fn = () => { ... };` — compiler memoizes automatically | C 1.0 |
|
||||
| `React.memo(Component)` in compiled components | Plain component — compiler skips re-render when props are unchanged | C 1.0 |
|
||||
| `eslint-plugin-react-compiler` (standalone) | `eslint-plugin-react-hooks@latest` (compiler rules merged into recommended) | C 1.0 |
|
||||
| `useRef` + `useLayoutEffect` for stable callbacks | `useEffectEvent(fn)` — compiler handles both, but `useEffectEvent` is clearer | 19.2 |
|
||||
|
||||
## New capabilities
|
||||
|
||||
These enable things that weren't practical before. Reach for them in the described situations.
|
||||
|
||||
| What | Since | When to use it |
|
||||
| -------------------------------------------------------------------- | ------- | --------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
|
||||
| `useTransition()` / `startTransition()` | 18 | Mark a state update as non-urgent so React can interrupt it to handle clicks or keystrokes. The `isPending` boolean lets you show a loading indicator without blocking the UI. |
|
||||
| `useDeferredValue(value, initialValue?)` | 18 / 19 | Defer re-rendering a slow subtree: pass the deferred value as a prop, wrap the expensive child in `memo`. Unlike debounce, uses no fixed timeout — renders as soon as the browser is idle. The `initialValue` arg (19) avoids a flash on first render. |
|
||||
| `useId()` | 18 | Generate a stable, SSR-consistent ID for accessibility attributes (`htmlFor`, `aria-describedby`). Do not use for list keys. |
|
||||
| `useSyncExternalStore(subscribe, getSnapshot, getServerSnapshot?)` | 18 | Subscribe to external (non-React) state stores safely under concurrent rendering. Preferred over `useEffect`-based subscriptions in libraries. |
|
||||
| `useActionState(action, initialState)` | 19 | Manage an async mutation: returns `[state, wrappedAction, isPending]`. Handles pending, result, and error state as a unit. Replaces the manual `isPending` + `try/catch` + `setError` pattern. |
|
||||
| `useOptimistic(currentValue)` | 19 | Show a speculative value while an async Action is in flight. Returns `[optimisticValue, setOptimistic]`. React automatically reverts to `currentValue` when the transition settles. |
|
||||
| `use(promiseOrContext)` | 19 | Read a promise or Context value inside a component or custom hook. Unlike hooks, `use` can be called conditionally (after early returns). Promises must come from a cache — do not create them during render. |
|
||||
| `useFormStatus()` (from `react-dom`) | 19 | Read `{ pending, data, method, action }` of the nearest parent `<form>` Action. Works across component boundaries without prop drilling — useful for submit buttons inside design-system components. |
|
||||
| `useEffectEvent(fn)` | 19.2 | Extract a non-reactive callback from an effect. The function sees the latest props/state without being listed in deps, and is never stale. Replaces the `useRef`-and-mutate-in-layout-effect workaround for stable event-like callbacks. The compiler has built-in knowledge of this hook and correctly prunes its return value from effect dependency arrays. Both `useEffectEvent` and the old ref workaround compile cleanly; `useEffectEvent` is preferred for clarity. |
|
||||
| `<Activity>` | 19.2 | Hide part of the UI while preserving its state and DOM. React deprioritizes updates to hidden content. Use via framework APIs for route prerendering or tab preservation — not a direct replacement for CSS `visibility`. |
|
||||
| `captureOwnerStack()` | 19.1 | Dev-only API that returns a string showing which components are responsible for rendering the current component (owner stack, not call stack). Useful for custom error overlays. Returns `null` in production. |
|
||||
| `<form action={fn}>` | 19 | Pass an async function as a form's `action` prop. React handles submission, pending state, and automatic form reset on success. Works with `useActionState` and `useFormStatus`. |
|
||||
| Ref cleanup function | 19 | Return a cleanup function from a ref callback: `ref={el => { ...; return () => cleanup(); }}`. React calls it on unmount. Replaces the pattern of checking `el === null` in the callback. |
|
||||
| `<link rel="stylesheet" precedence="default">` | 19 | Declare a stylesheet next to the component that needs it. React deduplicates and inserts it in the correct order before revealing Suspense content. |
|
||||
| `preinit`, `preload`, `prefetchDNS`, `preconnect` (from `react-dom`) | 19 | Imperatively hint the browser to load resources early. Call from render or event handlers. React deduplicates hints across the component tree. |
|
||||
| React Compiler (`babel-plugin-react-compiler`) | C 1.0 | Build-time automatic memoization for components and hooks. Install, add to Babel/Vite pipeline. Projects typically start with path-based overrides to compile a subset of files. |
|
||||
| `"use memo"` directive | C 1.0 | Opt a single function into compilation when using `compilationMode: 'annotation'`. Place at the start of the function body. Module-level `"use memo"` at the top of a file compiles all functions in that file. |
|
||||
| `"use no memo"` directive | C 1.0 | Temporary escape hatch — skip compilation for a specific component or hook that causes a runtime regression. Not a permanent solution. Place at the start of the function body. |
|
||||
| Compiler-powered ESLint rules | C 1.0 | Rules for purity, refs, set-state-in-render, immutability, etc. now ship in `eslint-plugin-react-hooks` recommended preset. Surface Rules-of-React violations even without the compiler installed. Note: some projects use Biome instead — check project lint config. |
|
||||
|
||||
## Key APIs
|
||||
|
||||
### `useTransition` and `startTransition` (18)
|
||||
|
||||
`useTransition` returns `[isPending, startTransition]`. Wrap any state update that is not directly tied to the user's current gesture inside `startTransition`. React will render the old UI while computing the new one, and `isPending` is `true` during that window.
|
||||
|
||||
In React 19, `startTransition` can accept an async function (an "Action"). React sets `isPending` to `true` for the entire duration of the async work, not just during the synchronous part.
|
||||
|
||||
```tsx
|
||||
// 18: synchronous transition
|
||||
const [isPending, startTransition] = useTransition();
|
||||
startTransition(() => setQuery(input));
|
||||
|
||||
// 19: async Action — isPending stays true until the await settles
|
||||
startTransition(async () => {
|
||||
const err = await updateName(name);
|
||||
if (err) setError(err);
|
||||
});
|
||||
```
|
||||
|
||||
Use `startTransition` (the module-level export) when you cannot use the hook (outside a component, in a router callback, etc.).
|
||||
|
||||
### `useDeferredValue` (18 / 19)
|
||||
|
||||
Creates a "lagging" copy of a value. Pass it to a memoized, expensive component so that React can render the stale UI while computing the updated one.
|
||||
|
||||
```tsx
|
||||
// 19: initialValue shows '' on first render; avoids loading flash
|
||||
const deferred = useDeferredValue(searchQuery, "");
|
||||
return <Results query={deferred} />; // Results wrapped in memo
|
||||
```
|
||||
|
||||
`deferred !== searchQuery` while the deferred render is in progress — use this to show a "stale" indicator.
|
||||
|
||||
### `useActionState` (19)
|
||||
|
||||
Replaces the `useState` + `isPending` + `try/catch` + `setError` boilerplate for any async operation that can be retried or submitted as a form.
|
||||
|
||||
```tsx
|
||||
const [error, submitAction, isPending] = useActionState(
|
||||
async (prevState, formData) => {
|
||||
const err = await updateName(formData.get("name"));
|
||||
if (err) return err; // returned value becomes next state
|
||||
redirect("/profile");
|
||||
return null;
|
||||
},
|
||||
null, // initialState
|
||||
);
|
||||
|
||||
// Use submitAction as the form's action prop or call it directly
|
||||
<form action={submitAction}>
|
||||
<input name="name" />
|
||||
<button disabled={isPending}>Save</button>
|
||||
{error && <p>{error}</p>}
|
||||
</form>;
|
||||
```
|
||||
|
||||
### `useOptimistic` (19)
|
||||
|
||||
Shows a speculative value immediately while an async Action is in progress. React automatically reverts to the server-confirmed value when the Action resolves or rejects.
|
||||
|
||||
```tsx
|
||||
const [optimisticName, setOptimisticName] = useOptimistic(currentName);
|
||||
|
||||
const submit = async (formData) => {
|
||||
const newName = formData.get("name");
|
||||
setOptimisticName(newName); // shows immediately
|
||||
await updateName(newName); // reverts if this throws
|
||||
};
|
||||
```
|
||||
|
||||
### `use()` (19)
|
||||
|
||||
Unlike hooks, `use` can appear after conditional statements. Two primary uses:
|
||||
|
||||
**Reading a promise** (must be stable — from a cache, not created inline):
|
||||
|
||||
```tsx
|
||||
function Comments({ commentsPromise }) {
|
||||
const comments = use(commentsPromise); // suspends until resolved
|
||||
return comments.map((c) => <p key={c.id}>{c.text}</p>);
|
||||
}
|
||||
```
|
||||
|
||||
**Reading context after an early return** (hooks cannot appear after `return`):
|
||||
|
||||
```tsx
|
||||
function Heading({ children }) {
|
||||
if (!children) return null;
|
||||
const theme = use(ThemeContext); // valid here; hooks would not be
|
||||
return <h1 style={{ color: theme.color }}>{children}</h1>;
|
||||
}
|
||||
```
|
||||
|
||||
### `useSyncExternalStore` (18)
|
||||
|
||||
The correct way for libraries (and app code) to subscribe to non-React state. Prevents tearing under concurrent rendering.
|
||||
|
||||
```tsx
|
||||
const value = useSyncExternalStore(
|
||||
store.subscribe, // called when store changes
|
||||
store.getSnapshot, // returns current value (must be stable reference if unchanged)
|
||||
store.getServerSnapshot, // optional: for SSR
|
||||
);
|
||||
```
|
||||
|
||||
## Verifying compiler behavior
|
||||
|
||||
The compiler is a black box unless you inspect its output. When reviewing code in compiled paths, run the compiler on the specific code to see what it actually does. Do not guess — verify.
|
||||
|
||||
**Run the compiler on a code snippet:**
|
||||
|
||||
```sh
|
||||
cd site && node -e "
|
||||
const {transformSync} = require('@babel/core');
|
||||
const code = \`<paste component here>\`;
|
||||
const diagnostics = [];
|
||||
const result = transformSync(code, {
|
||||
plugins: [
|
||||
['@babel/plugin-syntax-typescript', {isTSX: true}],
|
||||
['babel-plugin-react-compiler', {
|
||||
logger: {
|
||||
logEvent(_, event) {
|
||||
if (event.kind === 'CompileError' || event.kind === 'CompileSkip') {
|
||||
diagnostics.push(event.detail?.toString?.()?.substring(0, 200));
|
||||
}
|
||||
},
|
||||
},
|
||||
}],
|
||||
],
|
||||
filename: 'test.tsx',
|
||||
});
|
||||
console.log('Compiled:', result.code.includes('_c('));
|
||||
if (diagnostics.length) console.log('Diagnostics:', diagnostics);
|
||||
console.log(result.code);
|
||||
"
|
||||
```
|
||||
|
||||
**Reading compiled output:**
|
||||
|
||||
- `const $ = _c(N)` — allocates N memoization cache slots.
|
||||
- `if ($[n] !== dep)` — cache invalidation guard. Re-computes when `dep` changes (referential equality).
|
||||
- `if ($[n] === Symbol.for("react.memo_cache_sentinel"))` — one-time initialization. Runs once on first render, cached forever after. This is how the compiler handles expressions with no reactive dependencies.
|
||||
- `_temp` functions — pure callbacks the compiler hoisted out of the component body.
|
||||
|
||||
**Check all compiled files at once:**
|
||||
|
||||
```sh
|
||||
cd site && pnpm run lint:compiler
|
||||
```
|
||||
|
||||
This runs the compiler on every file in the compiled paths and reports CompileError / CompileSkip diagnostics. Zero diagnostics means all functions compiled cleanly.
|
||||
|
||||
**What the compiler catches vs. what it does not:**
|
||||
|
||||
The compiler emits `CompileError` for mutations of props, state, or hook arguments during render, and for `ref.current` access during render. The project's lint pipeline catches these automatically — do not flag them in review.
|
||||
|
||||
The compiler does **not** flag impure function calls during render (`Math.random()`, `Date.now()`, `new Date()`). Instead it silently memoizes them with a sentinel guard, freezing the value after first render. This changes semantics without any diagnostic. Verify suspicious calls by running the compiler and checking for sentinel guards in the output.
|
||||
|
||||
## Pitfalls
|
||||
|
||||
Things that are easy to get wrong even when you know the modern API exists. Check your output against these.
|
||||
|
||||
**Effects run twice in development with StrictMode.** React 18 intentionally mounts → unmounts → remounts every component in dev to surface effects that are not resilient to remounting. This is not a bug. If an effect breaks on the second mount, it is missing a cleanup function. Write `return () => cleanup()` from every effect that sets up a subscription, timer, or external resource.
|
||||
|
||||
**Concurrent rendering can call render multiple times.** The render function (component body) may be called more than once before React commits to the DOM. Side effects (mutations, subscriptions, logging) in the render body will run multiple times. Move them into `useEffect` or event handlers.
|
||||
|
||||
**Do not create promises during render and pass them to `use()`.** A new promise is created every render, causing an infinite suspend-retry loop. Create the promise outside the component (module level), or use a caching library (SWR, React Query, `cache()` from React) to stabilize it.
|
||||
|
||||
**`useOptimistic` reverts automatically — do not fight it.** The optimistic value is a presentation layer only. When the Action settles, React replaces it with the real `currentValue` you passed in. Do not try to sync optimistic state back to your real state; let React handle the revert.
|
||||
|
||||
**`flushSync` opts out of automatic batching.** If third-party code or a browser API (e.g. `ResizeObserver`) calls `setState` and you need synchronous DOM flushing, wrap with `flushSync(() => setState(...))`. This is a last resort; prefer letting React batch.
|
||||
|
||||
**`forwardRef` still works in React 19 but will be deprecated.** Function components accept `ref` as a plain prop now. New code should use the prop directly. Existing `forwardRef` wrappers continue to work without changes; migrate when convenient.
|
||||
|
||||
**`<Activity>` does not unmount.** Content inside a hidden `<Activity>` boundary stays mounted. Effects keep running. Use it for preserving scroll position or form state, not for preventing expensive mounts — use lazy loading for that.
|
||||
|
||||
**TypeScript: implicit returns from ref callbacks are now type errors.** In React 19, returning anything other than a cleanup function (or nothing) from a ref callback is rejected by the TypeScript types. The most common case is arrow-function refs that implicitly return the DOM node:
|
||||
|
||||
```tsx
|
||||
// Error in React 19 types:
|
||||
<div ref={el => (instance = el)} />
|
||||
|
||||
// Fix — use a block body:
|
||||
<div ref={el => { instance = el; }} />
|
||||
```
|
||||
|
||||
**TypeScript: `useRef` now requires an argument.** `useRef<T>()` with no argument is a type error. Pass `undefined` for mutable refs or `null` for DOM refs you initialize on mount: `useRef<T>(undefined)` / `useRef<HTMLDivElement | null>(null)`.
|
||||
|
||||
**`useId` output format changed across versions.** React 18 produced `:r0:`. React 19.1 changed it to `«r0»`. React 19.2 changed it again to `_r0`. Do not parse or depend on the specific format — treat it as an opaque string.
|
||||
|
||||
**`useFormStatus` reads the nearest parent `<form>` with a function `action`.** It does not reflect native HTML form submissions — only React Actions. A submit button that is a sibling of `<form>` (rather than a descendant) will not see the form's status.
|
||||
|
||||
**Context as a provider (`<Context>`) requires React 19; `<Context.Provider>` still works.** Do not use `<Context>` shorthand in a codebase that needs to support React 18. The two forms can coexist during migration.
|
||||
|
||||
**Compiler freezes impure expressions silently.** `Math.random()`, `Date.now()`, `new Date()`, and `window.innerWidth` in a component body all compile without diagnostics. The compiler wraps them in a sentinel guard (`Symbol.for("react.memo_cache_sentinel")`) that runs the expression once and caches the result forever. The value never updates on re-render. Fix: move to a `useState` initializer (`useState(() => Math.random())`), `useEffect`, or event handler.
|
||||
|
||||
**Component granularity affects compiler optimization.** When one pattern in a component causes a `CompileError` (e.g., a necessary `ref.current` read during render), the compiler skips the **entire** component. If the rest of the component would benefit from compilation, extract the non-compilable pattern into a small child component. This keeps the parent compiled.
|
||||
|
||||
**The compiler only memoizes components and hooks.** Standalone utility functions (even expensive ones called during render) are not compiled. If a utility function is truly expensive, it still needs its own caching strategy outside of React (e.g., a module-level cache, `WeakMap`, etc.).
|
||||
|
||||
**Changing memoization can shift `useEffect` firing.** A value that was unstable before compilation may become stable after, causing an effect that depended on it to fire less often. Conversely, future compiler changes may alter memoization granularity. Effects that use memoized values as dependencies should be resilient to these changes — they should be true synchronization effects, not "run this when X changes" hacks.
|
||||
|
||||
## Behavioral changes that affect code
|
||||
|
||||
- **Automatic batching** (18): State updates in `setTimeout`, `Promise.then`, `addEventListener` callbacks, etc. are now batched into a single re-render. Previously only React synthetic event handlers were batched. Code that relied on unbatched updates (reading DOM synchronously after each `setState`) must use `flushSync`.
|
||||
|
||||
- **StrictMode double-invoke** (18): In development, every component is mounted → unmounted → remounted with the previous state. Every effect runs cleanup → setup twice on initial mount. `useMemo` and `useCallback` also double-invoke their functions. Production behavior is unchanged. If a test or component breaks under this, the component had a latent cleanup bug.
|
||||
|
||||
- **StrictMode ref double-invoke** (19): In development, ref callbacks are also invoked twice on mount (attach → detach → attach). Return a cleanup function from the ref callback to handle detach correctly.
|
||||
|
||||
- **StrictMode memoization reuse** (19): During the second pass of double-rendering, `useMemo` and `useCallback` now reuse the cached result from the first pass instead of calling the function again. Components that are already StrictMode-compatible should not notice a difference.
|
||||
|
||||
- **Suspense fallback commits immediately** (19): When a component suspends, React now commits the nearest `<Suspense>` fallback without waiting for sibling trees to finish rendering. After the fallback is shown, React "pre-warms" suspended siblings in the background. This makes fallbacks appear faster but changes the order of rendering work.
|
||||
|
||||
- **Error re-throwing removed** (19): Errors that are not caught by an Error Boundary are now reported to `window.reportError` (not re-thrown). Errors caught by an Error Boundary go to `console.error` once. If your production monitoring relied on the re-thrown error, add handlers to `createRoot`: `createRoot(el, { onUncaughtError, onCaughtError })`.
|
||||
|
||||
- **Transitions in `popstate` are synchronous** (19): Browser back/forward navigation triggers synchronous transition flushing. This ensures the URL and UI update together atomically during history navigation.
|
||||
|
||||
- **`useEffect` from discrete events flushes synchronously** (18): Effects triggered by a click or keydown (discrete events) are now flushed synchronously before the browser paints, consistent with `useLayoutEffect` for those cases.
|
||||
|
||||
- **Hydration mismatches treated as errors** (18 / improved in 19): Text content mismatches between server HTML and client render revert to client rendering up to the nearest `<Suspense>` boundary. React 19 logs a single diff instead of multiple warnings, making mismatches much easier to diagnose.
|
||||
|
||||
- **New JSX transform required** (19): The automatic JSX runtime introduced in 2020 (`react/jsx-runtime`) is now mandatory. The classic transform (which required `import React from 'react'` in every file) is no longer supported. Most toolchains have already shipped the new transform; check your Babel or TypeScript config if you see warnings.
|
||||
|
||||
- **UMD builds removed** (19): React no longer ships UMD bundles. Load via npm and a bundler, or use an ESM CDN (`import React from "https://esm.sh/react@19"`).
|
||||
|
||||
- **React Compiler automatic memoization** (Compiler 1.0): Build-time Babel plugin that inserts memoization into components and hooks. Components that follow the Rules of React are automatically memoized; components that violate them are silently skipped (no build error, no runtime change). The compiler can memoize conditionally and after early returns — things impossible with manual `useMemo`/`useCallback`. Works with React 17+ via `react-compiler-runtime`; best with React 19+. Projects adopt incrementally via path-based Babel overrides, `compilationMode: 'annotation'`, or the `"use memo"` / `"use no memo"` directives. Check the project's Vite/Babel config to know which paths are compiled. Compiled components show a "Memo ✨" badge in React DevTools.
|
||||
@@ -0,0 +1,199 @@
|
||||
# Modern TypeScript (5.0–6.0 RC) — Reference
|
||||
|
||||
Reference for writing idiomatic TypeScript. Covers what changed, what it replaced, and what to reach for. Respect the project's minimum TypeScript version: don't emit features from a version newer than what the project targets. Check `package.json` and `tsconfig.json` before writing code.
|
||||
|
||||
## How modern TypeScript thinks differently
|
||||
|
||||
The 5.x era resolves years of module system ambiguity and cleans house on legacy options. Three themes dominate:
|
||||
|
||||
**Module semantics are explicit.** `--verbatimModuleSyntax` (5.0) makes import/export intent visible in source: type imports must carry `type`, value imports stay. Combined with `--module preserve` or `--moduleResolution bundler`, the compiler now accurately models what bundlers and modern runtimes actually do. `import defer` (5.9) extends the model to deferred evaluation.
|
||||
|
||||
**Resource lifetimes are first-class.** `using` and `await using` (5.2) provide deterministic cleanup without `try/finally`. Any object implementing `Symbol.dispose` participates. `DisposableStack` handles ad-hoc multi-resource cleanup in functions where creating a full class is overkill.
|
||||
|
||||
**Inference is smarter about what it knows.** Inferred type predicates (5.5) let `.filter(x => x !== undefined)` produce `T[]` instead of `(T | undefined)[]` automatically. `NoInfer<T>` (5.4) gives library authors precise control over which parameters drive inference. Narrowing now survives closures after last assignment, constant indexed accesses, and `switch (true)` patterns.
|
||||
|
||||
**TypeScript 6.0 is a transition release toward 7.0** (the Go-native port). It turns years of soft deprecations into errors and changes several defaults. Most impactful: `types` defaults to `[]` (must list `@types` packages explicitly), `rootDir` defaults to `.`, `strict` defaults to `true`, `module` defaults to `esnext`. Projects relying on implicit behavior need explicit config. Check the deprecations section before upgrading.
|
||||
|
||||
## Replace these patterns
|
||||
|
||||
The left column reflects patterns still common before TypeScript 5.x. Write the right column instead. The "Since" column tells you the minimum TypeScript version required.
|
||||
|
||||
| Old pattern | Modern replacement | Since |
|
||||
| ---------------------------------------------------------------------------------------------------------------- | ---------------------------------------------------------------------------------------------------------------------------- | -------------------------------- | ------ |
|
||||
| `--experimentalDecorators` + legacy decorator signatures | Standard decorators (TC39): `function dec(target, context: ClassMethodDecoratorContext)` — no flag needed | 5.0 |
|
||||
| Requiring callers to add `as const` at call sites | `<const T extends HasNames>(arg: T)` — `const` modifier on type parameter | 5.0 |
|
||||
| `--importsNotUsedAsValues` + `--preserveValueImports` | `--verbatimModuleSyntax` | 5.0 |
|
||||
| `import { Foo } from "..."` when `Foo` is only used as a type | `import { type Foo } from "..."` or `import type { Foo } from "..."` | 5.0 |
|
||||
| `"extends": "@tsconfig/strictest/tsconfig.json"` chain | `"extends": ["@tsconfig/strictest/tsconfig.json", "./tsconfig.base.json"]` (array form) | 5.0 |
|
||||
| `try { ... } finally { resource.close(); resource.delete(); }` | `using resource = acquireResource()` — calls `[Symbol.dispose]()` automatically | 5.2 |
|
||||
| `try { ... } finally { await resource.close() }` | `await using resource = acquireAsyncResource()` | 5.2 |
|
||||
| Ad-hoc cleanup with multiple `try/finally` blocks | `using cleanup = new DisposableStack(); cleanup.defer(() => ...)` | 5.2 |
|
||||
| `import data from "./data.json" assert { type: "json" }` | `import data from "./data.json" with { type: "json" }` | 5.3 |
|
||||
| `.filter(Boolean)` or `.filter(x => !!x)` to remove nulls | `.filter(x => x !== undefined)` or `.filter(x => x !== null)` (infers type predicate) | 5.5 |
|
||||
| Extra phantom type param to block inference bleed: `<C extends string, D extends C>` | `NoInfer<C>` on the parameter you don't want to drive inference | 5.4 |
|
||||
| `/** @typedef {import("./types").Foo} Foo */` in JS files | `/** @import { Foo } from "./types" */` (JSDoc `@import` tag) | 5.5 |
|
||||
| `myArray.reverse()` mutating in place | `myArray.toReversed()` (returns new array) | 5.2 |
|
||||
| `myArray.sort(cmp)` mutating in place | `myArray.toSorted(cmp)` (returns new array) | 5.2 |
|
||||
| `const copy = [...arr]; copy[i] = v` | `arr.with(i, v)` (returns new array) | 5.2 |
|
||||
| Manual `has`/`get`/`set` pattern on `Map` | `map.getOrInsert(key, defaultValue)` or `getOrInsertComputed(key, fn)` | 6.0 RC |
|
||||
| `new RegExp(str.replace(/[.\*+?^${}() | [\]\\]/g, '\\$&'))` | `new RegExp(RegExp.escape(str))` | 6.0 RC |
|
||||
| `--moduleResolution node` (node10) | `--moduleResolution nodenext` (Node.js) or `--moduleResolution bundler` (bundlers/Bun) | 6.0 RC |
|
||||
| `"baseUrl": "./src"` + `"@app/*": ["app/*"]` in paths | Remove `baseUrl`; use `"@app/*": ["./src/app/*"]` in paths directly | 6.0 RC |
|
||||
| `module Foo { export const x = 1; }` | `namespace Foo { export const x = 1; }` | 6.0 RC |
|
||||
| `export * from "..."` when all re-exported members are types | `export type * from "..."` (or `export type * as ns from "..."`) | 5.0 |
|
||||
| `function f(): undefined { return undefined; }` — explicit return required in `: undefined`-returning function | Remove the `return` entirely; `undefined`-returning functions no longer require any return statement | 5.1 |
|
||||
| Manual type predicate annotation on a simple arrow: `(x: T \| undefined): x is T => x !== undefined` | Remove the annotation; TypeScript infers `x is T` from `!== null/undefined` and `instanceof` checks automatically | 5.5 |
|
||||
| `const val = obj[key]; if (typeof val === "string") { use(val); }` — extract to const to narrow indexed access | `if (typeof obj[key] === "string") { obj[key].toUpperCase(); }` directly — both `obj` and `key` must be effectively constant | 5.5 |
|
||||
| Copy narrowed `let`/param to a `const`, or restructure code to escape stale closure narrowing after reassignment | Remove the copy; narrowing survives into closures created after the last assignment to the variable | 5.4 |
|
||||
| `(arr as string[]).filter(...)` or restructure to avoid "not callable" errors on `string[] \| number[]` | Call `.filter`, `.find`, `.some`, `.every`, `.reduce` directly on union-of-array types | 5.2 |
|
||||
| `if`/`else` chain used to work around lack of narrowing inside a `switch (true)` body | `switch (true)` — each `case` condition now narrows the tested variable in its clause | 5.3 |
|
||||
|
||||
## New capabilities
|
||||
|
||||
These enable things that weren't practical before. Reach for them in the described situations.
|
||||
|
||||
| What | Since | When to use it |
|
||||
| ----------------------------------------------- | ------ | --------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
|
||||
| `using` / `await using` declarations | 5.2 | Any resource needing deterministic cleanup (file handles, DB connections, locks, event listeners). Object must implement `Symbol.dispose` / `Symbol.asyncDispose`. |
|
||||
| `DisposableStack` / `AsyncDisposableStack` | 5.2 | Ad-hoc multi-resource cleanup without creating a class. Call `.defer(fn)` right after acquiring each resource. Stack disposes in LIFO order. |
|
||||
| `const` modifier on type parameters | 5.0 | Force `const`-like (literal/readonly tuple) inference at call sites without requiring callers to write `as const`. Constraint must use `readonly` arrays. |
|
||||
| Decorator metadata (`Symbol.metadata`) | 5.2 | Attach and read per-class metadata from decorators via `context.metadata`. Retrieved as `MyClass[Symbol.metadata]`. Requires `Symbol.metadata ??= Symbol(...)` polyfill. |
|
||||
| `NoInfer<T>` utility type | 5.4 | Prevent a parameter from contributing inference candidates for `T`. Use when one argument should be the "source of truth" and others should only be checked against it. |
|
||||
| Inferred type predicates | 5.5 | Filter callbacks that test for `!== null` or `instanceof` now automatically produce a type predicate. `Array.prototype.filter` then narrows the result array type. |
|
||||
| `--isolatedDeclarations` | 5.5 | Require explicit return types on exported declarations. Unlocks parallel declaration emit by external tooling (esbuild, oxc, etc.) without needing a full type-checker pass. |
|
||||
| `${configDir}` in tsconfig paths | 5.5 | Anchor `typeRoots`, `paths`, `outDir`, etc. in a shared base tsconfig to the _consuming_ project's directory, not the shared file's location. |
|
||||
| Always-truthy/nullish check errors | 5.6 | Catches regex literals in `if`, arrow functions as comparators, `?? 100` on non-nullable left side, misplaced parentheses. No API to call; existing bugs now surface as errors. |
|
||||
| Iterator helper methods (`IteratorObject`) | 5.6 | Built-in iterators from `Map`, `Set`, generators, etc. now have `.map()`, `.filter()`, `.take()`, `.drop()`, `.flatMap()`, `.toArray()`, `.reduce()`, etc. Use `Iterator.from(iterable)` to wrap any iterable. |
|
||||
| `--noUncheckedSideEffectImports` | 5.6 | Error when a side-effect import (`import "..."`) resolves to nothing. Catches typos in polyfill or CSS imports. |
|
||||
| `--noCheck` | 5.6 | Skip type checking entirely during emit. Useful for separating "fast emit" from "thorough check" pipeline stages, especially with `--isolatedDeclarations`. |
|
||||
| `--rewriteRelativeImportExtensions` | 5.7 | Rewrite `.ts`→`.js`, `.tsx`→`.jsx`, `.mts`→`.mjs`, `.cts`→`.cjs` in relative imports during emit. Required when writing `.ts` imports for Node.js strip-types mode and still needing `.js` output for library distribution. |
|
||||
| `--erasableSyntaxOnly` | 5.8 | Error on constructs that can't be type-stripped by Node.js `--experimental-strip-types`: `enum`, `namespace` with code, parameter properties, `import =` aliases. |
|
||||
| `require()` of ESM under `--module nodenext` | 5.8 | Node.js 22+ allows CJS to `require()` ESM files (no top-level `await`). TypeScript now allows this under `nodenext` without error. |
|
||||
| `import defer * as ns from "..."` | 5.9 | Defer module _evaluation_ (not loading) until first property access. Module is loaded and verified at import time; side-effects are delayed. Only works with `--module preserve` or `esnext`. |
|
||||
| `Set` algebra methods | 5.5 | Non-mutating: `union`, `intersection`, `difference`, `symmetricDifference` → new `Set`. Predicate: `isSubsetOf`, `isSupersetOf`, `isDisjointFrom` → `boolean`. Requires `esnext` or `es2025` lib. |
|
||||
| `Object.groupBy` / `Map.groupBy` | 5.4 | Group an iterable into buckets by key function. Return type has all keys as optional (not every key is guaranteed present). Requires `esnext` or `es2024`+ lib. |
|
||||
| `Temporal` API types | 6.0 RC | `Temporal.Now`, `Temporal.Instant`, `Temporal.PlainDate`, etc. Available under `esnext` or `esnext.temporal` lib. Usable in runtimes that already ship it (V8 118+, SpiderMonkey, etc.). |
|
||||
| `@satisfies` in JSDoc | 5.0 | Validates that a JS expression satisfies a type without widening it — the TS `satisfies` operator for `.js` files. Write `/** @satisfies {MyType} */` above the declaration or inline on a parenthesized expression. |
|
||||
| `@overload` in JSDoc | 5.0 | Declare multiple call signatures for a JS function. Each JSDoc comment tagged `@overload` is treated as a distinct overload; the final JSDoc comment (without `@overload`) describes the implementation signature. |
|
||||
| Getter/setter with completely unrelated types | 5.1 | `get style(): CSSStyleDeclaration` and `set style(v: string)` can now have fully unrelated types, provided both have explicit type annotations. Previously the getter type was required to be a subtype of the setter type. |
|
||||
| `instanceof` narrowing via `Symbol.hasInstance` | 5.3 | When a class defines `static [Symbol.hasInstance](val: unknown): val is T`, the `instanceof` operator now narrows to the predicate type `T`, not the class type itself. Useful when the runtime check and the structural type differ. |
|
||||
| Regex literal syntax checking | 5.5 | TypeScript validates regex literal syntax: malformed groups, nonexistent backreferences, named capture mismatches, and features not available at the current `--target`. No API needed; existing latent bugs surface as errors automatically. |
|
||||
| `--build` continues past intermediate errors | 5.6 | `tsc --build` no longer stops at the first failing project. All projects are built and errors reported together. Use `--stopOnBuildErrors` to restore the old stop-on-first-error behavior. Useful for monorepos during upgrades. |
|
||||
| `--module node18` | 5.8 | Stable `--module` flag for Node.js 18 semantics: disallows `require()` of ESM (unlike `nodenext`) and still allows import assertions. Use when pinned to Node 18 and not ready for `nodenext` behavior changes. |
|
||||
| `--module node20` | 5.9 | Stable `--module` flag for Node.js 20 semantics: permits `require()` of ESM, rejects import assertions. Implies `--target es2023` (unlike `nodenext`, which floats to `esnext`). |
|
||||
|
||||
## Key APIs
|
||||
|
||||
### `Disposable` / `AsyncDisposable` / stacks (5.2)
|
||||
|
||||
Global types provided by TypeScript's lib (requires `esnext.disposable` or `esnext` in `lib`):
|
||||
|
||||
- `Disposable` — `{ [Symbol.dispose](): void }`
|
||||
- `AsyncDisposable` — `{ [Symbol.asyncDispose](): PromiseLike<void> }`
|
||||
- `DisposableStack` — `defer(fn)`, `use(resource)`, `adopt(value, disposeFn)`, `move()`. Is itself `Disposable`.
|
||||
- `AsyncDisposableStack` — async equivalent. Is itself `AsyncDisposable`.
|
||||
- `SuppressedError` — thrown when both the scope body and a `[Symbol.dispose]` throw. `.error` holds the dispose-phase error; `.suppressed` holds the original error.
|
||||
|
||||
Polyfill the symbols in older runtimes:
|
||||
|
||||
```ts
|
||||
Symbol.dispose ??= Symbol("Symbol.dispose");
|
||||
Symbol.asyncDispose ??= Symbol("Symbol.asyncDispose");
|
||||
```
|
||||
|
||||
### Decorator context types (5.0)
|
||||
|
||||
Each decorator kind receives a typed context object as its second parameter:
|
||||
|
||||
- `ClassDecoratorContext`
|
||||
- `ClassMethodDecoratorContext`
|
||||
- `ClassGetterDecoratorContext`
|
||||
- `ClassSetterDecoratorContext`
|
||||
- `ClassFieldDecoratorContext`
|
||||
- `ClassAccessorDecoratorContext`
|
||||
|
||||
All context objects have `.name`, `.kind`, `.static`, `.private`, and `.metadata`. Method/getter/setter/accessor contexts also have `.addInitializer(fn)` for running code at construction time.
|
||||
|
||||
### `IteratorObject` (5.6)
|
||||
|
||||
`IteratorObject<T, TReturn, TNext>` is the new type for built-in iterable iterators. Key methods: `map`, `filter`, `take`, `drop`, `flatMap`, `forEach`, `reduce`, `some`, `every`, `find`, `toArray`. Not the same as the pre-existing structural `Iterator<T>` protocol.
|
||||
|
||||
- Generators produce `Generator<T>` which extends `IteratorObject`.
|
||||
- `Map.prototype.entries()` returns `MapIterator<[K, V]>`, `Set.prototype.values()` returns `SetIterator<T>`, etc.
|
||||
- `Iterator.from(iterable)` converts any `Iterable` to an `IteratorObject`.
|
||||
- `AsyncIteratorObject` exists for async parity.
|
||||
- `--strictBuiltinIteratorReturn` (new `--strict`-mode flag in 5.6) makes the return type of `BuiltinIteratorReturn` be `undefined` instead of `any`, catching unchecked `done` access.
|
||||
|
||||
### Array copying methods (5.2)
|
||||
|
||||
Declared on `Array`, `ReadonlyArray`, and all `TypedArray` types. Use these instead of the mutating variants when you need to preserve the original:
|
||||
|
||||
| Mutating | Non-mutating copy |
|
||||
| ---------------------------------- | ------------------------------------- |
|
||||
| `arr.sort(cmp)` | `arr.toSorted(cmp)` |
|
||||
| `arr.reverse()` | `arr.toReversed()` |
|
||||
| `arr.splice(start, del, ...items)` | `arr.toSpliced(start, del, ...items)` |
|
||||
| `arr[i] = v` | `arr.with(i, v)` |
|
||||
|
||||
## Pitfalls
|
||||
|
||||
Things easy to get wrong even when you know the modern API exists. Check your output against these.
|
||||
|
||||
**tsconfig defaults changed hard in 6.0.** `types: []` means no `@types/*` packages load implicitly. If you see floods of "cannot find name 'process'" or "cannot find module 'fs'" after upgrading to 6.0, add `"types": ["node"]` (or whatever you need) to `compilerOptions`. `rootDir: "."` means a project with source in `src/` will emit to `dist/src/` instead of `dist/` — add `"rootDir": "./src"` explicitly. `strict: true` by default means projects with loose code see new errors.
|
||||
|
||||
**`using` requires a runtime polyfill on older runtimes.** `Symbol.dispose` and `Symbol.asyncDispose` don't exist before Node.js 18.x / Chrome 120. Add the two-line polyfill at your entry point. `DisposableStack` and `AsyncDisposableStack` need a more substantial polyfill (e.g. from `@microsoft/using-polyfill`).
|
||||
|
||||
**`using` disposes in LIFO order.** Resources declared later in a scope are disposed first. Declare in the order you want reversed cleanup (acquisition order). `DisposableStack.defer` also runs in LIFO order.
|
||||
|
||||
**Inferred type predicates have if-and-only-if semantics.** `x => !!x` does NOT infer `x is NonNullable<T>` because `0`, `""`, and `false` are falsy but not absent. TypeScript correctly refuses the predicate. Use `x => x !== undefined` or `x => x !== null` for precise null/undefined filters. If a predicate isn't being inferred, the false branch is probably ambiguous.
|
||||
|
||||
**`--verbatimModuleSyntax` breaks CJS `require` emit.** Under this flag ESM `import`/`export` is emitted verbatim. You cannot produce `require()` calls from standard `import` syntax. For CJS output you must use `import foo = require("foo")` and `export = { ... }` syntax explicitly.
|
||||
|
||||
**`NoInfer<T>` doesn't prevent `T` from being resolved, only from being contributed at that position.** Other parameters can still infer `T`. It means "don't use me as an inference candidate", not "block `T` from being resolved".
|
||||
|
||||
**`--isolatedDeclarations` requires explicit return types on all exports.** Exported arrow functions, function declarations, and class methods all need annotations if their return type isn't trivially inferrable from a literal or type assertion. Editor quick-fixes can add them automatically.
|
||||
|
||||
**Standard decorators are incompatible with `--experimentalDecorators`.** Different type signatures, metadata model, and emit. A decorator written for one will not work with the other. `--emitDecoratorMetadata` is not supported with standard decorators. Don't mix the two systems in one project.
|
||||
|
||||
**`import defer` does not downlevel.** TypeScript does not transform `import defer` to polyfill-compatible code. The module is still _loaded_ eagerly (must exist); only _evaluation_ is deferred. Only use it under `--module preserve` or `esnext` with a runtime or bundler that supports it.
|
||||
|
||||
**`--erasableSyntaxOnly` prohibits parameter properties.** `constructor(public x: number)` is not allowed. Expand to an explicit field declaration plus assignment in the constructor body.
|
||||
|
||||
**Closure narrowing is invalidated if the variable is assigned anywhere in a nested function.** TypeScript cannot know when a nested function will run, so any assignment to a `let`/param inside a nested function — even a no-op like `value = value` — invalidates narrowing for all closures in the outer scope. Only the outer "no further assignments after this point" pattern is safe.
|
||||
|
||||
**Constant indexed access narrowing requires both `obj` and `key` to be unmodified between the check and the use.** If either is a `let` that could be reassigned, TypeScript will not narrow `obj[key]`. Extract the value to a `const` in that case.
|
||||
|
||||
**`switch (true)` narrowing does not carry across fall-through cases.** In a `switch (true)`, each `case` condition narrows independently. A variable narrowed in `case typeof x === "string":` that falls through to the next case will have its narrowing widened by the next condition, not accumulated from the previous one.
|
||||
|
||||
**`const` type parameter modifier falls back when constraint is mutable.** `<const T extends string[]>(args: T)` falls back to `string[]` because `readonly ["a", "b"]` isn't assignable to `string[]`. Use `<const T extends readonly string[]>` for arrays.
|
||||
|
||||
**`assert` import syntax errors under `--module nodenext` since 5.8.** Any remaining `import x from "..." assert { ... }` must be updated to `import x from "..." with { ... }`.
|
||||
|
||||
**`Array.prototype.filter(x => x !== null)` now narrows to non-null (5.5).** This is almost always correct, but if you intentionally needed the nullable type downstream, add an explicit annotation: `const items: (T | null)[] = arr.filter(x => x !== null)`.
|
||||
|
||||
## Behavioral changes that affect code
|
||||
|
||||
- **All enums are union enums** (5.0): Every enum member gets its own literal type. Out-of-domain literal assignment to an enum type now errors. Cross-enum assignment between enums with identical names but differing values now errors.
|
||||
- **Relational operators no longer allow implicit string/number coercions** (5.0): `ns > 4` where `ns: number | string` is a type error. Use `+ns > 4` to explicitly coerce.
|
||||
- **`--module`/`--moduleResolution` must agree on node flavor** (5.2): Mixing `--module nodenext` with `--moduleResolution bundler` is an error. Use `--module nodenext` alone or `--module esnext --moduleResolution bundler`.
|
||||
- **Deprecations from 5.0 become hard errors in 5.5**: `--importsNotUsedAsValues`, `--preserveValueImports`, `--target ES3`, `--out`, and several others are fully removed in 5.5. They can no longer be specified, even with `"ignoreDeprecations": "5.0"`. Migrate to `--verbatimModuleSyntax` for the import flags.
|
||||
- **Type-only imports conflicting with local values** (5.4): Under `--isolatedModules`, `import { Foo } from "..."` where a local `let Foo` also exists now errors. Use `import type { Foo }` or `import { type Foo }`.
|
||||
- **Reference directives no longer synthesized or preserved in declaration emit** (5.5): `/// <reference types="node" />` TypeScript used to add automatically is no longer emitted. User-written directives are dropped unless they carry `preserve="true"`. Update library `tsconfig.json` if you relied on this.
|
||||
- **`.mts` files never emit CJS; `.cts` files never emit ESM** (5.6): Regardless of `--module` setting. Previously the extension was ignored in some modes.
|
||||
- **JSON imports under `--module nodenext` require `with { type: "json" }`** (5.7): `import data from "./config.json"` without the attribute is now a type error.
|
||||
- **`TypedArray`s are now generic** (5.7): `Uint8Array` is `Uint8Array<TArrayBuffer extends ArrayBufferLike = ArrayBufferLike>`. Code passing `Buffer` (from `@types/node`) to typed-array parameters may see new errors. Update `@types/node` to a version that matches.
|
||||
- **`import assert { ... }` is an error under `--module nodenext`** (5.8): Node.js 22 dropped support for the old syntax. Use `with { ... }`.
|
||||
- **`types` defaults to `[]` in 6.0**: All implicit `@types/*` loading stops. Add an explicit `"types": ["node"]` or the array will remain empty. Using `"types": ["*"]` restores the 5.x behavior.
|
||||
- **`rootDir` defaults to `.` (the tsconfig directory) in 6.0**: Previously inferred from the common ancestor of all source files. Projects with `"include": ["./src"]` and no explicit `rootDir` will now emit into `dist/src/` instead of `dist/`. Add `"rootDir": "./src"` to fix.
|
||||
- **`strict` defaults to `true` in 6.0**: Projects that were implicitly not strict will see new errors. Set `"strict": false` explicitly if you're not ready to fix them.
|
||||
- **`--baseUrl` deprecated in 6.0** and no longer acts as a module resolution root. Add explicit prefixes to your `paths` entries instead.
|
||||
- **`--moduleResolution node` (node10) deprecated in 6.0**: Removed in 7.0. Migrate to `nodenext` or `bundler`.
|
||||
- **`amd`, `umd`, `systemjs`, `none` module targets deprecated in 6.0**: Removed in 7.0. Migrate to a bundler.
|
||||
- **`--outFile` removed in 6.0**: Use a bundler (esbuild, Rollup, Webpack, etc.).
|
||||
- **`module Foo { }` syntax removed in 6.0**: Rename all such declarations to `namespace Foo { }`.
|
||||
- **`--esModuleInterop false` and `--allowSyntheticDefaultImports false` removed in 6.0**: Safe interop is now always on. Default imports from CJS modules (`import express from "express"`) are always valid.
|
||||
- **Explicit `typeRoots` disables upward `node_modules/@types` fallback** (5.1): When `typeRoots` is specified and a lookup fails in those directories, TypeScript no longer walks parent directories for `@types`. If you relied on the fallback, add `"./node_modules/@types"` explicitly to your `typeRoots` array.
|
||||
- **`super.` on instance field properties is a type error** (5.3): Calling `super.foo()` where `foo` is a class field (arrow function assigned in the constructor) rather than a prototype method now errors. Instance fields don't exist on the prototype; `super.field` is `undefined` at runtime.
|
||||
- **`--build` always emits `.tsbuildinfo`** (5.6): Previously only written when `--incremental` or `--composite` was set. Now written unconditionally in any `--build` invocation. Update `.gitignore` or CI artifact management if needed.
|
||||
- **`.mts`/`.cts` extensions and `package.json` `"type"` respected in all module modes** (5.6): Format-specific extensions and the `"type"` field inside `node_modules` are now honored regardless of `--module` setting (except `amd`, `umd`, `system`). A `.mts` file will never emit CJS output even under `--module commonjs`.
|
||||
- **Granular return expression checking** (5.8): Each branch of a conditional expression (`cond ? a : b`) directly inside a `return` statement is now checked individually against the declared return type. Previously an `any`-typed branch could silently suppress type errors in the other branch.
|
||||
@@ -5,6 +5,6 @@ runs:
|
||||
using: "composite"
|
||||
steps:
|
||||
- name: Install syft
|
||||
uses: anchore/sbom-action/download-syft@f325610c9f50a54015d37c8d16cb3b0e2c8f4de0 # v0.18.0
|
||||
uses: anchore/sbom-action/download-syft@e22c389904149dbc22b58101806040fa8d37a610 # v0.24.0
|
||||
with:
|
||||
syft-version: "v1.20.0"
|
||||
syft-version: "v1.26.1"
|
||||
|
||||
@@ -4,7 +4,7 @@ description: |
|
||||
inputs:
|
||||
version:
|
||||
description: "The Go version to use."
|
||||
default: "1.25.7"
|
||||
default: "1.25.8"
|
||||
use-cache:
|
||||
description: "Whether to use the cache."
|
||||
default: "true"
|
||||
|
||||
@@ -82,9 +82,6 @@ updates:
|
||||
mui:
|
||||
patterns:
|
||||
- "@mui*"
|
||||
radix:
|
||||
patterns:
|
||||
- "@radix-ui/*"
|
||||
react:
|
||||
patterns:
|
||||
- "react"
|
||||
|
||||
+29
-101
@@ -181,7 +181,7 @@ jobs:
|
||||
echo "LINT_CACHE_DIR=$dir" >> "$GITHUB_ENV"
|
||||
|
||||
- name: golangci-lint cache
|
||||
uses: actions/cache@cdf6c1fa76f9f475f3d7449005a359c84ca0f306 # v5.0.3
|
||||
uses: actions/cache@668228422ae6a00e4ad889ee87cd7109ec5666a7 # v5.0.4
|
||||
with:
|
||||
path: |
|
||||
${{ env.LINT_CACHE_DIR }}
|
||||
@@ -204,7 +204,7 @@ jobs:
|
||||
|
||||
# Needed for helm chart linting
|
||||
- name: Install helm
|
||||
uses: azure/setup-helm@1a275c3b69536ee54be43f2070a358922e12c8d4 # v4.3.1
|
||||
uses: azure/setup-helm@dda3372f752e03dde6b3237bc9431cdc2f7a02a2 # v5.0.0
|
||||
with:
|
||||
version: v3.9.2
|
||||
continue-on-error: true
|
||||
@@ -870,7 +870,7 @@ jobs:
|
||||
# the check to pass. This is desired in PRs, but not in mainline.
|
||||
- name: Publish to Chromatic (non-mainline)
|
||||
if: github.ref != 'refs/heads/main' && github.repository_owner == 'coder'
|
||||
uses: chromaui/action@07791f8243f4cb2698bf4d00426baf4b2d1cb7e0 # v13.3.5
|
||||
uses: chromaui/action@f191a0224b10e1a38b2091cefb7b7a2337009116 # v16.0.0
|
||||
env:
|
||||
NODE_OPTIONS: "--max_old_space_size=4096"
|
||||
STORYBOOK: true
|
||||
@@ -902,7 +902,7 @@ jobs:
|
||||
# infinitely "in progress" in mainline unless we re-review each build.
|
||||
- name: Publish to Chromatic (mainline)
|
||||
if: github.ref == 'refs/heads/main' && github.repository_owner == 'coder'
|
||||
uses: chromaui/action@07791f8243f4cb2698bf4d00426baf4b2d1cb7e0 # v13.3.5
|
||||
uses: chromaui/action@f191a0224b10e1a38b2091cefb7b7a2337009116 # v16.0.0
|
||||
env:
|
||||
NODE_OPTIONS: "--max_old_space_size=4096"
|
||||
STORYBOOK: true
|
||||
@@ -1316,122 +1316,50 @@ jobs:
|
||||
"${IMAGE}"
|
||||
done
|
||||
|
||||
# GitHub attestation provides SLSA provenance for the Docker images, establishing a verifiable
|
||||
# record that these images were built in GitHub Actions with specific inputs and environment.
|
||||
# This complements our existing cosign attestations which focus on SBOMs.
|
||||
#
|
||||
# We attest each tag separately to ensure all tags have proper provenance records.
|
||||
# TODO: Consider refactoring these steps to use a matrix strategy or composite action to reduce duplication
|
||||
# while maintaining the required functionality for each tag.
|
||||
- name: Resolve Docker image digests for attestation
|
||||
id: docker_digests
|
||||
if: github.ref == 'refs/heads/main'
|
||||
continue-on-error: true
|
||||
env:
|
||||
IMAGE_BASE: ghcr.io/coder/coder-preview
|
||||
BUILD_TAG: ${{ steps.build-docker.outputs.tag }}
|
||||
run: |
|
||||
set -euxo pipefail
|
||||
main_digest=$(docker buildx imagetools inspect --raw "${IMAGE_BASE}:main" | sha256sum | awk '{print "sha256:"$1}')
|
||||
echo "main_digest=${main_digest}" >> "$GITHUB_OUTPUT"
|
||||
latest_digest=$(docker buildx imagetools inspect --raw "${IMAGE_BASE}:latest" | sha256sum | awk '{print "sha256:"$1}')
|
||||
echo "latest_digest=${latest_digest}" >> "$GITHUB_OUTPUT"
|
||||
version_digest=$(docker buildx imagetools inspect --raw "${IMAGE_BASE}:${BUILD_TAG}" | sha256sum | awk '{print "sha256:"$1}')
|
||||
echo "version_digest=${version_digest}" >> "$GITHUB_OUTPUT"
|
||||
|
||||
- name: GitHub Attestation for Docker image
|
||||
id: attest_main
|
||||
if: github.ref == 'refs/heads/main'
|
||||
if: github.ref == 'refs/heads/main' && steps.docker_digests.outputs.main_digest != ''
|
||||
continue-on-error: true
|
||||
uses: actions/attest@59d89421af93a897026c735860bf21b6eb4f7b26 # v4.1.0
|
||||
with:
|
||||
subject-name: "ghcr.io/coder/coder-preview:main"
|
||||
predicate-type: "https://slsa.dev/provenance/v1"
|
||||
predicate: |
|
||||
{
|
||||
"buildType": "https://github.com/actions/runner-images/",
|
||||
"builder": {
|
||||
"id": "https://github.com/${{ github.repository }}/actions/runs/${{ github.run_id }}"
|
||||
},
|
||||
"invocation": {
|
||||
"configSource": {
|
||||
"uri": "git+https://github.com/${{ github.repository }}@${{ github.ref }}",
|
||||
"digest": {
|
||||
"sha1": "${{ github.sha }}"
|
||||
},
|
||||
"entryPoint": ".github/workflows/ci.yaml"
|
||||
},
|
||||
"environment": {
|
||||
"github_workflow": "${{ github.workflow }}",
|
||||
"github_run_id": "${{ github.run_id }}"
|
||||
}
|
||||
},
|
||||
"metadata": {
|
||||
"buildInvocationID": "${{ github.run_id }}",
|
||||
"completeness": {
|
||||
"environment": true,
|
||||
"materials": true
|
||||
}
|
||||
}
|
||||
}
|
||||
subject-name: ghcr.io/coder/coder-preview
|
||||
subject-digest: ${{ steps.docker_digests.outputs.main_digest }}
|
||||
push-to-registry: true
|
||||
|
||||
- name: GitHub Attestation for Docker image (latest tag)
|
||||
id: attest_latest
|
||||
if: github.ref == 'refs/heads/main'
|
||||
if: github.ref == 'refs/heads/main' && steps.docker_digests.outputs.latest_digest != ''
|
||||
continue-on-error: true
|
||||
uses: actions/attest@59d89421af93a897026c735860bf21b6eb4f7b26 # v4.1.0
|
||||
with:
|
||||
subject-name: "ghcr.io/coder/coder-preview:latest"
|
||||
predicate-type: "https://slsa.dev/provenance/v1"
|
||||
predicate: |
|
||||
{
|
||||
"buildType": "https://github.com/actions/runner-images/",
|
||||
"builder": {
|
||||
"id": "https://github.com/${{ github.repository }}/actions/runs/${{ github.run_id }}"
|
||||
},
|
||||
"invocation": {
|
||||
"configSource": {
|
||||
"uri": "git+https://github.com/${{ github.repository }}@${{ github.ref }}",
|
||||
"digest": {
|
||||
"sha1": "${{ github.sha }}"
|
||||
},
|
||||
"entryPoint": ".github/workflows/ci.yaml"
|
||||
},
|
||||
"environment": {
|
||||
"github_workflow": "${{ github.workflow }}",
|
||||
"github_run_id": "${{ github.run_id }}"
|
||||
}
|
||||
},
|
||||
"metadata": {
|
||||
"buildInvocationID": "${{ github.run_id }}",
|
||||
"completeness": {
|
||||
"environment": true,
|
||||
"materials": true
|
||||
}
|
||||
}
|
||||
}
|
||||
subject-name: ghcr.io/coder/coder-preview
|
||||
subject-digest: ${{ steps.docker_digests.outputs.latest_digest }}
|
||||
push-to-registry: true
|
||||
|
||||
- name: GitHub Attestation for version-specific Docker image
|
||||
id: attest_version
|
||||
if: github.ref == 'refs/heads/main'
|
||||
if: github.ref == 'refs/heads/main' && steps.docker_digests.outputs.version_digest != ''
|
||||
continue-on-error: true
|
||||
uses: actions/attest@59d89421af93a897026c735860bf21b6eb4f7b26 # v4.1.0
|
||||
with:
|
||||
subject-name: "ghcr.io/coder/coder-preview:${{ steps.build-docker.outputs.tag }}"
|
||||
predicate-type: "https://slsa.dev/provenance/v1"
|
||||
predicate: |
|
||||
{
|
||||
"buildType": "https://github.com/actions/runner-images/",
|
||||
"builder": {
|
||||
"id": "https://github.com/${{ github.repository }}/actions/runs/${{ github.run_id }}"
|
||||
},
|
||||
"invocation": {
|
||||
"configSource": {
|
||||
"uri": "git+https://github.com/${{ github.repository }}@${{ github.ref }}",
|
||||
"digest": {
|
||||
"sha1": "${{ github.sha }}"
|
||||
},
|
||||
"entryPoint": ".github/workflows/ci.yaml"
|
||||
},
|
||||
"environment": {
|
||||
"github_workflow": "${{ github.workflow }}",
|
||||
"github_run_id": "${{ github.run_id }}"
|
||||
}
|
||||
},
|
||||
"metadata": {
|
||||
"buildInvocationID": "${{ github.run_id }}",
|
||||
"completeness": {
|
||||
"environment": true,
|
||||
"materials": true
|
||||
}
|
||||
}
|
||||
}
|
||||
subject-name: ghcr.io/coder/coder-preview
|
||||
subject-digest: ${{ steps.docker_digests.outputs.version_digest }}
|
||||
push-to-registry: true
|
||||
|
||||
# Report attestation failures but don't fail the workflow
|
||||
|
||||
@@ -95,7 +95,7 @@ jobs:
|
||||
AWS_DOGFOOD_DEPLOY_REGION: ${{ vars.AWS_DOGFOOD_DEPLOY_REGION }}
|
||||
|
||||
- name: Set up Flux CLI
|
||||
uses: fluxcd/flux2/action@8454b02a32e48d775b9f563cb51fdcb1787b5b93 # v2.7.5
|
||||
uses: fluxcd/flux2/action@871be9b40d53627786d3a3835a3ddba1e3234bd2 # v2.8.3
|
||||
with:
|
||||
# Keep this and the github action up to date with the version of flux installed in dogfood cluster
|
||||
version: "2.8.2"
|
||||
|
||||
@@ -240,6 +240,7 @@ jobs:
|
||||
- name: Create Coder Task for Documentation Check
|
||||
if: steps.check-secrets.outputs.skip != 'true'
|
||||
id: create_task
|
||||
continue-on-error: true
|
||||
uses: ./.github/actions/create-task-action
|
||||
with:
|
||||
coder-url: ${{ secrets.DOC_CHECK_CODER_URL }}
|
||||
@@ -254,8 +255,21 @@ jobs:
|
||||
github-issue-url: ${{ steps.determine-context.outputs.pr_url }}
|
||||
comment-on-issue: false
|
||||
|
||||
- name: Handle Task Creation Failure
|
||||
if: steps.check-secrets.outputs.skip != 'true' && steps.create_task.outcome != 'success'
|
||||
run: |
|
||||
{
|
||||
echo "## Documentation Check Task"
|
||||
echo ""
|
||||
echo "⚠️ The external Coder task service was unavailable, so this"
|
||||
echo "advisory documentation check did not run."
|
||||
echo ""
|
||||
echo "Maintainers can rerun the workflow or trigger it manually"
|
||||
echo "after the service recovers."
|
||||
} >> "${GITHUB_STEP_SUMMARY}"
|
||||
|
||||
- name: Write Task Info
|
||||
if: steps.check-secrets.outputs.skip != 'true'
|
||||
if: steps.check-secrets.outputs.skip != 'true' && steps.create_task.outcome == 'success'
|
||||
env:
|
||||
TASK_CREATED: ${{ steps.create_task.outputs.task-created }}
|
||||
TASK_NAME: ${{ steps.create_task.outputs.task-name }}
|
||||
@@ -273,7 +287,7 @@ jobs:
|
||||
} >> "${GITHUB_STEP_SUMMARY}"
|
||||
|
||||
- name: Wait for Task Completion
|
||||
if: steps.check-secrets.outputs.skip != 'true'
|
||||
if: steps.check-secrets.outputs.skip != 'true' && steps.create_task.outcome == 'success'
|
||||
id: wait_task
|
||||
env:
|
||||
TASK_NAME: ${{ steps.create_task.outputs.task-name }}
|
||||
@@ -363,7 +377,7 @@ jobs:
|
||||
fi
|
||||
|
||||
- name: Fetch Task Logs
|
||||
if: always() && steps.check-secrets.outputs.skip != 'true'
|
||||
if: always() && steps.check-secrets.outputs.skip != 'true' && steps.create_task.outcome == 'success'
|
||||
env:
|
||||
TASK_NAME: ${{ steps.create_task.outputs.task-name }}
|
||||
run: |
|
||||
@@ -376,7 +390,7 @@ jobs:
|
||||
echo "::endgroup::"
|
||||
|
||||
- name: Cleanup Task
|
||||
if: always() && steps.check-secrets.outputs.skip != 'true'
|
||||
if: always() && steps.check-secrets.outputs.skip != 'true' && steps.create_task.outcome == 'success'
|
||||
env:
|
||||
TASK_NAME: ${{ steps.create_task.outputs.task-name }}
|
||||
run: |
|
||||
@@ -390,6 +404,7 @@ jobs:
|
||||
- name: Write Final Summary
|
||||
if: always() && steps.check-secrets.outputs.skip != 'true'
|
||||
env:
|
||||
CREATE_TASK_OUTCOME: ${{ steps.create_task.outcome }}
|
||||
TASK_NAME: ${{ steps.create_task.outputs.task-name }}
|
||||
TASK_MESSAGE: ${{ steps.wait_task.outputs.task_message }}
|
||||
RESULT_URI: ${{ steps.wait_task.outputs.result_uri }}
|
||||
@@ -400,10 +415,15 @@ jobs:
|
||||
echo "---"
|
||||
echo "### Result"
|
||||
echo ""
|
||||
echo "**Status:** ${TASK_MESSAGE:-Task completed}"
|
||||
if [[ -n "${RESULT_URI}" ]]; then
|
||||
echo "**Comment:** ${RESULT_URI}"
|
||||
if [[ "${CREATE_TASK_OUTCOME}" == "success" ]]; then
|
||||
echo "**Status:** ${TASK_MESSAGE:-Task completed}"
|
||||
if [[ -n "${RESULT_URI}" ]]; then
|
||||
echo "**Comment:** ${RESULT_URI}"
|
||||
fi
|
||||
echo ""
|
||||
echo "Task \`${TASK_NAME}\` has been cleaned up."
|
||||
else
|
||||
echo "**Status:** Skipped because the external Coder task"
|
||||
echo "service was unavailable."
|
||||
fi
|
||||
echo ""
|
||||
echo "Task \`${TASK_NAME}\` has been cleaned up."
|
||||
} >> "${GITHUB_STEP_SUMMARY}"
|
||||
|
||||
@@ -4,23 +4,20 @@ on:
|
||||
push:
|
||||
branches:
|
||||
- main
|
||||
# This event reads the workflow from the default branch (main), not the
|
||||
# release branch. No cherry-pick needed.
|
||||
# https://docs.github.com/en/actions/using-workflows/events-that-trigger-workflows#release
|
||||
release:
|
||||
types: [published]
|
||||
- "release/2.[0-9]+"
|
||||
|
||||
permissions:
|
||||
contents: read
|
||||
|
||||
concurrency:
|
||||
group: ${{ github.workflow }}-${{ github.ref }}
|
||||
cancel-in-progress: true
|
||||
# Queue rather than cancel so back-to-back pushes to main don't cancel the first sync.
|
||||
cancel-in-progress: false
|
||||
|
||||
jobs:
|
||||
sync:
|
||||
name: Sync issues to Linear release
|
||||
if: github.event_name == 'push'
|
||||
sync-main:
|
||||
name: Sync issues to next Linear release
|
||||
if: github.event_name == 'push' && github.ref_name == 'main'
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
|
||||
@@ -28,38 +25,86 @@ jobs:
|
||||
fetch-depth: 0
|
||||
persist-credentials: false
|
||||
|
||||
- name: Detect next release version
|
||||
id: version
|
||||
# Find the highest release/2.X branch (exact pattern, no suffixes
|
||||
# like release/2.31_hotfix) and derive the next minor version for
|
||||
# the release currently in development on main.
|
||||
run: |
|
||||
LATEST_MINOR=$(git branch -r | grep -E '^\s*origin/release/2\.[0-9]+$' | \
|
||||
sed 's/.*release\/2\.//' | sort -n | tail -1)
|
||||
if [ -z "$LATEST_MINOR" ]; then
|
||||
echo "No release branch found, skipping sync."
|
||||
echo "skip=true" >> "$GITHUB_OUTPUT"
|
||||
exit 0
|
||||
fi
|
||||
NEXT="2.$((LATEST_MINOR + 1))"
|
||||
echo "version=$NEXT" >> "$GITHUB_OUTPUT"
|
||||
echo "skip=false" >> "$GITHUB_OUTPUT"
|
||||
echo "Detected next release: $NEXT"
|
||||
|
||||
- name: Sync issues
|
||||
id: sync
|
||||
uses: linear/linear-release-action@5cbaabc187ceb63eee9d446e62e68e5c29a03ae8 # v0.5.0
|
||||
if: steps.version.outputs.skip != 'true'
|
||||
uses: linear/linear-release-action@755d50b5adb7dd42b976ee9334952745d62ceb2d # v0.6.0
|
||||
with:
|
||||
access_key: ${{ secrets.LINEAR_ACCESS_KEY }}
|
||||
command: sync
|
||||
version: ${{ steps.version.outputs.version }}
|
||||
name: ${{ steps.version.outputs.version }}
|
||||
timeout: 300
|
||||
|
||||
- name: Print release URL
|
||||
if: steps.sync.outputs.release-url
|
||||
run: echo "Synced to $RELEASE_URL"
|
||||
env:
|
||||
RELEASE_URL: ${{ steps.sync.outputs.release-url }}
|
||||
sync-release-branch:
|
||||
name: Sync backports to Linear release
|
||||
if: github.event_name == 'push' && startsWith(github.ref_name, 'release/')
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
|
||||
with:
|
||||
fetch-depth: 0
|
||||
persist-credentials: false
|
||||
|
||||
complete:
|
||||
name: Complete Linear release
|
||||
if: github.event_name == 'release'
|
||||
- name: Extract release version
|
||||
id: version
|
||||
# The trigger only allows exact release/2.X branch names.
|
||||
run: |
|
||||
echo "version=${GITHUB_REF_NAME#release/}" >> "$GITHUB_OUTPUT"
|
||||
|
||||
- name: Sync issues
|
||||
id: sync
|
||||
uses: linear/linear-release-action@755d50b5adb7dd42b976ee9334952745d62ceb2d # v0.6.0
|
||||
with:
|
||||
access_key: ${{ secrets.LINEAR_ACCESS_KEY }}
|
||||
command: sync
|
||||
version: ${{ steps.version.outputs.version }}
|
||||
name: ${{ steps.version.outputs.version }}
|
||||
timeout: 300
|
||||
|
||||
code-freeze:
|
||||
name: Move Linear release to Code Freeze
|
||||
needs: sync-release-branch
|
||||
if: >
|
||||
github.event_name == 'push' &&
|
||||
startsWith(github.ref_name, 'release/') &&
|
||||
github.event.created == true
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
|
||||
with:
|
||||
persist-credentials: false
|
||||
|
||||
- name: Complete release
|
||||
id: complete
|
||||
uses: linear/linear-release-action@5cbaabc187ceb63eee9d446e62e68e5c29a03ae8 # v0
|
||||
- name: Extract release version
|
||||
id: version
|
||||
run: |
|
||||
echo "version=${GITHUB_REF_NAME#release/}" >> "$GITHUB_OUTPUT"
|
||||
|
||||
- name: Move to Code Freeze
|
||||
id: update
|
||||
uses: linear/linear-release-action@755d50b5adb7dd42b976ee9334952745d62ceb2d # v0.6.0
|
||||
with:
|
||||
access_key: ${{ secrets.LINEAR_ACCESS_KEY }}
|
||||
command: complete
|
||||
version: ${{ github.event.release.tag_name }}
|
||||
command: update
|
||||
stage: Code Freeze
|
||||
version: ${{ steps.version.outputs.version }}
|
||||
timeout: 300
|
||||
|
||||
- name: Print release URL
|
||||
if: steps.complete.outputs.release-url
|
||||
run: echo "Completed $RELEASE_URL"
|
||||
env:
|
||||
RELEASE_URL: ${{ steps.complete.outputs.release-url }}
|
||||
|
||||
+84
-117
@@ -9,6 +9,7 @@ on:
|
||||
options:
|
||||
- mainline
|
||||
- stable
|
||||
- rc
|
||||
release_notes:
|
||||
description: Release notes for the publishing the release. This is required to create a release.
|
||||
dry_run:
|
||||
@@ -119,9 +120,19 @@ jobs:
|
||||
exit 1
|
||||
fi
|
||||
|
||||
# 2.10.2 -> release/2.10
|
||||
# Derive the release branch from the version tag.
|
||||
# Standard: 2.10.2 -> release/2.10
|
||||
# RC: 2.32.0-rc.0 -> release/2.32-rc.0
|
||||
version="$(./scripts/version.sh)"
|
||||
release_branch=release/${version%.*}
|
||||
if [[ "$version" == *-rc.* ]]; then
|
||||
# Extract major.minor and rc suffix from e.g. 2.32.0-rc.0
|
||||
base_version="${version%%-rc.*}" # 2.32.0
|
||||
major_minor="${base_version%.*}" # 2.32
|
||||
rc_suffix="${version##*-rc.}" # 0
|
||||
release_branch="release/${major_minor}-rc.${rc_suffix}"
|
||||
else
|
||||
release_branch=release/${version%.*}
|
||||
fi
|
||||
branch_contains_tag=$(git branch --remotes --contains "${GITHUB_REF}" --list "*/${release_branch}" --format='%(refname)')
|
||||
if [[ -z "${branch_contains_tag}" ]]; then
|
||||
echo "Ref tag must exist in a branch named ${release_branch} when creating a release, did you use scripts/release.sh?"
|
||||
@@ -302,6 +313,7 @@ jobs:
|
||||
|
||||
# This uses OIDC authentication, so no auth variables are required.
|
||||
- name: Build base Docker image via depot.dev
|
||||
id: build_base_image
|
||||
if: steps.image-base-tag.outputs.tag != ''
|
||||
uses: depot/build-push-action@5f3b3c2e5a00f0093de47f657aeaefcedff27d18 # v1.17.0
|
||||
with:
|
||||
@@ -349,48 +361,14 @@ jobs:
|
||||
env:
|
||||
IMAGE_TAG: ${{ steps.image-base-tag.outputs.tag }}
|
||||
|
||||
# GitHub attestation provides SLSA provenance for Docker images, establishing a verifiable
|
||||
# record that these images were built in GitHub Actions with specific inputs and environment.
|
||||
# This complements our existing cosign attestations (which focus on SBOMs) by adding
|
||||
# GitHub-specific build provenance to enhance our supply chain security.
|
||||
#
|
||||
# TODO: Consider refactoring these attestation steps to use a matrix strategy or composite action
|
||||
# to reduce duplication while maintaining the required functionality for each distinct image tag.
|
||||
- name: GitHub Attestation for Base Docker image
|
||||
id: attest_base
|
||||
if: ${{ !inputs.dry_run && steps.image-base-tag.outputs.tag != '' }}
|
||||
if: ${{ !inputs.dry_run && steps.build_base_image.outputs.digest != '' }}
|
||||
continue-on-error: true
|
||||
uses: actions/attest@59d89421af93a897026c735860bf21b6eb4f7b26 # v4.1.0
|
||||
with:
|
||||
subject-name: ${{ steps.image-base-tag.outputs.tag }}
|
||||
predicate-type: "https://slsa.dev/provenance/v1"
|
||||
predicate: |
|
||||
{
|
||||
"buildType": "https://github.com/actions/runner-images/",
|
||||
"builder": {
|
||||
"id": "https://github.com/${{ github.repository }}/actions/runs/${{ github.run_id }}"
|
||||
},
|
||||
"invocation": {
|
||||
"configSource": {
|
||||
"uri": "git+https://github.com/${{ github.repository }}@${{ github.ref }}",
|
||||
"digest": {
|
||||
"sha1": "${{ github.sha }}"
|
||||
},
|
||||
"entryPoint": ".github/workflows/release.yaml"
|
||||
},
|
||||
"environment": {
|
||||
"github_workflow": "${{ github.workflow }}",
|
||||
"github_run_id": "${{ github.run_id }}"
|
||||
}
|
||||
},
|
||||
"metadata": {
|
||||
"buildInvocationID": "${{ github.run_id }}",
|
||||
"completeness": {
|
||||
"environment": true,
|
||||
"materials": true
|
||||
}
|
||||
}
|
||||
}
|
||||
subject-name: ghcr.io/coder/coder-base
|
||||
subject-digest: ${{ steps.build_base_image.outputs.digest }}
|
||||
push-to-registry: true
|
||||
|
||||
- name: Build Linux Docker images
|
||||
@@ -413,7 +391,6 @@ jobs:
|
||||
# being pushed so will automatically push them.
|
||||
make push/build/coder_"$version"_linux.tag
|
||||
|
||||
# Save multiarch image tag for attestation
|
||||
multiarch_image="$(./scripts/image_tag.sh)"
|
||||
echo "multiarch_image=${multiarch_image}" >> "$GITHUB_OUTPUT"
|
||||
|
||||
@@ -424,12 +401,14 @@ jobs:
|
||||
# version in the repo, also create a multi-arch image as ":latest" and
|
||||
# push it
|
||||
if [[ "$(git tag | grep '^v' | grep -vE '(rc|dev|-|\+|\/)' | sort -r --version-sort | head -n1)" == "v$(./scripts/version.sh)" ]]; then
|
||||
latest_target="$(./scripts/image_tag.sh --version latest)"
|
||||
# shellcheck disable=SC2046
|
||||
./scripts/build_docker_multiarch.sh \
|
||||
--push \
|
||||
--target "$(./scripts/image_tag.sh --version latest)" \
|
||||
--target "${latest_target}" \
|
||||
$(cat build/coder_"$version"_linux_{amd64,arm64,armv7}.tag)
|
||||
echo "created_latest_tag=true" >> "$GITHUB_OUTPUT"
|
||||
echo "latest_target=${latest_target}" >> "$GITHUB_OUTPUT"
|
||||
else
|
||||
echo "created_latest_tag=false" >> "$GITHUB_OUTPUT"
|
||||
fi
|
||||
@@ -450,7 +429,6 @@ jobs:
|
||||
echo "Generating SBOM for multi-arch image: ${MULTIARCH_IMAGE}"
|
||||
syft "${MULTIARCH_IMAGE}" -o spdx-json > "coder_${VERSION}_sbom.spdx.json"
|
||||
|
||||
# Attest SBOM to multi-arch image
|
||||
echo "Attesting SBOM to multi-arch image: ${MULTIARCH_IMAGE}"
|
||||
cosign clean --force=true "${MULTIARCH_IMAGE}"
|
||||
cosign attest --type spdxjson \
|
||||
@@ -472,85 +450,42 @@ jobs:
|
||||
"${latest_tag}"
|
||||
fi
|
||||
|
||||
- name: GitHub Attestation for Docker image
|
||||
id: attest_main
|
||||
- name: Resolve Docker image digests for attestation
|
||||
id: docker_digests
|
||||
if: ${{ !inputs.dry_run }}
|
||||
continue-on-error: true
|
||||
uses: actions/attest@59d89421af93a897026c735860bf21b6eb4f7b26 # v4.1.0
|
||||
with:
|
||||
subject-name: ${{ steps.build_docker.outputs.multiarch_image }}
|
||||
predicate-type: "https://slsa.dev/provenance/v1"
|
||||
predicate: |
|
||||
{
|
||||
"buildType": "https://github.com/actions/runner-images/",
|
||||
"builder": {
|
||||
"id": "https://github.com/${{ github.repository }}/actions/runs/${{ github.run_id }}"
|
||||
},
|
||||
"invocation": {
|
||||
"configSource": {
|
||||
"uri": "git+https://github.com/${{ github.repository }}@${{ github.ref }}",
|
||||
"digest": {
|
||||
"sha1": "${{ github.sha }}"
|
||||
},
|
||||
"entryPoint": ".github/workflows/release.yaml"
|
||||
},
|
||||
"environment": {
|
||||
"github_workflow": "${{ github.workflow }}",
|
||||
"github_run_id": "${{ github.run_id }}"
|
||||
}
|
||||
},
|
||||
"metadata": {
|
||||
"buildInvocationID": "${{ github.run_id }}",
|
||||
"completeness": {
|
||||
"environment": true,
|
||||
"materials": true
|
||||
}
|
||||
}
|
||||
}
|
||||
push-to-registry: true
|
||||
env:
|
||||
MULTIARCH_IMAGE: ${{ steps.build_docker.outputs.multiarch_image }}
|
||||
LATEST_TARGET: ${{ steps.build_docker.outputs.latest_target }}
|
||||
run: |
|
||||
set -euxo pipefail
|
||||
if [[ -n "${MULTIARCH_IMAGE}" ]]; then
|
||||
multiarch_digest=$(docker buildx imagetools inspect --raw "${MULTIARCH_IMAGE}" | sha256sum | awk '{print "sha256:"$1}')
|
||||
echo "multiarch_digest=${multiarch_digest}" >> "$GITHUB_OUTPUT"
|
||||
fi
|
||||
if [[ -n "${LATEST_TARGET}" ]]; then
|
||||
latest_digest=$(docker buildx imagetools inspect --raw "${LATEST_TARGET}" | sha256sum | awk '{print "sha256:"$1}')
|
||||
echo "latest_digest=${latest_digest}" >> "$GITHUB_OUTPUT"
|
||||
fi
|
||||
|
||||
# Get the latest tag name for attestation
|
||||
- name: Get latest tag name
|
||||
id: latest_tag
|
||||
if: ${{ !inputs.dry_run && steps.build_docker.outputs.created_latest_tag == 'true' }}
|
||||
run: echo "tag=$(./scripts/image_tag.sh --version latest)" >> "$GITHUB_OUTPUT"
|
||||
|
||||
# If this is the highest version according to semver, also attest the "latest" tag
|
||||
- name: GitHub Attestation for "latest" Docker image
|
||||
id: attest_latest
|
||||
if: ${{ !inputs.dry_run && steps.build_docker.outputs.created_latest_tag == 'true' }}
|
||||
- name: GitHub Attestation for Docker image
|
||||
id: attest_main
|
||||
if: ${{ !inputs.dry_run && steps.docker_digests.outputs.multiarch_digest != '' }}
|
||||
continue-on-error: true
|
||||
uses: actions/attest@59d89421af93a897026c735860bf21b6eb4f7b26 # v4.1.0
|
||||
with:
|
||||
subject-name: ${{ steps.latest_tag.outputs.tag }}
|
||||
predicate-type: "https://slsa.dev/provenance/v1"
|
||||
predicate: |
|
||||
{
|
||||
"buildType": "https://github.com/actions/runner-images/",
|
||||
"builder": {
|
||||
"id": "https://github.com/${{ github.repository }}/actions/runs/${{ github.run_id }}"
|
||||
},
|
||||
"invocation": {
|
||||
"configSource": {
|
||||
"uri": "git+https://github.com/${{ github.repository }}@${{ github.ref }}",
|
||||
"digest": {
|
||||
"sha1": "${{ github.sha }}"
|
||||
},
|
||||
"entryPoint": ".github/workflows/release.yaml"
|
||||
},
|
||||
"environment": {
|
||||
"github_workflow": "${{ github.workflow }}",
|
||||
"github_run_id": "${{ github.run_id }}"
|
||||
}
|
||||
},
|
||||
"metadata": {
|
||||
"buildInvocationID": "${{ github.run_id }}",
|
||||
"completeness": {
|
||||
"environment": true,
|
||||
"materials": true
|
||||
}
|
||||
}
|
||||
}
|
||||
subject-name: ghcr.io/coder/coder
|
||||
subject-digest: ${{ steps.docker_digests.outputs.multiarch_digest }}
|
||||
push-to-registry: true
|
||||
|
||||
- name: GitHub Attestation for "latest" Docker image
|
||||
id: attest_latest
|
||||
if: ${{ !inputs.dry_run && steps.docker_digests.outputs.latest_digest != '' }}
|
||||
continue-on-error: true
|
||||
uses: actions/attest@59d89421af93a897026c735860bf21b6eb4f7b26 # v4.1.0
|
||||
with:
|
||||
subject-name: ghcr.io/coder/coder
|
||||
subject-digest: ${{ steps.docker_digests.outputs.latest_digest }}
|
||||
push-to-registry: true
|
||||
|
||||
# Report attestation failures but don't fail the workflow
|
||||
@@ -607,6 +542,9 @@ jobs:
|
||||
if [[ $CODER_RELEASE_CHANNEL == "stable" ]]; then
|
||||
publish_args+=(--stable)
|
||||
fi
|
||||
if [[ $CODER_RELEASE_CHANNEL == "rc" ]]; then
|
||||
publish_args+=(--rc)
|
||||
fi
|
||||
if [[ $CODER_DRY_RUN == *t* ]]; then
|
||||
publish_args+=(--dry-run)
|
||||
fi
|
||||
@@ -639,6 +577,35 @@ jobs:
|
||||
VERSION: ${{ steps.version.outputs.version }}
|
||||
CREATED_LATEST_TAG: ${{ steps.build_docker.outputs.created_latest_tag }}
|
||||
|
||||
# Mark the Linear release as shipped.
|
||||
- name: Extract Linear release version
|
||||
if: ${{ !inputs.dry_run }}
|
||||
id: linear_version
|
||||
run: |
|
||||
# Skip RC releases — they must not complete the Linear release.
|
||||
if [[ "$VERSION" == *-rc* ]]; then
|
||||
echo "RC release (${VERSION}), skipping Linear release completion."
|
||||
echo "skip=true" >> "$GITHUB_OUTPUT"
|
||||
exit 0
|
||||
fi
|
||||
# Strip patch to get the Linear release version (e.g. 2.32.0 -> 2.32).
|
||||
linear_version=$(echo "$VERSION" | cut -d. -f1,2)
|
||||
echo "version=$linear_version" >> "$GITHUB_OUTPUT"
|
||||
echo "skip=false" >> "$GITHUB_OUTPUT"
|
||||
echo "Completing Linear release ${linear_version}"
|
||||
env:
|
||||
VERSION: ${{ steps.version.outputs.version }}
|
||||
|
||||
- name: Complete Linear release
|
||||
if: ${{ !inputs.dry_run && steps.linear_version.outputs.skip != 'true' }}
|
||||
continue-on-error: true
|
||||
uses: linear/linear-release-action@755d50b5adb7dd42b976ee9334952745d62ceb2d # v0.6.0
|
||||
with:
|
||||
access_key: ${{ secrets.LINEAR_ACCESS_KEY }}
|
||||
command: complete
|
||||
version: ${{ steps.linear_version.outputs.version }}
|
||||
timeout: 300
|
||||
|
||||
- name: Authenticate to Google Cloud
|
||||
uses: google-github-actions/auth@7c6bc770dae815cd3e89ee6cdf493a5fab2cc093 # v3.0.0
|
||||
with:
|
||||
@@ -690,7 +657,7 @@ jobs:
|
||||
retention-days: 7
|
||||
|
||||
- name: Send repository-dispatch event
|
||||
if: ${{ !inputs.dry_run }}
|
||||
if: ${{ !inputs.dry_run && inputs.release_channel != 'rc' }}
|
||||
uses: peter-evans/repository-dispatch@28959ce8df70de7be546dd1250a005dd32156697 # v4.0.1
|
||||
with:
|
||||
token: ${{ secrets.CDRCI_GITHUB_TOKEN }}
|
||||
@@ -778,7 +745,7 @@ jobs:
|
||||
name: Publish to winget-pkgs
|
||||
runs-on: windows-latest
|
||||
needs: release
|
||||
if: ${{ !inputs.dry_run }}
|
||||
if: ${{ !inputs.dry_run && inputs.release_channel != 'rc' }}
|
||||
|
||||
steps:
|
||||
- name: Harden Runner
|
||||
|
||||
@@ -125,7 +125,7 @@ jobs:
|
||||
egress-policy: audit
|
||||
|
||||
- name: Delete PR Cleanup workflow runs
|
||||
uses: Mattraks/delete-workflow-runs@5bf9a1dac5c4d041c029f0a8370ddf0c5cb5aeb7 # v2.1.0
|
||||
uses: Mattraks/delete-workflow-runs@b3018382ca039b53d238908238bd35d1fb14f8ee # v2.1.0
|
||||
with:
|
||||
token: ${{ github.token }}
|
||||
repository: ${{ github.repository }}
|
||||
@@ -134,7 +134,7 @@ jobs:
|
||||
delete_workflow_pattern: pr-cleanup.yaml
|
||||
|
||||
- name: Delete PR Deploy workflow skipped runs
|
||||
uses: Mattraks/delete-workflow-runs@5bf9a1dac5c4d041c029f0a8370ddf0c5cb5aeb7 # v2.1.0
|
||||
uses: Mattraks/delete-workflow-runs@b3018382ca039b53d238908238bd35d1fb14f8ee # v2.1.0
|
||||
with:
|
||||
token: ${{ github.token }}
|
||||
repository: ${{ github.repository }}
|
||||
|
||||
@@ -46,8 +46,16 @@ jobs:
|
||||
echo " replacement: \"https://github.com/coder/coder/tree/${HEAD_SHA}/\""
|
||||
} >> .github/.linkspector.yml
|
||||
|
||||
# TODO: Remove this workaround once action-linkspector sets
|
||||
# package-manager-cache: false in its internal setup-node step.
|
||||
# See: https://github.com/UmbrellaDocs/action-linkspector/issues/54
|
||||
- name: Enable corepack and create pnpm store
|
||||
run: |
|
||||
corepack enable pnpm
|
||||
mkdir -p "$(pnpm store path --silent)"
|
||||
|
||||
- name: Check Markdown links
|
||||
uses: umbrelladocs/action-linkspector@652f85bc57bb1e7d4327260decc10aa68f7694c3 # v1.4.0
|
||||
uses: umbrelladocs/action-linkspector@37c85bcde51b30bf929936502bac6bfb7e8f0a4d # v1.4.1
|
||||
id: markdown-link-check
|
||||
# checks all markdown files from /docs including all subfolders
|
||||
with:
|
||||
|
||||
@@ -54,6 +54,7 @@ site/stats/
|
||||
*.tfstate.backup
|
||||
*.tfplan
|
||||
*.lock.hcl
|
||||
!provisioner/terraform/testdata/resources/.terraform.lock.hcl
|
||||
.terraform/
|
||||
!coderd/testdata/parameters/modules/.terraform/
|
||||
!provisioner/terraform/testdata/modules-source-caching/.terraform/
|
||||
|
||||
@@ -988,6 +988,7 @@ coderd/httpmw/loggermw/loggermock/loggermock.go: coderd/httpmw/loggermw/logger.g
|
||||
|
||||
codersdk/workspacesdk/agentconnmock/agentconnmock.go: codersdk/workspacesdk/agentconn.go
|
||||
go generate ./codersdk/workspacesdk/agentconnmock/
|
||||
./scripts/format_go_file.sh "$@"
|
||||
touch "$@"
|
||||
|
||||
$(AIBRIDGED_MOCKS): enterprise/aibridged/client.go enterprise/aibridged/pool.go
|
||||
@@ -1260,11 +1261,21 @@ provisioner/terraform/testdata/.gen-golden: $(wildcard provisioner/terraform/tes
|
||||
touch "$@"
|
||||
|
||||
provisioner/terraform/testdata/version:
|
||||
if [[ "$(shell cat provisioner/terraform/testdata/version.txt)" != "$(shell terraform version -json | jq -r '.terraform_version')" ]]; then
|
||||
./provisioner/terraform/testdata/generate.sh
|
||||
@tf_match=true; \
|
||||
if [[ "$$(cat provisioner/terraform/testdata/version.txt)" != \
|
||||
"$$(terraform version -json | jq -r '.terraform_version')" ]]; then \
|
||||
tf_match=false; \
|
||||
fi; \
|
||||
if ! $$tf_match || \
|
||||
! ./provisioner/terraform/testdata/generate.sh --check; then \
|
||||
./provisioner/terraform/testdata/generate.sh; \
|
||||
fi
|
||||
.PHONY: provisioner/terraform/testdata/version
|
||||
|
||||
update-terraform-testdata:
|
||||
./provisioner/terraform/testdata/generate.sh --upgrade
|
||||
.PHONY: update-terraform-testdata
|
||||
|
||||
# Set the retry flags if TEST_RETRIES is set
|
||||
ifdef TEST_RETRIES
|
||||
GOTESTSUM_RETRY_FLAGS := --rerun-fails=$(TEST_RETRIES)
|
||||
|
||||
+29
-4
@@ -38,6 +38,7 @@ import (
|
||||
"cdr.dev/slog/v3"
|
||||
"github.com/coder/clistat"
|
||||
"github.com/coder/coder/v2/agent/agentcontainers"
|
||||
"github.com/coder/coder/v2/agent/agentcontextconfig"
|
||||
"github.com/coder/coder/v2/agent/agentexec"
|
||||
"github.com/coder/coder/v2/agent/agentfiles"
|
||||
"github.com/coder/coder/v2/agent/agentgit"
|
||||
@@ -50,6 +51,7 @@ import (
|
||||
"github.com/coder/coder/v2/agent/proto/resourcesmonitor"
|
||||
"github.com/coder/coder/v2/agent/reconnectingpty"
|
||||
"github.com/coder/coder/v2/agent/x/agentdesktop"
|
||||
"github.com/coder/coder/v2/agent/x/agentmcp"
|
||||
"github.com/coder/coder/v2/buildinfo"
|
||||
"github.com/coder/coder/v2/cli/gitauth"
|
||||
"github.com/coder/coder/v2/coderd/database/dbtime"
|
||||
@@ -307,10 +309,13 @@ type agent struct {
|
||||
containerAPI *agentcontainers.API
|
||||
gitAPIOptions []agentgit.Option
|
||||
|
||||
filesAPI *agentfiles.API
|
||||
gitAPI *agentgit.API
|
||||
processAPI *agentproc.API
|
||||
desktopAPI *agentdesktop.API
|
||||
filesAPI *agentfiles.API
|
||||
gitAPI *agentgit.API
|
||||
processAPI *agentproc.API
|
||||
desktopAPI *agentdesktop.API
|
||||
mcpManager *agentmcp.Manager
|
||||
mcpAPI *agentmcp.API
|
||||
contextConfigAPI *agentcontextconfig.API
|
||||
|
||||
socketServerEnabled bool
|
||||
socketPath string
|
||||
@@ -396,6 +401,14 @@ func (a *agent) init() {
|
||||
a.logger.Named("desktop"), a.execer, a.scriptRunner.ScriptBinDir(),
|
||||
)
|
||||
a.desktopAPI = agentdesktop.NewAPI(a.logger.Named("desktop"), desktop, a.clock)
|
||||
a.mcpManager = agentmcp.NewManager(a.logger.Named("mcp"))
|
||||
a.mcpAPI = agentmcp.NewAPI(a.logger.Named("mcp"), a.mcpManager)
|
||||
a.contextConfigAPI = agentcontextconfig.NewAPI(func() string {
|
||||
if m := a.manifest.Load(); m != nil {
|
||||
return m.Directory
|
||||
}
|
||||
return ""
|
||||
})
|
||||
a.reconnectingPTYServer = reconnectingpty.NewServer(
|
||||
a.logger.Named("reconnecting-pty"),
|
||||
a.sshServer,
|
||||
@@ -1348,6 +1361,14 @@ func (a *agent) handleManifest(manifestOK *checkpoint) func(ctx context.Context,
|
||||
}
|
||||
a.metrics.startupScriptSeconds.WithLabelValues(label).Set(dur)
|
||||
a.scriptRunner.StartCron()
|
||||
|
||||
// Connect to workspace MCP servers after the
|
||||
// lifecycle transition to avoid delaying Ready.
|
||||
// This runs inside the tracked goroutine so it
|
||||
// is properly awaited on shutdown.
|
||||
if mcpErr := a.mcpManager.Connect(a.gracefulCtx, a.contextConfigAPI.Config().MCPConfigFiles); mcpErr != nil {
|
||||
a.logger.Warn(ctx, "failed to connect to workspace MCP servers", slog.Error(mcpErr))
|
||||
}
|
||||
})
|
||||
if err != nil {
|
||||
return xerrors.Errorf("track conn goroutine: %w", err)
|
||||
@@ -2070,6 +2091,10 @@ func (a *agent) Close() error {
|
||||
a.logger.Error(a.hardCtx, "desktop API close", slog.Error(err))
|
||||
}
|
||||
|
||||
if err := a.mcpManager.Close(); err != nil {
|
||||
a.logger.Error(a.hardCtx, "mcp manager close", slog.Error(err))
|
||||
}
|
||||
|
||||
if a.boundaryLogProxy != nil {
|
||||
err = a.boundaryLogProxy.Close()
|
||||
if err != nil {
|
||||
|
||||
@@ -1,6 +1,8 @@
|
||||
package agent
|
||||
|
||||
import (
|
||||
"path/filepath"
|
||||
"runtime"
|
||||
"testing"
|
||||
|
||||
"github.com/google/uuid"
|
||||
@@ -8,10 +10,22 @@ import (
|
||||
|
||||
"cdr.dev/slog/v3"
|
||||
"cdr.dev/slog/v3/sloggers/slogtest"
|
||||
"github.com/coder/coder/v2/agent/agentcontextconfig"
|
||||
"github.com/coder/coder/v2/agent/proto"
|
||||
agentsdk "github.com/coder/coder/v2/codersdk/agentsdk"
|
||||
"github.com/coder/coder/v2/testutil"
|
||||
)
|
||||
|
||||
// platformAbsPath constructs an absolute path that is valid
|
||||
// on the current platform. On Windows, paths must include a
|
||||
// drive letter to be considered absolute.
|
||||
func platformAbsPath(parts ...string) string {
|
||||
if runtime.GOOS == "windows" {
|
||||
return `C:\` + filepath.Join(parts...)
|
||||
}
|
||||
return "/" + filepath.Join(parts...)
|
||||
}
|
||||
|
||||
// TestReportConnectionEmpty tests that reportConnection() doesn't choke if given an empty IP string, which is what we
|
||||
// send if we cannot get the remote address.
|
||||
func TestReportConnectionEmpty(t *testing.T) {
|
||||
@@ -42,3 +56,41 @@ func TestReportConnectionEmpty(t *testing.T) {
|
||||
require.Equal(t, proto.Connection_DISCONNECT, req1.GetConnection().GetAction())
|
||||
require.Equal(t, "because", req1.GetConnection().GetReason())
|
||||
}
|
||||
|
||||
func TestContextConfigAPI_InitOnce(t *testing.T) {
|
||||
// Not parallel: uses t.Setenv to clear env vars.
|
||||
|
||||
// Clear env vars so defaults are used and the test is
|
||||
// hermetic regardless of the surrounding environment.
|
||||
t.Setenv(agentcontextconfig.EnvInstructionsDirs, "")
|
||||
t.Setenv(agentcontextconfig.EnvInstructionsFile, "")
|
||||
t.Setenv(agentcontextconfig.EnvSkillsDirs, "")
|
||||
t.Setenv(agentcontextconfig.EnvSkillMetaFile, "")
|
||||
t.Setenv(agentcontextconfig.EnvMCPConfigFiles, "")
|
||||
|
||||
// After the fix, contextConfigAPI is set once in init() and
|
||||
// never reassigned. Config() evaluates lazily via the
|
||||
// manifest, so there is no concurrent write to race with.
|
||||
dir1 := platformAbsPath("dir1")
|
||||
dir2 := platformAbsPath("dir2")
|
||||
|
||||
a := &agent{}
|
||||
a.manifest.Store(&agentsdk.Manifest{Directory: dir1})
|
||||
a.contextConfigAPI = agentcontextconfig.NewAPI(func() string {
|
||||
if m := a.manifest.Load(); m != nil {
|
||||
return m.Directory
|
||||
}
|
||||
return ""
|
||||
})
|
||||
|
||||
cfg1 := a.contextConfigAPI.Config()
|
||||
require.NotEmpty(t, cfg1.MCPConfigFiles)
|
||||
require.Contains(t, cfg1.MCPConfigFiles[0], dir1)
|
||||
|
||||
// Simulate manifest update on reconnection — no field
|
||||
// reassignment needed, the lazy closure picks it up.
|
||||
a.manifest.Store(&agentsdk.Manifest{Directory: dir2})
|
||||
cfg2 := a.contextConfigAPI.Config()
|
||||
require.NotEmpty(t, cfg2.MCPConfigFiles)
|
||||
require.Contains(t, cfg2.MCPConfigFiles[0], dir2)
|
||||
}
|
||||
|
||||
+15
-8
@@ -3007,7 +3007,7 @@ func TestAgent_Speedtest(t *testing.T) {
|
||||
|
||||
func TestAgent_Reconnect(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := testutil.Context(t, testutil.WaitShort)
|
||||
ctx := testutil.Context(t, testutil.WaitLong)
|
||||
logger := testutil.Logger(t)
|
||||
// After the agent is disconnected from a coordinator, it's supposed
|
||||
// to reconnect!
|
||||
@@ -3020,7 +3020,8 @@ func TestAgent_Reconnect(t *testing.T) {
|
||||
logger,
|
||||
agentID,
|
||||
agentsdk.Manifest{
|
||||
DERPMap: derpMap,
|
||||
DERPMap: derpMap,
|
||||
Directory: "/test/workspace",
|
||||
},
|
||||
statsCh,
|
||||
fCoordinator,
|
||||
@@ -3033,13 +3034,19 @@ func TestAgent_Reconnect(t *testing.T) {
|
||||
})
|
||||
defer closer.Close()
|
||||
|
||||
call1 := testutil.RequireReceive(ctx, t, fCoordinator.CoordinateCalls)
|
||||
require.Equal(t, client.GetNumRefreshTokenCalls(), 1)
|
||||
close(call1.Resps) // hang up
|
||||
// expect reconnect
|
||||
// Each iteration forces the agent to reconnect by closing
|
||||
// the current coordinate call while the tracked HTTP server
|
||||
// goroutine (from connection 1's createTailnet) is still
|
||||
// alive, widening the race window.
|
||||
const reconnections = 5
|
||||
for i := range reconnections {
|
||||
call := testutil.RequireReceive(ctx, t, fCoordinator.CoordinateCalls)
|
||||
require.Equal(t, i+1, client.GetNumRefreshTokenCalls())
|
||||
close(call.Resps) // hang up — triggers reconnect
|
||||
}
|
||||
// Verify final reconnect succeeds.
|
||||
testutil.RequireReceive(ctx, t, fCoordinator.CoordinateCalls)
|
||||
// Check that the agent refreshes the token when it reconnects.
|
||||
require.Equal(t, client.GetNumRefreshTokenCalls(), 2)
|
||||
require.Equal(t, reconnections+1, client.GetNumRefreshTokenCalls())
|
||||
closer.Close()
|
||||
}
|
||||
|
||||
|
||||
@@ -159,7 +159,6 @@ func TestConvertDockerVolume(t *testing.T) {
|
||||
func TestConvertDockerInspect(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
//nolint:paralleltest // variable recapture no longer required
|
||||
for _, tt := range []struct {
|
||||
name string
|
||||
expect []codersdk.WorkspaceAgentContainer
|
||||
@@ -388,7 +387,6 @@ func TestConvertDockerInspect(t *testing.T) {
|
||||
},
|
||||
},
|
||||
} {
|
||||
// nolint:paralleltest // variable recapture no longer required
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
bs, err := os.ReadFile(filepath.Join("testdata", tt.name, "docker_inspect.json"))
|
||||
|
||||
@@ -166,7 +166,6 @@ func TestDockerEnvInfoer(t *testing.T) {
|
||||
|
||||
pool, err := dockertest.NewPool("")
|
||||
require.NoError(t, err, "Could not connect to docker")
|
||||
// nolint:paralleltest // variable recapture no longer required
|
||||
for idx, tt := range []struct {
|
||||
image string
|
||||
labels map[string]string
|
||||
@@ -223,7 +222,6 @@ func TestDockerEnvInfoer(t *testing.T) {
|
||||
expectedUserShell: "/bin/bash",
|
||||
},
|
||||
} {
|
||||
//nolint:paralleltest // variable recapture no longer required
|
||||
t.Run(fmt.Sprintf("#%d", idx), func(t *testing.T) {
|
||||
// Start a container with the given image
|
||||
// and environment variables
|
||||
|
||||
@@ -0,0 +1,84 @@
|
||||
package agentcontextconfig
|
||||
|
||||
import (
|
||||
"cmp"
|
||||
"net/http"
|
||||
"os"
|
||||
"strings"
|
||||
|
||||
"github.com/go-chi/chi/v5"
|
||||
|
||||
"github.com/coder/coder/v2/coderd/httpapi"
|
||||
"github.com/coder/coder/v2/codersdk/workspacesdk"
|
||||
)
|
||||
|
||||
// Env var names for context configuration. Prefixed with EXP_
|
||||
// to indicate these are experimental and may change.
|
||||
const (
|
||||
EnvInstructionsDirs = "CODER_AGENT_EXP_INSTRUCTIONS_DIRS"
|
||||
EnvInstructionsFile = "CODER_AGENT_EXP_INSTRUCTIONS_FILE"
|
||||
EnvSkillsDirs = "CODER_AGENT_EXP_SKILLS_DIRS"
|
||||
EnvSkillMetaFile = "CODER_AGENT_EXP_SKILL_META_FILE"
|
||||
EnvMCPConfigFiles = "CODER_AGENT_EXP_MCP_CONFIG_FILES"
|
||||
)
|
||||
|
||||
// Defaults are defined in codersdk/workspacesdk so both
|
||||
// the agent and server can reference them without a
|
||||
// cross-layer import.
|
||||
|
||||
// API exposes the resolved context configuration through the
|
||||
// agent's HTTP API.
|
||||
type API struct {
|
||||
workingDir func() string
|
||||
}
|
||||
|
||||
// NewAPI accepts a closure that returns the working directory.
|
||||
// The directory is evaluated lazily on each call to Config(),
|
||||
// so the caller can update it after construction.
|
||||
func NewAPI(workingDir func() string) *API {
|
||||
if workingDir == nil {
|
||||
workingDir = func() string { return "" }
|
||||
}
|
||||
return &API{workingDir: workingDir}
|
||||
}
|
||||
|
||||
// Config reads env vars and resolves paths. Exported for use
|
||||
// by the MCP manager and tests.
|
||||
func Config(workingDir string) workspacesdk.ContextConfigResponse {
|
||||
// TrimSpace all env vars before cmp.Or so that a
|
||||
// whitespace-only value falls through to the default
|
||||
// consistently. ResolvePaths also trims each comma-
|
||||
// separated entry, but without pre-trimming here a
|
||||
// bare " " would bypass cmp.Or and produce nil.
|
||||
instructionsDir := cmp.Or(strings.TrimSpace(os.Getenv(EnvInstructionsDirs)), workspacesdk.DefaultInstructionsDir)
|
||||
instructionsFile := cmp.Or(strings.TrimSpace(os.Getenv(EnvInstructionsFile)), workspacesdk.DefaultInstructionsFile)
|
||||
skillsDir := cmp.Or(strings.TrimSpace(os.Getenv(EnvSkillsDirs)), workspacesdk.DefaultSkillsDir)
|
||||
skillMetaFile := cmp.Or(strings.TrimSpace(os.Getenv(EnvSkillMetaFile)), workspacesdk.DefaultSkillMetaFile)
|
||||
mcpConfigFile := cmp.Or(strings.TrimSpace(os.Getenv(EnvMCPConfigFiles)), workspacesdk.DefaultMCPConfigFile)
|
||||
|
||||
return workspacesdk.ContextConfigResponse{
|
||||
InstructionsDirs: ResolvePaths(instructionsDir, workingDir),
|
||||
InstructionsFile: instructionsFile,
|
||||
SkillsDirs: ResolvePaths(skillsDir, workingDir),
|
||||
SkillMetaFile: skillMetaFile,
|
||||
MCPConfigFiles: ResolvePaths(mcpConfigFile, workingDir),
|
||||
}
|
||||
}
|
||||
|
||||
// Config returns the resolved config for use by other agent
|
||||
// components (e.g. MCP manager).
|
||||
func (api *API) Config() workspacesdk.ContextConfigResponse {
|
||||
return Config(api.workingDir())
|
||||
}
|
||||
|
||||
// Routes returns the HTTP handler for the context config
|
||||
// endpoint.
|
||||
func (api *API) Routes() http.Handler {
|
||||
r := chi.NewRouter()
|
||||
r.Get("/", api.handleGet)
|
||||
return r
|
||||
}
|
||||
|
||||
func (api *API) handleGet(rw http.ResponseWriter, r *http.Request) {
|
||||
httpapi.Write(r.Context(), rw, http.StatusOK, api.Config())
|
||||
}
|
||||
@@ -0,0 +1,117 @@
|
||||
package agentcontextconfig_test
|
||||
|
||||
import (
|
||||
"path/filepath"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/coder/coder/v2/agent/agentcontextconfig"
|
||||
"github.com/coder/coder/v2/codersdk/workspacesdk"
|
||||
)
|
||||
|
||||
func TestConfig(t *testing.T) {
|
||||
t.Run("Defaults", func(t *testing.T) {
|
||||
fakeHome := t.TempDir()
|
||||
t.Setenv("HOME", fakeHome)
|
||||
t.Setenv("USERPROFILE", fakeHome)
|
||||
|
||||
// Clear all env vars so defaults are used.
|
||||
t.Setenv(agentcontextconfig.EnvInstructionsDirs, "")
|
||||
t.Setenv(agentcontextconfig.EnvInstructionsFile, "")
|
||||
t.Setenv(agentcontextconfig.EnvSkillsDirs, "")
|
||||
t.Setenv(agentcontextconfig.EnvSkillMetaFile, "")
|
||||
t.Setenv(agentcontextconfig.EnvMCPConfigFiles, "")
|
||||
|
||||
workDir := platformAbsPath("work")
|
||||
cfg := agentcontextconfig.Config(workDir)
|
||||
|
||||
require.Equal(t, workspacesdk.DefaultInstructionsFile, cfg.InstructionsFile)
|
||||
require.Equal(t, workspacesdk.DefaultSkillMetaFile, cfg.SkillMetaFile)
|
||||
// Default instructions dir is "~/.coder" which resolves
|
||||
// to the home directory.
|
||||
require.Equal(t, []string{filepath.Join(fakeHome, ".coder")}, cfg.InstructionsDirs)
|
||||
// Default skills dir is ".agents/skills" (relative),
|
||||
// resolved against the working directory.
|
||||
require.Equal(t, []string{filepath.Join(workDir, ".agents", "skills")}, cfg.SkillsDirs)
|
||||
// Default MCP config file is ".mcp.json" (relative),
|
||||
// resolved against the working directory.
|
||||
require.Equal(t, []string{filepath.Join(workDir, ".mcp.json")}, cfg.MCPConfigFiles)
|
||||
})
|
||||
|
||||
t.Run("CustomEnvVars", func(t *testing.T) {
|
||||
fakeHome := t.TempDir()
|
||||
t.Setenv("HOME", fakeHome)
|
||||
t.Setenv("USERPROFILE", fakeHome)
|
||||
|
||||
optInstructions := platformAbsPath("opt", "instructions")
|
||||
optSkills := platformAbsPath("opt", "skills")
|
||||
optMCP := platformAbsPath("opt", "mcp.json")
|
||||
|
||||
t.Setenv(agentcontextconfig.EnvInstructionsDirs, optInstructions)
|
||||
t.Setenv(agentcontextconfig.EnvInstructionsFile, "CUSTOM.md")
|
||||
t.Setenv(agentcontextconfig.EnvSkillsDirs, optSkills)
|
||||
t.Setenv(agentcontextconfig.EnvSkillMetaFile, "META.yaml")
|
||||
t.Setenv(agentcontextconfig.EnvMCPConfigFiles, optMCP)
|
||||
|
||||
workDir := platformAbsPath("work")
|
||||
cfg := agentcontextconfig.Config(workDir)
|
||||
|
||||
require.Equal(t, "CUSTOM.md", cfg.InstructionsFile)
|
||||
require.Equal(t, "META.yaml", cfg.SkillMetaFile)
|
||||
require.Equal(t, []string{optInstructions}, cfg.InstructionsDirs)
|
||||
require.Equal(t, []string{optSkills}, cfg.SkillsDirs)
|
||||
require.Equal(t, []string{optMCP}, cfg.MCPConfigFiles)
|
||||
})
|
||||
|
||||
t.Run("WhitespaceInFileNames", func(t *testing.T) {
|
||||
t.Setenv(agentcontextconfig.EnvInstructionsDirs, "")
|
||||
t.Setenv(agentcontextconfig.EnvInstructionsFile, " CLAUDE.md ")
|
||||
t.Setenv(agentcontextconfig.EnvSkillsDirs, "")
|
||||
t.Setenv(agentcontextconfig.EnvSkillMetaFile, "")
|
||||
t.Setenv(agentcontextconfig.EnvMCPConfigFiles, "")
|
||||
|
||||
workDir := platformAbsPath("work")
|
||||
cfg := agentcontextconfig.Config(workDir)
|
||||
|
||||
require.Equal(t, "CLAUDE.md", cfg.InstructionsFile)
|
||||
})
|
||||
|
||||
t.Run("CommaSeparatedDirs", func(t *testing.T) {
|
||||
a := platformAbsPath("opt", "a")
|
||||
b := platformAbsPath("opt", "b")
|
||||
|
||||
t.Setenv(agentcontextconfig.EnvInstructionsDirs, a+","+b)
|
||||
t.Setenv(agentcontextconfig.EnvInstructionsFile, "")
|
||||
t.Setenv(agentcontextconfig.EnvSkillsDirs, "")
|
||||
t.Setenv(agentcontextconfig.EnvSkillMetaFile, "")
|
||||
t.Setenv(agentcontextconfig.EnvMCPConfigFiles, "")
|
||||
|
||||
workDir := platformAbsPath("work")
|
||||
cfg := agentcontextconfig.Config(workDir)
|
||||
|
||||
require.Equal(t, []string{a, b}, cfg.InstructionsDirs)
|
||||
})
|
||||
}
|
||||
|
||||
func TestNewAPI_LazyDirectory(t *testing.T) {
|
||||
t.Setenv(agentcontextconfig.EnvInstructionsDirs, "")
|
||||
t.Setenv(agentcontextconfig.EnvInstructionsFile, "")
|
||||
t.Setenv(agentcontextconfig.EnvSkillsDirs, "")
|
||||
t.Setenv(agentcontextconfig.EnvSkillMetaFile, "")
|
||||
t.Setenv(agentcontextconfig.EnvMCPConfigFiles, "")
|
||||
|
||||
dir := ""
|
||||
api := agentcontextconfig.NewAPI(func() string { return dir })
|
||||
|
||||
// Before directory is set, relative paths resolve to nothing.
|
||||
cfg := api.Config()
|
||||
require.Empty(t, cfg.SkillsDirs)
|
||||
require.Empty(t, cfg.MCPConfigFiles)
|
||||
|
||||
// After setting the directory, Config() picks it up lazily.
|
||||
dir = platformAbsPath("work")
|
||||
cfg = api.Config()
|
||||
require.NotEmpty(t, cfg.SkillsDirs)
|
||||
require.Equal(t, []string{filepath.Join(dir, ".agents", "skills")}, cfg.SkillsDirs)
|
||||
}
|
||||
@@ -0,0 +1,55 @@
|
||||
package agentcontextconfig
|
||||
|
||||
import (
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// ResolvePath resolves a single path that may be absolute,
|
||||
// home-relative (~/ or ~), or relative to the given base
|
||||
// directory. Returns an absolute path. Empty input returns empty.
|
||||
func ResolvePath(raw, baseDir string) string {
|
||||
raw = strings.TrimSpace(raw)
|
||||
if raw == "" {
|
||||
return ""
|
||||
}
|
||||
switch {
|
||||
case raw == "~":
|
||||
home, err := os.UserHomeDir()
|
||||
if err != nil {
|
||||
return ""
|
||||
}
|
||||
return home
|
||||
case strings.HasPrefix(raw, "~/"):
|
||||
home, err := os.UserHomeDir()
|
||||
if err != nil {
|
||||
return ""
|
||||
}
|
||||
return filepath.Join(home, raw[2:])
|
||||
case filepath.IsAbs(raw):
|
||||
return raw
|
||||
default:
|
||||
if baseDir == "" {
|
||||
return ""
|
||||
}
|
||||
return filepath.Join(baseDir, raw)
|
||||
}
|
||||
}
|
||||
|
||||
// ResolvePaths splits a comma-separated list of paths and
|
||||
// resolves each entry independently. Empty entries and entries
|
||||
// that resolve to empty strings are skipped.
|
||||
func ResolvePaths(raw, baseDir string) []string {
|
||||
if strings.TrimSpace(raw) == "" {
|
||||
return nil
|
||||
}
|
||||
parts := strings.Split(raw, ",")
|
||||
out := make([]string, 0, len(parts))
|
||||
for _, p := range parts {
|
||||
if resolved := ResolvePath(p, baseDir); resolved != "" {
|
||||
out = append(out, resolved)
|
||||
}
|
||||
}
|
||||
return out
|
||||
}
|
||||
@@ -0,0 +1,152 @@
|
||||
package agentcontextconfig_test
|
||||
|
||||
import (
|
||||
"path/filepath"
|
||||
"runtime"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/coder/coder/v2/agent/agentcontextconfig"
|
||||
)
|
||||
|
||||
// platformAbsPath constructs an absolute path that is valid
|
||||
// on the current platform. On Windows paths must include a
|
||||
// drive letter to be considered absolute.
|
||||
func platformAbsPath(parts ...string) string {
|
||||
if runtime.GOOS == "windows" {
|
||||
return `C:\` + filepath.Join(parts...)
|
||||
}
|
||||
return "/" + filepath.Join(parts...)
|
||||
}
|
||||
|
||||
func TestResolvePath(t *testing.T) { //nolint:tparallel // subtests using t.Setenv cannot be parallel
|
||||
t.Run("EmptyInput", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
require.Equal(t, "", agentcontextconfig.ResolvePath("", platformAbsPath("base")))
|
||||
})
|
||||
|
||||
t.Run("WhitespaceOnly", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
require.Equal(t, "", agentcontextconfig.ResolvePath(" ", platformAbsPath("base")))
|
||||
})
|
||||
|
||||
// Tests that use t.Setenv cannot be parallel.
|
||||
t.Run("TildeAlone", func(t *testing.T) {
|
||||
fakeHome := t.TempDir()
|
||||
t.Setenv("HOME", fakeHome)
|
||||
t.Setenv("USERPROFILE", fakeHome)
|
||||
got := agentcontextconfig.ResolvePath("~", platformAbsPath("base"))
|
||||
require.Equal(t, fakeHome, got)
|
||||
})
|
||||
|
||||
t.Run("TildeSlashPath", func(t *testing.T) {
|
||||
fakeHome := t.TempDir()
|
||||
t.Setenv("HOME", fakeHome)
|
||||
t.Setenv("USERPROFILE", fakeHome)
|
||||
got := agentcontextconfig.ResolvePath("~/docs/readme", platformAbsPath("base"))
|
||||
require.Equal(t, filepath.Join(fakeHome, "docs", "readme"), got)
|
||||
})
|
||||
|
||||
t.Run("AbsolutePath", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
p := platformAbsPath("etc", "coder")
|
||||
got := agentcontextconfig.ResolvePath(p, platformAbsPath("base"))
|
||||
require.Equal(t, p, got)
|
||||
})
|
||||
|
||||
t.Run("RelativePath", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
base := platformAbsPath("work")
|
||||
got := agentcontextconfig.ResolvePath("foo/bar", base)
|
||||
require.Equal(t, filepath.Join(base, "foo", "bar"), got)
|
||||
})
|
||||
|
||||
t.Run("RelativePathWithWhitespace", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
base := platformAbsPath("work")
|
||||
got := agentcontextconfig.ResolvePath(" foo/bar ", base)
|
||||
require.Equal(t, filepath.Join(base, "foo", "bar"), got)
|
||||
})
|
||||
|
||||
t.Run("RelativePathWithEmptyBaseDir", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
got := agentcontextconfig.ResolvePath(".agents/skills", "")
|
||||
require.Equal(t, "", got)
|
||||
})
|
||||
}
|
||||
|
||||
func TestResolvePath_HomeUnset(t *testing.T) {
|
||||
// Cannot be parallel — modifies HOME env var.
|
||||
t.Setenv("HOME", "")
|
||||
// Also clear USERPROFILE for Windows compatibility.
|
||||
t.Setenv("USERPROFILE", "")
|
||||
|
||||
require.Equal(t, "", agentcontextconfig.ResolvePath("~", platformAbsPath("base")))
|
||||
require.Equal(t, "", agentcontextconfig.ResolvePath("~/docs", platformAbsPath("base")))
|
||||
}
|
||||
|
||||
func TestResolvePaths(t *testing.T) { //nolint:tparallel // subtests using t.Setenv cannot be parallel
|
||||
t.Run("EmptyString", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
require.Nil(t, agentcontextconfig.ResolvePaths("", platformAbsPath("base")))
|
||||
})
|
||||
|
||||
t.Run("WhitespaceOnly", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
require.Nil(t, agentcontextconfig.ResolvePaths(" ", platformAbsPath("base")))
|
||||
})
|
||||
|
||||
t.Run("SingleEntry", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
p := platformAbsPath("abs", "path")
|
||||
got := agentcontextconfig.ResolvePaths(p, platformAbsPath("base"))
|
||||
require.Equal(t, []string{p}, got)
|
||||
})
|
||||
|
||||
// Tests that use t.Setenv cannot be parallel.
|
||||
t.Run("MultipleEntries", func(t *testing.T) {
|
||||
fakeHome := t.TempDir()
|
||||
t.Setenv("HOME", fakeHome)
|
||||
t.Setenv("USERPROFILE", fakeHome)
|
||||
b := platformAbsPath("b")
|
||||
base := platformAbsPath("base")
|
||||
got := agentcontextconfig.ResolvePaths("~/a,"+b+",rel", base)
|
||||
require.Equal(t, []string{
|
||||
filepath.Join(fakeHome, "a"),
|
||||
b,
|
||||
filepath.Join(base, "rel"),
|
||||
}, got)
|
||||
})
|
||||
|
||||
t.Run("TrimsWhitespace", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
a := platformAbsPath("a")
|
||||
b := platformAbsPath("b")
|
||||
got := agentcontextconfig.ResolvePaths(" "+a+" , "+b+" ", platformAbsPath("base"))
|
||||
require.Equal(t, []string{a, b}, got)
|
||||
})
|
||||
|
||||
t.Run("SkipsEmptyEntries", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
a := platformAbsPath("a")
|
||||
b := platformAbsPath("b")
|
||||
got := agentcontextconfig.ResolvePaths(a+",,"+b+",", platformAbsPath("base"))
|
||||
require.Equal(t, []string{a, b}, got)
|
||||
})
|
||||
|
||||
t.Run("TrailingComma", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
p := platformAbsPath("only")
|
||||
got := agentcontextconfig.ResolvePaths(p+",", platformAbsPath("base"))
|
||||
require.Equal(t, []string{p}, got)
|
||||
})
|
||||
|
||||
t.Run("RelativePathSkippedWhenBaseDirEmpty", func(t *testing.T) {
|
||||
fakeHome := t.TempDir()
|
||||
t.Setenv("HOME", fakeHome)
|
||||
t.Setenv("USERPROFILE", fakeHome)
|
||||
got := agentcontextconfig.ResolvePaths("~/.coder,.agents/skills", "")
|
||||
require.Equal(t, []string{filepath.Join(fakeHome, ".coder")}, got)
|
||||
})
|
||||
}
|
||||
@@ -148,6 +148,11 @@ func (m *manager) start(req workspacesdk.StartProcessRequest, chatID string) (*p
|
||||
for k, v := range req.Env {
|
||||
cmd.Env = append(cmd.Env, fmt.Sprintf("%s=%s", k, v))
|
||||
}
|
||||
// Propagate the chat ID so child processes (e.g.
|
||||
// GIT_ASKPASS) can send it back to the server.
|
||||
if chatID != "" {
|
||||
cmd.Env = append(cmd.Env, fmt.Sprintf("CODER_CHAT_ID=%s", chatID))
|
||||
}
|
||||
|
||||
if err := cmd.Start(); err != nil {
|
||||
cancel()
|
||||
|
||||
@@ -211,7 +211,7 @@ func TestServer_X11_EvictionLRU(t *testing.T) {
|
||||
require.NoError(t, err)
|
||||
stderr, err := sess.StderrPipe()
|
||||
require.NoError(t, err)
|
||||
require.NoError(t, sess.Shell())
|
||||
require.NoError(t, sess.Start("sh"))
|
||||
|
||||
// The SSH server lazily starts the session. We need to write a command
|
||||
// and read back to ensure the X11 forwarding is started.
|
||||
|
||||
@@ -31,6 +31,8 @@ func (a *agent) apiHandler() http.Handler {
|
||||
r.Mount("/api/v0/git", a.gitAPI.Routes())
|
||||
r.Mount("/api/v0/processes", a.processAPI.Routes())
|
||||
r.Mount("/api/v0/desktop", a.desktopAPI.Routes())
|
||||
r.Mount("/api/v0/mcp", a.mcpAPI.Routes())
|
||||
r.Mount("/api/v0/context-config", a.contextConfigAPI.Routes())
|
||||
|
||||
if a.devcontainers {
|
||||
r.Mount("/api/v0/containers", a.containerAPI.Routes())
|
||||
|
||||
@@ -0,0 +1,88 @@
|
||||
package agentmcp
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"net/http"
|
||||
|
||||
"github.com/go-chi/chi/v5"
|
||||
|
||||
"cdr.dev/slog/v3"
|
||||
"github.com/coder/coder/v2/coderd/httpapi"
|
||||
"github.com/coder/coder/v2/codersdk"
|
||||
"github.com/coder/coder/v2/codersdk/workspacesdk"
|
||||
)
|
||||
|
||||
// API exposes MCP tool discovery and call proxying through the
|
||||
// agent.
|
||||
type API struct {
|
||||
logger slog.Logger
|
||||
manager *Manager
|
||||
}
|
||||
|
||||
// NewAPI creates a new MCP API handler backed by the given
|
||||
// manager.
|
||||
func NewAPI(logger slog.Logger, manager *Manager) *API {
|
||||
return &API{
|
||||
logger: logger,
|
||||
manager: manager,
|
||||
}
|
||||
}
|
||||
|
||||
// Routes returns the HTTP handler for MCP-related routes.
|
||||
func (api *API) Routes() http.Handler {
|
||||
r := chi.NewRouter()
|
||||
r.Get("/tools", api.handleListTools)
|
||||
r.Post("/call-tool", api.handleCallTool)
|
||||
return r
|
||||
}
|
||||
|
||||
// handleListTools returns the cached MCP tool definitions,
|
||||
// optionally refreshing them first if ?refresh=true is set.
|
||||
func (api *API) handleListTools(rw http.ResponseWriter, r *http.Request) {
|
||||
ctx := r.Context()
|
||||
|
||||
// Allow callers to force a tool re-scan before listing.
|
||||
if r.URL.Query().Get("refresh") == "true" {
|
||||
if err := api.manager.RefreshTools(ctx); err != nil {
|
||||
api.logger.Warn(ctx, "failed to refresh MCP tools", slog.Error(err))
|
||||
}
|
||||
}
|
||||
|
||||
tools := api.manager.Tools()
|
||||
// Ensure non-nil so JSON serialization returns [] not null.
|
||||
if tools == nil {
|
||||
tools = []workspacesdk.MCPToolInfo{}
|
||||
}
|
||||
|
||||
httpapi.Write(ctx, rw, http.StatusOK, workspacesdk.ListMCPToolsResponse{
|
||||
Tools: tools,
|
||||
})
|
||||
}
|
||||
|
||||
// handleCallTool proxies a tool invocation to the appropriate
|
||||
// MCP server based on the tool name prefix.
|
||||
func (api *API) handleCallTool(rw http.ResponseWriter, r *http.Request) {
|
||||
ctx := r.Context()
|
||||
|
||||
var req workspacesdk.CallMCPToolRequest
|
||||
if !httpapi.Read(ctx, rw, r, &req) {
|
||||
return
|
||||
}
|
||||
|
||||
resp, err := api.manager.CallTool(ctx, req)
|
||||
if err != nil {
|
||||
status := http.StatusBadGateway
|
||||
if errors.Is(err, ErrInvalidToolName) {
|
||||
status = http.StatusBadRequest
|
||||
} else if errors.Is(err, ErrUnknownServer) {
|
||||
status = http.StatusNotFound
|
||||
}
|
||||
httpapi.Write(ctx, rw, status, codersdk.Response{
|
||||
Message: "MCP tool call failed.",
|
||||
Detail: err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
httpapi.Write(ctx, rw, http.StatusOK, resp)
|
||||
}
|
||||
@@ -0,0 +1,115 @@
|
||||
package agentmcp
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"os"
|
||||
"slices"
|
||||
"strings"
|
||||
|
||||
"golang.org/x/xerrors"
|
||||
)
|
||||
|
||||
// ServerConfig describes a single MCP server parsed from a .mcp.json file.
|
||||
type ServerConfig struct {
|
||||
Name string `json:"name"`
|
||||
Transport string `json:"type"`
|
||||
Command string `json:"command"`
|
||||
Args []string `json:"args"`
|
||||
Env map[string]string `json:"env"`
|
||||
URL string `json:"url"`
|
||||
Headers map[string]string `json:"headers"`
|
||||
}
|
||||
|
||||
// mcpConfigFile mirrors the on-disk .mcp.json schema.
|
||||
type mcpConfigFile struct {
|
||||
MCPServers map[string]json.RawMessage `json:"mcpServers"`
|
||||
}
|
||||
|
||||
// mcpServerEntry is a single server block inside mcpServers.
|
||||
type mcpServerEntry struct {
|
||||
Command string `json:"command"`
|
||||
Args []string `json:"args"`
|
||||
Env map[string]string `json:"env"`
|
||||
Type string `json:"type"`
|
||||
URL string `json:"url"`
|
||||
Headers map[string]string `json:"headers"`
|
||||
}
|
||||
|
||||
// ParseConfig reads a .mcp.json file at path and returns the declared
|
||||
// MCP servers sorted by name. It returns an empty slice when the
|
||||
// mcpServers key is missing or empty.
|
||||
func ParseConfig(path string) ([]ServerConfig, error) {
|
||||
data, err := os.ReadFile(path)
|
||||
if err != nil {
|
||||
return nil, xerrors.Errorf("read mcp config %q: %w", path, err)
|
||||
}
|
||||
|
||||
var cfg mcpConfigFile
|
||||
if err := json.Unmarshal(data, &cfg); err != nil {
|
||||
return nil, xerrors.Errorf("parse mcp config %q: %w", path, err)
|
||||
}
|
||||
|
||||
if len(cfg.MCPServers) == 0 {
|
||||
return []ServerConfig{}, nil
|
||||
}
|
||||
|
||||
servers := make([]ServerConfig, 0, len(cfg.MCPServers))
|
||||
for name, raw := range cfg.MCPServers {
|
||||
var entry mcpServerEntry
|
||||
if err := json.Unmarshal(raw, &entry); err != nil {
|
||||
return nil, xerrors.Errorf("parse server %q in %q: %w", name, path, err)
|
||||
}
|
||||
|
||||
if strings.Contains(name, ToolNameSep) || strings.HasPrefix(name, "_") || strings.HasSuffix(name, "_") {
|
||||
return nil, xerrors.Errorf("server name %q in %q contains reserved separator %q or leading/trailing underscore", name, path, ToolNameSep)
|
||||
}
|
||||
|
||||
transport := inferTransport(entry)
|
||||
|
||||
if transport == "" {
|
||||
return nil, xerrors.Errorf("server %q in %q has no command or url", name, path)
|
||||
}
|
||||
|
||||
resolveEnvVars(entry.Env)
|
||||
|
||||
servers = append(servers, ServerConfig{
|
||||
Name: name,
|
||||
Transport: transport,
|
||||
Command: entry.Command,
|
||||
Args: entry.Args,
|
||||
Env: entry.Env,
|
||||
URL: entry.URL,
|
||||
Headers: entry.Headers,
|
||||
})
|
||||
}
|
||||
|
||||
slices.SortFunc(servers, func(a, b ServerConfig) int {
|
||||
return strings.Compare(a.Name, b.Name)
|
||||
})
|
||||
|
||||
return servers, nil
|
||||
}
|
||||
|
||||
// inferTransport determines the transport type for a server entry.
|
||||
// An explicit "type" field takes priority; otherwise the presence
|
||||
// of "command" implies stdio and "url" implies http.
|
||||
func inferTransport(e mcpServerEntry) string {
|
||||
if e.Type != "" {
|
||||
return e.Type
|
||||
}
|
||||
if e.Command != "" {
|
||||
return "stdio"
|
||||
}
|
||||
if e.URL != "" {
|
||||
return "http"
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
// resolveEnvVars expands ${VAR} references in env map values
|
||||
// using the current process environment.
|
||||
func resolveEnvVars(env map[string]string) {
|
||||
for k, v := range env {
|
||||
env[k] = os.Expand(v, os.Getenv)
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,254 @@
|
||||
package agentmcp_test
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/coder/coder/v2/agent/x/agentmcp"
|
||||
)
|
||||
|
||||
func TestParseConfig(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
content string
|
||||
expected []agentmcp.ServerConfig
|
||||
expectError bool
|
||||
}{
|
||||
{
|
||||
name: "StdioServer",
|
||||
content: mustJSON(t, map[string]any{
|
||||
"mcpServers": map[string]any{
|
||||
"my-server": map[string]any{
|
||||
"command": "npx",
|
||||
"args": []string{"-y", "@example/mcp-server"},
|
||||
"env": map[string]string{"FOO": "bar"},
|
||||
},
|
||||
},
|
||||
}),
|
||||
expected: []agentmcp.ServerConfig{
|
||||
{
|
||||
Name: "my-server",
|
||||
Transport: "stdio",
|
||||
Command: "npx",
|
||||
Args: []string{"-y", "@example/mcp-server"},
|
||||
Env: map[string]string{"FOO": "bar"},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "HTTPServer",
|
||||
content: mustJSON(t, map[string]any{
|
||||
"mcpServers": map[string]any{
|
||||
"remote": map[string]any{
|
||||
"url": "https://example.com/mcp",
|
||||
"headers": map[string]string{"Authorization": "Bearer tok"},
|
||||
},
|
||||
},
|
||||
}),
|
||||
expected: []agentmcp.ServerConfig{
|
||||
{
|
||||
Name: "remote",
|
||||
Transport: "http",
|
||||
URL: "https://example.com/mcp",
|
||||
Headers: map[string]string{"Authorization": "Bearer tok"},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "SSEServer",
|
||||
content: mustJSON(t, map[string]any{
|
||||
"mcpServers": map[string]any{
|
||||
"events": map[string]any{
|
||||
"type": "sse",
|
||||
"url": "https://example.com/sse",
|
||||
},
|
||||
},
|
||||
}),
|
||||
expected: []agentmcp.ServerConfig{
|
||||
{
|
||||
Name: "events",
|
||||
Transport: "sse",
|
||||
URL: "https://example.com/sse",
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "ExplicitTypeOverridesInference",
|
||||
content: mustJSON(t, map[string]any{
|
||||
"mcpServers": map[string]any{
|
||||
"hybrid": map[string]any{
|
||||
"command": "some-binary",
|
||||
"type": "http",
|
||||
},
|
||||
},
|
||||
}),
|
||||
expected: []agentmcp.ServerConfig{
|
||||
{
|
||||
Name: "hybrid",
|
||||
Transport: "http",
|
||||
Command: "some-binary",
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "EnvVarPassthrough",
|
||||
content: mustJSON(t, map[string]any{
|
||||
"mcpServers": map[string]any{
|
||||
"srv": map[string]any{
|
||||
"command": "run",
|
||||
"env": map[string]string{"PLAIN": "literal-value"},
|
||||
},
|
||||
},
|
||||
}),
|
||||
expected: []agentmcp.ServerConfig{
|
||||
{
|
||||
Name: "srv",
|
||||
Transport: "stdio",
|
||||
Command: "run",
|
||||
Env: map[string]string{"PLAIN": "literal-value"},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "EmptyMCPServers",
|
||||
content: mustJSON(t, map[string]any{
|
||||
"mcpServers": map[string]any{},
|
||||
}),
|
||||
expected: []agentmcp.ServerConfig{},
|
||||
},
|
||||
{
|
||||
name: "MalformedJSON",
|
||||
content: `{not valid json`,
|
||||
expectError: true,
|
||||
},
|
||||
{
|
||||
name: "ServerNameContainsSeparator",
|
||||
content: mustJSON(t, map[string]any{
|
||||
"mcpServers": map[string]any{
|
||||
"bad__name": map[string]any{"command": "run"},
|
||||
},
|
||||
}),
|
||||
expectError: true,
|
||||
},
|
||||
{
|
||||
name: "ServerNameTrailingUnderscore",
|
||||
content: mustJSON(t, map[string]any{
|
||||
"mcpServers": map[string]any{
|
||||
"server_": map[string]any{"command": "run"},
|
||||
},
|
||||
}),
|
||||
expectError: true,
|
||||
},
|
||||
{
|
||||
name: "ServerNameLeadingUnderscore",
|
||||
content: mustJSON(t, map[string]any{
|
||||
"mcpServers": map[string]any{
|
||||
"_server": map[string]any{"command": "run"},
|
||||
},
|
||||
}),
|
||||
expectError: true,
|
||||
},
|
||||
{
|
||||
name: "EmptyTransport", content: mustJSON(t, map[string]any{
|
||||
"mcpServers": map[string]any{
|
||||
"empty": map[string]any{},
|
||||
},
|
||||
}),
|
||||
expectError: true,
|
||||
},
|
||||
{
|
||||
name: "MissingMCPServersKey",
|
||||
content: mustJSON(t, map[string]any{
|
||||
"servers": map[string]any{},
|
||||
}),
|
||||
expected: []agentmcp.ServerConfig{},
|
||||
},
|
||||
{
|
||||
name: "MultipleServersSortedByName",
|
||||
content: mustJSON(t, map[string]any{
|
||||
"mcpServers": map[string]any{
|
||||
"zeta": map[string]any{"command": "z"},
|
||||
"alpha": map[string]any{"command": "a"},
|
||||
"mu": map[string]any{"command": "m"},
|
||||
},
|
||||
}),
|
||||
expected: []agentmcp.ServerConfig{
|
||||
{Name: "alpha", Transport: "stdio", Command: "a"},
|
||||
{Name: "mu", Transport: "stdio", Command: "m"},
|
||||
{Name: "zeta", Transport: "stdio", Command: "z"},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
dir := t.TempDir()
|
||||
path := filepath.Join(dir, ".mcp.json")
|
||||
err := os.WriteFile(path, []byte(tt.content), 0o600)
|
||||
require.NoError(t, err)
|
||||
|
||||
got, err := agentmcp.ParseConfig(path)
|
||||
if tt.expectError {
|
||||
require.Error(t, err)
|
||||
return
|
||||
}
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, tt.expected, got)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestParseConfig_EnvVarInterpolation verifies that ${VAR} references
|
||||
// in env values are resolved from the process environment. This test
|
||||
// cannot be parallel because t.Setenv is incompatible with t.Parallel.
|
||||
func TestParseConfig_EnvVarInterpolation(t *testing.T) {
|
||||
t.Setenv("TEST_MCP_TOKEN", "secret123")
|
||||
|
||||
content := mustJSON(t, map[string]any{
|
||||
"mcpServers": map[string]any{
|
||||
"srv": map[string]any{
|
||||
"command": "run",
|
||||
"env": map[string]string{"TOKEN": "${TEST_MCP_TOKEN}"},
|
||||
},
|
||||
},
|
||||
})
|
||||
|
||||
dir := t.TempDir()
|
||||
path := filepath.Join(dir, ".mcp.json")
|
||||
err := os.WriteFile(path, []byte(content), 0o600)
|
||||
require.NoError(t, err)
|
||||
|
||||
got, err := agentmcp.ParseConfig(path)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, []agentmcp.ServerConfig{
|
||||
{
|
||||
Name: "srv",
|
||||
Transport: "stdio",
|
||||
Command: "run",
|
||||
Env: map[string]string{"TOKEN": "secret123"},
|
||||
},
|
||||
}, got)
|
||||
}
|
||||
|
||||
func TestParseConfig_FileNotFound(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
_, err := agentmcp.ParseConfig(filepath.Join(t.TempDir(), "nonexistent.json"))
|
||||
require.Error(t, err)
|
||||
}
|
||||
|
||||
// mustJSON marshals v to a JSON string, failing the test on error.
|
||||
func mustJSON(t *testing.T, v any) string {
|
||||
t.Helper()
|
||||
data, err := json.Marshal(v)
|
||||
require.NoError(t, err)
|
||||
return string(data)
|
||||
}
|
||||
@@ -0,0 +1,470 @@
|
||||
package agentmcp
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io/fs"
|
||||
"os"
|
||||
"slices"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/mark3labs/mcp-go/client"
|
||||
"github.com/mark3labs/mcp-go/client/transport"
|
||||
"github.com/mark3labs/mcp-go/mcp"
|
||||
"golang.org/x/sync/errgroup"
|
||||
"golang.org/x/xerrors"
|
||||
|
||||
"cdr.dev/slog/v3"
|
||||
"github.com/coder/coder/v2/buildinfo"
|
||||
"github.com/coder/coder/v2/codersdk/workspacesdk"
|
||||
)
|
||||
|
||||
// ToolNameSep separates the server name from the original tool name
|
||||
// in prefixed tool names. Double underscore avoids collisions with
|
||||
// tool names that may contain single underscores.
|
||||
const ToolNameSep = "__"
|
||||
|
||||
// connectTimeout bounds how long we wait for a single MCP server
|
||||
// to start its transport and complete initialization.
|
||||
const connectTimeout = 30 * time.Second
|
||||
|
||||
// toolCallTimeout bounds how long a single tool invocation may
|
||||
// take before being canceled.
|
||||
const toolCallTimeout = 60 * time.Second
|
||||
|
||||
var (
|
||||
// ErrInvalidToolName is returned when the tool name format
|
||||
// is not "server__tool".
|
||||
ErrInvalidToolName = xerrors.New("invalid tool name format")
|
||||
// ErrUnknownServer is returned when no MCP server matches
|
||||
// the prefix in the tool name.
|
||||
ErrUnknownServer = xerrors.New("unknown MCP server")
|
||||
)
|
||||
|
||||
// Manager manages connections to MCP servers discovered from a
|
||||
// workspace's .mcp.json file. It caches the aggregated tool list
|
||||
// and proxies tool calls to the appropriate server.
|
||||
type Manager struct {
|
||||
mu sync.RWMutex
|
||||
logger slog.Logger
|
||||
closed bool
|
||||
servers map[string]*serverEntry // keyed by server name
|
||||
tools []workspacesdk.MCPToolInfo
|
||||
}
|
||||
|
||||
// serverEntry pairs a server config with its connected client.
|
||||
type serverEntry struct {
|
||||
config ServerConfig
|
||||
client *client.Client
|
||||
}
|
||||
|
||||
// NewManager creates a new MCP client manager.
|
||||
func NewManager(logger slog.Logger) *Manager {
|
||||
return &Manager{
|
||||
logger: logger,
|
||||
servers: make(map[string]*serverEntry),
|
||||
}
|
||||
}
|
||||
|
||||
// Connect reads MCP config files at the given absolute paths and
|
||||
// connects to all configured servers. Failed servers are logged
|
||||
// and skipped. Missing config files are silently skipped.
|
||||
func (m *Manager) Connect(ctx context.Context, mcpConfigFiles []string) error {
|
||||
var allConfigs []ServerConfig
|
||||
for _, configPath := range mcpConfigFiles {
|
||||
configs, err := ParseConfig(configPath)
|
||||
if err != nil {
|
||||
if errors.Is(err, fs.ErrNotExist) {
|
||||
continue
|
||||
}
|
||||
m.logger.Warn(ctx, "failed to parse MCP config",
|
||||
slog.F("path", configPath),
|
||||
slog.Error(err),
|
||||
)
|
||||
continue
|
||||
}
|
||||
allConfigs = append(allConfigs, configs...)
|
||||
}
|
||||
|
||||
// Deduplicate by server name; first occurrence wins.
|
||||
seen := make(map[string]struct{})
|
||||
deduped := make([]ServerConfig, 0, len(allConfigs))
|
||||
for _, cfg := range allConfigs {
|
||||
if _, ok := seen[cfg.Name]; ok {
|
||||
continue
|
||||
}
|
||||
seen[cfg.Name] = struct{}{}
|
||||
deduped = append(deduped, cfg)
|
||||
}
|
||||
allConfigs = deduped
|
||||
|
||||
if len(allConfigs) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Connect to servers in parallel without holding the
|
||||
// lock, since each connectServer call may block on
|
||||
// network I/O for up to connectTimeout.
|
||||
type connectedServer struct {
|
||||
name string
|
||||
config ServerConfig
|
||||
client *client.Client
|
||||
}
|
||||
var (
|
||||
mu sync.Mutex
|
||||
connected []connectedServer
|
||||
)
|
||||
var eg errgroup.Group
|
||||
for _, cfg := range allConfigs {
|
||||
eg.Go(func() error {
|
||||
c, err := m.connectServer(ctx, cfg)
|
||||
if err != nil {
|
||||
m.logger.Warn(ctx, "skipping MCP server",
|
||||
slog.F("server", cfg.Name),
|
||||
slog.F("transport", cfg.Transport),
|
||||
slog.Error(err),
|
||||
)
|
||||
return nil // Don't fail the group.
|
||||
}
|
||||
mu.Lock()
|
||||
connected = append(connected, connectedServer{
|
||||
name: cfg.Name, config: cfg, client: c,
|
||||
})
|
||||
mu.Unlock()
|
||||
return nil
|
||||
})
|
||||
}
|
||||
_ = eg.Wait()
|
||||
|
||||
m.mu.Lock()
|
||||
if m.closed {
|
||||
m.mu.Unlock()
|
||||
// Close the freshly-connected clients since we're
|
||||
// shutting down.
|
||||
for _, cs := range connected {
|
||||
_ = cs.client.Close()
|
||||
}
|
||||
return xerrors.New("manager closed")
|
||||
}
|
||||
|
||||
// Close previous connections to avoid leaking child
|
||||
// processes on agent reconnect.
|
||||
for _, entry := range m.servers {
|
||||
_ = entry.client.Close()
|
||||
}
|
||||
m.servers = make(map[string]*serverEntry, len(connected))
|
||||
|
||||
for _, cs := range connected {
|
||||
m.servers[cs.name] = &serverEntry{
|
||||
config: cs.config,
|
||||
client: cs.client,
|
||||
}
|
||||
}
|
||||
m.mu.Unlock()
|
||||
|
||||
// Refresh tools outside the lock to avoid blocking
|
||||
// concurrent reads during network I/O.
|
||||
if err := m.RefreshTools(ctx); err != nil {
|
||||
m.logger.Warn(ctx, "failed to refresh MCP tools after connect", slog.Error(err))
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// connectServer establishes a connection to a single MCP server
|
||||
// and returns the connected client. It does not modify any Manager
|
||||
// state.
|
||||
func (*Manager) connectServer(ctx context.Context, cfg ServerConfig) (*client.Client, error) {
|
||||
tr, err := createTransport(cfg)
|
||||
if err != nil {
|
||||
return nil, xerrors.Errorf("create transport for %q: %w", cfg.Name, err)
|
||||
}
|
||||
|
||||
c := client.NewClient(tr)
|
||||
|
||||
connectCtx, cancel := context.WithTimeout(ctx, connectTimeout)
|
||||
defer cancel()
|
||||
|
||||
if err := c.Start(connectCtx); err != nil {
|
||||
_ = c.Close()
|
||||
return nil, xerrors.Errorf("start %q: %w", cfg.Name, err)
|
||||
}
|
||||
|
||||
_, err = c.Initialize(connectCtx, mcp.InitializeRequest{
|
||||
Params: mcp.InitializeParams{
|
||||
ProtocolVersion: mcp.LATEST_PROTOCOL_VERSION,
|
||||
ClientInfo: mcp.Implementation{
|
||||
Name: "coder-agent",
|
||||
Version: buildinfo.Version(),
|
||||
},
|
||||
},
|
||||
})
|
||||
if err != nil {
|
||||
_ = c.Close()
|
||||
return nil, xerrors.Errorf("initialize %q: %w", cfg.Name, err)
|
||||
}
|
||||
|
||||
return c, nil
|
||||
}
|
||||
|
||||
// createTransport builds the mcp-go transport for a server config.
|
||||
func createTransport(cfg ServerConfig) (transport.Interface, error) {
|
||||
switch cfg.Transport {
|
||||
case "stdio":
|
||||
return transport.NewStdio(
|
||||
cfg.Command,
|
||||
buildEnv(cfg.Env),
|
||||
cfg.Args...,
|
||||
), nil
|
||||
case "http", "":
|
||||
return transport.NewStreamableHTTP(
|
||||
cfg.URL,
|
||||
transport.WithHTTPHeaders(cfg.Headers),
|
||||
)
|
||||
case "sse":
|
||||
return transport.NewSSE(
|
||||
cfg.URL,
|
||||
transport.WithHeaders(cfg.Headers),
|
||||
)
|
||||
default:
|
||||
return nil, xerrors.Errorf("unsupported transport %q", cfg.Transport)
|
||||
}
|
||||
}
|
||||
|
||||
// buildEnv merges the current process environment with explicit
|
||||
// overrides, returning the result as KEY=VALUE strings suitable
|
||||
// for the stdio transport.
|
||||
func buildEnv(explicit map[string]string) []string {
|
||||
env := os.Environ()
|
||||
if len(explicit) == 0 {
|
||||
return env
|
||||
}
|
||||
|
||||
// Index existing env so explicit keys can override in-place.
|
||||
existing := make(map[string]int, len(env))
|
||||
for i, kv := range env {
|
||||
if k, _, ok := strings.Cut(kv, "="); ok {
|
||||
existing[k] = i
|
||||
}
|
||||
}
|
||||
|
||||
for k, v := range explicit {
|
||||
entry := k + "=" + v
|
||||
if idx, ok := existing[k]; ok {
|
||||
env[idx] = entry
|
||||
} else {
|
||||
env = append(env, entry)
|
||||
}
|
||||
}
|
||||
return env
|
||||
}
|
||||
|
||||
// Tools returns the cached tool list. Thread-safe.
|
||||
func (m *Manager) Tools() []workspacesdk.MCPToolInfo {
|
||||
m.mu.RLock()
|
||||
defer m.mu.RUnlock()
|
||||
|
||||
return slices.Clone(m.tools)
|
||||
}
|
||||
|
||||
// CallTool proxies a tool call to the appropriate MCP server.
|
||||
func (m *Manager) CallTool(ctx context.Context, req workspacesdk.CallMCPToolRequest) (workspacesdk.CallMCPToolResponse, error) {
|
||||
serverName, originalName, err := splitToolName(req.ToolName)
|
||||
if err != nil {
|
||||
return workspacesdk.CallMCPToolResponse{}, err
|
||||
}
|
||||
|
||||
m.mu.RLock()
|
||||
entry, ok := m.servers[serverName]
|
||||
m.mu.RUnlock()
|
||||
|
||||
if !ok {
|
||||
return workspacesdk.CallMCPToolResponse{}, xerrors.Errorf("%w: %q", ErrUnknownServer, serverName)
|
||||
}
|
||||
|
||||
callCtx, cancel := context.WithTimeout(ctx, toolCallTimeout)
|
||||
defer cancel()
|
||||
|
||||
result, err := entry.client.CallTool(callCtx, mcp.CallToolRequest{
|
||||
Params: mcp.CallToolParams{
|
||||
Name: originalName,
|
||||
Arguments: req.Arguments,
|
||||
},
|
||||
})
|
||||
if err != nil {
|
||||
return workspacesdk.CallMCPToolResponse{}, xerrors.Errorf("call tool %q on %q: %w", originalName, serverName, err)
|
||||
}
|
||||
|
||||
return convertResult(result), nil
|
||||
}
|
||||
|
||||
// splitToolName extracts the server name and original tool name
|
||||
// from a prefixed tool name like "server__tool".
|
||||
func splitToolName(prefixed string) (serverName, toolName string, err error) {
|
||||
server, tool, ok := strings.Cut(prefixed, ToolNameSep)
|
||||
if !ok || server == "" || tool == "" {
|
||||
return "", "", xerrors.Errorf("%w: expected format \"server%stool\", got %q", ErrInvalidToolName, ToolNameSep, prefixed)
|
||||
}
|
||||
return server, tool, nil
|
||||
}
|
||||
|
||||
// convertResult translates an MCP CallToolResult into a
|
||||
// workspacesdk.CallMCPToolResponse. It iterates over content
|
||||
// items and maps each recognized type.
|
||||
func convertResult(result *mcp.CallToolResult) workspacesdk.CallMCPToolResponse {
|
||||
if result == nil {
|
||||
return workspacesdk.CallMCPToolResponse{}
|
||||
}
|
||||
|
||||
var content []workspacesdk.MCPToolContent
|
||||
for _, item := range result.Content {
|
||||
switch c := item.(type) {
|
||||
case mcp.TextContent:
|
||||
content = append(content, workspacesdk.MCPToolContent{
|
||||
Type: "text",
|
||||
Text: c.Text,
|
||||
})
|
||||
case mcp.ImageContent:
|
||||
content = append(content, workspacesdk.MCPToolContent{
|
||||
Type: "image",
|
||||
Data: c.Data,
|
||||
MediaType: c.MIMEType,
|
||||
})
|
||||
case mcp.AudioContent:
|
||||
content = append(content, workspacesdk.MCPToolContent{
|
||||
Type: "audio",
|
||||
Data: c.Data,
|
||||
MediaType: c.MIMEType,
|
||||
})
|
||||
case mcp.EmbeddedResource:
|
||||
content = append(content, workspacesdk.MCPToolContent{
|
||||
Type: "resource",
|
||||
Text: fmt.Sprintf("[embedded resource: %T]", c.Resource),
|
||||
})
|
||||
case mcp.ResourceLink:
|
||||
content = append(content, workspacesdk.MCPToolContent{
|
||||
Type: "resource",
|
||||
Text: fmt.Sprintf("[resource link: %s]", c.URI),
|
||||
})
|
||||
default:
|
||||
content = append(content, workspacesdk.MCPToolContent{
|
||||
Type: "text",
|
||||
Text: fmt.Sprintf("[unsupported content type: %T]", item),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
return workspacesdk.CallMCPToolResponse{
|
||||
Content: content,
|
||||
IsError: result.IsError,
|
||||
}
|
||||
}
|
||||
|
||||
// RefreshTools re-fetches tool lists from all connected servers
|
||||
// in parallel and rebuilds the cache. On partial failure, tools
|
||||
// from servers that responded successfully are merged with the
|
||||
// existing cached tools for servers that failed, so a single
|
||||
// dead server doesn't block updates from healthy ones.
|
||||
func (m *Manager) RefreshTools(ctx context.Context) error {
|
||||
// Snapshot servers under read lock.
|
||||
m.mu.RLock()
|
||||
servers := make(map[string]*serverEntry, len(m.servers))
|
||||
for k, v := range m.servers {
|
||||
servers[k] = v
|
||||
}
|
||||
m.mu.RUnlock()
|
||||
|
||||
// Fetch tool lists in parallel without holding any lock.
|
||||
type serverTools struct {
|
||||
name string
|
||||
tools []workspacesdk.MCPToolInfo
|
||||
}
|
||||
var (
|
||||
mu sync.Mutex
|
||||
results []serverTools
|
||||
failed []string
|
||||
errs []error
|
||||
)
|
||||
var eg errgroup.Group
|
||||
for name, entry := range servers {
|
||||
eg.Go(func() error {
|
||||
listCtx, cancel := context.WithTimeout(ctx, connectTimeout)
|
||||
result, err := entry.client.ListTools(listCtx, mcp.ListToolsRequest{})
|
||||
cancel()
|
||||
if err != nil {
|
||||
m.logger.Warn(ctx, "failed to list tools from MCP server",
|
||||
slog.F("server", name),
|
||||
slog.Error(err),
|
||||
)
|
||||
mu.Lock()
|
||||
errs = append(errs, xerrors.Errorf("list tools from %q: %w", name, err))
|
||||
failed = append(failed, name)
|
||||
mu.Unlock()
|
||||
return nil
|
||||
}
|
||||
var tools []workspacesdk.MCPToolInfo
|
||||
for _, tool := range result.Tools {
|
||||
tools = append(tools, workspacesdk.MCPToolInfo{
|
||||
ServerName: name,
|
||||
Name: name + ToolNameSep + tool.Name,
|
||||
Description: tool.Description,
|
||||
Schema: tool.InputSchema.Properties,
|
||||
Required: tool.InputSchema.Required,
|
||||
})
|
||||
}
|
||||
mu.Lock()
|
||||
results = append(results, serverTools{name: name, tools: tools})
|
||||
mu.Unlock()
|
||||
return nil
|
||||
})
|
||||
}
|
||||
_ = eg.Wait()
|
||||
|
||||
// Build the new tool list. For servers that failed, preserve
|
||||
// their tools from the existing cache so a single dead server
|
||||
// doesn't remove healthy tools.
|
||||
var merged []workspacesdk.MCPToolInfo
|
||||
for _, st := range results {
|
||||
merged = append(merged, st.tools...)
|
||||
}
|
||||
if len(failed) > 0 {
|
||||
failedSet := make(map[string]struct{}, len(failed))
|
||||
for _, f := range failed {
|
||||
failedSet[f] = struct{}{}
|
||||
}
|
||||
m.mu.RLock()
|
||||
for _, t := range m.tools {
|
||||
if _, ok := failedSet[t.ServerName]; ok {
|
||||
merged = append(merged, t)
|
||||
}
|
||||
}
|
||||
m.mu.RUnlock()
|
||||
}
|
||||
slices.SortFunc(merged, func(a, b workspacesdk.MCPToolInfo) int {
|
||||
return strings.Compare(a.Name, b.Name)
|
||||
})
|
||||
|
||||
m.mu.Lock()
|
||||
m.tools = merged
|
||||
m.mu.Unlock()
|
||||
|
||||
return errors.Join(errs...)
|
||||
}
|
||||
|
||||
// Close terminates all MCP server connections and child
|
||||
// processes.
|
||||
func (m *Manager) Close() error {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
|
||||
m.closed = true
|
||||
var errs []error
|
||||
for _, entry := range m.servers {
|
||||
errs = append(errs, entry.client.Close())
|
||||
}
|
||||
m.servers = make(map[string]*serverEntry)
|
||||
m.tools = nil
|
||||
return errors.Join(errs...)
|
||||
}
|
||||
@@ -0,0 +1,195 @@
|
||||
package agentmcp
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/mark3labs/mcp-go/mcp"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/coder/coder/v2/codersdk/workspacesdk"
|
||||
)
|
||||
|
||||
func TestSplitToolName(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
input string
|
||||
wantServer string
|
||||
wantTool string
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "Valid",
|
||||
input: "server__tool",
|
||||
wantServer: "server",
|
||||
wantTool: "tool",
|
||||
},
|
||||
{
|
||||
name: "ValidWithUnderscoresInTool",
|
||||
input: "server__my_tool",
|
||||
wantServer: "server",
|
||||
wantTool: "my_tool",
|
||||
},
|
||||
{
|
||||
name: "MissingSeparator",
|
||||
input: "servertool",
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "EmptyServer",
|
||||
input: "__tool",
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "EmptyTool",
|
||||
input: "server__",
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "JustSeparator",
|
||||
input: "__",
|
||||
wantErr: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
server, tool, err := splitToolName(tt.input)
|
||||
if tt.wantErr {
|
||||
require.Error(t, err)
|
||||
assert.ErrorIs(t, err, ErrInvalidToolName)
|
||||
return
|
||||
}
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, tt.wantServer, server)
|
||||
assert.Equal(t, tt.wantTool, tool)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestConvertResult(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
// input is a pointer so we can test nil.
|
||||
input *mcp.CallToolResult
|
||||
want workspacesdk.CallMCPToolResponse
|
||||
}{
|
||||
{
|
||||
name: "NilInput",
|
||||
input: nil,
|
||||
want: workspacesdk.CallMCPToolResponse{},
|
||||
},
|
||||
{
|
||||
name: "TextContent",
|
||||
input: &mcp.CallToolResult{
|
||||
Content: []mcp.Content{
|
||||
mcp.TextContent{Type: "text", Text: "hello"},
|
||||
},
|
||||
},
|
||||
want: workspacesdk.CallMCPToolResponse{
|
||||
Content: []workspacesdk.MCPToolContent{
|
||||
{Type: "text", Text: "hello"},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "ImageContent",
|
||||
input: &mcp.CallToolResult{
|
||||
Content: []mcp.Content{
|
||||
mcp.ImageContent{
|
||||
Type: "image",
|
||||
Data: "base64data",
|
||||
MIMEType: "image/png",
|
||||
},
|
||||
},
|
||||
},
|
||||
want: workspacesdk.CallMCPToolResponse{
|
||||
Content: []workspacesdk.MCPToolContent{
|
||||
{Type: "image", Data: "base64data", MediaType: "image/png"},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "AudioContent",
|
||||
input: &mcp.CallToolResult{
|
||||
Content: []mcp.Content{
|
||||
mcp.AudioContent{
|
||||
Type: "audio",
|
||||
Data: "base64audio",
|
||||
MIMEType: "audio/mp3",
|
||||
},
|
||||
},
|
||||
},
|
||||
want: workspacesdk.CallMCPToolResponse{
|
||||
Content: []workspacesdk.MCPToolContent{
|
||||
{Type: "audio", Data: "base64audio", MediaType: "audio/mp3"},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "IsErrorPropagation",
|
||||
input: &mcp.CallToolResult{
|
||||
Content: []mcp.Content{
|
||||
mcp.TextContent{Type: "text", Text: "fail"},
|
||||
},
|
||||
IsError: true,
|
||||
},
|
||||
want: workspacesdk.CallMCPToolResponse{
|
||||
Content: []workspacesdk.MCPToolContent{
|
||||
{Type: "text", Text: "fail"},
|
||||
},
|
||||
IsError: true,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "MultipleContentItems",
|
||||
input: &mcp.CallToolResult{
|
||||
Content: []mcp.Content{
|
||||
mcp.TextContent{Type: "text", Text: "caption"},
|
||||
mcp.ImageContent{
|
||||
Type: "image",
|
||||
Data: "imgdata",
|
||||
MIMEType: "image/jpeg",
|
||||
},
|
||||
},
|
||||
},
|
||||
want: workspacesdk.CallMCPToolResponse{
|
||||
Content: []workspacesdk.MCPToolContent{
|
||||
{Type: "text", Text: "caption"},
|
||||
{Type: "image", Data: "imgdata", MediaType: "image/jpeg"},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "ResourceLink",
|
||||
input: &mcp.CallToolResult{
|
||||
Content: []mcp.Content{
|
||||
mcp.ResourceLink{
|
||||
Type: "resource_link",
|
||||
URI: "file:///tmp/test.txt",
|
||||
},
|
||||
},
|
||||
},
|
||||
want: workspacesdk.CallMCPToolResponse{
|
||||
Content: []workspacesdk.MCPToolContent{
|
||||
{Type: "resource", Text: "[resource link: file:///tmp/test.txt]"},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
got := convertResult(tt.input)
|
||||
assert.Equal(t, tt.want, got)
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -173,7 +173,10 @@ func Start(t *testing.T, inv *serpent.Invocation) {
|
||||
StartWithAssert(t, inv, nil)
|
||||
}
|
||||
|
||||
func StartWithAssert(t *testing.T, inv *serpent.Invocation, assertCallback func(t *testing.T, err error)) { //nolint:revive
|
||||
// StartWithAssert starts the given invocation and calls assertCallback
|
||||
// with the resulting error when the invocation completes. If assertCallback
|
||||
// is nil, expected shutdown errors are silently tolerated.
|
||||
func StartWithAssert(t *testing.T, inv *serpent.Invocation, assertCallback func(t *testing.T, err error)) {
|
||||
t.Helper()
|
||||
|
||||
closeCh := make(chan struct{})
|
||||
|
||||
@@ -173,7 +173,6 @@ func (selectModel) Init() tea.Cmd {
|
||||
return nil
|
||||
}
|
||||
|
||||
//nolint:revive // The linter complains about modifying 'm' but this is typical practice for bubbletea
|
||||
func (m selectModel) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
|
||||
var cmd tea.Cmd
|
||||
|
||||
@@ -463,7 +462,6 @@ func (multiSelectModel) Init() tea.Cmd {
|
||||
return nil
|
||||
}
|
||||
|
||||
//nolint:revive // For same reason as previous Update definition
|
||||
func (m multiSelectModel) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
|
||||
var cmd tea.Cmd
|
||||
|
||||
|
||||
@@ -1414,7 +1414,6 @@ func tailLineStyle() pretty.Style {
|
||||
return pretty.Style{pretty.Nop}
|
||||
}
|
||||
|
||||
//nolint:unused
|
||||
func SlimUnsupported(w io.Writer, cmd string) {
|
||||
_, _ = fmt.Fprintf(w, "You are using a 'slim' build of Coder, which does not support the %s subcommand.\n", pretty.Sprint(cliui.DefaultStyles.Code, cmd))
|
||||
_, _ = fmt.Fprintln(w, "")
|
||||
|
||||
+13
-2
@@ -352,8 +352,6 @@ func TestScheduleOverride(t *testing.T) {
|
||||
require.NoError(t, err, "invalid schedule")
|
||||
ownerClient, _, _, ws := setupTestSchedule(t, sched)
|
||||
now := time.Now()
|
||||
// To avoid the likelihood of time-related flakes, only matching up to the hour.
|
||||
expectedDeadline := now.In(loc).Add(10 * time.Hour).Format("2006-01-02T15:")
|
||||
|
||||
// When: we override the stop schedule
|
||||
inv, root := clitest.New(t,
|
||||
@@ -364,6 +362,19 @@ func TestScheduleOverride(t *testing.T) {
|
||||
pty := ptytest.New(t).Attach(inv)
|
||||
require.NoError(t, inv.Run())
|
||||
|
||||
// Fetch the workspace to get the actual deadline set by the
|
||||
// server. Computing our own expected deadline from a separately
|
||||
// captured time.Now() is racy: the CLI command calls time.Now()
|
||||
// internally, and with the Asia/Kolkata +05:30 offset the hour
|
||||
// boundary falls at :30 UTC minutes. A small delay between our
|
||||
// time.Now() and the command's is enough to land in different
|
||||
// hours.
|
||||
updated, err := ownerClient.Workspace(context.Background(), ws[0].ID)
|
||||
require.NoError(t, err)
|
||||
require.False(t, updated.LatestBuild.Deadline.IsZero(), "deadline should be set after extend")
|
||||
require.WithinDuration(t, now.Add(10*time.Hour), updated.LatestBuild.Deadline.Time, 5*time.Minute)
|
||||
expectedDeadline := updated.LatestBuild.Deadline.Time.In(loc).Format(time.RFC3339)
|
||||
|
||||
// Then: the updated schedule should be shown
|
||||
pty.ExpectMatch(ws[0].OwnerName + "/" + ws[0].Name)
|
||||
pty.ExpectMatch(sched.Humanize())
|
||||
|
||||
+2
-5
@@ -305,7 +305,6 @@ func enablePrometheus(
|
||||
}
|
||||
options.ProvisionerdServerMetrics = provisionerdserverMetrics
|
||||
|
||||
//nolint:revive
|
||||
return ServeHandler(
|
||||
ctx, logger, promhttp.InstrumentMetricHandler(
|
||||
options.PrometheusRegistry, promhttp.HandlerFor(options.PrometheusRegistry, promhttp.HandlerOpts{}),
|
||||
@@ -1637,8 +1636,6 @@ var defaultCipherSuites = func() []uint16 {
|
||||
// configureServerTLS returns the TLS config used for the Coderd server
|
||||
// connections to clients. A logger is passed in to allow printing warning
|
||||
// messages that do not block startup.
|
||||
//
|
||||
//nolint:revive
|
||||
func configureServerTLS(ctx context.Context, logger slog.Logger, tlsMinVersion, tlsClientAuth string, tlsCertFiles, tlsKeyFiles []string, tlsClientCAFile string, ciphers []string, allowInsecureCiphers bool) (*tls.Config, error) {
|
||||
tlsConfig := &tls.Config{
|
||||
MinVersion: tls.VersionTLS12,
|
||||
@@ -2055,7 +2052,6 @@ func getGithubOAuth2ConfigParams(ctx context.Context, db database.Store, vals *c
|
||||
return ¶ms, nil
|
||||
}
|
||||
|
||||
//nolint:revive // Ignore flag-parameter: parameter 'allowEveryone' seems to be a control flag, avoid control coupling (revive)
|
||||
func configureGithubOAuth2(instrument *promoauth.Factory, params *githubOAuth2ConfigParams) (*coderd.GithubOAuth2Config, error) {
|
||||
redirectURL, err := params.accessURL.Parse("/api/v2/users/oauth2/github/callback")
|
||||
if err != nil {
|
||||
@@ -2331,7 +2327,8 @@ func ConfigureHTTPClient(ctx context.Context, clientCertFile, clientKeyFile stri
|
||||
return ctx, nil, err
|
||||
}
|
||||
|
||||
tlsClientConfig := &tls.Config{ //nolint:gosec
|
||||
tlsClientConfig := &tls.Config{
|
||||
MinVersion: tls.VersionTLS12,
|
||||
Certificates: certificates,
|
||||
NextProtos: []string{"h2", "http/1.1"},
|
||||
}
|
||||
|
||||
@@ -2123,7 +2123,6 @@ func TestServer_TelemetryDisable(t *testing.T) {
|
||||
// Set the default telemetry to true (normally disabled in tests).
|
||||
t.Setenv("CODER_TEST_TELEMETRY_DEFAULT_ENABLE", "true")
|
||||
|
||||
//nolint:paralleltest // No need to reinitialise the variable tt (Go version).
|
||||
for _, tt := range []struct {
|
||||
key string
|
||||
val string
|
||||
|
||||
@@ -165,6 +165,37 @@ func TestSyncCommands_Golden(t *testing.T) {
|
||||
clitest.TestGoldenFile(t, "TestSyncCommands_Golden/want_success", outBuf.Bytes(), nil)
|
||||
})
|
||||
|
||||
t.Run("want_multiple_deps", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
path, cleanup := setupSocketServer(t)
|
||||
defer cleanup()
|
||||
|
||||
ctx := testutil.Context(t, testutil.WaitShort)
|
||||
|
||||
var outBuf bytes.Buffer
|
||||
inv, _ := clitest.New(t, "exp", "sync", "want", "test-unit", "dep-1", "dep-2", "dep-3", "--socket-path", path)
|
||||
inv.Stdout = &outBuf
|
||||
inv.Stderr = &outBuf
|
||||
|
||||
err := inv.WithContext(ctx).Run()
|
||||
require.NoError(t, err)
|
||||
|
||||
// Verify all dependencies were registered by checking status.
|
||||
outBuf.Reset()
|
||||
inv, _ = clitest.New(t, "exp", "sync", "status", "test-unit", "--socket-path", path, "--output", "json")
|
||||
inv.Stdout = &outBuf
|
||||
inv.Stderr = &outBuf
|
||||
|
||||
err = inv.WithContext(ctx).Run()
|
||||
require.NoError(t, err)
|
||||
|
||||
// The output should mention all three dependencies.
|
||||
output := outBuf.String()
|
||||
require.Contains(t, output, "dep-1")
|
||||
require.Contains(t, output, "dep-2")
|
||||
require.Contains(t, output, "dep-3")
|
||||
})
|
||||
|
||||
t.Run("complete", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
path, cleanup := setupSocketServer(t)
|
||||
|
||||
+9
-8
@@ -11,17 +11,16 @@ import (
|
||||
|
||||
func (*RootCmd) syncWant(socketPath *string) *serpent.Command {
|
||||
cmd := &serpent.Command{
|
||||
Use: "want <unit> <depends-on>",
|
||||
Short: "Declare that a unit depends on another unit completing before it can start",
|
||||
Long: "Declare that a unit depends on another unit completing before it can start. The unit specified first will not start until the second has signaled that it has completed.",
|
||||
Use: "want <unit> <depends-on> [depends-on...]",
|
||||
Short: "Declare that a unit depends on other units completing before it can start",
|
||||
Long: "Declare that a unit depends on one or more other units completing before it can start. The unit specified first will not start until all subsequent units have signaled that they have completed.",
|
||||
Handler: func(i *serpent.Invocation) error {
|
||||
ctx := i.Context()
|
||||
|
||||
if len(i.Args) != 2 {
|
||||
return xerrors.New("exactly two arguments are required: unit and depends-on")
|
||||
if len(i.Args) < 2 {
|
||||
return xerrors.New("at least two arguments are required: unit and one or more depends-on")
|
||||
}
|
||||
dependentUnit := unit.ID(i.Args[0])
|
||||
dependsOn := unit.ID(i.Args[1])
|
||||
|
||||
opts := []agentsocket.Option{}
|
||||
if *socketPath != "" {
|
||||
@@ -34,8 +33,10 @@ func (*RootCmd) syncWant(socketPath *string) *serpent.Command {
|
||||
}
|
||||
defer client.Close()
|
||||
|
||||
if err := client.SyncWant(ctx, dependentUnit, dependsOn); err != nil {
|
||||
return xerrors.Errorf("declare dependency failed: %w", err)
|
||||
for _, dep := range i.Args[1:] {
|
||||
if err := client.SyncWant(ctx, dependentUnit, unit.ID(dep)); err != nil {
|
||||
return xerrors.Errorf("declare dependency failed: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
cliui.Info(i.Stdout, "Success")
|
||||
|
||||
@@ -828,7 +828,7 @@ func TestTemplateEdit(t *testing.T) {
|
||||
"--require-active-version",
|
||||
}
|
||||
inv, root := clitest.New(t, cmdArgs...)
|
||||
//nolint
|
||||
//nolint:gocritic // Using owner client is required for template editing.
|
||||
clitest.SetupConfig(t, client, root)
|
||||
|
||||
ctx := testutil.Context(t, testutil.WaitLong)
|
||||
@@ -858,7 +858,7 @@ func TestTemplateEdit(t *testing.T) {
|
||||
"--name", "something-new",
|
||||
}
|
||||
inv, root := clitest.New(t, cmdArgs...)
|
||||
//nolint
|
||||
//nolint:gocritic // Using owner client is required for template editing.
|
||||
clitest.SetupConfig(t, client, root)
|
||||
|
||||
ctx := testutil.Context(t, testutil.WaitLong)
|
||||
|
||||
+1
-1
@@ -16,7 +16,7 @@ SUBCOMMANDS:
|
||||
ping Test agent socket connectivity and health
|
||||
start Wait until all unit dependencies are satisfied
|
||||
status Show unit status and dependency state
|
||||
want Declare that a unit depends on another unit completing before it
|
||||
want Declare that a unit depends on other units completing before it
|
||||
can start
|
||||
|
||||
OPTIONS:
|
||||
|
||||
+5
-5
@@ -1,13 +1,13 @@
|
||||
coder v0.0.0-devel
|
||||
|
||||
USAGE:
|
||||
coder exp sync want <unit> <depends-on>
|
||||
coder exp sync want <unit> <depends-on> [depends-on...]
|
||||
|
||||
Declare that a unit depends on another unit completing before it can start
|
||||
Declare that a unit depends on other units completing before it can start
|
||||
|
||||
Declare that a unit depends on another unit completing before it can start.
|
||||
The unit specified first will not start until the second has signaled that it
|
||||
has completed.
|
||||
Declare that a unit depends on one or more other units completing before it
|
||||
can start. The unit specified first will not start until all subsequent units
|
||||
have signaled that they have completed.
|
||||
|
||||
———
|
||||
Run `coder --help` for a list of global options.
|
||||
|
||||
+4
-2
@@ -17,7 +17,8 @@
|
||||
"name": "owner",
|
||||
"display_name": "Owner"
|
||||
}
|
||||
]
|
||||
],
|
||||
"has_ai_seat": false
|
||||
},
|
||||
{
|
||||
"id": "==========[second user ID]==========",
|
||||
@@ -31,6 +32,7 @@
|
||||
"organization_ids": [
|
||||
"===========[first org ID]==========="
|
||||
],
|
||||
"roles": []
|
||||
"roles": [],
|
||||
"has_ai_seat": false
|
||||
}
|
||||
]
|
||||
|
||||
+7
-2
@@ -857,13 +857,18 @@ aibridgeproxy:
|
||||
# Comma-separated list of AI provider domains for which HTTPS traffic will be
|
||||
# decrypted and routed through AI Bridge. Requests to other domains will be
|
||||
# tunneled directly without decryption. Supported domains: api.anthropic.com,
|
||||
# api.openai.com, api.individual.githubcopilot.com.
|
||||
# (default: api.anthropic.com,api.openai.com,api.individual.githubcopilot.com,
|
||||
# api.openai.com, api.individual.githubcopilot.com,
|
||||
# api.business.githubcopilot.com, api.enterprise.githubcopilot.com, chatgpt.com.
|
||||
# (default:
|
||||
# api.anthropic.com,api.openai.com,api.individual.githubcopilot.com,api.business.githubcopilot.com,api.enterprise.githubcopilot.com,chatgpt.com,
|
||||
# type: string-array)
|
||||
domain_allowlist:
|
||||
- api.anthropic.com
|
||||
- api.openai.com
|
||||
- api.individual.githubcopilot.com
|
||||
- api.business.githubcopilot.com
|
||||
- api.enterprise.githubcopilot.com
|
||||
- chatgpt.com
|
||||
# URL of an upstream HTTP proxy to chain tunneled (non-allowlisted) requests
|
||||
# through. Format: http://[user:pass@]host:port or https://[user:pass@]host:port.
|
||||
# (default: <unset>, type: string)
|
||||
|
||||
@@ -101,7 +101,6 @@ func TestConnectionLog(t *testing.T) {
|
||||
reason: "because error says so",
|
||||
},
|
||||
}
|
||||
//nolint:paralleltest // No longer necessary to reinitialise the variable tt.
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
@@ -3,6 +3,7 @@ package agentapi
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
@@ -60,6 +61,8 @@ func (a *MetadataAPI) BatchUpdateMetadata(ctx context.Context, req *agentproto.B
|
||||
}
|
||||
)
|
||||
for _, md := range req.Metadata {
|
||||
md.Result.Value = strings.TrimSpace(md.Result.Value)
|
||||
md.Result.Error = strings.TrimSpace(md.Result.Error)
|
||||
metadataError := md.Result.Error
|
||||
|
||||
allKeysLen += len(md.Key)
|
||||
|
||||
@@ -57,16 +57,44 @@ func TestBatchUpdateMetadata(t *testing.T) {
|
||||
CollectedAt: timestamppb.New(now.Add(-3 * time.Second)),
|
||||
Age: 3,
|
||||
Value: "",
|
||||
Error: "uncool value",
|
||||
Error: "\t uncool error ",
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
batchSize := len(req.Metadata)
|
||||
// This test sends 2 metadata entries. With batch size 2, we expect
|
||||
// exactly 1 capacity flush.
|
||||
// This test sends 2 metadata entries (one clean, one with
|
||||
// whitespace padding). With batch size 2 we expect exactly
|
||||
// 1 capacity flush. The matcher verifies that stored values
|
||||
// are trimmed while clean values pass through unchanged.
|
||||
expectedValues := map[string]string{
|
||||
"awesome key": "awesome value",
|
||||
"uncool key": "",
|
||||
}
|
||||
expectedErrors := map[string]string{
|
||||
"awesome key": "",
|
||||
"uncool key": "uncool error",
|
||||
}
|
||||
store.EXPECT().
|
||||
BatchUpdateWorkspaceAgentMetadata(gomock.Any(), gomock.Any()).
|
||||
BatchUpdateWorkspaceAgentMetadata(
|
||||
gomock.Any(),
|
||||
gomock.Cond(func(arg database.BatchUpdateWorkspaceAgentMetadataParams) bool {
|
||||
if len(arg.Key) != len(expectedValues) {
|
||||
return false
|
||||
}
|
||||
for i, key := range arg.Key {
|
||||
expVal, ok := expectedValues[key]
|
||||
if !ok || arg.Value[i] != expVal {
|
||||
return false
|
||||
}
|
||||
expErr, ok := expectedErrors[key]
|
||||
if !ok || arg.Error[i] != expErr {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
}),
|
||||
).
|
||||
Return(nil).
|
||||
Times(1)
|
||||
|
||||
|
||||
@@ -16,6 +16,25 @@ import (
|
||||
// that use per-user LLM credentials but cannot set custom headers.
|
||||
const HeaderCoderToken = "X-Coder-AI-Governance-Token" //nolint:gosec // This is a header name, not a credential.
|
||||
|
||||
// HeaderCoderRequestID is a header set by aibridgeproxyd on each
|
||||
// request forwarded to aibridged for cross-service log correlation.
|
||||
const HeaderCoderRequestID = "X-Coder-AI-Governance-Request-Id"
|
||||
|
||||
// Copilot provider.
|
||||
const (
|
||||
ProviderCopilotBusiness = "copilot-business"
|
||||
HostCopilotBusiness = "api.business.githubcopilot.com"
|
||||
ProviderCopilotEnterprise = "copilot-enterprise"
|
||||
HostCopilotEnterprise = "api.enterprise.githubcopilot.com"
|
||||
)
|
||||
|
||||
// ChatGPT provider.
|
||||
const (
|
||||
ProviderChatGPT = "chatgpt"
|
||||
HostChatGPT = "chatgpt.com"
|
||||
BaseURLChatGPT = "https://" + HostChatGPT + "/backend-api/codex"
|
||||
)
|
||||
|
||||
// IsBYOK reports whether the request is using BYOK mode, determined
|
||||
// by the presence of the X-Coder-AI-Governance-Token header.
|
||||
func IsBYOK(header http.Header) bool {
|
||||
|
||||
Generated
+267
@@ -84,6 +84,34 @@ const docTemplate = `{
|
||||
}
|
||||
}
|
||||
},
|
||||
"/aibridge/clients": {
|
||||
"get": {
|
||||
"produces": [
|
||||
"application/json"
|
||||
],
|
||||
"tags": [
|
||||
"AI Bridge"
|
||||
],
|
||||
"summary": "List AI Bridge clients",
|
||||
"operationId": "list-ai-bridge-clients",
|
||||
"responses": {
|
||||
"200": {
|
||||
"description": "OK",
|
||||
"schema": {
|
||||
"type": "array",
|
||||
"items": {
|
||||
"type": "string"
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
"security": [
|
||||
{
|
||||
"CoderSessionToken": []
|
||||
}
|
||||
]
|
||||
}
|
||||
},
|
||||
"/aibridge/interceptions": {
|
||||
"get": {
|
||||
"produces": [
|
||||
@@ -214,6 +242,58 @@ const docTemplate = `{
|
||||
]
|
||||
}
|
||||
},
|
||||
"/aibridge/sessions/{session_id}": {
|
||||
"get": {
|
||||
"produces": [
|
||||
"application/json"
|
||||
],
|
||||
"tags": [
|
||||
"AI Bridge"
|
||||
],
|
||||
"summary": "Get AI Bridge session threads",
|
||||
"operationId": "get-ai-bridge-session-threads",
|
||||
"parameters": [
|
||||
{
|
||||
"type": "string",
|
||||
"description": "Session ID (client_session_id or interception UUID)",
|
||||
"name": "session_id",
|
||||
"in": "path",
|
||||
"required": true
|
||||
},
|
||||
{
|
||||
"type": "string",
|
||||
"description": "Thread pagination cursor (forward/older)",
|
||||
"name": "after_id",
|
||||
"in": "query"
|
||||
},
|
||||
{
|
||||
"type": "string",
|
||||
"description": "Thread pagination cursor (backward/newer)",
|
||||
"name": "before_id",
|
||||
"in": "query"
|
||||
},
|
||||
{
|
||||
"type": "integer",
|
||||
"description": "Number of threads per page (default 50)",
|
||||
"name": "limit",
|
||||
"in": "query"
|
||||
}
|
||||
],
|
||||
"responses": {
|
||||
"200": {
|
||||
"description": "OK",
|
||||
"schema": {
|
||||
"$ref": "#/definitions/codersdk.AIBridgeSessionThreadsResponse"
|
||||
}
|
||||
}
|
||||
},
|
||||
"security": [
|
||||
{
|
||||
"CoderSessionToken": []
|
||||
}
|
||||
]
|
||||
}
|
||||
},
|
||||
"/appearance": {
|
||||
"get": {
|
||||
"produces": [
|
||||
@@ -12675,6 +12755,29 @@ const docTemplate = `{
|
||||
}
|
||||
}
|
||||
},
|
||||
"codersdk.AIBridgeAgenticAction": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"model": {
|
||||
"type": "string"
|
||||
},
|
||||
"thinking": {
|
||||
"type": "array",
|
||||
"items": {
|
||||
"$ref": "#/definitions/codersdk.AIBridgeModelThought"
|
||||
}
|
||||
},
|
||||
"token_usage": {
|
||||
"$ref": "#/definitions/codersdk.AIBridgeSessionThreadsTokenUsage"
|
||||
},
|
||||
"tool_calls": {
|
||||
"type": "array",
|
||||
"items": {
|
||||
"$ref": "#/definitions/codersdk.AIBridgeToolCall"
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
"codersdk.AIBridgeAnthropicConfig": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
@@ -12791,6 +12894,9 @@ const docTemplate = `{
|
||||
"provider": {
|
||||
"type": "string"
|
||||
},
|
||||
"provider_name": {
|
||||
"type": "string"
|
||||
},
|
||||
"started_at": {
|
||||
"type": "string",
|
||||
"format": "date-time"
|
||||
@@ -12843,6 +12949,14 @@ const docTemplate = `{
|
||||
}
|
||||
}
|
||||
},
|
||||
"codersdk.AIBridgeModelThought": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"text": {
|
||||
"type": "string"
|
||||
}
|
||||
}
|
||||
},
|
||||
"codersdk.AIBridgeOpenAIConfig": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
@@ -12942,6 +13056,76 @@ const docTemplate = `{
|
||||
}
|
||||
}
|
||||
},
|
||||
"codersdk.AIBridgeSessionThreadsResponse": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"client": {
|
||||
"type": "string"
|
||||
},
|
||||
"ended_at": {
|
||||
"type": "string",
|
||||
"format": "date-time"
|
||||
},
|
||||
"id": {
|
||||
"type": "string"
|
||||
},
|
||||
"initiator": {
|
||||
"$ref": "#/definitions/codersdk.MinimalUser"
|
||||
},
|
||||
"metadata": {
|
||||
"type": "object",
|
||||
"additionalProperties": {}
|
||||
},
|
||||
"models": {
|
||||
"type": "array",
|
||||
"items": {
|
||||
"type": "string"
|
||||
}
|
||||
},
|
||||
"page_ended_at": {
|
||||
"type": "string",
|
||||
"format": "date-time"
|
||||
},
|
||||
"page_started_at": {
|
||||
"type": "string",
|
||||
"format": "date-time"
|
||||
},
|
||||
"providers": {
|
||||
"type": "array",
|
||||
"items": {
|
||||
"type": "string"
|
||||
}
|
||||
},
|
||||
"started_at": {
|
||||
"type": "string",
|
||||
"format": "date-time"
|
||||
},
|
||||
"threads": {
|
||||
"type": "array",
|
||||
"items": {
|
||||
"$ref": "#/definitions/codersdk.AIBridgeThread"
|
||||
}
|
||||
},
|
||||
"token_usage_summary": {
|
||||
"$ref": "#/definitions/codersdk.AIBridgeSessionThreadsTokenUsage"
|
||||
}
|
||||
}
|
||||
},
|
||||
"codersdk.AIBridgeSessionThreadsTokenUsage": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"input_tokens": {
|
||||
"type": "integer"
|
||||
},
|
||||
"metadata": {
|
||||
"type": "object",
|
||||
"additionalProperties": {}
|
||||
},
|
||||
"output_tokens": {
|
||||
"type": "integer"
|
||||
}
|
||||
}
|
||||
},
|
||||
"codersdk.AIBridgeSessionTokenUsageSummary": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
@@ -12953,6 +13137,41 @@ const docTemplate = `{
|
||||
}
|
||||
}
|
||||
},
|
||||
"codersdk.AIBridgeThread": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"agentic_actions": {
|
||||
"type": "array",
|
||||
"items": {
|
||||
"$ref": "#/definitions/codersdk.AIBridgeAgenticAction"
|
||||
}
|
||||
},
|
||||
"ended_at": {
|
||||
"type": "string",
|
||||
"format": "date-time"
|
||||
},
|
||||
"id": {
|
||||
"type": "string",
|
||||
"format": "uuid"
|
||||
},
|
||||
"model": {
|
||||
"type": "string"
|
||||
},
|
||||
"prompt": {
|
||||
"type": "string"
|
||||
},
|
||||
"provider": {
|
||||
"type": "string"
|
||||
},
|
||||
"started_at": {
|
||||
"type": "string",
|
||||
"format": "date-time"
|
||||
},
|
||||
"token_usage": {
|
||||
"$ref": "#/definitions/codersdk.AIBridgeSessionThreadsTokenUsage"
|
||||
}
|
||||
}
|
||||
},
|
||||
"codersdk.AIBridgeTokenUsage": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
@@ -12983,6 +13202,42 @@ const docTemplate = `{
|
||||
}
|
||||
}
|
||||
},
|
||||
"codersdk.AIBridgeToolCall": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"created_at": {
|
||||
"type": "string",
|
||||
"format": "date-time"
|
||||
},
|
||||
"id": {
|
||||
"type": "string",
|
||||
"format": "uuid"
|
||||
},
|
||||
"injected": {
|
||||
"type": "boolean"
|
||||
},
|
||||
"input": {
|
||||
"type": "string"
|
||||
},
|
||||
"interception_id": {
|
||||
"type": "string",
|
||||
"format": "uuid"
|
||||
},
|
||||
"metadata": {
|
||||
"type": "object",
|
||||
"additionalProperties": {}
|
||||
},
|
||||
"provider_response_id": {
|
||||
"type": "string"
|
||||
},
|
||||
"server_url": {
|
||||
"type": "string"
|
||||
},
|
||||
"tool": {
|
||||
"type": "string"
|
||||
}
|
||||
}
|
||||
},
|
||||
"codersdk.AIBridgeToolUsage": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
@@ -17426,6 +17681,10 @@ const docTemplate = `{
|
||||
"$ref": "#/definitions/codersdk.SlimRole"
|
||||
}
|
||||
},
|
||||
"has_ai_seat": {
|
||||
"description": "HasAISeat intentionally omits omitempty so the API always includes the\nfield, even when false.",
|
||||
"type": "boolean"
|
||||
},
|
||||
"is_service_account": {
|
||||
"type": "boolean"
|
||||
},
|
||||
@@ -20222,6 +20481,10 @@ const docTemplate = `{
|
||||
"type": "string",
|
||||
"format": "email"
|
||||
},
|
||||
"has_ai_seat": {
|
||||
"description": "HasAISeat intentionally omits omitempty so the API always includes the\nfield, even when false.",
|
||||
"type": "boolean"
|
||||
},
|
||||
"id": {
|
||||
"type": "string",
|
||||
"format": "uuid"
|
||||
@@ -21071,6 +21334,10 @@ const docTemplate = `{
|
||||
"type": "string",
|
||||
"format": "email"
|
||||
},
|
||||
"has_ai_seat": {
|
||||
"description": "HasAISeat intentionally omits omitempty so the API always includes the\nfield, even when false.",
|
||||
"type": "boolean"
|
||||
},
|
||||
"id": {
|
||||
"type": "string",
|
||||
"format": "uuid"
|
||||
|
||||
Generated
+259
@@ -65,6 +65,30 @@
|
||||
}
|
||||
}
|
||||
},
|
||||
"/aibridge/clients": {
|
||||
"get": {
|
||||
"produces": ["application/json"],
|
||||
"tags": ["AI Bridge"],
|
||||
"summary": "List AI Bridge clients",
|
||||
"operationId": "list-ai-bridge-clients",
|
||||
"responses": {
|
||||
"200": {
|
||||
"description": "OK",
|
||||
"schema": {
|
||||
"type": "array",
|
||||
"items": {
|
||||
"type": "string"
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
"security": [
|
||||
{
|
||||
"CoderSessionToken": []
|
||||
}
|
||||
]
|
||||
}
|
||||
},
|
||||
"/aibridge/interceptions": {
|
||||
"get": {
|
||||
"produces": ["application/json"],
|
||||
@@ -183,6 +207,54 @@
|
||||
]
|
||||
}
|
||||
},
|
||||
"/aibridge/sessions/{session_id}": {
|
||||
"get": {
|
||||
"produces": ["application/json"],
|
||||
"tags": ["AI Bridge"],
|
||||
"summary": "Get AI Bridge session threads",
|
||||
"operationId": "get-ai-bridge-session-threads",
|
||||
"parameters": [
|
||||
{
|
||||
"type": "string",
|
||||
"description": "Session ID (client_session_id or interception UUID)",
|
||||
"name": "session_id",
|
||||
"in": "path",
|
||||
"required": true
|
||||
},
|
||||
{
|
||||
"type": "string",
|
||||
"description": "Thread pagination cursor (forward/older)",
|
||||
"name": "after_id",
|
||||
"in": "query"
|
||||
},
|
||||
{
|
||||
"type": "string",
|
||||
"description": "Thread pagination cursor (backward/newer)",
|
||||
"name": "before_id",
|
||||
"in": "query"
|
||||
},
|
||||
{
|
||||
"type": "integer",
|
||||
"description": "Number of threads per page (default 50)",
|
||||
"name": "limit",
|
||||
"in": "query"
|
||||
}
|
||||
],
|
||||
"responses": {
|
||||
"200": {
|
||||
"description": "OK",
|
||||
"schema": {
|
||||
"$ref": "#/definitions/codersdk.AIBridgeSessionThreadsResponse"
|
||||
}
|
||||
}
|
||||
},
|
||||
"security": [
|
||||
{
|
||||
"CoderSessionToken": []
|
||||
}
|
||||
]
|
||||
}
|
||||
},
|
||||
"/appearance": {
|
||||
"get": {
|
||||
"produces": ["application/json"],
|
||||
@@ -11261,6 +11333,29 @@
|
||||
}
|
||||
}
|
||||
},
|
||||
"codersdk.AIBridgeAgenticAction": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"model": {
|
||||
"type": "string"
|
||||
},
|
||||
"thinking": {
|
||||
"type": "array",
|
||||
"items": {
|
||||
"$ref": "#/definitions/codersdk.AIBridgeModelThought"
|
||||
}
|
||||
},
|
||||
"token_usage": {
|
||||
"$ref": "#/definitions/codersdk.AIBridgeSessionThreadsTokenUsage"
|
||||
},
|
||||
"tool_calls": {
|
||||
"type": "array",
|
||||
"items": {
|
||||
"$ref": "#/definitions/codersdk.AIBridgeToolCall"
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
"codersdk.AIBridgeAnthropicConfig": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
@@ -11377,6 +11472,9 @@
|
||||
"provider": {
|
||||
"type": "string"
|
||||
},
|
||||
"provider_name": {
|
||||
"type": "string"
|
||||
},
|
||||
"started_at": {
|
||||
"type": "string",
|
||||
"format": "date-time"
|
||||
@@ -11429,6 +11527,14 @@
|
||||
}
|
||||
}
|
||||
},
|
||||
"codersdk.AIBridgeModelThought": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"text": {
|
||||
"type": "string"
|
||||
}
|
||||
}
|
||||
},
|
||||
"codersdk.AIBridgeOpenAIConfig": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
@@ -11528,6 +11634,76 @@
|
||||
}
|
||||
}
|
||||
},
|
||||
"codersdk.AIBridgeSessionThreadsResponse": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"client": {
|
||||
"type": "string"
|
||||
},
|
||||
"ended_at": {
|
||||
"type": "string",
|
||||
"format": "date-time"
|
||||
},
|
||||
"id": {
|
||||
"type": "string"
|
||||
},
|
||||
"initiator": {
|
||||
"$ref": "#/definitions/codersdk.MinimalUser"
|
||||
},
|
||||
"metadata": {
|
||||
"type": "object",
|
||||
"additionalProperties": {}
|
||||
},
|
||||
"models": {
|
||||
"type": "array",
|
||||
"items": {
|
||||
"type": "string"
|
||||
}
|
||||
},
|
||||
"page_ended_at": {
|
||||
"type": "string",
|
||||
"format": "date-time"
|
||||
},
|
||||
"page_started_at": {
|
||||
"type": "string",
|
||||
"format": "date-time"
|
||||
},
|
||||
"providers": {
|
||||
"type": "array",
|
||||
"items": {
|
||||
"type": "string"
|
||||
}
|
||||
},
|
||||
"started_at": {
|
||||
"type": "string",
|
||||
"format": "date-time"
|
||||
},
|
||||
"threads": {
|
||||
"type": "array",
|
||||
"items": {
|
||||
"$ref": "#/definitions/codersdk.AIBridgeThread"
|
||||
}
|
||||
},
|
||||
"token_usage_summary": {
|
||||
"$ref": "#/definitions/codersdk.AIBridgeSessionThreadsTokenUsage"
|
||||
}
|
||||
}
|
||||
},
|
||||
"codersdk.AIBridgeSessionThreadsTokenUsage": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"input_tokens": {
|
||||
"type": "integer"
|
||||
},
|
||||
"metadata": {
|
||||
"type": "object",
|
||||
"additionalProperties": {}
|
||||
},
|
||||
"output_tokens": {
|
||||
"type": "integer"
|
||||
}
|
||||
}
|
||||
},
|
||||
"codersdk.AIBridgeSessionTokenUsageSummary": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
@@ -11539,6 +11715,41 @@
|
||||
}
|
||||
}
|
||||
},
|
||||
"codersdk.AIBridgeThread": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"agentic_actions": {
|
||||
"type": "array",
|
||||
"items": {
|
||||
"$ref": "#/definitions/codersdk.AIBridgeAgenticAction"
|
||||
}
|
||||
},
|
||||
"ended_at": {
|
||||
"type": "string",
|
||||
"format": "date-time"
|
||||
},
|
||||
"id": {
|
||||
"type": "string",
|
||||
"format": "uuid"
|
||||
},
|
||||
"model": {
|
||||
"type": "string"
|
||||
},
|
||||
"prompt": {
|
||||
"type": "string"
|
||||
},
|
||||
"provider": {
|
||||
"type": "string"
|
||||
},
|
||||
"started_at": {
|
||||
"type": "string",
|
||||
"format": "date-time"
|
||||
},
|
||||
"token_usage": {
|
||||
"$ref": "#/definitions/codersdk.AIBridgeSessionThreadsTokenUsage"
|
||||
}
|
||||
}
|
||||
},
|
||||
"codersdk.AIBridgeTokenUsage": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
@@ -11569,6 +11780,42 @@
|
||||
}
|
||||
}
|
||||
},
|
||||
"codersdk.AIBridgeToolCall": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"created_at": {
|
||||
"type": "string",
|
||||
"format": "date-time"
|
||||
},
|
||||
"id": {
|
||||
"type": "string",
|
||||
"format": "uuid"
|
||||
},
|
||||
"injected": {
|
||||
"type": "boolean"
|
||||
},
|
||||
"input": {
|
||||
"type": "string"
|
||||
},
|
||||
"interception_id": {
|
||||
"type": "string",
|
||||
"format": "uuid"
|
||||
},
|
||||
"metadata": {
|
||||
"type": "object",
|
||||
"additionalProperties": {}
|
||||
},
|
||||
"provider_response_id": {
|
||||
"type": "string"
|
||||
},
|
||||
"server_url": {
|
||||
"type": "string"
|
||||
},
|
||||
"tool": {
|
||||
"type": "string"
|
||||
}
|
||||
}
|
||||
},
|
||||
"codersdk.AIBridgeToolUsage": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
@@ -15851,6 +16098,10 @@
|
||||
"$ref": "#/definitions/codersdk.SlimRole"
|
||||
}
|
||||
},
|
||||
"has_ai_seat": {
|
||||
"description": "HasAISeat intentionally omits omitempty so the API always includes the\nfield, even when false.",
|
||||
"type": "boolean"
|
||||
},
|
||||
"is_service_account": {
|
||||
"type": "boolean"
|
||||
},
|
||||
@@ -18547,6 +18798,10 @@
|
||||
"type": "string",
|
||||
"format": "email"
|
||||
},
|
||||
"has_ai_seat": {
|
||||
"description": "HasAISeat intentionally omits omitempty so the API always includes the\nfield, even when false.",
|
||||
"type": "boolean"
|
||||
},
|
||||
"id": {
|
||||
"type": "string",
|
||||
"format": "uuid"
|
||||
@@ -19339,6 +19594,10 @@
|
||||
"type": "string",
|
||||
"format": "email"
|
||||
},
|
||||
"has_ai_seat": {
|
||||
"description": "HasAISeat intentionally omits omitempty so the API always includes the\nfield, even when false.",
|
||||
"type": "boolean"
|
||||
},
|
||||
"id": {
|
||||
"type": "string",
|
||||
"format": "uuid"
|
||||
|
||||
@@ -582,5 +582,20 @@ func (api *API) createAPIKey(ctx context.Context, params apikey.CreateParams) (*
|
||||
Value: sessionToken,
|
||||
Path: "/",
|
||||
HttpOnly: true,
|
||||
// MaxAge is set so the browser persists the cookie to disk rather
|
||||
// than keeping it in memory as a session cookie. Standalone PWAs
|
||||
// (display: standalone) run in their own browser process, and
|
||||
// mobile OSes kill that process when the app is swiped away —
|
||||
// deleting in-memory cookies and forcing an unexpected login.
|
||||
//
|
||||
// We use a long static value (1 year) instead of the key's
|
||||
// LifetimeSeconds because the server refreshes the key's
|
||||
// ExpiresAt on activity but does not re-set the cookie. Tying
|
||||
// MaxAge to the key lifetime would cause the cookie to expire
|
||||
// client-side even when the server-side key is still valid.
|
||||
//
|
||||
// Security is not affected: the server validates ExpiresAt on
|
||||
// every request regardless of the cookie's MaxAge.
|
||||
MaxAge: int((365 * 24 * time.Hour).Seconds()),
|
||||
}), &newkey, nil
|
||||
}
|
||||
|
||||
@@ -394,6 +394,55 @@ func TestSessionExpiry(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
// TestSessionCookieMaxAge verifies that the session cookie is a persistent
|
||||
// cookie (has MaxAge set) rather than a session cookie. Standalone PWAs
|
||||
// run in their own browser process and mobile OSes purge in-memory
|
||||
// (session) cookies when that process is killed, so the cookie must be
|
||||
// persisted to disk.
|
||||
func TestSessionCookieMaxAge(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong)
|
||||
defer cancel()
|
||||
|
||||
client := coderdtest.New(t, nil)
|
||||
|
||||
// Create the first user (password-based login).
|
||||
req := codersdk.CreateFirstUserRequest{
|
||||
Email: "testuser@coder.com",
|
||||
Username: "testuser",
|
||||
Password: "SomeSecurePassword!",
|
||||
}
|
||||
_, err := client.CreateFirstUser(ctx, req)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Login via the raw HTTP endpoint so we can inspect the Set-Cookie header.
|
||||
loginURL, err := client.URL.Parse("/api/v2/users/login")
|
||||
require.NoError(t, err)
|
||||
|
||||
res, err := client.Request(ctx, http.MethodPost, loginURL.String(), codersdk.LoginWithPasswordRequest{
|
||||
Email: req.Email,
|
||||
Password: req.Password,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
defer res.Body.Close()
|
||||
require.Equal(t, http.StatusCreated, res.StatusCode)
|
||||
|
||||
oneYear := int((365 * 24 * time.Hour).Seconds())
|
||||
var found bool
|
||||
for _, cookie := range res.Cookies() {
|
||||
if cookie.Name == codersdk.SessionTokenCookie {
|
||||
// MaxAge should be set to a long value so the browser
|
||||
// persists the cookie to disk. The server handles real
|
||||
// expiry via the API key's ExpiresAt field.
|
||||
require.Equal(t, oneYear, cookie.MaxAge,
|
||||
"Session cookie MaxAge should be set to 1 year for disk persistence")
|
||||
found = true
|
||||
}
|
||||
}
|
||||
require.True(t, found, "session cookie should be present in login response")
|
||||
}
|
||||
|
||||
func TestAPIKey_OK(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
|
||||
+1
-1
@@ -220,7 +220,7 @@ func (api *API) checkAuthorization(rw http.ResponseWriter, r *http.Request) {
|
||||
Type: string(v.Object.ResourceType),
|
||||
AnyOrgOwner: v.Object.AnyOrgOwner,
|
||||
}
|
||||
if obj.Owner == "me" {
|
||||
if obj.Owner == codersdk.Me {
|
||||
obj.Owner = auth.ID
|
||||
}
|
||||
|
||||
|
||||
@@ -1155,6 +1155,7 @@ func New(options *Options) *API {
|
||||
apiKeyMiddleware,
|
||||
httpmw.RequireExperimentWithDevBypass(api.Experiments, codersdk.ExperimentAgents),
|
||||
)
|
||||
r.Get("/by-workspace", api.chatsByWorkspace)
|
||||
r.Get("/", api.listChats)
|
||||
r.Post("/", api.postChats)
|
||||
r.Get("/models", api.listChatModels)
|
||||
@@ -1233,6 +1234,7 @@ func New(options *Options) *API {
|
||||
r.Get("/git", api.watchChatGit)
|
||||
})
|
||||
r.Post("/interrupt", api.interruptChat)
|
||||
r.Post("/title/regenerate", api.regenerateChatTitle)
|
||||
r.Get("/diff", api.getChatDiffContents)
|
||||
r.Route("/queue/{queuedMessage}", func(r chi.Router) {
|
||||
r.Delete("/", api.deleteChatQueuedMessage)
|
||||
|
||||
@@ -999,15 +999,16 @@ func AIBridgeInterception(interception database.AIBridgeInterception, initiator
|
||||
return sdkToolUsages[i].CreatedAt.Before(sdkToolUsages[j].CreatedAt)
|
||||
})
|
||||
intc := codersdk.AIBridgeInterception{
|
||||
ID: interception.ID,
|
||||
Initiator: MinimalUserFromVisibleUser(initiator),
|
||||
Provider: interception.Provider,
|
||||
Model: interception.Model,
|
||||
Metadata: jsonOrEmptyMap(interception.Metadata),
|
||||
StartedAt: interception.StartedAt,
|
||||
TokenUsages: sdkTokenUsages,
|
||||
UserPrompts: sdkUserPrompts,
|
||||
ToolUsages: sdkToolUsages,
|
||||
ID: interception.ID,
|
||||
Initiator: MinimalUserFromVisibleUser(initiator),
|
||||
Provider: interception.Provider,
|
||||
ProviderName: interception.ProviderName,
|
||||
Model: interception.Model,
|
||||
Metadata: jsonOrEmptyMap(interception.Metadata),
|
||||
StartedAt: interception.StartedAt,
|
||||
TokenUsages: sdkTokenUsages,
|
||||
UserPrompts: sdkUserPrompts,
|
||||
ToolUsages: sdkToolUsages,
|
||||
}
|
||||
if interception.APIKeyID.Valid {
|
||||
intc.APIKeyID = &interception.APIKeyID.String
|
||||
@@ -1097,6 +1098,287 @@ func AIBridgeToolUsage(usage database.AIBridgeToolUsage) codersdk.AIBridgeToolUs
|
||||
}
|
||||
}
|
||||
|
||||
// AIBridgeSessionThreads converts session metadata and thread interceptions
|
||||
// into the threads response. It groups interceptions into threads, builds
|
||||
// agentic actions from tool usages and model thoughts, and aggregates
|
||||
// token usage with metadata.
|
||||
func AIBridgeSessionThreads(
|
||||
session database.ListAIBridgeSessionsRow,
|
||||
interceptions []database.ListAIBridgeSessionThreadsRow,
|
||||
tokenUsages []database.AIBridgeTokenUsage,
|
||||
toolUsages []database.AIBridgeToolUsage,
|
||||
userPrompts []database.AIBridgeUserPrompt,
|
||||
modelThoughts []database.AIBridgeModelThought,
|
||||
) codersdk.AIBridgeSessionThreadsResponse {
|
||||
// Index subresources by interception ID.
|
||||
tokensByInterception := make(map[uuid.UUID][]database.AIBridgeTokenUsage, len(interceptions))
|
||||
for _, tu := range tokenUsages {
|
||||
tokensByInterception[tu.InterceptionID] = append(tokensByInterception[tu.InterceptionID], tu)
|
||||
}
|
||||
toolsByInterception := make(map[uuid.UUID][]database.AIBridgeToolUsage, len(interceptions))
|
||||
for _, tu := range toolUsages {
|
||||
toolsByInterception[tu.InterceptionID] = append(toolsByInterception[tu.InterceptionID], tu)
|
||||
}
|
||||
promptsByInterception := make(map[uuid.UUID][]database.AIBridgeUserPrompt, len(interceptions))
|
||||
for _, up := range userPrompts {
|
||||
promptsByInterception[up.InterceptionID] = append(promptsByInterception[up.InterceptionID], up)
|
||||
}
|
||||
thoughtsByInterception := make(map[uuid.UUID][]database.AIBridgeModelThought, len(interceptions))
|
||||
for _, mt := range modelThoughts {
|
||||
thoughtsByInterception[mt.InterceptionID] = append(thoughtsByInterception[mt.InterceptionID], mt)
|
||||
}
|
||||
|
||||
// Group interceptions by thread_id, preserving the order returned by the
|
||||
// SQL query.
|
||||
interceptionsByThread := make(map[uuid.UUID][]database.AIBridgeInterception, len(interceptions))
|
||||
var threadIDs []uuid.UUID
|
||||
for _, row := range interceptions {
|
||||
if _, ok := interceptionsByThread[row.ThreadID]; !ok {
|
||||
threadIDs = append(threadIDs, row.ThreadID)
|
||||
}
|
||||
interceptionsByThread[row.ThreadID] = append(interceptionsByThread[row.ThreadID], row.AIBridgeInterception)
|
||||
}
|
||||
|
||||
// Build threads and track page time bounds.
|
||||
threads := make([]codersdk.AIBridgeThread, 0, len(threadIDs))
|
||||
var pageStartedAt, pageEndedAt *time.Time
|
||||
for _, threadID := range threadIDs {
|
||||
intcs := interceptionsByThread[threadID]
|
||||
thread := buildAIBridgeThread(threadID, intcs, tokensByInterception, toolsByInterception, promptsByInterception, thoughtsByInterception)
|
||||
for _, intc := range intcs {
|
||||
if pageStartedAt == nil || intc.StartedAt.Before(*pageStartedAt) {
|
||||
t := intc.StartedAt
|
||||
pageStartedAt = &t
|
||||
}
|
||||
if intc.EndedAt.Valid {
|
||||
if pageEndedAt == nil || intc.EndedAt.Time.After(*pageEndedAt) {
|
||||
t := intc.EndedAt.Time
|
||||
pageEndedAt = &t
|
||||
}
|
||||
}
|
||||
}
|
||||
threads = append(threads, thread)
|
||||
}
|
||||
|
||||
// Aggregate session-level token usage metadata from all token
|
||||
// usages in the session (not just the page).
|
||||
sessionTokenMeta := aggregateTokenMetadata(tokenUsages)
|
||||
|
||||
resp := codersdk.AIBridgeSessionThreadsResponse{
|
||||
ID: session.SessionID,
|
||||
Initiator: MinimalUserFromVisibleUser(database.VisibleUser{
|
||||
ID: session.UserID,
|
||||
Username: session.UserUsername,
|
||||
Name: session.UserName,
|
||||
AvatarURL: session.UserAvatarUrl,
|
||||
}),
|
||||
Providers: session.Providers,
|
||||
Models: session.Models,
|
||||
Metadata: jsonOrEmptyMap(pqtype.NullRawMessage{RawMessage: session.Metadata, Valid: len(session.Metadata) > 0}),
|
||||
StartedAt: session.StartedAt,
|
||||
PageStartedAt: pageStartedAt,
|
||||
PageEndedAt: pageEndedAt,
|
||||
TokenUsageSummary: codersdk.AIBridgeSessionThreadsTokenUsage{
|
||||
InputTokens: session.InputTokens,
|
||||
OutputTokens: session.OutputTokens,
|
||||
Metadata: sessionTokenMeta,
|
||||
},
|
||||
Threads: threads,
|
||||
}
|
||||
if resp.Providers == nil {
|
||||
resp.Providers = []string{}
|
||||
}
|
||||
if resp.Models == nil {
|
||||
resp.Models = []string{}
|
||||
}
|
||||
if session.Client != "" {
|
||||
resp.Client = &session.Client
|
||||
}
|
||||
if !session.EndedAt.IsZero() {
|
||||
resp.EndedAt = &session.EndedAt
|
||||
}
|
||||
return resp
|
||||
}
|
||||
|
||||
func buildAIBridgeThread(
|
||||
threadID uuid.UUID,
|
||||
interceptions []database.AIBridgeInterception,
|
||||
tokensByInterception map[uuid.UUID][]database.AIBridgeTokenUsage,
|
||||
toolsByInterception map[uuid.UUID][]database.AIBridgeToolUsage,
|
||||
promptsByInterception map[uuid.UUID][]database.AIBridgeUserPrompt,
|
||||
thoughtsByInterception map[uuid.UUID][]database.AIBridgeModelThought,
|
||||
) codersdk.AIBridgeThread {
|
||||
// Find the root interception (where id == threadID) to get the
|
||||
// thread prompt and model.
|
||||
var rootIntc *database.AIBridgeInterception
|
||||
for i := range interceptions {
|
||||
if interceptions[i].ID == threadID {
|
||||
rootIntc = &interceptions[i]
|
||||
break
|
||||
}
|
||||
}
|
||||
// Fallback to first interception if root not found.
|
||||
if rootIntc == nil && len(interceptions) > 0 {
|
||||
rootIntc = &interceptions[0]
|
||||
}
|
||||
|
||||
thread := codersdk.AIBridgeThread{
|
||||
ID: threadID,
|
||||
}
|
||||
if rootIntc != nil {
|
||||
thread.Model = rootIntc.Model
|
||||
thread.Provider = rootIntc.Provider
|
||||
// Get first user prompt from root interception.
|
||||
// A thread can only have one prompt, by definition, since we currently
|
||||
// only store the last prompt observed in an interception.
|
||||
if prompts := promptsByInterception[rootIntc.ID]; len(prompts) > 0 {
|
||||
thread.Prompt = &prompts[0].Prompt
|
||||
}
|
||||
}
|
||||
|
||||
// Compute thread time bounds from interceptions.
|
||||
for _, intc := range interceptions {
|
||||
if thread.StartedAt.IsZero() || intc.StartedAt.Before(thread.StartedAt) {
|
||||
thread.StartedAt = intc.StartedAt
|
||||
}
|
||||
if intc.EndedAt.Valid {
|
||||
if thread.EndedAt == nil || intc.EndedAt.Time.After(*thread.EndedAt) {
|
||||
t := intc.EndedAt.Time
|
||||
thread.EndedAt = &t
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Build agentic actions grouped by interception. Each interception that
|
||||
// has tool calls produces one action with all its tool calls, thinking
|
||||
// blocks, and token usage.
|
||||
var actions []codersdk.AIBridgeAgenticAction
|
||||
for _, intc := range interceptions {
|
||||
tools := toolsByInterception[intc.ID]
|
||||
if len(tools) == 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
// Thinking blocks for this interception.
|
||||
thoughts := thoughtsByInterception[intc.ID]
|
||||
thinking := make([]codersdk.AIBridgeModelThought, 0, len(thoughts))
|
||||
for _, mt := range thoughts {
|
||||
thinking = append(thinking, codersdk.AIBridgeModelThought{
|
||||
Text: mt.Content,
|
||||
})
|
||||
}
|
||||
|
||||
// Token usage for the interception.
|
||||
actionTokenUsage := aggregateTokenUsage(tokensByInterception[intc.ID])
|
||||
|
||||
// Build tool call list.
|
||||
toolCalls := make([]codersdk.AIBridgeToolCall, 0, len(tools))
|
||||
for _, tu := range tools {
|
||||
toolCalls = append(toolCalls, codersdk.AIBridgeToolCall{
|
||||
ID: tu.ID,
|
||||
InterceptionID: tu.InterceptionID,
|
||||
ProviderResponseID: tu.ProviderResponseID,
|
||||
ServerURL: tu.ServerUrl.String,
|
||||
Tool: tu.Tool,
|
||||
Injected: tu.Injected,
|
||||
Input: tu.Input,
|
||||
Metadata: jsonOrEmptyMap(tu.Metadata),
|
||||
CreatedAt: tu.CreatedAt,
|
||||
})
|
||||
}
|
||||
|
||||
actions = append(actions, codersdk.AIBridgeAgenticAction{
|
||||
Model: intc.Model,
|
||||
TokenUsage: actionTokenUsage,
|
||||
Thinking: thinking,
|
||||
ToolCalls: toolCalls,
|
||||
})
|
||||
}
|
||||
|
||||
if actions == nil {
|
||||
// Make an empty slice so we don't serialize `null`.
|
||||
actions = make([]codersdk.AIBridgeAgenticAction, 0)
|
||||
}
|
||||
|
||||
thread.AgenticActions = actions
|
||||
|
||||
// Aggregate thread-level token usage.
|
||||
var threadTokens []database.AIBridgeTokenUsage
|
||||
for _, intc := range interceptions {
|
||||
threadTokens = append(threadTokens, tokensByInterception[intc.ID]...)
|
||||
}
|
||||
thread.TokenUsage = aggregateTokenUsage(threadTokens)
|
||||
|
||||
return thread
|
||||
}
|
||||
|
||||
// aggregateTokenUsage sums token usage rows and aggregates metadata.
|
||||
func aggregateTokenUsage(tokens []database.AIBridgeTokenUsage) codersdk.AIBridgeSessionThreadsTokenUsage {
|
||||
var inputTokens, outputTokens int64
|
||||
for _, tu := range tokens {
|
||||
inputTokens += tu.InputTokens
|
||||
outputTokens += tu.OutputTokens
|
||||
// TODO: once https://github.com/coder/aibridge/issues/150 lands we
|
||||
// should aggregate the other token types.
|
||||
}
|
||||
return codersdk.AIBridgeSessionThreadsTokenUsage{
|
||||
InputTokens: inputTokens,
|
||||
OutputTokens: outputTokens,
|
||||
Metadata: aggregateTokenMetadata(tokens),
|
||||
}
|
||||
}
|
||||
|
||||
// aggregateTokenMetadata sums all numeric values from the metadata
|
||||
// JSONB across the given token usage rows by key. Nested objects are
|
||||
// flattened using dot-notation (e.g. {"cache": {"read_tokens": 10}}
|
||||
// becomes "cache.read_tokens"). Non-numeric leaves (strings,
|
||||
// booleans, arrays, nulls) are silently skipped.
|
||||
func aggregateTokenMetadata(tokens []database.AIBridgeTokenUsage) map[string]any {
|
||||
sums := make(map[string]int64)
|
||||
for _, tu := range tokens {
|
||||
if !tu.Metadata.Valid || len(tu.Metadata.RawMessage) == 0 {
|
||||
continue
|
||||
}
|
||||
var m map[string]json.RawMessage
|
||||
if err := json.Unmarshal(tu.Metadata.RawMessage, &m); err != nil {
|
||||
continue
|
||||
}
|
||||
flattenAndSum(sums, "", m)
|
||||
}
|
||||
result := make(map[string]any, len(sums))
|
||||
for k, v := range sums {
|
||||
result[k] = v
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
// flattenAndSum recursively walks a JSON object and sums all numeric
|
||||
// leaf values into sums, using dot-separated keys for nested objects.
|
||||
func flattenAndSum(sums map[string]int64, prefix string, m map[string]json.RawMessage) {
|
||||
for k, raw := range m {
|
||||
key := k
|
||||
if prefix != "" {
|
||||
key = prefix + "." + k
|
||||
}
|
||||
|
||||
// Try as a number first.
|
||||
var n json.Number
|
||||
if err := json.Unmarshal(raw, &n); err == nil {
|
||||
if v, err := n.Int64(); err == nil {
|
||||
sums[key] += v
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
// Try as a nested object.
|
||||
var nested map[string]json.RawMessage
|
||||
if err := json.Unmarshal(raw, &nested); err == nil {
|
||||
flattenAndSum(sums, key, nested)
|
||||
}
|
||||
// Arrays, strings, booleans, nulls are skipped.
|
||||
}
|
||||
}
|
||||
|
||||
func InvalidatedPresets(invalidatedPresets []database.UpdatePresetsLastInvalidatedAtRow) []codersdk.InvalidatedPreset {
|
||||
var presets []codersdk.InvalidatedPreset
|
||||
for _, p := range invalidatedPresets {
|
||||
@@ -1235,6 +1517,98 @@ func nullInt64Ptr(v sql.NullInt64) *int64 {
|
||||
return &value
|
||||
}
|
||||
|
||||
// Chat converts a database.Chat to a codersdk.Chat. It coalesces
|
||||
// nil slices and maps to empty values for JSON serialization and
|
||||
// derives RootChatID from the parent chain when not explicitly set.
|
||||
func Chat(c database.Chat, diffStatus *database.ChatDiffStatus) codersdk.Chat {
|
||||
mcpServerIDs := c.MCPServerIDs
|
||||
if mcpServerIDs == nil {
|
||||
mcpServerIDs = []uuid.UUID{}
|
||||
}
|
||||
labels := map[string]string(c.Labels)
|
||||
if labels == nil {
|
||||
labels = map[string]string{}
|
||||
}
|
||||
chat := codersdk.Chat{
|
||||
ID: c.ID,
|
||||
OwnerID: c.OwnerID,
|
||||
LastModelConfigID: c.LastModelConfigID,
|
||||
Title: c.Title,
|
||||
Status: codersdk.ChatStatus(c.Status),
|
||||
Archived: c.Archived,
|
||||
PinOrder: c.PinOrder,
|
||||
CreatedAt: c.CreatedAt,
|
||||
UpdatedAt: c.UpdatedAt,
|
||||
MCPServerIDs: mcpServerIDs,
|
||||
Labels: labels,
|
||||
}
|
||||
if c.LastError.Valid {
|
||||
chat.LastError = &c.LastError.String
|
||||
}
|
||||
if c.ParentChatID.Valid {
|
||||
parentChatID := c.ParentChatID.UUID
|
||||
chat.ParentChatID = &parentChatID
|
||||
}
|
||||
switch {
|
||||
case c.RootChatID.Valid:
|
||||
rootChatID := c.RootChatID.UUID
|
||||
chat.RootChatID = &rootChatID
|
||||
case c.ParentChatID.Valid:
|
||||
rootChatID := c.ParentChatID.UUID
|
||||
chat.RootChatID = &rootChatID
|
||||
default:
|
||||
rootChatID := c.ID
|
||||
chat.RootChatID = &rootChatID
|
||||
}
|
||||
if c.WorkspaceID.Valid {
|
||||
chat.WorkspaceID = &c.WorkspaceID.UUID
|
||||
}
|
||||
if c.BuildID.Valid {
|
||||
chat.BuildID = &c.BuildID.UUID
|
||||
}
|
||||
if c.AgentID.Valid {
|
||||
chat.AgentID = &c.AgentID.UUID
|
||||
}
|
||||
if diffStatus != nil {
|
||||
convertedDiffStatus := ChatDiffStatus(c.ID, diffStatus)
|
||||
chat.DiffStatus = &convertedDiffStatus
|
||||
}
|
||||
if c.LastInjectedContext.Valid {
|
||||
var parts []codersdk.ChatMessagePart
|
||||
// Internal fields are stripped at write time in
|
||||
// chatd.updateLastInjectedContext, so no
|
||||
// StripInternal call is needed here. Unmarshal
|
||||
// errors are suppressed — the column is written by
|
||||
// us with a known schema.
|
||||
if err := json.Unmarshal(c.LastInjectedContext.RawMessage, &parts); err == nil {
|
||||
chat.LastInjectedContext = parts
|
||||
}
|
||||
}
|
||||
return chat
|
||||
}
|
||||
|
||||
// ChatRows converts a slice of database.GetChatsRow (which embeds
|
||||
// Chat plus HasUnread) to codersdk.Chat, looking up diff statuses
|
||||
// from the provided map. When diffStatusesByChatID is non-nil,
|
||||
// chats without an entry receive an empty DiffStatus.
|
||||
func ChatRows(rows []database.GetChatsRow, diffStatusesByChatID map[uuid.UUID]database.ChatDiffStatus) []codersdk.Chat {
|
||||
result := make([]codersdk.Chat, len(rows))
|
||||
for i, row := range rows {
|
||||
diffStatus, ok := diffStatusesByChatID[row.Chat.ID]
|
||||
if ok {
|
||||
result[i] = Chat(row.Chat, &diffStatus)
|
||||
} else {
|
||||
result[i] = Chat(row.Chat, nil)
|
||||
if diffStatusesByChatID != nil {
|
||||
emptyDiffStatus := ChatDiffStatus(row.Chat.ID, nil)
|
||||
result[i].DiffStatus = &emptyDiffStatus
|
||||
}
|
||||
}
|
||||
result[i].HasUnread = row.HasUnread
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
// ChatDiffStatus converts a database.ChatDiffStatus to a
|
||||
// codersdk.ChatDiffStatus. When status is nil an empty value
|
||||
// containing only the chatID is returned.
|
||||
|
||||
@@ -0,0 +1,308 @@
|
||||
package db2sdk
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"testing"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/sqlc-dev/pqtype"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/coder/coder/v2/coderd/database"
|
||||
)
|
||||
|
||||
func TestAggregateTokenMetadata(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
t.Run("empty_input", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
result := aggregateTokenMetadata(nil)
|
||||
require.Empty(t, result)
|
||||
})
|
||||
|
||||
t.Run("sums_across_rows", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
tokens := []database.AIBridgeTokenUsage{
|
||||
{
|
||||
ID: uuid.New(),
|
||||
Metadata: pqtype.NullRawMessage{
|
||||
RawMessage: json.RawMessage(`{"cache_read_tokens":100,"reasoning_tokens":50}`),
|
||||
Valid: true,
|
||||
},
|
||||
},
|
||||
{
|
||||
ID: uuid.New(),
|
||||
Metadata: pqtype.NullRawMessage{
|
||||
RawMessage: json.RawMessage(`{"cache_read_tokens":200,"reasoning_tokens":75}`),
|
||||
Valid: true,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
result := aggregateTokenMetadata(tokens)
|
||||
require.Equal(t, int64(300), result["cache_read_tokens"])
|
||||
require.Equal(t, int64(125), result["reasoning_tokens"])
|
||||
require.Len(t, result, 2)
|
||||
})
|
||||
|
||||
t.Run("skips_null_and_invalid_metadata", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
tokens := []database.AIBridgeTokenUsage{
|
||||
{
|
||||
ID: uuid.New(),
|
||||
Metadata: pqtype.NullRawMessage{Valid: false},
|
||||
},
|
||||
{
|
||||
ID: uuid.New(),
|
||||
Metadata: pqtype.NullRawMessage{
|
||||
RawMessage: nil,
|
||||
Valid: true,
|
||||
},
|
||||
},
|
||||
{
|
||||
ID: uuid.New(),
|
||||
Metadata: pqtype.NullRawMessage{
|
||||
RawMessage: json.RawMessage(`{"tokens":42}`),
|
||||
Valid: true,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
result := aggregateTokenMetadata(tokens)
|
||||
require.Equal(t, int64(42), result["tokens"])
|
||||
require.Len(t, result, 1)
|
||||
})
|
||||
|
||||
t.Run("skips_non_integer_values", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
tokens := []database.AIBridgeTokenUsage{
|
||||
{
|
||||
ID: uuid.New(),
|
||||
Metadata: pqtype.NullRawMessage{
|
||||
// Float values fail json.Number.Int64(), so they
|
||||
// are silently dropped.
|
||||
RawMessage: json.RawMessage(`{"good":10,"fractional":1.5}`),
|
||||
Valid: true,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
result := aggregateTokenMetadata(tokens)
|
||||
require.Equal(t, int64(10), result["good"])
|
||||
_, hasFractional := result["fractional"]
|
||||
require.False(t, hasFractional)
|
||||
})
|
||||
|
||||
t.Run("skips_malformed_json", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
tokens := []database.AIBridgeTokenUsage{
|
||||
{
|
||||
ID: uuid.New(),
|
||||
Metadata: pqtype.NullRawMessage{
|
||||
RawMessage: json.RawMessage(`not json`),
|
||||
Valid: true,
|
||||
},
|
||||
},
|
||||
{
|
||||
ID: uuid.New(),
|
||||
Metadata: pqtype.NullRawMessage{
|
||||
RawMessage: json.RawMessage(`{"tokens":5}`),
|
||||
Valid: true,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
result := aggregateTokenMetadata(tokens)
|
||||
// The malformed row is skipped, the valid one is counted.
|
||||
require.Equal(t, int64(5), result["tokens"])
|
||||
require.Len(t, result, 1)
|
||||
})
|
||||
|
||||
t.Run("flattens_nested_objects", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
tokens := []database.AIBridgeTokenUsage{
|
||||
{
|
||||
ID: uuid.New(),
|
||||
Metadata: pqtype.NullRawMessage{
|
||||
RawMessage: json.RawMessage(`{
|
||||
"cache_read_tokens": 100,
|
||||
"cache": {"creation_tokens": 40, "read_tokens": 60},
|
||||
"reasoning_tokens": 50,
|
||||
"tags": ["a", "b"]
|
||||
}`),
|
||||
Valid: true,
|
||||
},
|
||||
},
|
||||
{
|
||||
ID: uuid.New(),
|
||||
Metadata: pqtype.NullRawMessage{
|
||||
RawMessage: json.RawMessage(`{
|
||||
"cache_read_tokens": 200,
|
||||
"cache": {"creation_tokens": 10}
|
||||
}`),
|
||||
Valid: true,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
result := aggregateTokenMetadata(tokens)
|
||||
require.Equal(t, int64(300), result["cache_read_tokens"])
|
||||
require.Equal(t, int64(50), result["reasoning_tokens"])
|
||||
require.Equal(t, int64(50), result["cache.creation_tokens"])
|
||||
require.Equal(t, int64(60), result["cache.read_tokens"])
|
||||
// Arrays are skipped.
|
||||
_, hasTags := result["tags"]
|
||||
require.False(t, hasTags)
|
||||
require.Len(t, result, 4)
|
||||
})
|
||||
|
||||
t.Run("flattens_deeply_nested_objects", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
tokens := []database.AIBridgeTokenUsage{
|
||||
{
|
||||
ID: uuid.New(),
|
||||
Metadata: pqtype.NullRawMessage{
|
||||
RawMessage: json.RawMessage(`{
|
||||
"provider": {
|
||||
"anthropic": {"cache_creation_tokens": 100, "cache_read_tokens": 200},
|
||||
"openai": {"reasoning_tokens": 50}
|
||||
},
|
||||
"total": 500
|
||||
}`),
|
||||
Valid: true,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
result := aggregateTokenMetadata(tokens)
|
||||
require.Equal(t, int64(100), result["provider.anthropic.cache_creation_tokens"])
|
||||
require.Equal(t, int64(200), result["provider.anthropic.cache_read_tokens"])
|
||||
require.Equal(t, int64(50), result["provider.openai.reasoning_tokens"])
|
||||
require.Equal(t, int64(500), result["total"])
|
||||
require.Len(t, result, 4)
|
||||
})
|
||||
|
||||
// Real-world provider metadata shapes from
|
||||
// https://github.com/coder/aibridge/issues/150.
|
||||
t.Run("aggregates_real_provider_metadata", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
tokens := []database.AIBridgeTokenUsage{
|
||||
{
|
||||
// Anthropic-style: cache fields are top-level.
|
||||
ID: uuid.New(),
|
||||
Metadata: pqtype.NullRawMessage{
|
||||
RawMessage: json.RawMessage(`{
|
||||
"cache_creation_input_tokens": 0,
|
||||
"cache_read_input_tokens": 23490
|
||||
}`),
|
||||
Valid: true,
|
||||
},
|
||||
},
|
||||
{
|
||||
// OpenAI-style: cache fields are nested inside
|
||||
// input_tokens_details.
|
||||
ID: uuid.New(),
|
||||
Metadata: pqtype.NullRawMessage{
|
||||
RawMessage: json.RawMessage(`{
|
||||
"input_tokens_details": {"cached_tokens": 11904}
|
||||
}`),
|
||||
Valid: true,
|
||||
},
|
||||
},
|
||||
{
|
||||
// Second Anthropic row to verify summing.
|
||||
ID: uuid.New(),
|
||||
Metadata: pqtype.NullRawMessage{
|
||||
RawMessage: json.RawMessage(`{
|
||||
"cache_creation_input_tokens": 500,
|
||||
"cache_read_input_tokens": 10000
|
||||
}`),
|
||||
Valid: true,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
result := aggregateTokenMetadata(tokens)
|
||||
// Anthropic fields are summed across two rows.
|
||||
require.Equal(t, int64(500), result["cache_creation_input_tokens"])
|
||||
require.Equal(t, int64(33490), result["cache_read_input_tokens"])
|
||||
// OpenAI nested field is flattened with dot notation.
|
||||
require.Equal(t, int64(11904), result["input_tokens_details.cached_tokens"])
|
||||
require.Len(t, result, 3)
|
||||
})
|
||||
|
||||
t.Run("skips_string_boolean_null_values", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
tokens := []database.AIBridgeTokenUsage{
|
||||
{
|
||||
ID: uuid.New(),
|
||||
Metadata: pqtype.NullRawMessage{
|
||||
RawMessage: json.RawMessage(`{"tokens":10,"name":"test","enabled":true,"nothing":null}`),
|
||||
Valid: true,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
result := aggregateTokenMetadata(tokens)
|
||||
require.Equal(t, int64(10), result["tokens"])
|
||||
require.Len(t, result, 1)
|
||||
})
|
||||
}
|
||||
|
||||
func TestAggregateTokenUsage(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
t.Run("empty_input", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
result := aggregateTokenUsage(nil)
|
||||
require.Equal(t, int64(0), result.InputTokens)
|
||||
require.Equal(t, int64(0), result.OutputTokens)
|
||||
require.Empty(t, result.Metadata)
|
||||
})
|
||||
|
||||
t.Run("sums_tokens_and_metadata", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
tokens := []database.AIBridgeTokenUsage{
|
||||
{
|
||||
ID: uuid.New(),
|
||||
InputTokens: 100,
|
||||
OutputTokens: 50,
|
||||
Metadata: pqtype.NullRawMessage{
|
||||
RawMessage: json.RawMessage(`{"reasoning_tokens":20}`),
|
||||
Valid: true,
|
||||
},
|
||||
},
|
||||
{
|
||||
ID: uuid.New(),
|
||||
InputTokens: 200,
|
||||
OutputTokens: 75,
|
||||
Metadata: pqtype.NullRawMessage{
|
||||
RawMessage: json.RawMessage(`{"reasoning_tokens":30}`),
|
||||
Valid: true,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
result := aggregateTokenUsage(tokens)
|
||||
require.Equal(t, int64(300), result.InputTokens)
|
||||
require.Equal(t, int64(125), result.OutputTokens)
|
||||
require.Equal(t, int64(50), result.Metadata["reasoning_tokens"])
|
||||
})
|
||||
|
||||
t.Run("handles_rows_without_metadata", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
tokens := []database.AIBridgeTokenUsage{
|
||||
{
|
||||
ID: uuid.New(),
|
||||
InputTokens: 500,
|
||||
OutputTokens: 200,
|
||||
Metadata: pqtype.NullRawMessage{Valid: false},
|
||||
},
|
||||
}
|
||||
|
||||
result := aggregateTokenUsage(tokens)
|
||||
require.Equal(t, int64(500), result.InputTokens)
|
||||
require.Equal(t, int64(200), result.OutputTokens)
|
||||
require.Empty(t, result.Metadata)
|
||||
})
|
||||
}
|
||||
@@ -5,6 +5,7 @@ import (
|
||||
"database/sql"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"reflect"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
@@ -513,6 +514,69 @@ func TestChatQueuedMessage_ParsesUserContentParts(t *testing.T) {
|
||||
require.Equal(t, "queued text", queued.Content[0].Text)
|
||||
}
|
||||
|
||||
func TestChat_AllFieldsPopulated(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
// Every field of database.Chat is set to a non-zero value so
|
||||
// that the reflection check below catches any field that
|
||||
// db2sdk.Chat forgets to populate. When someone adds a new
|
||||
// field to codersdk.Chat, this test will fail until the
|
||||
// converter is updated.
|
||||
now := dbtime.Now()
|
||||
input := database.Chat{
|
||||
ID: uuid.New(),
|
||||
OwnerID: uuid.New(),
|
||||
WorkspaceID: uuid.NullUUID{UUID: uuid.New(), Valid: true},
|
||||
BuildID: uuid.NullUUID{UUID: uuid.New(), Valid: true},
|
||||
AgentID: uuid.NullUUID{UUID: uuid.New(), Valid: true},
|
||||
ParentChatID: uuid.NullUUID{UUID: uuid.New(), Valid: true},
|
||||
RootChatID: uuid.NullUUID{UUID: uuid.New(), Valid: true},
|
||||
LastModelConfigID: uuid.New(),
|
||||
Title: "all-fields-test",
|
||||
Status: database.ChatStatusRunning,
|
||||
LastError: sql.NullString{String: "boom", Valid: true},
|
||||
CreatedAt: now,
|
||||
UpdatedAt: now,
|
||||
Archived: true,
|
||||
PinOrder: 1,
|
||||
MCPServerIDs: []uuid.UUID{uuid.New()},
|
||||
Labels: database.StringMap{"env": "prod"},
|
||||
LastInjectedContext: pqtype.NullRawMessage{
|
||||
// Use a context-file part to verify internal
|
||||
// fields are not present (they are stripped at
|
||||
// write time by chatd, not at read time).
|
||||
RawMessage: json.RawMessage(`[{"type":"context-file","context_file_path":"/AGENTS.md"}]`),
|
||||
Valid: true,
|
||||
},
|
||||
}
|
||||
// Only ChatID is needed here. This test checks that
|
||||
// Chat.DiffStatus is non-nil, not that every DiffStatus
|
||||
// field is populated — that would be a separate test for
|
||||
// the ChatDiffStatus converter.
|
||||
diffStatus := &database.ChatDiffStatus{
|
||||
ChatID: input.ID,
|
||||
}
|
||||
|
||||
got := db2sdk.Chat(input, diffStatus)
|
||||
|
||||
v := reflect.ValueOf(got)
|
||||
typ := v.Type()
|
||||
// HasUnread is populated by ChatRows (which joins the
|
||||
// read-cursor query), not by Chat, so it is expected
|
||||
// to remain zero here.
|
||||
skip := map[string]bool{"HasUnread": true}
|
||||
for i := range typ.NumField() {
|
||||
field := typ.Field(i)
|
||||
if skip[field.Name] {
|
||||
continue
|
||||
}
|
||||
require.False(t, v.Field(i).IsZero(),
|
||||
"codersdk.Chat field %q is zero-valued — db2sdk.Chat may not be populating it",
|
||||
field.Name,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
func TestChatQueuedMessage_MalformedContent(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
|
||||
@@ -1570,13 +1570,13 @@ func (q *querier) AllUserIDs(ctx context.Context, includeSystem bool) ([]uuid.UU
|
||||
return q.db.AllUserIDs(ctx, includeSystem)
|
||||
}
|
||||
|
||||
func (q *querier) ArchiveChatByID(ctx context.Context, id uuid.UUID) error {
|
||||
func (q *querier) ArchiveChatByID(ctx context.Context, id uuid.UUID) ([]database.Chat, error) {
|
||||
chat, err := q.db.GetChatByID(ctx, id)
|
||||
if err != nil {
|
||||
return err
|
||||
return nil, err
|
||||
}
|
||||
if err := q.authorizeContext(ctx, policy.ActionUpdate, chat); err != nil {
|
||||
return err
|
||||
return nil, err
|
||||
}
|
||||
return q.db.ArchiveChatByID(ctx, id)
|
||||
}
|
||||
@@ -2578,6 +2578,18 @@ func (q *querier) GetChatFilesByIDs(ctx context.Context, ids []uuid.UUID) ([]dat
|
||||
return files, nil
|
||||
}
|
||||
|
||||
func (q *querier) GetChatIncludeDefaultSystemPrompt(ctx context.Context) (bool, error) {
|
||||
// The include-default-system-prompt flag is a deployment-wide setting read
|
||||
// during chat creation by every authenticated user, so no RBAC policy
|
||||
// check is needed. We still verify that a valid actor exists in the
|
||||
// context to ensure this is never callable by an unauthenticated or
|
||||
// system-internal path without an explicit actor.
|
||||
if _, ok := ActorFromContext(ctx); !ok {
|
||||
return false, ErrNoActor
|
||||
}
|
||||
return q.db.GetChatIncludeDefaultSystemPrompt(ctx)
|
||||
}
|
||||
|
||||
func (q *querier) GetChatMessageByID(ctx context.Context, id int64) (database.ChatMessage, error) {
|
||||
// ChatMessages are authorized through their parent Chat.
|
||||
// We need to fetch the message first to get its chat_id.
|
||||
@@ -2602,6 +2614,14 @@ func (q *querier) GetChatMessagesByChatID(ctx context.Context, arg database.GetC
|
||||
return q.db.GetChatMessagesByChatID(ctx, arg)
|
||||
}
|
||||
|
||||
func (q *querier) GetChatMessagesByChatIDAscPaginated(ctx context.Context, arg database.GetChatMessagesByChatIDAscPaginatedParams) ([]database.ChatMessage, error) {
|
||||
_, err := q.GetChatByID(ctx, arg.ChatID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return q.db.GetChatMessagesByChatIDAscPaginated(ctx, arg)
|
||||
}
|
||||
|
||||
func (q *querier) GetChatMessagesByChatIDDescPaginated(ctx context.Context, arg database.GetChatMessagesByChatIDDescPaginatedParams) ([]database.ChatMessage, error) {
|
||||
_, err := q.GetChatByID(ctx, arg.ChatID)
|
||||
if err != nil {
|
||||
@@ -2674,6 +2694,18 @@ func (q *querier) GetChatSystemPrompt(ctx context.Context) (string, error) {
|
||||
return q.db.GetChatSystemPrompt(ctx)
|
||||
}
|
||||
|
||||
func (q *querier) GetChatSystemPromptConfig(ctx context.Context) (database.GetChatSystemPromptConfigRow, error) {
|
||||
// The system prompt configuration is a deployment-wide setting read during
|
||||
// chat creation by every authenticated user, so no RBAC policy check is
|
||||
// needed. We still verify that a valid actor exists in the context to
|
||||
// ensure this is never callable by an unauthenticated or system-internal
|
||||
// path without an explicit actor.
|
||||
if _, ok := ActorFromContext(ctx); !ok {
|
||||
return database.GetChatSystemPromptConfigRow{}, ErrNoActor
|
||||
}
|
||||
return q.db.GetChatSystemPromptConfig(ctx)
|
||||
}
|
||||
|
||||
// GetChatTemplateAllowlist requires deployment-config read permission,
|
||||
// unlike the peer getters (GetChatDesktopEnabled, etc.) which only
|
||||
// check actor presence. The allowlist is admin-configuration that
|
||||
@@ -2716,7 +2748,7 @@ func (q *querier) GetChatWorkspaceTTL(ctx context.Context) (string, error) {
|
||||
return q.db.GetChatWorkspaceTTL(ctx)
|
||||
}
|
||||
|
||||
func (q *querier) GetChats(ctx context.Context, arg database.GetChatsParams) ([]database.Chat, error) {
|
||||
func (q *querier) GetChats(ctx context.Context, arg database.GetChatsParams) ([]database.GetChatsRow, error) {
|
||||
prep, err := prepareSQLFilter(ctx, q.auth, policy.ActionRead, rbac.ResourceChat.Type)
|
||||
if err != nil {
|
||||
return nil, xerrors.Errorf("(dev error) prepare sql filter: %w", err)
|
||||
@@ -2724,6 +2756,10 @@ func (q *querier) GetChats(ctx context.Context, arg database.GetChatsParams) ([]
|
||||
return q.db.GetAuthorizedChats(ctx, arg, prep)
|
||||
}
|
||||
|
||||
func (q *querier) GetChatsByWorkspaceIDs(ctx context.Context, ids []uuid.UUID) ([]database.Chat, error) {
|
||||
return fetchWithPostFilter(q.auth, policy.ActionRead, q.db.GetChatsByWorkspaceIDs)(ctx, ids)
|
||||
}
|
||||
|
||||
func (q *querier) GetConnectionLogsOffset(ctx context.Context, arg database.GetConnectionLogsOffsetParams) ([]database.GetConnectionLogsOffsetRow, error) {
|
||||
// Just like with the audit logs query, shortcut if the user is an owner.
|
||||
err := q.authorizeContext(ctx, policy.ActionRead, rbac.ResourceConnectionLog)
|
||||
@@ -2775,7 +2811,15 @@ func (q *querier) GetDERPMeshKey(ctx context.Context) (string, error) {
|
||||
}
|
||||
|
||||
func (q *querier) GetDefaultChatModelConfig(ctx context.Context) (database.ChatModelConfig, error) {
|
||||
if err := q.authorizeContext(ctx, policy.ActionRead, rbac.ResourceDeploymentConfig); err != nil {
|
||||
// Any user who can read chat resources can read the default
|
||||
// model config, since model resolution is required to create
|
||||
// a chat. This avoids gating on ResourceDeploymentConfig
|
||||
// which regular members lack.
|
||||
act, ok := ActorFromContext(ctx)
|
||||
if !ok {
|
||||
return database.ChatModelConfig{}, ErrNoActor
|
||||
}
|
||||
if err := q.authorizeContext(ctx, policy.ActionRead, rbac.ResourceChat.WithOwner(act.ID)); err != nil {
|
||||
return database.ChatModelConfig{}, err
|
||||
}
|
||||
return q.db.GetDefaultChatModelConfig(ctx)
|
||||
@@ -3921,6 +3965,13 @@ func (q *querier) GetUnexpiredLicenses(ctx context.Context) ([]database.License,
|
||||
return q.db.GetUnexpiredLicenses(ctx)
|
||||
}
|
||||
|
||||
func (q *querier) GetUserAISeatStates(ctx context.Context, userIDs []uuid.UUID) ([]uuid.UUID, error) {
|
||||
if err := q.authorizeContext(ctx, policy.ActionRead, rbac.ResourceUser); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return q.db.GetUserAISeatStates(ctx, userIDs)
|
||||
}
|
||||
|
||||
func (q *querier) GetUserActivityInsights(ctx context.Context, arg database.GetUserActivityInsightsParams) ([]database.GetUserActivityInsightsRow, error) {
|
||||
// Used by insights endpoints. Need to check both for auditors and for regular users with template acl perms.
|
||||
if err := q.authorizeContext(ctx, policy.ActionViewInsights, rbac.ResourceTemplate); err != nil {
|
||||
@@ -5313,6 +5364,14 @@ func (q *querier) InsertWorkspaceResourceMetadata(ctx context.Context, arg datab
|
||||
return q.db.InsertWorkspaceResourceMetadata(ctx, arg)
|
||||
}
|
||||
|
||||
func (q *querier) ListAIBridgeClients(ctx context.Context, arg database.ListAIBridgeClientsParams) ([]string, error) {
|
||||
prep, err := prepareSQLFilter(ctx, q.auth, policy.ActionRead, rbac.ResourceAibridgeInterception.Type)
|
||||
if err != nil {
|
||||
return nil, xerrors.Errorf("(dev error) prepare sql filter: %w", err)
|
||||
}
|
||||
return q.db.ListAuthorizedAIBridgeClients(ctx, arg, prep)
|
||||
}
|
||||
|
||||
func (q *querier) ListAIBridgeInterceptions(ctx context.Context, arg database.ListAIBridgeInterceptionsParams) ([]database.ListAIBridgeInterceptionsRow, error) {
|
||||
prep, err := prepareSQLFilter(ctx, q.auth, policy.ActionRead, rbac.ResourceAibridgeInterception.Type)
|
||||
if err != nil {
|
||||
@@ -5328,6 +5387,13 @@ func (q *querier) ListAIBridgeInterceptionsTelemetrySummaries(ctx context.Contex
|
||||
return q.db.ListAIBridgeInterceptionsTelemetrySummaries(ctx, arg)
|
||||
}
|
||||
|
||||
func (q *querier) ListAIBridgeModelThoughtsByInterceptionIDs(ctx context.Context, interceptionIDs []uuid.UUID) ([]database.AIBridgeModelThought, error) {
|
||||
if err := q.authorizeContext(ctx, policy.ActionRead, rbac.ResourceAibridgeInterception); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return q.db.ListAIBridgeModelThoughtsByInterceptionIDs(ctx, interceptionIDs)
|
||||
}
|
||||
|
||||
func (q *querier) ListAIBridgeModels(ctx context.Context, arg database.ListAIBridgeModelsParams) ([]string, error) {
|
||||
prep, err := prepareSQLFilter(ctx, q.auth, policy.ActionRead, rbac.ResourceAibridgeInterception.Type)
|
||||
if err != nil {
|
||||
@@ -5336,6 +5402,14 @@ func (q *querier) ListAIBridgeModels(ctx context.Context, arg database.ListAIBri
|
||||
return q.db.ListAuthorizedAIBridgeModels(ctx, arg, prep)
|
||||
}
|
||||
|
||||
func (q *querier) ListAIBridgeSessionThreads(ctx context.Context, arg database.ListAIBridgeSessionThreadsParams) ([]database.ListAIBridgeSessionThreadsRow, error) {
|
||||
prep, err := prepareSQLFilter(ctx, q.auth, policy.ActionRead, rbac.ResourceAibridgeInterception.Type)
|
||||
if err != nil {
|
||||
return nil, xerrors.Errorf("(dev error) prepare sql filter: %w", err)
|
||||
}
|
||||
return q.db.ListAuthorizedAIBridgeSessionThreads(ctx, arg, prep)
|
||||
}
|
||||
|
||||
func (q *querier) ListAIBridgeSessions(ctx context.Context, arg database.ListAIBridgeSessionsParams) ([]database.ListAIBridgeSessionsRow, error) {
|
||||
prep, err := prepareSQLFilter(ctx, q.auth, policy.ActionRead, rbac.ResourceAibridgeInterception.Type)
|
||||
if err != nil {
|
||||
@@ -5473,6 +5547,17 @@ func (q *querier) PaginatedOrganizationMembers(ctx context.Context, arg database
|
||||
return q.db.PaginatedOrganizationMembers(ctx, arg)
|
||||
}
|
||||
|
||||
func (q *querier) PinChatByID(ctx context.Context, id uuid.UUID) error {
|
||||
chat, err := q.db.GetChatByID(ctx, id)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if err := q.authorizeContext(ctx, policy.ActionUpdate, chat); err != nil {
|
||||
return err
|
||||
}
|
||||
return q.db.PinChatByID(ctx, id)
|
||||
}
|
||||
|
||||
func (q *querier) PopNextQueuedMessage(ctx context.Context, chatID uuid.UUID) (database.ChatQueuedMessage, error) {
|
||||
chat, err := q.db.GetChatByID(ctx, chatID)
|
||||
if err != nil {
|
||||
@@ -5564,13 +5649,13 @@ func (q *querier) TryAcquireLock(ctx context.Context, id int64) (bool, error) {
|
||||
return q.db.TryAcquireLock(ctx, id)
|
||||
}
|
||||
|
||||
func (q *querier) UnarchiveChatByID(ctx context.Context, id uuid.UUID) error {
|
||||
func (q *querier) UnarchiveChatByID(ctx context.Context, id uuid.UUID) ([]database.Chat, error) {
|
||||
chat, err := q.db.GetChatByID(ctx, id)
|
||||
if err != nil {
|
||||
return err
|
||||
return nil, err
|
||||
}
|
||||
if err := q.authorizeContext(ctx, policy.ActionUpdate, chat); err != nil {
|
||||
return err
|
||||
return nil, err
|
||||
}
|
||||
return q.db.UnarchiveChatByID(ctx, id)
|
||||
}
|
||||
@@ -5598,6 +5683,17 @@ func (q *querier) UnfavoriteWorkspace(ctx context.Context, id uuid.UUID) error {
|
||||
return update(q.log, q.auth, fetch, q.db.UnfavoriteWorkspace)(ctx, id)
|
||||
}
|
||||
|
||||
func (q *querier) UnpinChatByID(ctx context.Context, id uuid.UUID) error {
|
||||
chat, err := q.db.GetChatByID(ctx, id)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if err := q.authorizeContext(ctx, policy.ActionUpdate, chat); err != nil {
|
||||
return err
|
||||
}
|
||||
return q.db.UnpinChatByID(ctx, id)
|
||||
}
|
||||
|
||||
func (q *querier) UnsetDefaultChatModelConfigs(ctx context.Context) error {
|
||||
if err := q.authorizeContext(ctx, policy.ActionUpdate, rbac.ResourceSystem); err != nil {
|
||||
return err
|
||||
@@ -5619,6 +5715,18 @@ func (q *querier) UpdateAPIKeyByID(ctx context.Context, arg database.UpdateAPIKe
|
||||
return update(q.log, q.auth, fetch, q.db.UpdateAPIKeyByID)(ctx, arg)
|
||||
}
|
||||
|
||||
func (q *querier) UpdateChatBuildAgentBinding(ctx context.Context, arg database.UpdateChatBuildAgentBindingParams) (database.Chat, error) {
|
||||
chat, err := q.db.GetChatByID(ctx, arg.ID)
|
||||
if err != nil {
|
||||
return database.Chat{}, err
|
||||
}
|
||||
if err := q.authorizeContext(ctx, policy.ActionUpdate, chat); err != nil {
|
||||
return database.Chat{}, err
|
||||
}
|
||||
|
||||
return q.db.UpdateChatBuildAgentBinding(ctx, arg)
|
||||
}
|
||||
|
||||
func (q *querier) UpdateChatByID(ctx context.Context, arg database.UpdateChatByIDParams) (database.Chat, error) {
|
||||
chat, err := q.db.GetChatByID(ctx, arg.ID)
|
||||
if err != nil {
|
||||
@@ -5652,6 +5760,39 @@ func (q *querier) UpdateChatLabelsByID(ctx context.Context, arg database.UpdateC
|
||||
return q.db.UpdateChatLabelsByID(ctx, arg)
|
||||
}
|
||||
|
||||
func (q *querier) UpdateChatLastInjectedContext(ctx context.Context, arg database.UpdateChatLastInjectedContextParams) (database.Chat, error) {
|
||||
chat, err := q.db.GetChatByID(ctx, arg.ID)
|
||||
if err != nil {
|
||||
return database.Chat{}, err
|
||||
}
|
||||
if err := q.authorizeContext(ctx, policy.ActionUpdate, chat); err != nil {
|
||||
return database.Chat{}, err
|
||||
}
|
||||
return q.db.UpdateChatLastInjectedContext(ctx, arg)
|
||||
}
|
||||
|
||||
func (q *querier) UpdateChatLastModelConfigByID(ctx context.Context, arg database.UpdateChatLastModelConfigByIDParams) (database.Chat, error) {
|
||||
chat, err := q.db.GetChatByID(ctx, arg.ID)
|
||||
if err != nil {
|
||||
return database.Chat{}, err
|
||||
}
|
||||
if err := q.authorizeContext(ctx, policy.ActionUpdate, chat); err != nil {
|
||||
return database.Chat{}, err
|
||||
}
|
||||
return q.db.UpdateChatLastModelConfigByID(ctx, arg)
|
||||
}
|
||||
|
||||
func (q *querier) UpdateChatLastReadMessageID(ctx context.Context, arg database.UpdateChatLastReadMessageIDParams) error {
|
||||
chat, err := q.db.GetChatByID(ctx, arg.ID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if err := q.authorizeContext(ctx, policy.ActionUpdate, chat); err != nil {
|
||||
return err
|
||||
}
|
||||
return q.db.UpdateChatLastReadMessageID(ctx, arg)
|
||||
}
|
||||
|
||||
func (q *querier) UpdateChatMCPServerIDs(ctx context.Context, arg database.UpdateChatMCPServerIDsParams) (database.Chat, error) {
|
||||
chat, err := q.db.GetChatByID(ctx, arg.ID)
|
||||
if err != nil {
|
||||
@@ -5686,6 +5827,17 @@ func (q *querier) UpdateChatModelConfig(ctx context.Context, arg database.Update
|
||||
return q.db.UpdateChatModelConfig(ctx, arg)
|
||||
}
|
||||
|
||||
func (q *querier) UpdateChatPinOrder(ctx context.Context, arg database.UpdateChatPinOrderParams) error {
|
||||
chat, err := q.db.GetChatByID(ctx, arg.ID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if err := q.authorizeContext(ctx, policy.ActionUpdate, chat); err != nil {
|
||||
return err
|
||||
}
|
||||
return q.db.UpdateChatPinOrder(ctx, arg)
|
||||
}
|
||||
|
||||
func (q *querier) UpdateChatProvider(ctx context.Context, arg database.UpdateChatProviderParams) (database.ChatProvider, error) {
|
||||
if err := q.authorizeContext(ctx, policy.ActionUpdate, rbac.ResourceDeploymentConfig); err != nil {
|
||||
return database.ChatProvider{}, err
|
||||
@@ -5706,7 +5858,18 @@ func (q *querier) UpdateChatStatus(ctx context.Context, arg database.UpdateChatS
|
||||
return q.db.UpdateChatStatus(ctx, arg)
|
||||
}
|
||||
|
||||
func (q *querier) UpdateChatWorkspace(ctx context.Context, arg database.UpdateChatWorkspaceParams) (database.Chat, error) {
|
||||
func (q *querier) UpdateChatStatusPreserveUpdatedAt(ctx context.Context, arg database.UpdateChatStatusPreserveUpdatedAtParams) (database.Chat, error) {
|
||||
chat, err := q.db.GetChatByID(ctx, arg.ID)
|
||||
if err != nil {
|
||||
return database.Chat{}, err
|
||||
}
|
||||
if err := q.authorizeContext(ctx, policy.ActionUpdate, chat); err != nil {
|
||||
return database.Chat{}, err
|
||||
}
|
||||
return q.db.UpdateChatStatusPreserveUpdatedAt(ctx, arg)
|
||||
}
|
||||
|
||||
func (q *querier) UpdateChatWorkspaceBinding(ctx context.Context, arg database.UpdateChatWorkspaceBindingParams) (database.Chat, error) {
|
||||
chat, err := q.db.GetChatByID(ctx, arg.ID)
|
||||
if err != nil {
|
||||
return database.Chat{}, err
|
||||
@@ -5715,15 +5878,7 @@ func (q *querier) UpdateChatWorkspace(ctx context.Context, arg database.UpdateCh
|
||||
return database.Chat{}, err
|
||||
}
|
||||
|
||||
// UpdateChatWorkspace is manually implemented for chat tables and may not be
|
||||
// present on every wrapped store interface yet.
|
||||
chatWorkspaceUpdater, ok := q.db.(interface {
|
||||
UpdateChatWorkspace(context.Context, database.UpdateChatWorkspaceParams) (database.Chat, error)
|
||||
})
|
||||
if !ok {
|
||||
return database.Chat{}, xerrors.New("update chat workspace is not implemented by wrapped store")
|
||||
}
|
||||
return chatWorkspaceUpdater.UpdateChatWorkspace(ctx, arg)
|
||||
return q.db.UpdateChatWorkspaceBinding(ctx, arg)
|
||||
}
|
||||
|
||||
func (q *querier) UpdateCryptoKeyDeletesAt(ctx context.Context, arg database.UpdateCryptoKeyDeletesAtParams) (database.CryptoKey, error) {
|
||||
@@ -6827,6 +6982,13 @@ func (q *querier) UpsertChatDiffStatusReference(ctx context.Context, arg databas
|
||||
return q.db.UpsertChatDiffStatusReference(ctx, arg)
|
||||
}
|
||||
|
||||
func (q *querier) UpsertChatIncludeDefaultSystemPrompt(ctx context.Context, includeDefaultSystemPrompt bool) error {
|
||||
if err := q.authorizeContext(ctx, policy.ActionUpdate, rbac.ResourceDeploymentConfig); err != nil {
|
||||
return err
|
||||
}
|
||||
return q.db.UpsertChatIncludeDefaultSystemPrompt(ctx, includeDefaultSystemPrompt)
|
||||
}
|
||||
|
||||
func (q *querier) UpsertChatSystemPrompt(ctx context.Context, value string) error {
|
||||
if err := q.authorizeContext(ctx, policy.ActionUpdate, rbac.ResourceDeploymentConfig); err != nil {
|
||||
return err
|
||||
@@ -7167,6 +7329,14 @@ func (q *querier) ListAuthorizedAIBridgeModels(ctx context.Context, arg database
|
||||
return q.ListAIBridgeModels(ctx, arg)
|
||||
}
|
||||
|
||||
func (q *querier) ListAuthorizedAIBridgeClients(ctx context.Context, arg database.ListAIBridgeClientsParams, _ rbac.PreparedAuthorized) ([]string, error) {
|
||||
// TODO: Delete this function, all ListAIBridgeClients should be
|
||||
// authorized. For now just call ListAIBridgeClients on the authz
|
||||
// querier. This cannot be deleted for now because it's included in
|
||||
// the database.Store interface, so dbauthz needs to implement it.
|
||||
return q.ListAIBridgeClients(ctx, arg)
|
||||
}
|
||||
|
||||
func (q *querier) ListAuthorizedAIBridgeSessions(ctx context.Context, arg database.ListAIBridgeSessionsParams, prepared rbac.PreparedAuthorized) ([]database.ListAIBridgeSessionsRow, error) {
|
||||
return q.db.ListAuthorizedAIBridgeSessions(ctx, arg, prepared)
|
||||
}
|
||||
@@ -7175,6 +7345,10 @@ func (q *querier) CountAuthorizedAIBridgeSessions(ctx context.Context, arg datab
|
||||
return q.db.CountAuthorizedAIBridgeSessions(ctx, arg, prepared)
|
||||
}
|
||||
|
||||
func (q *querier) GetAuthorizedChats(ctx context.Context, arg database.GetChatsParams, _ rbac.PreparedAuthorized) ([]database.Chat, error) {
|
||||
func (q *querier) ListAuthorizedAIBridgeSessionThreads(ctx context.Context, arg database.ListAIBridgeSessionThreadsParams, prepared rbac.PreparedAuthorized) ([]database.ListAIBridgeSessionThreadsRow, error) {
|
||||
return q.db.ListAuthorizedAIBridgeSessionThreads(ctx, arg, prepared)
|
||||
}
|
||||
|
||||
func (q *querier) GetAuthorizedChats(ctx context.Context, arg database.GetChatsParams, _ rbac.PreparedAuthorized) ([]database.GetChatsRow, error) {
|
||||
return q.GetChats(ctx, arg)
|
||||
}
|
||||
|
||||
@@ -392,13 +392,25 @@ func (s *MethodTestSuite) TestChats() {
|
||||
s.Run("ArchiveChatByID", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
|
||||
chat := testutil.Fake(s.T(), faker, database.Chat{})
|
||||
dbm.EXPECT().GetChatByID(gomock.Any(), chat.ID).Return(chat, nil).AnyTimes()
|
||||
dbm.EXPECT().ArchiveChatByID(gomock.Any(), chat.ID).Return(nil).AnyTimes()
|
||||
check.Args(chat.ID).Asserts(chat, policy.ActionUpdate).Returns()
|
||||
dbm.EXPECT().ArchiveChatByID(gomock.Any(), chat.ID).Return([]database.Chat{chat}, nil).AnyTimes()
|
||||
check.Args(chat.ID).Asserts(chat, policy.ActionUpdate).Returns([]database.Chat{chat})
|
||||
}))
|
||||
s.Run("UnarchiveChatByID", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
|
||||
chat := testutil.Fake(s.T(), faker, database.Chat{})
|
||||
dbm.EXPECT().GetChatByID(gomock.Any(), chat.ID).Return(chat, nil).AnyTimes()
|
||||
dbm.EXPECT().UnarchiveChatByID(gomock.Any(), chat.ID).Return(nil).AnyTimes()
|
||||
dbm.EXPECT().UnarchiveChatByID(gomock.Any(), chat.ID).Return([]database.Chat{chat}, nil).AnyTimes()
|
||||
check.Args(chat.ID).Asserts(chat, policy.ActionUpdate).Returns([]database.Chat{chat})
|
||||
}))
|
||||
s.Run("PinChatByID", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
|
||||
chat := testutil.Fake(s.T(), faker, database.Chat{})
|
||||
dbm.EXPECT().GetChatByID(gomock.Any(), chat.ID).Return(chat, nil).AnyTimes()
|
||||
dbm.EXPECT().PinChatByID(gomock.Any(), chat.ID).Return(nil).AnyTimes()
|
||||
check.Args(chat.ID).Asserts(chat, policy.ActionUpdate).Returns()
|
||||
}))
|
||||
s.Run("UnpinChatByID", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
|
||||
chat := testutil.Fake(s.T(), faker, database.Chat{})
|
||||
dbm.EXPECT().GetChatByID(gomock.Any(), chat.ID).Return(chat, nil).AnyTimes()
|
||||
dbm.EXPECT().UnpinChatByID(gomock.Any(), chat.ID).Return(nil).AnyTimes()
|
||||
check.Args(chat.ID).Asserts(chat, policy.ActionUpdate).Returns()
|
||||
}))
|
||||
s.Run("SoftDeleteChatMessagesAfterID", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
|
||||
@@ -449,6 +461,13 @@ func (s *MethodTestSuite) TestChats() {
|
||||
dbm.EXPECT().GetChatByIDForUpdate(gomock.Any(), chat.ID).Return(chat, nil).AnyTimes()
|
||||
check.Args(chat.ID).Asserts(chat, policy.ActionRead).Returns(chat)
|
||||
}))
|
||||
s.Run("GetChatsByWorkspaceIDs", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
|
||||
chatA := testutil.Fake(s.T(), faker, database.Chat{})
|
||||
chatB := testutil.Fake(s.T(), faker, database.Chat{})
|
||||
arg := []uuid.UUID{chatA.WorkspaceID.UUID, chatB.WorkspaceID.UUID}
|
||||
dbm.EXPECT().GetChatsByWorkspaceIDs(gomock.Any(), arg).Return([]database.Chat{chatA, chatB}, nil).AnyTimes()
|
||||
check.Args(arg).Asserts(chatA, policy.ActionRead, chatB, policy.ActionRead).Returns([]database.Chat{chatA, chatB})
|
||||
}))
|
||||
s.Run("GetChatCostPerChat", s.Mocked(func(dbm *dbmock.MockStore, _ *gofakeit.Faker, check *expects) {
|
||||
arg := database.GetChatCostPerChatParams{
|
||||
OwnerID: uuid.New(),
|
||||
@@ -573,6 +592,14 @@ func (s *MethodTestSuite) TestChats() {
|
||||
dbm.EXPECT().GetChatMessagesByChatID(gomock.Any(), arg).Return(msgs, nil).AnyTimes()
|
||||
check.Args(arg).Asserts(chat, policy.ActionRead).Returns(msgs)
|
||||
}))
|
||||
s.Run("GetChatMessagesByChatIDAscPaginated", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
|
||||
chat := testutil.Fake(s.T(), faker, database.Chat{})
|
||||
msgs := []database.ChatMessage{testutil.Fake(s.T(), faker, database.ChatMessage{ChatID: chat.ID})}
|
||||
arg := database.GetChatMessagesByChatIDAscPaginatedParams{ChatID: chat.ID, AfterID: 0, LimitVal: 50}
|
||||
dbm.EXPECT().GetChatByID(gomock.Any(), chat.ID).Return(chat, nil).AnyTimes()
|
||||
dbm.EXPECT().GetChatMessagesByChatIDAscPaginated(gomock.Any(), arg).Return(msgs, nil).AnyTimes()
|
||||
check.Args(arg).Asserts(chat, policy.ActionRead).Returns(msgs)
|
||||
}))
|
||||
s.Run("GetChatMessagesByChatIDDescPaginated", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
|
||||
chat := testutil.Fake(s.T(), faker, database.Chat{})
|
||||
msgs := []database.ChatMessage{testutil.Fake(s.T(), faker, database.ChatMessage{ChatID: chat.ID})}
|
||||
@@ -604,7 +631,7 @@ func (s *MethodTestSuite) TestChats() {
|
||||
s.Run("GetDefaultChatModelConfig", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
|
||||
config := testutil.Fake(s.T(), faker, database.ChatModelConfig{})
|
||||
dbm.EXPECT().GetDefaultChatModelConfig(gomock.Any()).Return(config, nil).AnyTimes()
|
||||
check.Asserts(rbac.ResourceDeploymentConfig, policy.ActionRead).Returns(config)
|
||||
check.Asserts(rbac.ResourceChat.WithOwner(testActorID.String()), policy.ActionRead).Returns(config)
|
||||
}))
|
||||
s.Run("GetChatModelConfigs", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
|
||||
configA := testutil.Fake(s.T(), faker, database.ChatModelConfig{})
|
||||
@@ -631,13 +658,13 @@ func (s *MethodTestSuite) TestChats() {
|
||||
}))
|
||||
s.Run("GetChats", s.Mocked(func(dbm *dbmock.MockStore, _ *gofakeit.Faker, check *expects) {
|
||||
params := database.GetChatsParams{}
|
||||
dbm.EXPECT().GetAuthorizedChats(gomock.Any(), params, gomock.Any()).Return([]database.Chat{}, nil).AnyTimes()
|
||||
dbm.EXPECT().GetAuthorizedChats(gomock.Any(), params, gomock.Any()).Return([]database.GetChatsRow{}, nil).AnyTimes()
|
||||
// No asserts here because SQLFilter.
|
||||
check.Args(params).Asserts()
|
||||
}))
|
||||
s.Run("GetAuthorizedChats", s.Mocked(func(dbm *dbmock.MockStore, _ *gofakeit.Faker, check *expects) {
|
||||
params := database.GetChatsParams{}
|
||||
dbm.EXPECT().GetAuthorizedChats(gomock.Any(), params, gomock.Any()).Return([]database.Chat{}, nil).AnyTimes()
|
||||
dbm.EXPECT().GetAuthorizedChats(gomock.Any(), params, gomock.Any()).Return([]database.GetChatsRow{}, nil).AnyTimes()
|
||||
// No asserts here because it re-routes through GetChats which uses SQLFilter.
|
||||
check.Args(params, emptyPreparedAuthorized{}).Asserts()
|
||||
}))
|
||||
@@ -648,6 +675,17 @@ func (s *MethodTestSuite) TestChats() {
|
||||
dbm.EXPECT().GetChatQueuedMessages(gomock.Any(), chat.ID).Return(qms, nil).AnyTimes()
|
||||
check.Args(chat.ID).Asserts(chat, policy.ActionRead).Returns(qms)
|
||||
}))
|
||||
s.Run("GetChatIncludeDefaultSystemPrompt", s.Mocked(func(dbm *dbmock.MockStore, _ *gofakeit.Faker, check *expects) {
|
||||
dbm.EXPECT().GetChatIncludeDefaultSystemPrompt(gomock.Any()).Return(true, nil).AnyTimes()
|
||||
check.Args().Asserts()
|
||||
}))
|
||||
s.Run("GetChatSystemPromptConfig", s.Mocked(func(dbm *dbmock.MockStore, _ *gofakeit.Faker, check *expects) {
|
||||
dbm.EXPECT().GetChatSystemPromptConfig(gomock.Any()).Return(database.GetChatSystemPromptConfigRow{
|
||||
ChatSystemPrompt: "prompt",
|
||||
IncludeDefaultSystemPrompt: true,
|
||||
}, nil).AnyTimes()
|
||||
check.Args().Asserts()
|
||||
}))
|
||||
s.Run("GetChatSystemPrompt", s.Mocked(func(dbm *dbmock.MockStore, _ *gofakeit.Faker, check *expects) {
|
||||
dbm.EXPECT().GetChatSystemPrompt(gomock.Any()).Return("prompt", nil).AnyTimes()
|
||||
check.Args().Asserts()
|
||||
@@ -683,7 +721,9 @@ func (s *MethodTestSuite) TestChats() {
|
||||
check.Args(threshold).Asserts(rbac.ResourceChat, policy.ActionRead).Returns(chats)
|
||||
}))
|
||||
s.Run("InsertChat", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
|
||||
arg := testutil.Fake(s.T(), faker, database.InsertChatParams{})
|
||||
arg := testutil.Fake(s.T(), faker, database.InsertChatParams{
|
||||
Status: database.ChatStatusWaiting,
|
||||
})
|
||||
chat := testutil.Fake(s.T(), faker, database.Chat{OwnerID: arg.OwnerID})
|
||||
dbm.EXPECT().InsertChat(gomock.Any(), arg).Return(chat, nil).AnyTimes()
|
||||
check.Args(arg).Asserts(rbac.ResourceChat.WithOwner(arg.OwnerID.String()), policy.ActionCreate).Returns(chat)
|
||||
@@ -759,6 +799,26 @@ func (s *MethodTestSuite) TestChats() {
|
||||
dbm.EXPECT().UpdateChatLabelsByID(gomock.Any(), arg).Return(chat, nil).AnyTimes()
|
||||
check.Args(arg).Asserts(chat, policy.ActionUpdate).Returns(chat)
|
||||
}))
|
||||
s.Run("UpdateChatLastModelConfigByID", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
|
||||
chat := testutil.Fake(s.T(), faker, database.Chat{})
|
||||
arg := database.UpdateChatLastModelConfigByIDParams{
|
||||
ID: chat.ID,
|
||||
LastModelConfigID: uuid.New(),
|
||||
}
|
||||
dbm.EXPECT().GetChatByID(gomock.Any(), chat.ID).Return(chat, nil).AnyTimes()
|
||||
dbm.EXPECT().UpdateChatLastModelConfigByID(gomock.Any(), arg).Return(chat, nil).AnyTimes()
|
||||
check.Args(arg).Asserts(chat, policy.ActionUpdate).Returns(chat)
|
||||
}))
|
||||
s.Run("UpdateChatStatusPreserveUpdatedAt", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
|
||||
chat := testutil.Fake(s.T(), faker, database.Chat{})
|
||||
arg := database.UpdateChatStatusPreserveUpdatedAtParams{
|
||||
ID: chat.ID,
|
||||
Status: database.ChatStatusRunning,
|
||||
}
|
||||
dbm.EXPECT().GetChatByID(gomock.Any(), chat.ID).Return(chat, nil).AnyTimes()
|
||||
dbm.EXPECT().UpdateChatStatusPreserveUpdatedAt(gomock.Any(), arg).Return(chat, nil).AnyTimes()
|
||||
check.Args(arg).Asserts(chat, policy.ActionUpdate).Returns(chat)
|
||||
}))
|
||||
s.Run("UpdateChatHeartbeat", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
|
||||
chat := testutil.Fake(s.T(), faker, database.Chat{})
|
||||
arg := database.UpdateChatHeartbeatParams{
|
||||
@@ -809,6 +869,16 @@ func (s *MethodTestSuite) TestChats() {
|
||||
dbm.EXPECT().UpdateChatProvider(gomock.Any(), arg).Return(provider, nil).AnyTimes()
|
||||
check.Args(arg).Asserts(rbac.ResourceDeploymentConfig, policy.ActionUpdate).Returns(provider)
|
||||
}))
|
||||
s.Run("UpdateChatPinOrder", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
|
||||
chat := testutil.Fake(s.T(), faker, database.Chat{})
|
||||
arg := database.UpdateChatPinOrderParams{
|
||||
ID: chat.ID,
|
||||
PinOrder: 2,
|
||||
}
|
||||
dbm.EXPECT().GetChatByID(gomock.Any(), chat.ID).Return(chat, nil).AnyTimes()
|
||||
dbm.EXPECT().UpdateChatPinOrder(gomock.Any(), arg).Return(nil).AnyTimes()
|
||||
check.Args(arg).Asserts(chat, policy.ActionUpdate).Returns()
|
||||
}))
|
||||
s.Run("UpdateChatStatus", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
|
||||
chat := testutil.Fake(s.T(), faker, database.Chat{})
|
||||
arg := database.UpdateChatStatusParams{
|
||||
@@ -819,15 +889,29 @@ func (s *MethodTestSuite) TestChats() {
|
||||
dbm.EXPECT().UpdateChatStatus(gomock.Any(), arg).Return(chat, nil).AnyTimes()
|
||||
check.Args(arg).Asserts(chat, policy.ActionUpdate).Returns(chat)
|
||||
}))
|
||||
s.Run("UpdateChatWorkspace", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
|
||||
s.Run("UpdateChatBuildAgentBinding", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
|
||||
chat := testutil.Fake(s.T(), faker, database.Chat{})
|
||||
arg := database.UpdateChatWorkspaceParams{
|
||||
ID: chat.ID,
|
||||
WorkspaceID: uuid.NullUUID{UUID: uuid.New(), Valid: true},
|
||||
arg := database.UpdateChatBuildAgentBindingParams{
|
||||
ID: chat.ID,
|
||||
BuildID: uuid.NullUUID{UUID: uuid.New(), Valid: true},
|
||||
AgentID: uuid.NullUUID{UUID: uuid.New(), Valid: true},
|
||||
}
|
||||
updatedChat := testutil.Fake(s.T(), faker, database.Chat{ID: chat.ID})
|
||||
dbm.EXPECT().GetChatByID(gomock.Any(), chat.ID).Return(chat, nil).AnyTimes()
|
||||
dbm.EXPECT().UpdateChatWorkspace(gomock.Any(), arg).Return(updatedChat, nil).AnyTimes()
|
||||
dbm.EXPECT().UpdateChatBuildAgentBinding(gomock.Any(), arg).Return(updatedChat, nil).AnyTimes()
|
||||
check.Args(arg).Asserts(chat, policy.ActionUpdate).Returns(updatedChat)
|
||||
}))
|
||||
s.Run("UpdateChatWorkspaceBinding", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
|
||||
chat := testutil.Fake(s.T(), faker, database.Chat{})
|
||||
arg := database.UpdateChatWorkspaceBindingParams{
|
||||
ID: chat.ID,
|
||||
WorkspaceID: uuid.NullUUID{UUID: uuid.New(), Valid: true},
|
||||
BuildID: uuid.NullUUID{UUID: uuid.New(), Valid: true},
|
||||
AgentID: uuid.NullUUID{UUID: uuid.New(), Valid: true},
|
||||
}
|
||||
updatedChat := testutil.Fake(s.T(), faker, database.Chat{ID: chat.ID})
|
||||
dbm.EXPECT().GetChatByID(gomock.Any(), chat.ID).Return(chat, nil).AnyTimes()
|
||||
dbm.EXPECT().UpdateChatWorkspaceBinding(gomock.Any(), arg).Return(updatedChat, nil).AnyTimes()
|
||||
check.Args(arg).Asserts(chat, policy.ActionUpdate).Returns(updatedChat)
|
||||
}))
|
||||
s.Run("UnsetDefaultChatModelConfigs", s.Mocked(func(dbm *dbmock.MockStore, _ *gofakeit.Faker, check *expects) {
|
||||
@@ -879,6 +963,10 @@ func (s *MethodTestSuite) TestChats() {
|
||||
dbm.EXPECT().BackoffChatDiffStatus(gomock.Any(), arg).Return(nil).AnyTimes()
|
||||
check.Args(arg).Asserts(rbac.ResourceChat, policy.ActionUpdate).Returns()
|
||||
}))
|
||||
s.Run("UpsertChatIncludeDefaultSystemPrompt", s.Mocked(func(dbm *dbmock.MockStore, _ *gofakeit.Faker, check *expects) {
|
||||
dbm.EXPECT().UpsertChatIncludeDefaultSystemPrompt(gomock.Any(), false).Return(nil).AnyTimes()
|
||||
check.Args(false).Asserts(rbac.ResourceDeploymentConfig, policy.ActionUpdate)
|
||||
}))
|
||||
s.Run("UpsertChatSystemPrompt", s.Mocked(func(dbm *dbmock.MockStore, _ *gofakeit.Faker, check *expects) {
|
||||
dbm.EXPECT().UpsertChatSystemPrompt(gomock.Any(), "").Return(nil).AnyTimes()
|
||||
check.Args("").Asserts(rbac.ResourceDeploymentConfig, policy.ActionUpdate)
|
||||
@@ -1118,6 +1206,29 @@ func (s *MethodTestSuite) TestChats() {
|
||||
dbm.EXPECT().UpdateChatMCPServerIDs(gomock.Any(), arg).Return(chat, nil).AnyTimes()
|
||||
check.Args(arg).Asserts(chat, policy.ActionUpdate).Returns(chat)
|
||||
}))
|
||||
s.Run("UpdateChatLastInjectedContext", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
|
||||
chat := testutil.Fake(s.T(), faker, database.Chat{})
|
||||
arg := database.UpdateChatLastInjectedContextParams{
|
||||
ID: chat.ID,
|
||||
LastInjectedContext: pqtype.NullRawMessage{
|
||||
RawMessage: json.RawMessage(`[{"type":"text","text":"test"}]`),
|
||||
Valid: true,
|
||||
},
|
||||
}
|
||||
dbm.EXPECT().GetChatByID(gomock.Any(), chat.ID).Return(chat, nil).AnyTimes()
|
||||
dbm.EXPECT().UpdateChatLastInjectedContext(gomock.Any(), arg).Return(chat, nil).AnyTimes()
|
||||
check.Args(arg).Asserts(chat, policy.ActionUpdate).Returns(chat)
|
||||
}))
|
||||
s.Run("UpdateChatLastReadMessageID", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
|
||||
chat := testutil.Fake(s.T(), faker, database.Chat{})
|
||||
arg := database.UpdateChatLastReadMessageIDParams{
|
||||
ID: chat.ID,
|
||||
LastReadMessageID: 42,
|
||||
}
|
||||
dbm.EXPECT().GetChatByID(gomock.Any(), chat.ID).Return(chat, nil).AnyTimes()
|
||||
dbm.EXPECT().UpdateChatLastReadMessageID(gomock.Any(), arg).Return(nil).AnyTimes()
|
||||
check.Args(arg).Asserts(chat, policy.ActionUpdate).Returns()
|
||||
}))
|
||||
s.Run("UpdateMCPServerConfig", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
|
||||
config := testutil.Fake(s.T(), faker, database.MCPServerConfig{})
|
||||
arg := database.UpdateMCPServerConfigParams{
|
||||
@@ -2159,6 +2270,14 @@ func (s *MethodTestSuite) TestUser() {
|
||||
dbm.EXPECT().GetQuotaConsumedForUser(gomock.Any(), arg).Return(int64(0), nil).AnyTimes()
|
||||
check.Args(arg).Asserts(u, policy.ActionRead).Returns(int64(0))
|
||||
}))
|
||||
s.Run("GetUserAISeatStates", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
|
||||
a := testutil.Fake(s.T(), faker, database.User{})
|
||||
b := testutil.Fake(s.T(), faker, database.User{})
|
||||
ids := []uuid.UUID{a.ID, b.ID}
|
||||
seatStates := []uuid.UUID{a.ID}
|
||||
dbm.EXPECT().GetUserAISeatStates(gomock.Any(), ids).Return(seatStates, nil).AnyTimes()
|
||||
check.Args(ids).Asserts(rbac.ResourceUser, policy.ActionRead).Returns(seatStates)
|
||||
}))
|
||||
s.Run("GetUserByEmailOrUsername", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
|
||||
u := testutil.Fake(s.T(), faker, database.User{})
|
||||
arg := database.GetUserByEmailOrUsernameParams{Email: u.Email}
|
||||
@@ -5438,6 +5557,20 @@ func (s *MethodTestSuite) TestAIBridge() {
|
||||
check.Args(params, emptyPreparedAuthorized{}).Asserts()
|
||||
}))
|
||||
|
||||
s.Run("ListAIBridgeClients", s.Mocked(func(db *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
|
||||
params := database.ListAIBridgeClientsParams{}
|
||||
db.EXPECT().ListAuthorizedAIBridgeClients(gomock.Any(), params, gomock.Any()).Return([]string{}, nil).AnyTimes()
|
||||
// No asserts here because SQLFilter.
|
||||
check.Args(params).Asserts()
|
||||
}))
|
||||
|
||||
s.Run("ListAuthorizedAIBridgeClients", s.Mocked(func(db *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
|
||||
params := database.ListAIBridgeClientsParams{}
|
||||
db.EXPECT().ListAuthorizedAIBridgeClients(gomock.Any(), params, gomock.Any()).Return([]string{}, nil).AnyTimes()
|
||||
// No asserts here because SQLFilter.
|
||||
check.Args(params, emptyPreparedAuthorized{}).Asserts()
|
||||
}))
|
||||
|
||||
s.Run("ListAIBridgeSessions", s.Mocked(func(db *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
|
||||
params := database.ListAIBridgeSessionsParams{}
|
||||
db.EXPECT().ListAuthorizedAIBridgeSessions(gomock.Any(), params, gomock.Any()).Return([]database.ListAIBridgeSessionsRow{}, nil).AnyTimes()
|
||||
@@ -5484,6 +5617,26 @@ func (s *MethodTestSuite) TestAIBridge() {
|
||||
check.Args(ids).Asserts(rbac.ResourceAibridgeInterception, policy.ActionRead).Returns([]database.AIBridgeToolUsage{})
|
||||
}))
|
||||
|
||||
s.Run("ListAIBridgeModelThoughtsByInterceptionIDs", s.Mocked(func(db *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
|
||||
ids := []uuid.UUID{{1}}
|
||||
db.EXPECT().ListAIBridgeModelThoughtsByInterceptionIDs(gomock.Any(), ids).Return([]database.AIBridgeModelThought{}, nil).AnyTimes()
|
||||
check.Args(ids).Asserts(rbac.ResourceAibridgeInterception, policy.ActionRead).Returns([]database.AIBridgeModelThought{})
|
||||
}))
|
||||
|
||||
s.Run("ListAIBridgeSessionThreads", s.Mocked(func(db *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
|
||||
params := database.ListAIBridgeSessionThreadsParams{}
|
||||
db.EXPECT().ListAuthorizedAIBridgeSessionThreads(gomock.Any(), params, gomock.Any()).Return([]database.ListAIBridgeSessionThreadsRow{}, nil).AnyTimes()
|
||||
// No asserts here because SQLFilter.
|
||||
check.Args(params).Asserts()
|
||||
}))
|
||||
|
||||
s.Run("ListAuthorizedAIBridgeSessionThreads", s.Mocked(func(db *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
|
||||
params := database.ListAIBridgeSessionThreadsParams{}
|
||||
db.EXPECT().ListAuthorizedAIBridgeSessionThreads(gomock.Any(), params, gomock.Any()).Return([]database.ListAIBridgeSessionThreadsRow{}, nil).AnyTimes()
|
||||
// No asserts here because SQLFilter.
|
||||
check.Args(params, emptyPreparedAuthorized{}).Asserts()
|
||||
}))
|
||||
|
||||
s.Run("UpdateAIBridgeInterceptionEnded", s.Mocked(func(db *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
|
||||
intcID := uuid.UUID{1}
|
||||
params := database.UpdateAIBridgeInterceptionEndedParams{ID: intcID}
|
||||
|
||||
@@ -1591,6 +1591,7 @@ func AIBridgeInterception(t testing.TB, db database.Store, seed database.InsertA
|
||||
APIKeyID: seed.APIKeyID,
|
||||
InitiatorID: takeFirst(seed.InitiatorID, uuid.New()),
|
||||
Provider: takeFirst(seed.Provider, "provider"),
|
||||
ProviderName: takeFirst(seed.ProviderName, "provider-name"),
|
||||
Model: takeFirst(seed.Model, "model"),
|
||||
Metadata: takeFirstSlice(seed.Metadata, json.RawMessage("{}")),
|
||||
StartedAt: takeFirst(seed.StartedAt, dbtime.Now()),
|
||||
@@ -1663,6 +1664,17 @@ func AIBridgeToolUsage(t testing.TB, db database.Store, seed database.InsertAIBr
|
||||
return toolUsage
|
||||
}
|
||||
|
||||
func AIBridgeModelThought(t testing.TB, db database.Store, seed database.InsertAIBridgeModelThoughtParams) database.AIBridgeModelThought {
|
||||
thought, err := db.InsertAIBridgeModelThought(genCtx, database.InsertAIBridgeModelThoughtParams{
|
||||
InterceptionID: takeFirst(seed.InterceptionID, uuid.New()),
|
||||
Content: takeFirst(seed.Content, ""),
|
||||
Metadata: takeFirstSlice(seed.Metadata, json.RawMessage("{}")),
|
||||
CreatedAt: takeFirst(seed.CreatedAt, dbtime.Now()),
|
||||
})
|
||||
require.NoError(t, err, "insert aibridge model thought")
|
||||
return thought
|
||||
}
|
||||
|
||||
func Task(t testing.TB, db database.Store, orig database.TaskTable) database.Task {
|
||||
t.Helper()
|
||||
|
||||
|
||||
@@ -160,12 +160,12 @@ func (m queryMetricsStore) AllUserIDs(ctx context.Context, includeSystem bool) (
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) ArchiveChatByID(ctx context.Context, id uuid.UUID) error {
|
||||
func (m queryMetricsStore) ArchiveChatByID(ctx context.Context, id uuid.UUID) ([]database.Chat, error) {
|
||||
start := time.Now()
|
||||
r0 := m.s.ArchiveChatByID(ctx, id)
|
||||
r0, r1 := m.s.ArchiveChatByID(ctx, id)
|
||||
m.queryLatencies.WithLabelValues("ArchiveChatByID").Observe(time.Since(start).Seconds())
|
||||
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "ArchiveChatByID").Inc()
|
||||
return r0
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) ArchiveUnusedTemplateVersions(ctx context.Context, arg database.ArchiveUnusedTemplateVersionsParams) ([]uuid.UUID, error) {
|
||||
@@ -1120,6 +1120,14 @@ func (m queryMetricsStore) GetChatFilesByIDs(ctx context.Context, ids []uuid.UUI
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) GetChatIncludeDefaultSystemPrompt(ctx context.Context) (bool, error) {
|
||||
start := time.Now()
|
||||
r0, r1 := m.s.GetChatIncludeDefaultSystemPrompt(ctx)
|
||||
m.queryLatencies.WithLabelValues("GetChatIncludeDefaultSystemPrompt").Observe(time.Since(start).Seconds())
|
||||
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "GetChatIncludeDefaultSystemPrompt").Inc()
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) GetChatMessageByID(ctx context.Context, id int64) (database.ChatMessage, error) {
|
||||
start := time.Now()
|
||||
r0, r1 := m.s.GetChatMessageByID(ctx, id)
|
||||
@@ -1136,6 +1144,14 @@ func (m queryMetricsStore) GetChatMessagesByChatID(ctx context.Context, chatID d
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) GetChatMessagesByChatIDAscPaginated(ctx context.Context, arg database.GetChatMessagesByChatIDAscPaginatedParams) ([]database.ChatMessage, error) {
|
||||
start := time.Now()
|
||||
r0, r1 := m.s.GetChatMessagesByChatIDAscPaginated(ctx, arg)
|
||||
m.queryLatencies.WithLabelValues("GetChatMessagesByChatIDAscPaginated").Observe(time.Since(start).Seconds())
|
||||
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "GetChatMessagesByChatIDAscPaginated").Inc()
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) GetChatMessagesByChatIDDescPaginated(ctx context.Context, arg database.GetChatMessagesByChatIDDescPaginatedParams) ([]database.ChatMessage, error) {
|
||||
start := time.Now()
|
||||
r0, r1 := m.s.GetChatMessagesByChatIDDescPaginated(ctx, arg)
|
||||
@@ -1208,6 +1224,14 @@ func (m queryMetricsStore) GetChatSystemPrompt(ctx context.Context) (string, err
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) GetChatSystemPromptConfig(ctx context.Context) (database.GetChatSystemPromptConfigRow, error) {
|
||||
start := time.Now()
|
||||
r0, r1 := m.s.GetChatSystemPromptConfig(ctx)
|
||||
m.queryLatencies.WithLabelValues("GetChatSystemPromptConfig").Observe(time.Since(start).Seconds())
|
||||
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "GetChatSystemPromptConfig").Inc()
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) GetChatTemplateAllowlist(ctx context.Context) (string, error) {
|
||||
start := time.Now()
|
||||
r0, r1 := m.s.GetChatTemplateAllowlist(ctx)
|
||||
@@ -1248,7 +1272,7 @@ func (m queryMetricsStore) GetChatWorkspaceTTL(ctx context.Context) (string, err
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) GetChats(ctx context.Context, arg database.GetChatsParams) ([]database.Chat, error) {
|
||||
func (m queryMetricsStore) GetChats(ctx context.Context, arg database.GetChatsParams) ([]database.GetChatsRow, error) {
|
||||
start := time.Now()
|
||||
r0, r1 := m.s.GetChats(ctx, arg)
|
||||
m.queryLatencies.WithLabelValues("GetChats").Observe(time.Since(start).Seconds())
|
||||
@@ -1256,6 +1280,14 @@ func (m queryMetricsStore) GetChats(ctx context.Context, arg database.GetChatsPa
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) GetChatsByWorkspaceIDs(ctx context.Context, ids []uuid.UUID) ([]database.Chat, error) {
|
||||
start := time.Now()
|
||||
r0, r1 := m.s.GetChatsByWorkspaceIDs(ctx, ids)
|
||||
m.queryLatencies.WithLabelValues("GetChatsByWorkspaceIDs").Observe(time.Since(start).Seconds())
|
||||
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "GetChatsByWorkspaceIDs").Inc()
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) GetConnectionLogsOffset(ctx context.Context, arg database.GetConnectionLogsOffsetParams) ([]database.GetConnectionLogsOffsetRow, error) {
|
||||
start := time.Now()
|
||||
r0, r1 := m.s.GetConnectionLogsOffset(ctx, arg)
|
||||
@@ -2448,6 +2480,14 @@ func (m queryMetricsStore) GetUnexpiredLicenses(ctx context.Context) ([]database
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) GetUserAISeatStates(ctx context.Context, userIds []uuid.UUID) ([]uuid.UUID, error) {
|
||||
start := time.Now()
|
||||
r0, r1 := m.s.GetUserAISeatStates(ctx, userIds)
|
||||
m.queryLatencies.WithLabelValues("GetUserAISeatStates").Observe(time.Since(start).Seconds())
|
||||
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "GetUserAISeatStates").Inc()
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) GetUserActivityInsights(ctx context.Context, arg database.GetUserActivityInsightsParams) ([]database.GetUserActivityInsightsRow, error) {
|
||||
start := time.Now()
|
||||
r0, r1 := m.s.GetUserActivityInsights(ctx, arg)
|
||||
@@ -3712,6 +3752,14 @@ func (m queryMetricsStore) InsertWorkspaceResourceMetadata(ctx context.Context,
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) ListAIBridgeClients(ctx context.Context, arg database.ListAIBridgeClientsParams) ([]string, error) {
|
||||
start := time.Now()
|
||||
r0, r1 := m.s.ListAIBridgeClients(ctx, arg)
|
||||
m.queryLatencies.WithLabelValues("ListAIBridgeClients").Observe(time.Since(start).Seconds())
|
||||
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "ListAIBridgeClients").Inc()
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) ListAIBridgeInterceptions(ctx context.Context, arg database.ListAIBridgeInterceptionsParams) ([]database.ListAIBridgeInterceptionsRow, error) {
|
||||
start := time.Now()
|
||||
r0, r1 := m.s.ListAIBridgeInterceptions(ctx, arg)
|
||||
@@ -3728,6 +3776,14 @@ func (m queryMetricsStore) ListAIBridgeInterceptionsTelemetrySummaries(ctx conte
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) ListAIBridgeModelThoughtsByInterceptionIDs(ctx context.Context, interceptionIds []uuid.UUID) ([]database.AIBridgeModelThought, error) {
|
||||
start := time.Now()
|
||||
r0, r1 := m.s.ListAIBridgeModelThoughtsByInterceptionIDs(ctx, interceptionIds)
|
||||
m.queryLatencies.WithLabelValues("ListAIBridgeModelThoughtsByInterceptionIDs").Observe(time.Since(start).Seconds())
|
||||
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "ListAIBridgeModelThoughtsByInterceptionIDs").Inc()
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) ListAIBridgeModels(ctx context.Context, arg database.ListAIBridgeModelsParams) ([]string, error) {
|
||||
start := time.Now()
|
||||
r0, r1 := m.s.ListAIBridgeModels(ctx, arg)
|
||||
@@ -3736,6 +3792,14 @@ func (m queryMetricsStore) ListAIBridgeModels(ctx context.Context, arg database.
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) ListAIBridgeSessionThreads(ctx context.Context, arg database.ListAIBridgeSessionThreadsParams) ([]database.ListAIBridgeSessionThreadsRow, error) {
|
||||
start := time.Now()
|
||||
r0, r1 := m.s.ListAIBridgeSessionThreads(ctx, arg)
|
||||
m.queryLatencies.WithLabelValues("ListAIBridgeSessionThreads").Observe(time.Since(start).Seconds())
|
||||
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "ListAIBridgeSessionThreads").Inc()
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) ListAIBridgeSessions(ctx context.Context, arg database.ListAIBridgeSessionsParams) ([]database.ListAIBridgeSessionsRow, error) {
|
||||
start := time.Now()
|
||||
r0, r1 := m.s.ListAIBridgeSessions(ctx, arg)
|
||||
@@ -3872,6 +3936,14 @@ func (m queryMetricsStore) PaginatedOrganizationMembers(ctx context.Context, arg
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) PinChatByID(ctx context.Context, id uuid.UUID) error {
|
||||
start := time.Now()
|
||||
r0 := m.s.PinChatByID(ctx, id)
|
||||
m.queryLatencies.WithLabelValues("PinChatByID").Observe(time.Since(start).Seconds())
|
||||
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "PinChatByID").Inc()
|
||||
return r0
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) PopNextQueuedMessage(ctx context.Context, chatID uuid.UUID) (database.ChatQueuedMessage, error) {
|
||||
start := time.Now()
|
||||
r0, r1 := m.s.PopNextQueuedMessage(ctx, chatID)
|
||||
@@ -3952,12 +4024,12 @@ func (m queryMetricsStore) TryAcquireLock(ctx context.Context, pgTryAdvisoryXact
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) UnarchiveChatByID(ctx context.Context, id uuid.UUID) error {
|
||||
func (m queryMetricsStore) UnarchiveChatByID(ctx context.Context, id uuid.UUID) ([]database.Chat, error) {
|
||||
start := time.Now()
|
||||
r0 := m.s.UnarchiveChatByID(ctx, id)
|
||||
r0, r1 := m.s.UnarchiveChatByID(ctx, id)
|
||||
m.queryLatencies.WithLabelValues("UnarchiveChatByID").Observe(time.Since(start).Seconds())
|
||||
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "UnarchiveChatByID").Inc()
|
||||
return r0
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) UnarchiveTemplateVersion(ctx context.Context, arg database.UnarchiveTemplateVersionParams) error {
|
||||
@@ -3976,6 +4048,14 @@ func (m queryMetricsStore) UnfavoriteWorkspace(ctx context.Context, id uuid.UUID
|
||||
return r0
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) UnpinChatByID(ctx context.Context, id uuid.UUID) error {
|
||||
start := time.Now()
|
||||
r0 := m.s.UnpinChatByID(ctx, id)
|
||||
m.queryLatencies.WithLabelValues("UnpinChatByID").Observe(time.Since(start).Seconds())
|
||||
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "UnpinChatByID").Inc()
|
||||
return r0
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) UnsetDefaultChatModelConfigs(ctx context.Context) error {
|
||||
start := time.Now()
|
||||
r0 := m.s.UnsetDefaultChatModelConfigs(ctx)
|
||||
@@ -4000,6 +4080,14 @@ func (m queryMetricsStore) UpdateAPIKeyByID(ctx context.Context, arg database.Up
|
||||
return r0
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) UpdateChatBuildAgentBinding(ctx context.Context, arg database.UpdateChatBuildAgentBindingParams) (database.Chat, error) {
|
||||
start := time.Now()
|
||||
r0, r1 := m.s.UpdateChatBuildAgentBinding(ctx, arg)
|
||||
m.queryLatencies.WithLabelValues("UpdateChatBuildAgentBinding").Observe(time.Since(start).Seconds())
|
||||
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "UpdateChatBuildAgentBinding").Inc()
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) UpdateChatByID(ctx context.Context, arg database.UpdateChatByIDParams) (database.Chat, error) {
|
||||
start := time.Now()
|
||||
r0, r1 := m.s.UpdateChatByID(ctx, arg)
|
||||
@@ -4024,6 +4112,30 @@ func (m queryMetricsStore) UpdateChatLabelsByID(ctx context.Context, arg databas
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) UpdateChatLastInjectedContext(ctx context.Context, arg database.UpdateChatLastInjectedContextParams) (database.Chat, error) {
|
||||
start := time.Now()
|
||||
r0, r1 := m.s.UpdateChatLastInjectedContext(ctx, arg)
|
||||
m.queryLatencies.WithLabelValues("UpdateChatLastInjectedContext").Observe(time.Since(start).Seconds())
|
||||
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "UpdateChatLastInjectedContext").Inc()
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) UpdateChatLastModelConfigByID(ctx context.Context, arg database.UpdateChatLastModelConfigByIDParams) (database.Chat, error) {
|
||||
start := time.Now()
|
||||
r0, r1 := m.s.UpdateChatLastModelConfigByID(ctx, arg)
|
||||
m.queryLatencies.WithLabelValues("UpdateChatLastModelConfigByID").Observe(time.Since(start).Seconds())
|
||||
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "UpdateChatLastModelConfigByID").Inc()
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) UpdateChatLastReadMessageID(ctx context.Context, arg database.UpdateChatLastReadMessageIDParams) error {
|
||||
start := time.Now()
|
||||
r0 := m.s.UpdateChatLastReadMessageID(ctx, arg)
|
||||
m.queryLatencies.WithLabelValues("UpdateChatLastReadMessageID").Observe(time.Since(start).Seconds())
|
||||
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "UpdateChatLastReadMessageID").Inc()
|
||||
return r0
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) UpdateChatMCPServerIDs(ctx context.Context, arg database.UpdateChatMCPServerIDsParams) (database.Chat, error) {
|
||||
start := time.Now()
|
||||
r0, r1 := m.s.UpdateChatMCPServerIDs(ctx, arg)
|
||||
@@ -4048,6 +4160,14 @@ func (m queryMetricsStore) UpdateChatModelConfig(ctx context.Context, arg databa
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) UpdateChatPinOrder(ctx context.Context, arg database.UpdateChatPinOrderParams) error {
|
||||
start := time.Now()
|
||||
r0 := m.s.UpdateChatPinOrder(ctx, arg)
|
||||
m.queryLatencies.WithLabelValues("UpdateChatPinOrder").Observe(time.Since(start).Seconds())
|
||||
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "UpdateChatPinOrder").Inc()
|
||||
return r0
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) UpdateChatProvider(ctx context.Context, arg database.UpdateChatProviderParams) (database.ChatProvider, error) {
|
||||
start := time.Now()
|
||||
r0, r1 := m.s.UpdateChatProvider(ctx, arg)
|
||||
@@ -4064,11 +4184,19 @@ func (m queryMetricsStore) UpdateChatStatus(ctx context.Context, arg database.Up
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) UpdateChatWorkspace(ctx context.Context, arg database.UpdateChatWorkspaceParams) (database.Chat, error) {
|
||||
func (m queryMetricsStore) UpdateChatStatusPreserveUpdatedAt(ctx context.Context, arg database.UpdateChatStatusPreserveUpdatedAtParams) (database.Chat, error) {
|
||||
start := time.Now()
|
||||
r0, r1 := m.s.UpdateChatWorkspace(ctx, arg)
|
||||
m.queryLatencies.WithLabelValues("UpdateChatWorkspace").Observe(time.Since(start).Seconds())
|
||||
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "UpdateChatWorkspace").Inc()
|
||||
r0, r1 := m.s.UpdateChatStatusPreserveUpdatedAt(ctx, arg)
|
||||
m.queryLatencies.WithLabelValues("UpdateChatStatusPreserveUpdatedAt").Observe(time.Since(start).Seconds())
|
||||
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "UpdateChatStatusPreserveUpdatedAt").Inc()
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) UpdateChatWorkspaceBinding(ctx context.Context, arg database.UpdateChatWorkspaceBindingParams) (database.Chat, error) {
|
||||
start := time.Now()
|
||||
r0, r1 := m.s.UpdateChatWorkspaceBinding(ctx, arg)
|
||||
m.queryLatencies.WithLabelValues("UpdateChatWorkspaceBinding").Observe(time.Since(start).Seconds())
|
||||
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "UpdateChatWorkspaceBinding").Inc()
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
@@ -4816,6 +4944,14 @@ func (m queryMetricsStore) UpsertChatDiffStatusReference(ctx context.Context, ar
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) UpsertChatIncludeDefaultSystemPrompt(ctx context.Context, includeDefaultSystemPrompt bool) error {
|
||||
start := time.Now()
|
||||
r0 := m.s.UpsertChatIncludeDefaultSystemPrompt(ctx, includeDefaultSystemPrompt)
|
||||
m.queryLatencies.WithLabelValues("UpsertChatIncludeDefaultSystemPrompt").Observe(time.Since(start).Seconds())
|
||||
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "UpsertChatIncludeDefaultSystemPrompt").Inc()
|
||||
return r0
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) UpsertChatSystemPrompt(ctx context.Context, value string) error {
|
||||
start := time.Now()
|
||||
r0 := m.s.UpsertChatSystemPrompt(ctx, value)
|
||||
@@ -5176,6 +5312,14 @@ func (m queryMetricsStore) ListAuthorizedAIBridgeModels(ctx context.Context, arg
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) ListAuthorizedAIBridgeClients(ctx context.Context, arg database.ListAIBridgeClientsParams, prepared rbac.PreparedAuthorized) ([]string, error) {
|
||||
start := time.Now()
|
||||
r0, r1 := m.s.ListAuthorizedAIBridgeClients(ctx, arg, prepared)
|
||||
m.queryLatencies.WithLabelValues("ListAuthorizedAIBridgeClients").Observe(time.Since(start).Seconds())
|
||||
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "ListAuthorizedAIBridgeClients").Inc()
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) ListAuthorizedAIBridgeSessions(ctx context.Context, arg database.ListAIBridgeSessionsParams, prepared rbac.PreparedAuthorized) ([]database.ListAIBridgeSessionsRow, error) {
|
||||
start := time.Now()
|
||||
r0, r1 := m.s.ListAuthorizedAIBridgeSessions(ctx, arg, prepared)
|
||||
@@ -5192,7 +5336,15 @@ func (m queryMetricsStore) CountAuthorizedAIBridgeSessions(ctx context.Context,
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) GetAuthorizedChats(ctx context.Context, arg database.GetChatsParams, prepared rbac.PreparedAuthorized) ([]database.Chat, error) {
|
||||
func (m queryMetricsStore) ListAuthorizedAIBridgeSessionThreads(ctx context.Context, arg database.ListAIBridgeSessionThreadsParams, prepared rbac.PreparedAuthorized) ([]database.ListAIBridgeSessionThreadsRow, error) {
|
||||
start := time.Now()
|
||||
r0, r1 := m.s.ListAuthorizedAIBridgeSessionThreads(ctx, arg, prepared)
|
||||
m.queryLatencies.WithLabelValues("ListAuthorizedAIBridgeSessionThreads").Observe(time.Since(start).Seconds())
|
||||
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "ListAuthorizedAIBridgeSessionThreads").Inc()
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) GetAuthorizedChats(ctx context.Context, arg database.GetChatsParams, prepared rbac.PreparedAuthorized) ([]database.GetChatsRow, error) {
|
||||
start := time.Now()
|
||||
r0, r1 := m.s.GetAuthorizedChats(ctx, arg, prepared)
|
||||
m.queryLatencies.WithLabelValues("GetAuthorizedChats").Observe(time.Since(start).Seconds())
|
||||
|
||||
@@ -148,11 +148,12 @@ func (mr *MockStoreMockRecorder) AllUserIDs(ctx, includeSystem any) *gomock.Call
|
||||
}
|
||||
|
||||
// ArchiveChatByID mocks base method.
|
||||
func (m *MockStore) ArchiveChatByID(ctx context.Context, id uuid.UUID) error {
|
||||
func (m *MockStore) ArchiveChatByID(ctx context.Context, id uuid.UUID) ([]database.Chat, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "ArchiveChatByID", ctx, id)
|
||||
ret0, _ := ret[0].(error)
|
||||
return ret0
|
||||
ret0, _ := ret[0].([]database.Chat)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// ArchiveChatByID indicates an expected call of ArchiveChatByID.
|
||||
@@ -1804,10 +1805,10 @@ func (mr *MockStoreMockRecorder) GetAuthorizedAuditLogsOffset(ctx, arg, prepared
|
||||
}
|
||||
|
||||
// GetAuthorizedChats mocks base method.
|
||||
func (m *MockStore) GetAuthorizedChats(ctx context.Context, arg database.GetChatsParams, prepared rbac.PreparedAuthorized) ([]database.Chat, error) {
|
||||
func (m *MockStore) GetAuthorizedChats(ctx context.Context, arg database.GetChatsParams, prepared rbac.PreparedAuthorized) ([]database.GetChatsRow, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "GetAuthorizedChats", ctx, arg, prepared)
|
||||
ret0, _ := ret[0].([]database.Chat)
|
||||
ret0, _ := ret[0].([]database.GetChatsRow)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
@@ -2058,6 +2059,21 @@ func (mr *MockStoreMockRecorder) GetChatFilesByIDs(ctx, ids any) *gomock.Call {
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetChatFilesByIDs", reflect.TypeOf((*MockStore)(nil).GetChatFilesByIDs), ctx, ids)
|
||||
}
|
||||
|
||||
// GetChatIncludeDefaultSystemPrompt mocks base method.
|
||||
func (m *MockStore) GetChatIncludeDefaultSystemPrompt(ctx context.Context) (bool, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "GetChatIncludeDefaultSystemPrompt", ctx)
|
||||
ret0, _ := ret[0].(bool)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// GetChatIncludeDefaultSystemPrompt indicates an expected call of GetChatIncludeDefaultSystemPrompt.
|
||||
func (mr *MockStoreMockRecorder) GetChatIncludeDefaultSystemPrompt(ctx any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetChatIncludeDefaultSystemPrompt", reflect.TypeOf((*MockStore)(nil).GetChatIncludeDefaultSystemPrompt), ctx)
|
||||
}
|
||||
|
||||
// GetChatMessageByID mocks base method.
|
||||
func (m *MockStore) GetChatMessageByID(ctx context.Context, id int64) (database.ChatMessage, error) {
|
||||
m.ctrl.T.Helper()
|
||||
@@ -2088,6 +2104,21 @@ func (mr *MockStoreMockRecorder) GetChatMessagesByChatID(ctx, arg any) *gomock.C
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetChatMessagesByChatID", reflect.TypeOf((*MockStore)(nil).GetChatMessagesByChatID), ctx, arg)
|
||||
}
|
||||
|
||||
// GetChatMessagesByChatIDAscPaginated mocks base method.
|
||||
func (m *MockStore) GetChatMessagesByChatIDAscPaginated(ctx context.Context, arg database.GetChatMessagesByChatIDAscPaginatedParams) ([]database.ChatMessage, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "GetChatMessagesByChatIDAscPaginated", ctx, arg)
|
||||
ret0, _ := ret[0].([]database.ChatMessage)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// GetChatMessagesByChatIDAscPaginated indicates an expected call of GetChatMessagesByChatIDAscPaginated.
|
||||
func (mr *MockStoreMockRecorder) GetChatMessagesByChatIDAscPaginated(ctx, arg any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetChatMessagesByChatIDAscPaginated", reflect.TypeOf((*MockStore)(nil).GetChatMessagesByChatIDAscPaginated), ctx, arg)
|
||||
}
|
||||
|
||||
// GetChatMessagesByChatIDDescPaginated mocks base method.
|
||||
func (m *MockStore) GetChatMessagesByChatIDDescPaginated(ctx context.Context, arg database.GetChatMessagesByChatIDDescPaginatedParams) ([]database.ChatMessage, error) {
|
||||
m.ctrl.T.Helper()
|
||||
@@ -2223,6 +2254,21 @@ func (mr *MockStoreMockRecorder) GetChatSystemPrompt(ctx any) *gomock.Call {
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetChatSystemPrompt", reflect.TypeOf((*MockStore)(nil).GetChatSystemPrompt), ctx)
|
||||
}
|
||||
|
||||
// GetChatSystemPromptConfig mocks base method.
|
||||
func (m *MockStore) GetChatSystemPromptConfig(ctx context.Context) (database.GetChatSystemPromptConfigRow, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "GetChatSystemPromptConfig", ctx)
|
||||
ret0, _ := ret[0].(database.GetChatSystemPromptConfigRow)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// GetChatSystemPromptConfig indicates an expected call of GetChatSystemPromptConfig.
|
||||
func (mr *MockStoreMockRecorder) GetChatSystemPromptConfig(ctx any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetChatSystemPromptConfig", reflect.TypeOf((*MockStore)(nil).GetChatSystemPromptConfig), ctx)
|
||||
}
|
||||
|
||||
// GetChatTemplateAllowlist mocks base method.
|
||||
func (m *MockStore) GetChatTemplateAllowlist(ctx context.Context) (string, error) {
|
||||
m.ctrl.T.Helper()
|
||||
@@ -2299,10 +2345,10 @@ func (mr *MockStoreMockRecorder) GetChatWorkspaceTTL(ctx any) *gomock.Call {
|
||||
}
|
||||
|
||||
// GetChats mocks base method.
|
||||
func (m *MockStore) GetChats(ctx context.Context, arg database.GetChatsParams) ([]database.Chat, error) {
|
||||
func (m *MockStore) GetChats(ctx context.Context, arg database.GetChatsParams) ([]database.GetChatsRow, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "GetChats", ctx, arg)
|
||||
ret0, _ := ret[0].([]database.Chat)
|
||||
ret0, _ := ret[0].([]database.GetChatsRow)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
@@ -2313,6 +2359,21 @@ func (mr *MockStoreMockRecorder) GetChats(ctx, arg any) *gomock.Call {
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetChats", reflect.TypeOf((*MockStore)(nil).GetChats), ctx, arg)
|
||||
}
|
||||
|
||||
// GetChatsByWorkspaceIDs mocks base method.
|
||||
func (m *MockStore) GetChatsByWorkspaceIDs(ctx context.Context, ids []uuid.UUID) ([]database.Chat, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "GetChatsByWorkspaceIDs", ctx, ids)
|
||||
ret0, _ := ret[0].([]database.Chat)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// GetChatsByWorkspaceIDs indicates an expected call of GetChatsByWorkspaceIDs.
|
||||
func (mr *MockStoreMockRecorder) GetChatsByWorkspaceIDs(ctx, ids any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetChatsByWorkspaceIDs", reflect.TypeOf((*MockStore)(nil).GetChatsByWorkspaceIDs), ctx, ids)
|
||||
}
|
||||
|
||||
// GetConnectionLogsOffset mocks base method.
|
||||
func (m *MockStore) GetConnectionLogsOffset(ctx context.Context, arg database.GetConnectionLogsOffsetParams) ([]database.GetConnectionLogsOffsetRow, error) {
|
||||
m.ctrl.T.Helper()
|
||||
@@ -4578,6 +4639,21 @@ func (mr *MockStoreMockRecorder) GetUnexpiredLicenses(ctx any) *gomock.Call {
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetUnexpiredLicenses", reflect.TypeOf((*MockStore)(nil).GetUnexpiredLicenses), ctx)
|
||||
}
|
||||
|
||||
// GetUserAISeatStates mocks base method.
|
||||
func (m *MockStore) GetUserAISeatStates(ctx context.Context, userIds []uuid.UUID) ([]uuid.UUID, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "GetUserAISeatStates", ctx, userIds)
|
||||
ret0, _ := ret[0].([]uuid.UUID)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// GetUserAISeatStates indicates an expected call of GetUserAISeatStates.
|
||||
func (mr *MockStoreMockRecorder) GetUserAISeatStates(ctx, userIds any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetUserAISeatStates", reflect.TypeOf((*MockStore)(nil).GetUserAISeatStates), ctx, userIds)
|
||||
}
|
||||
|
||||
// GetUserActivityInsights mocks base method.
|
||||
func (m *MockStore) GetUserActivityInsights(ctx context.Context, arg database.GetUserActivityInsightsParams) ([]database.GetUserActivityInsightsRow, error) {
|
||||
m.ctrl.T.Helper()
|
||||
@@ -6947,6 +7023,21 @@ func (mr *MockStoreMockRecorder) InsertWorkspaceResourceMetadata(ctx, arg any) *
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "InsertWorkspaceResourceMetadata", reflect.TypeOf((*MockStore)(nil).InsertWorkspaceResourceMetadata), ctx, arg)
|
||||
}
|
||||
|
||||
// ListAIBridgeClients mocks base method.
|
||||
func (m *MockStore) ListAIBridgeClients(ctx context.Context, arg database.ListAIBridgeClientsParams) ([]string, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "ListAIBridgeClients", ctx, arg)
|
||||
ret0, _ := ret[0].([]string)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// ListAIBridgeClients indicates an expected call of ListAIBridgeClients.
|
||||
func (mr *MockStoreMockRecorder) ListAIBridgeClients(ctx, arg any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ListAIBridgeClients", reflect.TypeOf((*MockStore)(nil).ListAIBridgeClients), ctx, arg)
|
||||
}
|
||||
|
||||
// ListAIBridgeInterceptions mocks base method.
|
||||
func (m *MockStore) ListAIBridgeInterceptions(ctx context.Context, arg database.ListAIBridgeInterceptionsParams) ([]database.ListAIBridgeInterceptionsRow, error) {
|
||||
m.ctrl.T.Helper()
|
||||
@@ -6977,6 +7068,21 @@ func (mr *MockStoreMockRecorder) ListAIBridgeInterceptionsTelemetrySummaries(ctx
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ListAIBridgeInterceptionsTelemetrySummaries", reflect.TypeOf((*MockStore)(nil).ListAIBridgeInterceptionsTelemetrySummaries), ctx, arg)
|
||||
}
|
||||
|
||||
// ListAIBridgeModelThoughtsByInterceptionIDs mocks base method.
|
||||
func (m *MockStore) ListAIBridgeModelThoughtsByInterceptionIDs(ctx context.Context, interceptionIds []uuid.UUID) ([]database.AIBridgeModelThought, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "ListAIBridgeModelThoughtsByInterceptionIDs", ctx, interceptionIds)
|
||||
ret0, _ := ret[0].([]database.AIBridgeModelThought)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// ListAIBridgeModelThoughtsByInterceptionIDs indicates an expected call of ListAIBridgeModelThoughtsByInterceptionIDs.
|
||||
func (mr *MockStoreMockRecorder) ListAIBridgeModelThoughtsByInterceptionIDs(ctx, interceptionIds any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ListAIBridgeModelThoughtsByInterceptionIDs", reflect.TypeOf((*MockStore)(nil).ListAIBridgeModelThoughtsByInterceptionIDs), ctx, interceptionIds)
|
||||
}
|
||||
|
||||
// ListAIBridgeModels mocks base method.
|
||||
func (m *MockStore) ListAIBridgeModels(ctx context.Context, arg database.ListAIBridgeModelsParams) ([]string, error) {
|
||||
m.ctrl.T.Helper()
|
||||
@@ -6992,6 +7098,21 @@ func (mr *MockStoreMockRecorder) ListAIBridgeModels(ctx, arg any) *gomock.Call {
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ListAIBridgeModels", reflect.TypeOf((*MockStore)(nil).ListAIBridgeModels), ctx, arg)
|
||||
}
|
||||
|
||||
// ListAIBridgeSessionThreads mocks base method.
|
||||
func (m *MockStore) ListAIBridgeSessionThreads(ctx context.Context, arg database.ListAIBridgeSessionThreadsParams) ([]database.ListAIBridgeSessionThreadsRow, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "ListAIBridgeSessionThreads", ctx, arg)
|
||||
ret0, _ := ret[0].([]database.ListAIBridgeSessionThreadsRow)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// ListAIBridgeSessionThreads indicates an expected call of ListAIBridgeSessionThreads.
|
||||
func (mr *MockStoreMockRecorder) ListAIBridgeSessionThreads(ctx, arg any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ListAIBridgeSessionThreads", reflect.TypeOf((*MockStore)(nil).ListAIBridgeSessionThreads), ctx, arg)
|
||||
}
|
||||
|
||||
// ListAIBridgeSessions mocks base method.
|
||||
func (m *MockStore) ListAIBridgeSessions(ctx context.Context, arg database.ListAIBridgeSessionsParams) ([]database.ListAIBridgeSessionsRow, error) {
|
||||
m.ctrl.T.Helper()
|
||||
@@ -7052,6 +7173,21 @@ func (mr *MockStoreMockRecorder) ListAIBridgeUserPromptsByInterceptionIDs(ctx, i
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ListAIBridgeUserPromptsByInterceptionIDs", reflect.TypeOf((*MockStore)(nil).ListAIBridgeUserPromptsByInterceptionIDs), ctx, interceptionIds)
|
||||
}
|
||||
|
||||
// ListAuthorizedAIBridgeClients mocks base method.
|
||||
func (m *MockStore) ListAuthorizedAIBridgeClients(ctx context.Context, arg database.ListAIBridgeClientsParams, prepared rbac.PreparedAuthorized) ([]string, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "ListAuthorizedAIBridgeClients", ctx, arg, prepared)
|
||||
ret0, _ := ret[0].([]string)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// ListAuthorizedAIBridgeClients indicates an expected call of ListAuthorizedAIBridgeClients.
|
||||
func (mr *MockStoreMockRecorder) ListAuthorizedAIBridgeClients(ctx, arg, prepared any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ListAuthorizedAIBridgeClients", reflect.TypeOf((*MockStore)(nil).ListAuthorizedAIBridgeClients), ctx, arg, prepared)
|
||||
}
|
||||
|
||||
// ListAuthorizedAIBridgeInterceptions mocks base method.
|
||||
func (m *MockStore) ListAuthorizedAIBridgeInterceptions(ctx context.Context, arg database.ListAIBridgeInterceptionsParams, prepared rbac.PreparedAuthorized) ([]database.ListAIBridgeInterceptionsRow, error) {
|
||||
m.ctrl.T.Helper()
|
||||
@@ -7082,6 +7218,21 @@ func (mr *MockStoreMockRecorder) ListAuthorizedAIBridgeModels(ctx, arg, prepared
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ListAuthorizedAIBridgeModels", reflect.TypeOf((*MockStore)(nil).ListAuthorizedAIBridgeModels), ctx, arg, prepared)
|
||||
}
|
||||
|
||||
// ListAuthorizedAIBridgeSessionThreads mocks base method.
|
||||
func (m *MockStore) ListAuthorizedAIBridgeSessionThreads(ctx context.Context, arg database.ListAIBridgeSessionThreadsParams, prepared rbac.PreparedAuthorized) ([]database.ListAIBridgeSessionThreadsRow, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "ListAuthorizedAIBridgeSessionThreads", ctx, arg, prepared)
|
||||
ret0, _ := ret[0].([]database.ListAIBridgeSessionThreadsRow)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// ListAuthorizedAIBridgeSessionThreads indicates an expected call of ListAuthorizedAIBridgeSessionThreads.
|
||||
func (mr *MockStoreMockRecorder) ListAuthorizedAIBridgeSessionThreads(ctx, arg, prepared any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ListAuthorizedAIBridgeSessionThreads", reflect.TypeOf((*MockStore)(nil).ListAuthorizedAIBridgeSessionThreads), ctx, arg, prepared)
|
||||
}
|
||||
|
||||
// ListAuthorizedAIBridgeSessions mocks base method.
|
||||
func (m *MockStore) ListAuthorizedAIBridgeSessions(ctx context.Context, arg database.ListAIBridgeSessionsParams, prepared rbac.PreparedAuthorized) ([]database.ListAIBridgeSessionsRow, error) {
|
||||
m.ctrl.T.Helper()
|
||||
@@ -7306,6 +7457,20 @@ func (mr *MockStoreMockRecorder) PaginatedOrganizationMembers(ctx, arg any) *gom
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "PaginatedOrganizationMembers", reflect.TypeOf((*MockStore)(nil).PaginatedOrganizationMembers), ctx, arg)
|
||||
}
|
||||
|
||||
// PinChatByID mocks base method.
|
||||
func (m *MockStore) PinChatByID(ctx context.Context, id uuid.UUID) error {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "PinChatByID", ctx, id)
|
||||
ret0, _ := ret[0].(error)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// PinChatByID indicates an expected call of PinChatByID.
|
||||
func (mr *MockStoreMockRecorder) PinChatByID(ctx, id any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "PinChatByID", reflect.TypeOf((*MockStore)(nil).PinChatByID), ctx, id)
|
||||
}
|
||||
|
||||
// Ping mocks base method.
|
||||
func (m *MockStore) Ping(ctx context.Context) (time.Duration, error) {
|
||||
m.ctrl.T.Helper()
|
||||
@@ -7468,11 +7633,12 @@ func (mr *MockStoreMockRecorder) TryAcquireLock(ctx, pgTryAdvisoryXactLock any)
|
||||
}
|
||||
|
||||
// UnarchiveChatByID mocks base method.
|
||||
func (m *MockStore) UnarchiveChatByID(ctx context.Context, id uuid.UUID) error {
|
||||
func (m *MockStore) UnarchiveChatByID(ctx context.Context, id uuid.UUID) ([]database.Chat, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "UnarchiveChatByID", ctx, id)
|
||||
ret0, _ := ret[0].(error)
|
||||
return ret0
|
||||
ret0, _ := ret[0].([]database.Chat)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// UnarchiveChatByID indicates an expected call of UnarchiveChatByID.
|
||||
@@ -7509,6 +7675,20 @@ func (mr *MockStoreMockRecorder) UnfavoriteWorkspace(ctx, id any) *gomock.Call {
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UnfavoriteWorkspace", reflect.TypeOf((*MockStore)(nil).UnfavoriteWorkspace), ctx, id)
|
||||
}
|
||||
|
||||
// UnpinChatByID mocks base method.
|
||||
func (m *MockStore) UnpinChatByID(ctx context.Context, id uuid.UUID) error {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "UnpinChatByID", ctx, id)
|
||||
ret0, _ := ret[0].(error)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// UnpinChatByID indicates an expected call of UnpinChatByID.
|
||||
func (mr *MockStoreMockRecorder) UnpinChatByID(ctx, id any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UnpinChatByID", reflect.TypeOf((*MockStore)(nil).UnpinChatByID), ctx, id)
|
||||
}
|
||||
|
||||
// UnsetDefaultChatModelConfigs mocks base method.
|
||||
func (m *MockStore) UnsetDefaultChatModelConfigs(ctx context.Context) error {
|
||||
m.ctrl.T.Helper()
|
||||
@@ -7552,6 +7732,21 @@ func (mr *MockStoreMockRecorder) UpdateAPIKeyByID(ctx, arg any) *gomock.Call {
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateAPIKeyByID", reflect.TypeOf((*MockStore)(nil).UpdateAPIKeyByID), ctx, arg)
|
||||
}
|
||||
|
||||
// UpdateChatBuildAgentBinding mocks base method.
|
||||
func (m *MockStore) UpdateChatBuildAgentBinding(ctx context.Context, arg database.UpdateChatBuildAgentBindingParams) (database.Chat, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "UpdateChatBuildAgentBinding", ctx, arg)
|
||||
ret0, _ := ret[0].(database.Chat)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// UpdateChatBuildAgentBinding indicates an expected call of UpdateChatBuildAgentBinding.
|
||||
func (mr *MockStoreMockRecorder) UpdateChatBuildAgentBinding(ctx, arg any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateChatBuildAgentBinding", reflect.TypeOf((*MockStore)(nil).UpdateChatBuildAgentBinding), ctx, arg)
|
||||
}
|
||||
|
||||
// UpdateChatByID mocks base method.
|
||||
func (m *MockStore) UpdateChatByID(ctx context.Context, arg database.UpdateChatByIDParams) (database.Chat, error) {
|
||||
m.ctrl.T.Helper()
|
||||
@@ -7597,6 +7792,50 @@ func (mr *MockStoreMockRecorder) UpdateChatLabelsByID(ctx, arg any) *gomock.Call
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateChatLabelsByID", reflect.TypeOf((*MockStore)(nil).UpdateChatLabelsByID), ctx, arg)
|
||||
}
|
||||
|
||||
// UpdateChatLastInjectedContext mocks base method.
|
||||
func (m *MockStore) UpdateChatLastInjectedContext(ctx context.Context, arg database.UpdateChatLastInjectedContextParams) (database.Chat, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "UpdateChatLastInjectedContext", ctx, arg)
|
||||
ret0, _ := ret[0].(database.Chat)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// UpdateChatLastInjectedContext indicates an expected call of UpdateChatLastInjectedContext.
|
||||
func (mr *MockStoreMockRecorder) UpdateChatLastInjectedContext(ctx, arg any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateChatLastInjectedContext", reflect.TypeOf((*MockStore)(nil).UpdateChatLastInjectedContext), ctx, arg)
|
||||
}
|
||||
|
||||
// UpdateChatLastModelConfigByID mocks base method.
|
||||
func (m *MockStore) UpdateChatLastModelConfigByID(ctx context.Context, arg database.UpdateChatLastModelConfigByIDParams) (database.Chat, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "UpdateChatLastModelConfigByID", ctx, arg)
|
||||
ret0, _ := ret[0].(database.Chat)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// UpdateChatLastModelConfigByID indicates an expected call of UpdateChatLastModelConfigByID.
|
||||
func (mr *MockStoreMockRecorder) UpdateChatLastModelConfigByID(ctx, arg any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateChatLastModelConfigByID", reflect.TypeOf((*MockStore)(nil).UpdateChatLastModelConfigByID), ctx, arg)
|
||||
}
|
||||
|
||||
// UpdateChatLastReadMessageID mocks base method.
|
||||
func (m *MockStore) UpdateChatLastReadMessageID(ctx context.Context, arg database.UpdateChatLastReadMessageIDParams) error {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "UpdateChatLastReadMessageID", ctx, arg)
|
||||
ret0, _ := ret[0].(error)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// UpdateChatLastReadMessageID indicates an expected call of UpdateChatLastReadMessageID.
|
||||
func (mr *MockStoreMockRecorder) UpdateChatLastReadMessageID(ctx, arg any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateChatLastReadMessageID", reflect.TypeOf((*MockStore)(nil).UpdateChatLastReadMessageID), ctx, arg)
|
||||
}
|
||||
|
||||
// UpdateChatMCPServerIDs mocks base method.
|
||||
func (m *MockStore) UpdateChatMCPServerIDs(ctx context.Context, arg database.UpdateChatMCPServerIDsParams) (database.Chat, error) {
|
||||
m.ctrl.T.Helper()
|
||||
@@ -7642,6 +7881,20 @@ func (mr *MockStoreMockRecorder) UpdateChatModelConfig(ctx, arg any) *gomock.Cal
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateChatModelConfig", reflect.TypeOf((*MockStore)(nil).UpdateChatModelConfig), ctx, arg)
|
||||
}
|
||||
|
||||
// UpdateChatPinOrder mocks base method.
|
||||
func (m *MockStore) UpdateChatPinOrder(ctx context.Context, arg database.UpdateChatPinOrderParams) error {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "UpdateChatPinOrder", ctx, arg)
|
||||
ret0, _ := ret[0].(error)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// UpdateChatPinOrder indicates an expected call of UpdateChatPinOrder.
|
||||
func (mr *MockStoreMockRecorder) UpdateChatPinOrder(ctx, arg any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateChatPinOrder", reflect.TypeOf((*MockStore)(nil).UpdateChatPinOrder), ctx, arg)
|
||||
}
|
||||
|
||||
// UpdateChatProvider mocks base method.
|
||||
func (m *MockStore) UpdateChatProvider(ctx context.Context, arg database.UpdateChatProviderParams) (database.ChatProvider, error) {
|
||||
m.ctrl.T.Helper()
|
||||
@@ -7672,19 +7925,34 @@ func (mr *MockStoreMockRecorder) UpdateChatStatus(ctx, arg any) *gomock.Call {
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateChatStatus", reflect.TypeOf((*MockStore)(nil).UpdateChatStatus), ctx, arg)
|
||||
}
|
||||
|
||||
// UpdateChatWorkspace mocks base method.
|
||||
func (m *MockStore) UpdateChatWorkspace(ctx context.Context, arg database.UpdateChatWorkspaceParams) (database.Chat, error) {
|
||||
// UpdateChatStatusPreserveUpdatedAt mocks base method.
|
||||
func (m *MockStore) UpdateChatStatusPreserveUpdatedAt(ctx context.Context, arg database.UpdateChatStatusPreserveUpdatedAtParams) (database.Chat, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "UpdateChatWorkspace", ctx, arg)
|
||||
ret := m.ctrl.Call(m, "UpdateChatStatusPreserveUpdatedAt", ctx, arg)
|
||||
ret0, _ := ret[0].(database.Chat)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// UpdateChatWorkspace indicates an expected call of UpdateChatWorkspace.
|
||||
func (mr *MockStoreMockRecorder) UpdateChatWorkspace(ctx, arg any) *gomock.Call {
|
||||
// UpdateChatStatusPreserveUpdatedAt indicates an expected call of UpdateChatStatusPreserveUpdatedAt.
|
||||
func (mr *MockStoreMockRecorder) UpdateChatStatusPreserveUpdatedAt(ctx, arg any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateChatWorkspace", reflect.TypeOf((*MockStore)(nil).UpdateChatWorkspace), ctx, arg)
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateChatStatusPreserveUpdatedAt", reflect.TypeOf((*MockStore)(nil).UpdateChatStatusPreserveUpdatedAt), ctx, arg)
|
||||
}
|
||||
|
||||
// UpdateChatWorkspaceBinding mocks base method.
|
||||
func (m *MockStore) UpdateChatWorkspaceBinding(ctx context.Context, arg database.UpdateChatWorkspaceBindingParams) (database.Chat, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "UpdateChatWorkspaceBinding", ctx, arg)
|
||||
ret0, _ := ret[0].(database.Chat)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// UpdateChatWorkspaceBinding indicates an expected call of UpdateChatWorkspaceBinding.
|
||||
func (mr *MockStoreMockRecorder) UpdateChatWorkspaceBinding(ctx, arg any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateChatWorkspaceBinding", reflect.TypeOf((*MockStore)(nil).UpdateChatWorkspaceBinding), ctx, arg)
|
||||
}
|
||||
|
||||
// UpdateCryptoKeyDeletesAt mocks base method.
|
||||
@@ -9029,6 +9297,20 @@ func (mr *MockStoreMockRecorder) UpsertChatDiffStatusReference(ctx, arg any) *go
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpsertChatDiffStatusReference", reflect.TypeOf((*MockStore)(nil).UpsertChatDiffStatusReference), ctx, arg)
|
||||
}
|
||||
|
||||
// UpsertChatIncludeDefaultSystemPrompt mocks base method.
|
||||
func (m *MockStore) UpsertChatIncludeDefaultSystemPrompt(ctx context.Context, includeDefaultSystemPrompt bool) error {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "UpsertChatIncludeDefaultSystemPrompt", ctx, includeDefaultSystemPrompt)
|
||||
ret0, _ := ret[0].(error)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// UpsertChatIncludeDefaultSystemPrompt indicates an expected call of UpsertChatIncludeDefaultSystemPrompt.
|
||||
func (mr *MockStoreMockRecorder) UpsertChatIncludeDefaultSystemPrompt(ctx, includeDefaultSystemPrompt any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpsertChatIncludeDefaultSystemPrompt", reflect.TypeOf((*MockStore)(nil).UpsertChatIncludeDefaultSystemPrompt), ctx, includeDefaultSystemPrompt)
|
||||
}
|
||||
|
||||
// UpsertChatSystemPrompt mocks base method.
|
||||
func (m *MockStore) UpsertChatSystemPrompt(ctx context.Context, value string) error {
|
||||
m.ctrl.T.Helper()
|
||||
|
||||
Generated
+17
-2
@@ -1100,7 +1100,8 @@ CREATE TABLE aibridge_interceptions (
|
||||
thread_parent_id uuid,
|
||||
thread_root_id uuid,
|
||||
client_session_id character varying(256),
|
||||
session_id text GENERATED ALWAYS AS (COALESCE(client_session_id, ((thread_root_id)::text)::character varying, ((id)::text)::character varying)) STORED NOT NULL
|
||||
session_id text GENERATED ALWAYS AS (COALESCE(client_session_id, ((thread_root_id)::text)::character varying, ((id)::text)::character varying)) STORED NOT NULL,
|
||||
provider_name text DEFAULT ''::text NOT NULL
|
||||
);
|
||||
|
||||
COMMENT ON TABLE aibridge_interceptions IS 'Audit log of requests intercepted by AI Bridge';
|
||||
@@ -1115,6 +1116,8 @@ COMMENT ON COLUMN aibridge_interceptions.client_session_id IS 'The session ID su
|
||||
|
||||
COMMENT ON COLUMN aibridge_interceptions.session_id IS 'Groups related interceptions into a logical session. Determined by a priority chain: (1) client_session_id — an explicit session identifier supplied by the calling client (e.g. Claude Code); (2) thread_root_id — the root of an agentic thread detected by Bridge through tool-call correlation, used when the client does not supply its own session ID; (3) id — the interception''s own ID, used as a last resort so every interception belongs to exactly one session even if it is standalone. This is a generated column stored on disk so it can be indexed and joined without recomputing the COALESCE on every query.';
|
||||
|
||||
COMMENT ON COLUMN aibridge_interceptions.provider_name IS 'The provider instance name which may differ from provider when multiple instances of the same provider type exist.';
|
||||
|
||||
CREATE TABLE aibridge_model_thoughts (
|
||||
interception_id uuid NOT NULL,
|
||||
content text NOT NULL,
|
||||
@@ -1399,7 +1402,12 @@ CREATE TABLE chats (
|
||||
last_error text,
|
||||
mode chat_mode,
|
||||
mcp_server_ids uuid[] DEFAULT '{}'::uuid[] NOT NULL,
|
||||
labels jsonb DEFAULT '{}'::jsonb NOT NULL
|
||||
labels jsonb DEFAULT '{}'::jsonb NOT NULL,
|
||||
build_id uuid,
|
||||
agent_id uuid,
|
||||
pin_order integer DEFAULT 0 NOT NULL,
|
||||
last_read_message_id bigint,
|
||||
last_injected_context jsonb
|
||||
);
|
||||
|
||||
CREATE TABLE connection_logs (
|
||||
@@ -1703,6 +1711,7 @@ CREATE TABLE mcp_server_configs (
|
||||
updated_by uuid,
|
||||
created_at timestamp with time zone DEFAULT now() NOT NULL,
|
||||
updated_at timestamp with time zone DEFAULT now() NOT NULL,
|
||||
model_intent boolean DEFAULT false NOT NULL,
|
||||
CONSTRAINT mcp_server_configs_auth_type_check CHECK ((auth_type = ANY (ARRAY['none'::text, 'oauth2'::text, 'api_key'::text, 'custom_headers'::text]))),
|
||||
CONSTRAINT mcp_server_configs_availability_check CHECK ((availability = ANY (ARRAY['force_on'::text, 'default_on'::text, 'default_off'::text]))),
|
||||
CONSTRAINT mcp_server_configs_transport_check CHECK ((transport = ANY (ARRAY['streamable_http'::text, 'sse'::text])))
|
||||
@@ -4033,6 +4042,12 @@ ALTER TABLE ONLY chat_providers
|
||||
ALTER TABLE ONLY chat_queued_messages
|
||||
ADD CONSTRAINT chat_queued_messages_chat_id_fkey FOREIGN KEY (chat_id) REFERENCES chats(id) ON DELETE CASCADE;
|
||||
|
||||
ALTER TABLE ONLY chats
|
||||
ADD CONSTRAINT chats_agent_id_fkey FOREIGN KEY (agent_id) REFERENCES workspace_agents(id) ON DELETE SET NULL;
|
||||
|
||||
ALTER TABLE ONLY chats
|
||||
ADD CONSTRAINT chats_build_id_fkey FOREIGN KEY (build_id) REFERENCES workspace_builds(id) ON DELETE SET NULL;
|
||||
|
||||
ALTER TABLE ONLY chats
|
||||
ADD CONSTRAINT chats_last_model_config_id_fkey FOREIGN KEY (last_model_config_id) REFERENCES chat_model_configs(id);
|
||||
|
||||
|
||||
@@ -20,6 +20,8 @@ const (
|
||||
ForeignKeyChatProvidersAPIKeyKeyID ForeignKeyConstraint = "chat_providers_api_key_key_id_fkey" // ALTER TABLE ONLY chat_providers ADD CONSTRAINT chat_providers_api_key_key_id_fkey FOREIGN KEY (api_key_key_id) REFERENCES dbcrypt_keys(active_key_digest);
|
||||
ForeignKeyChatProvidersCreatedBy ForeignKeyConstraint = "chat_providers_created_by_fkey" // ALTER TABLE ONLY chat_providers ADD CONSTRAINT chat_providers_created_by_fkey FOREIGN KEY (created_by) REFERENCES users(id);
|
||||
ForeignKeyChatQueuedMessagesChatID ForeignKeyConstraint = "chat_queued_messages_chat_id_fkey" // ALTER TABLE ONLY chat_queued_messages ADD CONSTRAINT chat_queued_messages_chat_id_fkey FOREIGN KEY (chat_id) REFERENCES chats(id) ON DELETE CASCADE;
|
||||
ForeignKeyChatsAgentID ForeignKeyConstraint = "chats_agent_id_fkey" // ALTER TABLE ONLY chats ADD CONSTRAINT chats_agent_id_fkey FOREIGN KEY (agent_id) REFERENCES workspace_agents(id) ON DELETE SET NULL;
|
||||
ForeignKeyChatsBuildID ForeignKeyConstraint = "chats_build_id_fkey" // ALTER TABLE ONLY chats ADD CONSTRAINT chats_build_id_fkey FOREIGN KEY (build_id) REFERENCES workspace_builds(id) ON DELETE SET NULL;
|
||||
ForeignKeyChatsLastModelConfigID ForeignKeyConstraint = "chats_last_model_config_id_fkey" // ALTER TABLE ONLY chats ADD CONSTRAINT chats_last_model_config_id_fkey FOREIGN KEY (last_model_config_id) REFERENCES chat_model_configs(id);
|
||||
ForeignKeyChatsOwnerID ForeignKeyConstraint = "chats_owner_id_fkey" // ALTER TABLE ONLY chats ADD CONSTRAINT chats_owner_id_fkey FOREIGN KEY (owner_id) REFERENCES users(id) ON DELETE CASCADE;
|
||||
ForeignKeyChatsParentChatID ForeignKeyConstraint = "chats_parent_chat_id_fkey" // ALTER TABLE ONLY chats ADD CONSTRAINT chats_parent_chat_id_fkey FOREIGN KEY (parent_chat_id) REFERENCES chats(id) ON DELETE SET NULL;
|
||||
|
||||
@@ -0,0 +1,3 @@
|
||||
ALTER TABLE chats
|
||||
DROP COLUMN IF EXISTS build_id,
|
||||
DROP COLUMN IF EXISTS agent_id;
|
||||
@@ -0,0 +1,3 @@
|
||||
ALTER TABLE chats
|
||||
ADD COLUMN build_id UUID REFERENCES workspace_builds(id) ON DELETE SET NULL,
|
||||
ADD COLUMN agent_id UUID REFERENCES workspace_agents(id) ON DELETE SET NULL;
|
||||
@@ -0,0 +1 @@
|
||||
ALTER TABLE chats DROP COLUMN pin_order;
|
||||
@@ -0,0 +1 @@
|
||||
ALTER TABLE chats ADD COLUMN pin_order integer DEFAULT 0 NOT NULL;
|
||||
@@ -0,0 +1 @@
|
||||
ALTER TABLE mcp_server_configs DROP COLUMN model_intent;
|
||||
@@ -0,0 +1,2 @@
|
||||
ALTER TABLE mcp_server_configs
|
||||
ADD COLUMN model_intent BOOLEAN NOT NULL DEFAULT false;
|
||||
@@ -0,0 +1 @@
|
||||
ALTER TABLE chats DROP COLUMN last_read_message_id;
|
||||
@@ -0,0 +1,9 @@
|
||||
ALTER TABLE chats ADD COLUMN last_read_message_id BIGINT;
|
||||
|
||||
-- Backfill existing chats so they don't appear unread after deploy.
|
||||
-- The has_unread query uses COALESCE(last_read_message_id, 0), so
|
||||
-- leaving this NULL would mark every existing chat as unread.
|
||||
UPDATE chats SET last_read_message_id = (
|
||||
SELECT MAX(cm.id) FROM chat_messages cm
|
||||
WHERE cm.chat_id = chats.id AND cm.role = 'assistant' AND cm.deleted = false
|
||||
);
|
||||
@@ -0,0 +1 @@
|
||||
ALTER TABLE chats DROP COLUMN last_injected_context;
|
||||
@@ -0,0 +1 @@
|
||||
ALTER TABLE chats ADD COLUMN last_injected_context JSONB;
|
||||
@@ -0,0 +1,4 @@
|
||||
-- Remove 'agents-access' from all users who have it.
|
||||
UPDATE users
|
||||
SET rbac_roles = array_remove(rbac_roles, 'agents-access')
|
||||
WHERE 'agents-access' = ANY(rbac_roles);
|
||||
@@ -0,0 +1,5 @@
|
||||
-- Grant 'agents-access' to every user who has ever created a chat.
|
||||
UPDATE users
|
||||
SET rbac_roles = array_append(rbac_roles, 'agents-access')
|
||||
WHERE id IN (SELECT DISTINCT owner_id FROM chats)
|
||||
AND NOT ('agents-access' = ANY(rbac_roles));
|
||||
@@ -0,0 +1 @@
|
||||
ALTER TABLE aibridge_interceptions DROP COLUMN provider_name;
|
||||
@@ -0,0 +1,6 @@
|
||||
ALTER TABLE aibridge_interceptions ADD COLUMN provider_name TEXT NOT NULL DEFAULT '';
|
||||
|
||||
COMMENT ON COLUMN aibridge_interceptions.provider_name IS 'The provider instance name which may differ from provider when multiple instances of the same provider type exist.';
|
||||
|
||||
-- Backfill existing records with the provider type as the provider name.
|
||||
UPDATE aibridge_interceptions SET provider_name = provider WHERE provider_name = '';
|
||||
@@ -877,3 +877,149 @@ func TestMigration000387MigrateTaskWorkspaces(t *testing.T) {
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, 0, antCount, "antagonist workspaces (deleted and regular) should not be migrated")
|
||||
}
|
||||
|
||||
func TestMigration000457ChatAccessRole(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
const migrationVersion = 457
|
||||
|
||||
sqlDB := testSQLDB(t)
|
||||
|
||||
// Migrate up to the migration before the one that grants
|
||||
// agents-access roles.
|
||||
next, err := migrations.Stepper(sqlDB)
|
||||
require.NoError(t, err)
|
||||
for {
|
||||
version, more, err := next()
|
||||
require.NoError(t, err)
|
||||
if !more {
|
||||
t.Fatalf("migration %d not found", migrationVersion)
|
||||
}
|
||||
if version == migrationVersion-1 {
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
ctx := testutil.Context(t, testutil.WaitSuperLong)
|
||||
|
||||
// Define test users.
|
||||
userWithChat := uuid.New() // Has a chat, no agents-access role.
|
||||
userAlreadyHasRole := uuid.New() // Has a chat and already has agents-access.
|
||||
userNoChat := uuid.New() // No chat at all.
|
||||
userWithChatAndRoles := uuid.New() // Has a chat and other existing roles.
|
||||
|
||||
now := time.Now().UTC().Truncate(time.Microsecond)
|
||||
|
||||
// We need a chat_provider and chat_model_config for the chats FK.
|
||||
providerID := uuid.New()
|
||||
modelConfigID := uuid.New()
|
||||
|
||||
tx, err := sqlDB.BeginTx(ctx, nil)
|
||||
require.NoError(t, err)
|
||||
defer tx.Rollback()
|
||||
|
||||
fixtures := []struct {
|
||||
query string
|
||||
args []any
|
||||
}{
|
||||
// Insert test users with varying rbac_roles.
|
||||
{
|
||||
`INSERT INTO users (id, username, email, hashed_password, created_at, updated_at, status, rbac_roles, login_type)
|
||||
VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9)`,
|
||||
[]any{userWithChat, "user-with-chat", "chat@test.com", []byte{}, now, now, "active", pq.StringArray{}, "password"},
|
||||
},
|
||||
{
|
||||
`INSERT INTO users (id, username, email, hashed_password, created_at, updated_at, status, rbac_roles, login_type)
|
||||
VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9)`,
|
||||
[]any{userAlreadyHasRole, "user-already-has-role", "already@test.com", []byte{}, now, now, "active", pq.StringArray{"agents-access"}, "password"},
|
||||
},
|
||||
{
|
||||
`INSERT INTO users (id, username, email, hashed_password, created_at, updated_at, status, rbac_roles, login_type)
|
||||
VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9)`,
|
||||
[]any{userNoChat, "user-no-chat", "nochat@test.com", []byte{}, now, now, "active", pq.StringArray{}, "password"},
|
||||
},
|
||||
{
|
||||
`INSERT INTO users (id, username, email, hashed_password, created_at, updated_at, status, rbac_roles, login_type)
|
||||
VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9)`,
|
||||
[]any{userWithChatAndRoles, "user-with-roles", "roles@test.com", []byte{}, now, now, "active", pq.StringArray{"template-admin"}, "password"},
|
||||
},
|
||||
// Insert a chat provider and model config for the chats FK.
|
||||
{
|
||||
`INSERT INTO chat_providers (id, provider, display_name, api_key, enabled, created_at, updated_at)
|
||||
VALUES ($1, $2, $3, $4, $5, $6, $7)`,
|
||||
[]any{providerID, "openai", "OpenAI", "", true, now, now},
|
||||
},
|
||||
{
|
||||
`INSERT INTO chat_model_configs (id, provider, model, display_name, enabled, context_limit, compression_threshold, created_at, updated_at)
|
||||
VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9)`,
|
||||
[]any{modelConfigID, "openai", "gpt-4", "GPT 4", true, 100000, 70, now, now},
|
||||
},
|
||||
// Insert chats for users A, B, and D (not C).
|
||||
{
|
||||
`INSERT INTO chats (id, owner_id, last_model_config_id, title, created_at, updated_at)
|
||||
VALUES ($1, $2, $3, $4, $5, $6)`,
|
||||
[]any{uuid.New(), userWithChat, modelConfigID, "Chat A", now, now},
|
||||
},
|
||||
{
|
||||
`INSERT INTO chats (id, owner_id, last_model_config_id, title, created_at, updated_at)
|
||||
VALUES ($1, $2, $3, $4, $5, $6)`,
|
||||
[]any{uuid.New(), userAlreadyHasRole, modelConfigID, "Chat B", now, now},
|
||||
},
|
||||
{
|
||||
`INSERT INTO chats (id, owner_id, last_model_config_id, title, created_at, updated_at)
|
||||
VALUES ($1, $2, $3, $4, $5, $6)`,
|
||||
[]any{uuid.New(), userWithChatAndRoles, modelConfigID, "Chat D", now, now},
|
||||
},
|
||||
}
|
||||
|
||||
for i, f := range fixtures {
|
||||
_, err := tx.ExecContext(ctx, f.query, f.args...)
|
||||
require.NoError(t, err, "fixture %d", i)
|
||||
}
|
||||
require.NoError(t, tx.Commit())
|
||||
|
||||
// Run the migration.
|
||||
version, _, err := next()
|
||||
require.NoError(t, err)
|
||||
require.EqualValues(t, migrationVersion, version)
|
||||
|
||||
// Helper to get rbac_roles for a user.
|
||||
getRoles := func(t *testing.T, userID uuid.UUID) []string {
|
||||
t.Helper()
|
||||
var roles pq.StringArray
|
||||
err := sqlDB.QueryRowContext(ctx,
|
||||
"SELECT rbac_roles FROM users WHERE id = $1", userID,
|
||||
).Scan(&roles)
|
||||
require.NoError(t, err)
|
||||
return roles
|
||||
}
|
||||
|
||||
// Verify: user with chat gets agents-access.
|
||||
roles := getRoles(t, userWithChat)
|
||||
require.Contains(t, roles, "agents-access",
|
||||
"user with chat should get agents-access")
|
||||
|
||||
// Verify: user who already had agents-access has no duplicate.
|
||||
roles = getRoles(t, userAlreadyHasRole)
|
||||
count := 0
|
||||
for _, r := range roles {
|
||||
if r == "agents-access" {
|
||||
count++
|
||||
}
|
||||
}
|
||||
require.Equal(t, 1, count,
|
||||
"user who already had agents-access should not get a duplicate")
|
||||
|
||||
// Verify: user without chat does NOT get agents-access.
|
||||
roles = getRoles(t, userNoChat)
|
||||
require.NotContains(t, roles, "agents-access",
|
||||
"user without chat should not get agents-access")
|
||||
|
||||
// Verify: user with chat and existing roles gets agents-access
|
||||
// appended while preserving existing roles.
|
||||
roles = getRoles(t, userWithChatAndRoles)
|
||||
require.Contains(t, roles, "agents-access",
|
||||
"user with chat and other roles should get agents-access")
|
||||
require.Contains(t, roles, "template-admin",
|
||||
"existing roles should be preserved")
|
||||
}
|
||||
|
||||
@@ -178,6 +178,10 @@ func (c Chat) RBACObject() rbac.Object {
|
||||
return rbac.ResourceChat.WithID(c.ID).WithOwner(c.OwnerID.String())
|
||||
}
|
||||
|
||||
func (r GetChatsRow) RBACObject() rbac.Object {
|
||||
return r.Chat.RBACObject()
|
||||
}
|
||||
|
||||
func (c ChatFile) RBACObject() rbac.Object {
|
||||
return rbac.ResourceChat.WithID(c.ID).WithOwner(c.OwnerID.String()).InOrg(c.OrganizationID)
|
||||
}
|
||||
|
||||
+119
-26
@@ -741,10 +741,10 @@ func (q *sqlQuerier) CountAuthorizedConnectionLogs(ctx context.Context, arg Coun
|
||||
}
|
||||
|
||||
type chatQuerier interface {
|
||||
GetAuthorizedChats(ctx context.Context, arg GetChatsParams, prepared rbac.PreparedAuthorized) ([]Chat, error)
|
||||
GetAuthorizedChats(ctx context.Context, arg GetChatsParams, prepared rbac.PreparedAuthorized) ([]GetChatsRow, error)
|
||||
}
|
||||
|
||||
func (q *sqlQuerier) GetAuthorizedChats(ctx context.Context, arg GetChatsParams, prepared rbac.PreparedAuthorized) ([]Chat, error) {
|
||||
func (q *sqlQuerier) GetAuthorizedChats(ctx context.Context, arg GetChatsParams, prepared rbac.PreparedAuthorized) ([]GetChatsRow, error) {
|
||||
authorizedFilter, err := prepared.CompileToSQL(ctx, rbac.ConfigChats())
|
||||
if err != nil {
|
||||
return nil, xerrors.Errorf("compile authorized filter: %w", err)
|
||||
@@ -769,29 +769,34 @@ func (q *sqlQuerier) GetAuthorizedChats(ctx context.Context, arg GetChatsParams,
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
var items []Chat
|
||||
var items []GetChatsRow
|
||||
for rows.Next() {
|
||||
var i Chat
|
||||
var i GetChatsRow
|
||||
if err := rows.Scan(
|
||||
&i.ID,
|
||||
&i.OwnerID,
|
||||
&i.WorkspaceID,
|
||||
&i.Title,
|
||||
&i.Status,
|
||||
&i.WorkerID,
|
||||
&i.StartedAt,
|
||||
&i.HeartbeatAt,
|
||||
&i.CreatedAt,
|
||||
&i.UpdatedAt,
|
||||
&i.ParentChatID,
|
||||
&i.RootChatID,
|
||||
&i.LastModelConfigID,
|
||||
&i.Archived,
|
||||
&i.LastError,
|
||||
&i.Mode,
|
||||
pq.Array(&i.MCPServerIDs),
|
||||
&i.Labels,
|
||||
); err != nil {
|
||||
&i.Chat.ID,
|
||||
&i.Chat.OwnerID,
|
||||
&i.Chat.WorkspaceID,
|
||||
&i.Chat.Title,
|
||||
&i.Chat.Status,
|
||||
&i.Chat.WorkerID,
|
||||
&i.Chat.StartedAt,
|
||||
&i.Chat.HeartbeatAt,
|
||||
&i.Chat.CreatedAt,
|
||||
&i.Chat.UpdatedAt,
|
||||
&i.Chat.ParentChatID,
|
||||
&i.Chat.RootChatID,
|
||||
&i.Chat.LastModelConfigID,
|
||||
&i.Chat.Archived,
|
||||
&i.Chat.LastError,
|
||||
&i.Chat.Mode,
|
||||
pq.Array(&i.Chat.MCPServerIDs),
|
||||
&i.Chat.Labels,
|
||||
&i.Chat.BuildID,
|
||||
&i.Chat.AgentID,
|
||||
&i.Chat.PinOrder,
|
||||
&i.Chat.LastReadMessageID,
|
||||
&i.Chat.LastInjectedContext,
|
||||
&i.HasUnread); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
items = append(items, i)
|
||||
@@ -809,8 +814,10 @@ type aibridgeQuerier interface {
|
||||
ListAuthorizedAIBridgeInterceptions(ctx context.Context, arg ListAIBridgeInterceptionsParams, prepared rbac.PreparedAuthorized) ([]ListAIBridgeInterceptionsRow, error)
|
||||
CountAuthorizedAIBridgeInterceptions(ctx context.Context, arg CountAIBridgeInterceptionsParams, prepared rbac.PreparedAuthorized) (int64, error)
|
||||
ListAuthorizedAIBridgeModels(ctx context.Context, arg ListAIBridgeModelsParams, prepared rbac.PreparedAuthorized) ([]string, error)
|
||||
ListAuthorizedAIBridgeClients(ctx context.Context, arg ListAIBridgeClientsParams, prepared rbac.PreparedAuthorized) ([]string, error)
|
||||
ListAuthorizedAIBridgeSessions(ctx context.Context, arg ListAIBridgeSessionsParams, prepared rbac.PreparedAuthorized) ([]ListAIBridgeSessionsRow, error)
|
||||
CountAuthorizedAIBridgeSessions(ctx context.Context, arg CountAIBridgeSessionsParams, prepared rbac.PreparedAuthorized) (int64, error)
|
||||
ListAuthorizedAIBridgeSessionThreads(ctx context.Context, arg ListAIBridgeSessionThreadsParams, prepared rbac.PreparedAuthorized) ([]ListAIBridgeSessionThreadsRow, error)
|
||||
}
|
||||
|
||||
func (q *sqlQuerier) ListAuthorizedAIBridgeInterceptions(ctx context.Context, arg ListAIBridgeInterceptionsParams, prepared rbac.PreparedAuthorized) ([]ListAIBridgeInterceptionsRow, error) {
|
||||
@@ -858,6 +865,7 @@ func (q *sqlQuerier) ListAuthorizedAIBridgeInterceptions(ctx context.Context, ar
|
||||
&i.AIBridgeInterception.ThreadRootID,
|
||||
&i.AIBridgeInterception.ClientSessionID,
|
||||
&i.AIBridgeInterception.SessionID,
|
||||
&i.AIBridgeInterception.ProviderName,
|
||||
&i.VisibleUser.ID,
|
||||
&i.VisibleUser.Username,
|
||||
&i.VisibleUser.Name,
|
||||
@@ -945,6 +953,35 @@ func (q *sqlQuerier) ListAuthorizedAIBridgeModels(ctx context.Context, arg ListA
|
||||
return items, nil
|
||||
}
|
||||
|
||||
func (q *sqlQuerier) ListAuthorizedAIBridgeClients(ctx context.Context, arg ListAIBridgeClientsParams, prepared rbac.PreparedAuthorized) ([]string, error) {
|
||||
authorizedFilter, err := prepared.CompileToSQL(ctx, regosql.ConvertConfig{
|
||||
VariableConverter: regosql.AIBridgeInterceptionConverter(),
|
||||
})
|
||||
if err != nil {
|
||||
return nil, xerrors.Errorf("compile authorized filter: %w", err)
|
||||
}
|
||||
filtered, err := insertAuthorizedFilter(listAIBridgeClients, fmt.Sprintf(" AND %s", authorizedFilter))
|
||||
if err != nil {
|
||||
return nil, xerrors.Errorf("insert authorized filter: %w", err)
|
||||
}
|
||||
|
||||
query := fmt.Sprintf("-- name: ListAIBridgeClients :many\n%s", filtered)
|
||||
rows, err := q.db.QueryContext(ctx, query, arg.Client, arg.Offset, arg.Limit)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
var items []string
|
||||
for rows.Next() {
|
||||
var client string
|
||||
if err := rows.Scan(&client); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
items = append(items, client)
|
||||
}
|
||||
return items, nil
|
||||
}
|
||||
|
||||
func (q *sqlQuerier) ListAuthorizedAIBridgeSessions(ctx context.Context, arg ListAIBridgeSessionsParams, prepared rbac.PreparedAuthorized) ([]ListAIBridgeSessionsRow, error) {
|
||||
authorizedFilter, err := prepared.CompileToSQL(ctx, regosql.ConvertConfig{
|
||||
VariableConverter: regosql.AIBridgeInterceptionConverter(),
|
||||
@@ -960,8 +997,6 @@ func (q *sqlQuerier) ListAuthorizedAIBridgeSessions(ctx context.Context, arg Lis
|
||||
query := fmt.Sprintf("-- name: ListAuthorizedAIBridgeSessions :many\n%s", filtered)
|
||||
rows, err := q.db.QueryContext(ctx, query,
|
||||
arg.AfterSessionID,
|
||||
arg.Offset,
|
||||
arg.Limit,
|
||||
arg.StartedAfter,
|
||||
arg.StartedBefore,
|
||||
arg.InitiatorID,
|
||||
@@ -969,6 +1004,8 @@ func (q *sqlQuerier) ListAuthorizedAIBridgeSessions(ctx context.Context, arg Lis
|
||||
arg.Model,
|
||||
arg.Client,
|
||||
arg.SessionID,
|
||||
arg.Offset,
|
||||
arg.Limit,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@@ -1048,11 +1085,67 @@ func (q *sqlQuerier) CountAuthorizedAIBridgeSessions(ctx context.Context, arg Co
|
||||
return count, nil
|
||||
}
|
||||
|
||||
func (q *sqlQuerier) ListAuthorizedAIBridgeSessionThreads(ctx context.Context, arg ListAIBridgeSessionThreadsParams, prepared rbac.PreparedAuthorized) ([]ListAIBridgeSessionThreadsRow, error) {
|
||||
authorizedFilter, err := prepared.CompileToSQL(ctx, regosql.ConvertConfig{
|
||||
VariableConverter: regosql.AIBridgeInterceptionConverter(),
|
||||
})
|
||||
if err != nil {
|
||||
return nil, xerrors.Errorf("compile authorized filter: %w", err)
|
||||
}
|
||||
filtered, err := insertAuthorizedFilter(listAIBridgeSessionThreads, fmt.Sprintf(" AND %s", authorizedFilter))
|
||||
if err != nil {
|
||||
return nil, xerrors.Errorf("insert authorized filter: %w", err)
|
||||
}
|
||||
|
||||
query := fmt.Sprintf("-- name: ListAuthorizedAIBridgeSessionThreads :many\n%s", filtered)
|
||||
rows, err := q.db.QueryContext(ctx, query,
|
||||
arg.SessionID,
|
||||
arg.AfterID,
|
||||
arg.BeforeID,
|
||||
arg.Limit,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
var items []ListAIBridgeSessionThreadsRow
|
||||
for rows.Next() {
|
||||
var i ListAIBridgeSessionThreadsRow
|
||||
if err := rows.Scan(
|
||||
&i.ThreadID,
|
||||
&i.AIBridgeInterception.ID,
|
||||
&i.AIBridgeInterception.InitiatorID,
|
||||
&i.AIBridgeInterception.Provider,
|
||||
&i.AIBridgeInterception.Model,
|
||||
&i.AIBridgeInterception.StartedAt,
|
||||
&i.AIBridgeInterception.Metadata,
|
||||
&i.AIBridgeInterception.EndedAt,
|
||||
&i.AIBridgeInterception.APIKeyID,
|
||||
&i.AIBridgeInterception.Client,
|
||||
&i.AIBridgeInterception.ThreadParentID,
|
||||
&i.AIBridgeInterception.ThreadRootID,
|
||||
&i.AIBridgeInterception.ClientSessionID,
|
||||
&i.AIBridgeInterception.SessionID,
|
||||
&i.AIBridgeInterception.ProviderName,
|
||||
); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
items = append(items, i)
|
||||
}
|
||||
if err := rows.Close(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := rows.Err(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return items, nil
|
||||
}
|
||||
|
||||
func insertAuthorizedFilter(query string, replaceWith string) (string, error) {
|
||||
if !strings.Contains(query, authorizedQueryPlaceholder) {
|
||||
return "", xerrors.Errorf("query does not contain authorized replace string, this is not an authorized query")
|
||||
}
|
||||
filtered := strings.Replace(query, authorizedQueryPlaceholder, replaceWith, 1)
|
||||
filtered := strings.ReplaceAll(query, authorizedQueryPlaceholder, replaceWith)
|
||||
return filtered, nil
|
||||
}
|
||||
|
||||
|
||||
+26
-18
@@ -4038,6 +4038,8 @@ type AIBridgeInterception struct {
|
||||
ClientSessionID sql.NullString `db:"client_session_id" json:"client_session_id"`
|
||||
// Groups related interceptions into a logical session. Determined by a priority chain: (1) client_session_id — an explicit session identifier supplied by the calling client (e.g. Claude Code); (2) thread_root_id — the root of an agentic thread detected by Bridge through tool-call correlation, used when the client does not supply its own session ID; (3) id — the interception's own ID, used as a last resort so every interception belongs to exactly one session even if it is standalone. This is a generated column stored on disk so it can be indexed and joined without recomputing the COALESCE on every query.
|
||||
SessionID string `db:"session_id" json:"session_id"`
|
||||
// The provider instance name which may differ from provider when multiple instances of the same provider type exist.
|
||||
ProviderName string `db:"provider_name" json:"provider_name"`
|
||||
}
|
||||
|
||||
// Audit log of model thinking in intercepted requests in AI Bridge
|
||||
@@ -4153,24 +4155,29 @@ type BoundaryUsageStat struct {
|
||||
}
|
||||
|
||||
type Chat struct {
|
||||
ID uuid.UUID `db:"id" json:"id"`
|
||||
OwnerID uuid.UUID `db:"owner_id" json:"owner_id"`
|
||||
WorkspaceID uuid.NullUUID `db:"workspace_id" json:"workspace_id"`
|
||||
Title string `db:"title" json:"title"`
|
||||
Status ChatStatus `db:"status" json:"status"`
|
||||
WorkerID uuid.NullUUID `db:"worker_id" json:"worker_id"`
|
||||
StartedAt sql.NullTime `db:"started_at" json:"started_at"`
|
||||
HeartbeatAt sql.NullTime `db:"heartbeat_at" json:"heartbeat_at"`
|
||||
CreatedAt time.Time `db:"created_at" json:"created_at"`
|
||||
UpdatedAt time.Time `db:"updated_at" json:"updated_at"`
|
||||
ParentChatID uuid.NullUUID `db:"parent_chat_id" json:"parent_chat_id"`
|
||||
RootChatID uuid.NullUUID `db:"root_chat_id" json:"root_chat_id"`
|
||||
LastModelConfigID uuid.UUID `db:"last_model_config_id" json:"last_model_config_id"`
|
||||
Archived bool `db:"archived" json:"archived"`
|
||||
LastError sql.NullString `db:"last_error" json:"last_error"`
|
||||
Mode NullChatMode `db:"mode" json:"mode"`
|
||||
MCPServerIDs []uuid.UUID `db:"mcp_server_ids" json:"mcp_server_ids"`
|
||||
Labels StringMap `db:"labels" json:"labels"`
|
||||
ID uuid.UUID `db:"id" json:"id"`
|
||||
OwnerID uuid.UUID `db:"owner_id" json:"owner_id"`
|
||||
WorkspaceID uuid.NullUUID `db:"workspace_id" json:"workspace_id"`
|
||||
Title string `db:"title" json:"title"`
|
||||
Status ChatStatus `db:"status" json:"status"`
|
||||
WorkerID uuid.NullUUID `db:"worker_id" json:"worker_id"`
|
||||
StartedAt sql.NullTime `db:"started_at" json:"started_at"`
|
||||
HeartbeatAt sql.NullTime `db:"heartbeat_at" json:"heartbeat_at"`
|
||||
CreatedAt time.Time `db:"created_at" json:"created_at"`
|
||||
UpdatedAt time.Time `db:"updated_at" json:"updated_at"`
|
||||
ParentChatID uuid.NullUUID `db:"parent_chat_id" json:"parent_chat_id"`
|
||||
RootChatID uuid.NullUUID `db:"root_chat_id" json:"root_chat_id"`
|
||||
LastModelConfigID uuid.UUID `db:"last_model_config_id" json:"last_model_config_id"`
|
||||
Archived bool `db:"archived" json:"archived"`
|
||||
LastError sql.NullString `db:"last_error" json:"last_error"`
|
||||
Mode NullChatMode `db:"mode" json:"mode"`
|
||||
MCPServerIDs []uuid.UUID `db:"mcp_server_ids" json:"mcp_server_ids"`
|
||||
Labels StringMap `db:"labels" json:"labels"`
|
||||
BuildID uuid.NullUUID `db:"build_id" json:"build_id"`
|
||||
AgentID uuid.NullUUID `db:"agent_id" json:"agent_id"`
|
||||
PinOrder int32 `db:"pin_order" json:"pin_order"`
|
||||
LastReadMessageID sql.NullInt64 `db:"last_read_message_id" json:"last_read_message_id"`
|
||||
LastInjectedContext pqtype.NullRawMessage `db:"last_injected_context" json:"last_injected_context"`
|
||||
}
|
||||
|
||||
type ChatDiffStatus struct {
|
||||
@@ -4485,6 +4492,7 @@ type MCPServerConfig struct {
|
||||
UpdatedBy uuid.NullUUID `db:"updated_by" json:"updated_by"`
|
||||
CreatedAt time.Time `db:"created_at" json:"created_at"`
|
||||
UpdatedAt time.Time `db:"updated_at" json:"updated_at"`
|
||||
ModelIntent bool `db:"model_intent" json:"model_intent"`
|
||||
}
|
||||
|
||||
type MCPServerUserToken struct {
|
||||
|
||||
@@ -54,7 +54,7 @@ type sqlcQuerier interface {
|
||||
ActivityBumpWorkspace(ctx context.Context, arg ActivityBumpWorkspaceParams) error
|
||||
// AllUserIDs returns all UserIDs regardless of user status or deletion.
|
||||
AllUserIDs(ctx context.Context, includeSystem bool) ([]uuid.UUID, error)
|
||||
ArchiveChatByID(ctx context.Context, id uuid.UUID) error
|
||||
ArchiveChatByID(ctx context.Context, id uuid.UUID) ([]Chat, error)
|
||||
// Archiving templates is a soft delete action, so is reversible.
|
||||
// Archiving prevents the version from being used and discovered
|
||||
// by listing.
|
||||
@@ -243,8 +243,14 @@ type sqlcQuerier interface {
|
||||
GetChatDiffStatusesByChatIDs(ctx context.Context, chatIds []uuid.UUID) ([]ChatDiffStatus, error)
|
||||
GetChatFileByID(ctx context.Context, id uuid.UUID) (ChatFile, error)
|
||||
GetChatFilesByIDs(ctx context.Context, ids []uuid.UUID) ([]ChatFile, error)
|
||||
// GetChatIncludeDefaultSystemPrompt preserves the legacy default
|
||||
// for deployments created before the explicit include-default toggle.
|
||||
// When the toggle is unset, a non-empty custom prompt implies false;
|
||||
// otherwise the setting defaults to true.
|
||||
GetChatIncludeDefaultSystemPrompt(ctx context.Context) (bool, error)
|
||||
GetChatMessageByID(ctx context.Context, id int64) (ChatMessage, error)
|
||||
GetChatMessagesByChatID(ctx context.Context, arg GetChatMessagesByChatIDParams) ([]ChatMessage, error)
|
||||
GetChatMessagesByChatIDAscPaginated(ctx context.Context, arg GetChatMessagesByChatIDAscPaginatedParams) ([]ChatMessage, error)
|
||||
GetChatMessagesByChatIDDescPaginated(ctx context.Context, arg GetChatMessagesByChatIDDescPaginatedParams) ([]ChatMessage, error)
|
||||
GetChatMessagesForPromptByChatID(ctx context.Context, chatID uuid.UUID) ([]ChatMessage, error)
|
||||
GetChatModelConfigByID(ctx context.Context, id uuid.UUID) (ChatModelConfig, error)
|
||||
@@ -254,6 +260,12 @@ type sqlcQuerier interface {
|
||||
GetChatProviders(ctx context.Context) ([]ChatProvider, error)
|
||||
GetChatQueuedMessages(ctx context.Context, chatID uuid.UUID) ([]ChatQueuedMessage, error)
|
||||
GetChatSystemPrompt(ctx context.Context) (string, error)
|
||||
// GetChatSystemPromptConfig returns both chat system prompt settings in a
|
||||
// single read to avoid torn reads between separate site-config lookups.
|
||||
// The include-default fallback preserves the legacy behavior where a
|
||||
// non-empty custom prompt implied opting out before the explicit toggle
|
||||
// existed.
|
||||
GetChatSystemPromptConfig(ctx context.Context) (GetChatSystemPromptConfigRow, error)
|
||||
// GetChatTemplateAllowlist returns the JSON-encoded template allowlist.
|
||||
// Returns an empty string when no allowlist has been configured (all templates allowed).
|
||||
GetChatTemplateAllowlist(ctx context.Context) (string, error)
|
||||
@@ -263,7 +275,8 @@ type sqlcQuerier interface {
|
||||
// Returns the global TTL for chat workspaces as a Go duration string.
|
||||
// Returns "0s" (disabled) when no value has been configured.
|
||||
GetChatWorkspaceTTL(ctx context.Context) (string, error)
|
||||
GetChats(ctx context.Context, arg GetChatsParams) ([]Chat, error)
|
||||
GetChats(ctx context.Context, arg GetChatsParams) ([]GetChatsRow, error)
|
||||
GetChatsByWorkspaceIDs(ctx context.Context, ids []uuid.UUID) ([]Chat, error)
|
||||
GetConnectionLogsOffset(ctx context.Context, arg GetConnectionLogsOffsetParams) ([]GetConnectionLogsOffsetRow, error)
|
||||
GetCryptoKeyByFeatureAndSequence(ctx context.Context, arg GetCryptoKeyByFeatureAndSequenceParams) (CryptoKey, error)
|
||||
GetCryptoKeys(ctx context.Context) ([]CryptoKey, error)
|
||||
@@ -548,6 +561,10 @@ type sqlcQuerier interface {
|
||||
// inclusive.
|
||||
GetTotalUsageDCManagedAgentsV1(ctx context.Context, arg GetTotalUsageDCManagedAgentsV1Params) (int64, error)
|
||||
GetUnexpiredLicenses(ctx context.Context) ([]License, error)
|
||||
// Returns user IDs from the provided list that are consuming an AI seat.
|
||||
// Filters to active, non-deleted, non-system users to match the canonical
|
||||
// seat count query (GetActiveAISeatCount).
|
||||
GetUserAISeatStates(ctx context.Context, userIds []uuid.UUID) ([]uuid.UUID, error)
|
||||
// GetUserActivityInsights returns the ranking with top active users.
|
||||
// The result can be filtered on template_ids, meaning only user data
|
||||
// from workspaces based on those templates will be included.
|
||||
@@ -758,14 +775,23 @@ type sqlcQuerier interface {
|
||||
InsertWorkspaceProxy(ctx context.Context, arg InsertWorkspaceProxyParams) (WorkspaceProxy, error)
|
||||
InsertWorkspaceResource(ctx context.Context, arg InsertWorkspaceResourceParams) (WorkspaceResource, error)
|
||||
InsertWorkspaceResourceMetadata(ctx context.Context, arg InsertWorkspaceResourceMetadataParams) ([]WorkspaceResourceMetadatum, error)
|
||||
ListAIBridgeClients(ctx context.Context, arg ListAIBridgeClientsParams) ([]string, error)
|
||||
ListAIBridgeInterceptions(ctx context.Context, arg ListAIBridgeInterceptionsParams) ([]ListAIBridgeInterceptionsRow, error)
|
||||
// Finds all unique AI Bridge interception telemetry summaries combinations
|
||||
// (provider, model, client) in the given timeframe for telemetry reporting.
|
||||
ListAIBridgeInterceptionsTelemetrySummaries(ctx context.Context, arg ListAIBridgeInterceptionsTelemetrySummariesParams) ([]ListAIBridgeInterceptionsTelemetrySummariesRow, error)
|
||||
ListAIBridgeModelThoughtsByInterceptionIDs(ctx context.Context, interceptionIds []uuid.UUID) ([]AIBridgeModelThought, error)
|
||||
ListAIBridgeModels(ctx context.Context, arg ListAIBridgeModelsParams) ([]string, error)
|
||||
// Returns all interceptions belonging to paginated threads within a session.
|
||||
// Threads are paginated by (started_at, thread_id) cursor.
|
||||
ListAIBridgeSessionThreads(ctx context.Context, arg ListAIBridgeSessionThreadsParams) ([]ListAIBridgeSessionThreadsRow, error)
|
||||
// Returns paginated sessions with aggregated metadata, token counts, and
|
||||
// the most recent user prompt. A "session" is a logical grouping of
|
||||
// interceptions that share the same session_id (set by the client).
|
||||
//
|
||||
// Pagination-first strategy: identify the page of sessions cheaply via a
|
||||
// single GROUP BY scan, then do expensive lateral joins (tokens, prompts,
|
||||
// first-interception metadata) only for the ~page-size result set.
|
||||
ListAIBridgeSessions(ctx context.Context, arg ListAIBridgeSessionsParams) ([]ListAIBridgeSessionsRow, error)
|
||||
ListAIBridgeTokenUsagesByInterceptionIDs(ctx context.Context, interceptionIds []uuid.UUID) ([]AIBridgeTokenUsage, error)
|
||||
ListAIBridgeToolUsagesByInterceptionIDs(ctx context.Context, interceptionIds []uuid.UUID) ([]AIBridgeToolUsage, error)
|
||||
@@ -789,6 +815,12 @@ type sqlcQuerier interface {
|
||||
// - Use both to get a specific org member row
|
||||
OrganizationMembers(ctx context.Context, arg OrganizationMembersParams) ([]OrganizationMembersRow, error)
|
||||
PaginatedOrganizationMembers(ctx context.Context, arg PaginatedOrganizationMembersParams) ([]PaginatedOrganizationMembersRow, error)
|
||||
// Under READ COMMITTED, concurrent pin operations for the same
|
||||
// owner may momentarily produce duplicate pin_order values because
|
||||
// each CTE snapshot does not see the other's writes. The next
|
||||
// pin/unpin/reorder operation's ROW_NUMBER() self-heals the
|
||||
// sequence, so this is acceptable.
|
||||
PinChatByID(ctx context.Context, id uuid.UUID) error
|
||||
PopNextQueuedMessage(ctx context.Context, chatID uuid.UUID) (ChatQueuedMessage, error)
|
||||
ReduceWorkspaceAgentShareLevelToAuthenticatedByTemplate(ctx context.Context, templateID uuid.UUID) error
|
||||
RegisterWorkspaceProxy(ctx context.Context, arg RegisterWorkspaceProxyParams) (WorkspaceProxy, error)
|
||||
@@ -812,24 +844,37 @@ type sqlcQuerier interface {
|
||||
// This must be called from within a transaction. The lock will be automatically
|
||||
// released when the transaction ends.
|
||||
TryAcquireLock(ctx context.Context, pgTryAdvisoryXactLock int64) (bool, error)
|
||||
UnarchiveChatByID(ctx context.Context, id uuid.UUID) error
|
||||
UnarchiveChatByID(ctx context.Context, id uuid.UUID) ([]Chat, error)
|
||||
// This will always work regardless of the current state of the template version.
|
||||
UnarchiveTemplateVersion(ctx context.Context, arg UnarchiveTemplateVersionParams) error
|
||||
UnfavoriteWorkspace(ctx context.Context, id uuid.UUID) error
|
||||
UnpinChatByID(ctx context.Context, id uuid.UUID) error
|
||||
UnsetDefaultChatModelConfigs(ctx context.Context) error
|
||||
UpdateAIBridgeInterceptionEnded(ctx context.Context, arg UpdateAIBridgeInterceptionEndedParams) (AIBridgeInterception, error)
|
||||
UpdateAPIKeyByID(ctx context.Context, arg UpdateAPIKeyByIDParams) error
|
||||
UpdateChatBuildAgentBinding(ctx context.Context, arg UpdateChatBuildAgentBindingParams) (Chat, error)
|
||||
UpdateChatByID(ctx context.Context, arg UpdateChatByIDParams) (Chat, error)
|
||||
// Bumps the heartbeat timestamp for a running chat so that other
|
||||
// replicas know the worker is still alive.
|
||||
UpdateChatHeartbeat(ctx context.Context, arg UpdateChatHeartbeatParams) (int64, error)
|
||||
UpdateChatLabelsByID(ctx context.Context, arg UpdateChatLabelsByIDParams) (Chat, error)
|
||||
// Updates the cached injected context parts (AGENTS.md +
|
||||
// skills) on the chat row. Called only when context changes
|
||||
// (first workspace attach or agent change). updated_at is
|
||||
// intentionally not touched to avoid reordering the chat list.
|
||||
UpdateChatLastInjectedContext(ctx context.Context, arg UpdateChatLastInjectedContextParams) (Chat, error)
|
||||
UpdateChatLastModelConfigByID(ctx context.Context, arg UpdateChatLastModelConfigByIDParams) (Chat, error)
|
||||
// Updates the last read message ID for a chat. This is used to track
|
||||
// which messages the owner has seen, enabling unread indicators.
|
||||
UpdateChatLastReadMessageID(ctx context.Context, arg UpdateChatLastReadMessageIDParams) error
|
||||
UpdateChatMCPServerIDs(ctx context.Context, arg UpdateChatMCPServerIDsParams) (Chat, error)
|
||||
UpdateChatMessageByID(ctx context.Context, arg UpdateChatMessageByIDParams) (ChatMessage, error)
|
||||
UpdateChatModelConfig(ctx context.Context, arg UpdateChatModelConfigParams) (ChatModelConfig, error)
|
||||
UpdateChatPinOrder(ctx context.Context, arg UpdateChatPinOrderParams) error
|
||||
UpdateChatProvider(ctx context.Context, arg UpdateChatProviderParams) (ChatProvider, error)
|
||||
UpdateChatStatus(ctx context.Context, arg UpdateChatStatusParams) (Chat, error)
|
||||
UpdateChatWorkspace(ctx context.Context, arg UpdateChatWorkspaceParams) (Chat, error)
|
||||
UpdateChatStatusPreserveUpdatedAt(ctx context.Context, arg UpdateChatStatusPreserveUpdatedAtParams) (Chat, error)
|
||||
UpdateChatWorkspaceBinding(ctx context.Context, arg UpdateChatWorkspaceBindingParams) (Chat, error)
|
||||
UpdateCryptoKeyDeletesAt(ctx context.Context, arg UpdateCryptoKeyDeletesAtParams) (CryptoKey, error)
|
||||
UpdateCustomRole(ctx context.Context, arg UpdateCustomRoleParams) (CustomRole, error)
|
||||
UpdateExternalAuthLink(ctx context.Context, arg UpdateExternalAuthLinkParams) (ExternalAuthLink, error)
|
||||
@@ -936,6 +981,7 @@ type sqlcQuerier interface {
|
||||
UpsertChatDesktopEnabled(ctx context.Context, enableDesktop bool) error
|
||||
UpsertChatDiffStatus(ctx context.Context, arg UpsertChatDiffStatusParams) (ChatDiffStatus, error)
|
||||
UpsertChatDiffStatusReference(ctx context.Context, arg UpsertChatDiffStatusReferenceParams) (ChatDiffStatus, error)
|
||||
UpsertChatIncludeDefaultSystemPrompt(ctx context.Context, includeDefaultSystemPrompt bool) error
|
||||
UpsertChatSystemPrompt(ctx context.Context, value string) error
|
||||
UpsertChatTemplateAllowlist(ctx context.Context, templateAllowlist string) error
|
||||
UpsertChatUsageLimitConfig(ctx context.Context, arg UpsertChatUsageLimitConfigParams) (ChatUsageLimitConfig, error)
|
||||
|
||||
+329
-12
@@ -1251,8 +1251,12 @@ func TestGetAuthorizedChats(t *testing.T) {
|
||||
owner := dbgen.User(t, db, database.User{
|
||||
RBACRoles: []string{rbac.RoleOwner().String()},
|
||||
})
|
||||
member := dbgen.User(t, db, database.User{})
|
||||
secondMember := dbgen.User(t, db, database.User{})
|
||||
member := dbgen.User(t, db, database.User{
|
||||
RBACRoles: pq.StringArray{rbac.RoleAgentsAccess().String()},
|
||||
})
|
||||
secondMember := dbgen.User(t, db, database.User{
|
||||
RBACRoles: pq.StringArray{rbac.RoleAgentsAccess().String()},
|
||||
})
|
||||
|
||||
// Create FK dependencies: a chat provider and model config.
|
||||
ctx := testutil.Context(t, testutil.WaitMedium)
|
||||
@@ -1281,6 +1285,7 @@ func TestGetAuthorizedChats(t *testing.T) {
|
||||
// Create 3 chats owned by owner.
|
||||
for i := range 3 {
|
||||
_, err := db.InsertChat(ctx, database.InsertChatParams{
|
||||
Status: database.ChatStatusWaiting,
|
||||
OwnerID: owner.ID,
|
||||
LastModelConfigID: modelCfg.ID,
|
||||
Title: fmt.Sprintf("owner chat %d", i+1),
|
||||
@@ -1291,6 +1296,7 @@ func TestGetAuthorizedChats(t *testing.T) {
|
||||
// Create 2 chats owned by member.
|
||||
for i := range 2 {
|
||||
_, err := db.InsertChat(ctx, database.InsertChatParams{
|
||||
Status: database.ChatStatusWaiting,
|
||||
OwnerID: member.ID,
|
||||
LastModelConfigID: modelCfg.ID,
|
||||
Title: fmt.Sprintf("member chat %d", i+1),
|
||||
@@ -1311,7 +1317,7 @@ func TestGetAuthorizedChats(t *testing.T) {
|
||||
require.NoError(t, err)
|
||||
require.Len(t, memberRows, 2)
|
||||
for _, row := range memberRows {
|
||||
require.Equal(t, member.ID, row.OwnerID, "member should only see own chats")
|
||||
require.Equal(t, member.ID, row.Chat.OwnerID, "member should only see own chats")
|
||||
}
|
||||
|
||||
// Owner should see at least the 5 pre-created chats (site-wide
|
||||
@@ -1381,7 +1387,7 @@ func TestGetAuthorizedChats(t *testing.T) {
|
||||
require.NoError(t, err)
|
||||
require.Len(t, memberRows, 2)
|
||||
for _, row := range memberRows {
|
||||
require.Equal(t, member.ID, row.OwnerID, "member should only see own chats")
|
||||
require.Equal(t, member.ID, row.Chat.OwnerID, "member should only see own chats")
|
||||
}
|
||||
|
||||
// As owner: should see at least the 5 pre-created chats.
|
||||
@@ -1407,9 +1413,12 @@ func TestGetAuthorizedChats(t *testing.T) {
|
||||
|
||||
// Use a dedicated user for pagination to avoid interference
|
||||
// with the other parallel subtests.
|
||||
paginationUser := dbgen.User(t, db, database.User{})
|
||||
paginationUser := dbgen.User(t, db, database.User{
|
||||
RBACRoles: pq.StringArray{rbac.RoleAgentsAccess().String()},
|
||||
})
|
||||
for i := range 7 {
|
||||
_, err := db.InsertChat(ctx, database.InsertChatParams{
|
||||
Status: database.ChatStatusWaiting,
|
||||
OwnerID: paginationUser.ID,
|
||||
LastModelConfigID: modelCfg.ID,
|
||||
Title: fmt.Sprintf("pagination chat %d", i+1),
|
||||
@@ -1429,13 +1438,13 @@ func TestGetAuthorizedChats(t *testing.T) {
|
||||
require.NoError(t, err)
|
||||
require.Len(t, page1, 2)
|
||||
for _, row := range page1 {
|
||||
require.Equal(t, paginationUser.ID, row.OwnerID, "paginated results must belong to pagination user")
|
||||
require.Equal(t, paginationUser.ID, row.Chat.OwnerID, "paginated results must belong to pagination user")
|
||||
}
|
||||
|
||||
// Fetch remaining pages and collect all chat IDs.
|
||||
allIDs := make(map[uuid.UUID]struct{})
|
||||
for _, row := range page1 {
|
||||
allIDs[row.ID] = struct{}{}
|
||||
allIDs[row.Chat.ID] = struct{}{}
|
||||
}
|
||||
offset := int32(2)
|
||||
for {
|
||||
@@ -1445,8 +1454,8 @@ func TestGetAuthorizedChats(t *testing.T) {
|
||||
}, preparedMember)
|
||||
require.NoError(t, err)
|
||||
for _, row := range page {
|
||||
require.Equal(t, paginationUser.ID, row.OwnerID, "paginated results must belong to pagination user")
|
||||
allIDs[row.ID] = struct{}{}
|
||||
require.Equal(t, paginationUser.ID, row.Chat.OwnerID, "paginated results must belong to pagination user")
|
||||
allIDs[row.Chat.ID] = struct{}{}
|
||||
}
|
||||
if len(page) < 2 {
|
||||
break
|
||||
@@ -9466,6 +9475,7 @@ func TestInsertChatMessages(t *testing.T) {
|
||||
)
|
||||
|
||||
chat, err := store.InsertChat(ctx, database.InsertChatParams{
|
||||
Status: database.ChatStatusWaiting,
|
||||
OwnerID: user.ID,
|
||||
LastModelConfigID: modelConfigA.ID,
|
||||
Title: "test-chat-" + uuid.NewString(),
|
||||
@@ -9635,6 +9645,7 @@ func TestGetChatMessagesForPromptByChatID(t *testing.T) {
|
||||
newChat := func(t *testing.T) database.Chat {
|
||||
t.Helper()
|
||||
chat, err := db.InsertChat(ctx, database.InsertChatParams{
|
||||
Status: database.ChatStatusWaiting,
|
||||
OwnerID: user.ID,
|
||||
LastModelConfigID: modelCfg.ID,
|
||||
Title: "test-chat-" + uuid.NewString(),
|
||||
@@ -10008,6 +10019,7 @@ func TestGetPRInsights(t *testing.T) {
|
||||
createChat := func(t *testing.T, store database.Store, userID, mcID uuid.UUID, title string) database.Chat {
|
||||
t.Helper()
|
||||
chat, err := store.InsertChat(context.Background(), database.InsertChatParams{
|
||||
Status: database.ChatStatusWaiting,
|
||||
OwnerID: userID,
|
||||
LastModelConfigID: mcID,
|
||||
Title: title,
|
||||
@@ -10143,6 +10155,7 @@ func TestGetPRInsights(t *testing.T) {
|
||||
createChildChat := func(t *testing.T, store database.Store, userID, mcID, parentID, rootID uuid.UUID, title string) database.Chat {
|
||||
t.Helper()
|
||||
chat, err := store.InsertChat(context.Background(), database.InsertChatParams{
|
||||
Status: database.ChatStatusWaiting,
|
||||
OwnerID: userID,
|
||||
LastModelConfigID: mcID,
|
||||
Title: title,
|
||||
@@ -10487,6 +10500,187 @@ func TestGetPRInsights(t *testing.T) {
|
||||
})
|
||||
}
|
||||
|
||||
func TestChatPinOrderQueries(t *testing.T) {
|
||||
t.Parallel()
|
||||
if testing.Short() {
|
||||
t.SkipNow()
|
||||
}
|
||||
|
||||
setup := func(t *testing.T) (context.Context, database.Store, uuid.UUID, uuid.UUID) {
|
||||
t.Helper()
|
||||
|
||||
db, _ := dbtestutil.NewDB(t)
|
||||
owner := dbgen.User(t, db, database.User{})
|
||||
|
||||
// Use background context for fixture setup so the
|
||||
// timed test context doesn't tick during DB init.
|
||||
bg := context.Background()
|
||||
_, err := db.InsertChatProvider(bg, database.InsertChatProviderParams{
|
||||
Provider: "openai",
|
||||
DisplayName: "OpenAI",
|
||||
APIKey: "test-key",
|
||||
Enabled: true,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
modelCfg, err := db.InsertChatModelConfig(bg, database.InsertChatModelConfigParams{
|
||||
Provider: "openai",
|
||||
Model: "test-model",
|
||||
DisplayName: "Test Model",
|
||||
CreatedBy: uuid.NullUUID{UUID: owner.ID, Valid: true},
|
||||
UpdatedBy: uuid.NullUUID{UUID: owner.ID, Valid: true},
|
||||
Enabled: true,
|
||||
IsDefault: true,
|
||||
ContextLimit: 128000,
|
||||
CompressionThreshold: 80,
|
||||
Options: json.RawMessage(`{}`),
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
ctx := testutil.Context(t, testutil.WaitMedium)
|
||||
return ctx, db, owner.ID, modelCfg.ID
|
||||
}
|
||||
|
||||
createChat := func(t *testing.T, ctx context.Context, db database.Store, ownerID, modelCfgID uuid.UUID, title string) database.Chat {
|
||||
t.Helper()
|
||||
|
||||
chat, err := db.InsertChat(ctx, database.InsertChatParams{
|
||||
Status: database.ChatStatusWaiting,
|
||||
OwnerID: ownerID,
|
||||
LastModelConfigID: modelCfgID,
|
||||
Title: title,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
return chat
|
||||
}
|
||||
|
||||
requirePinOrders := func(t *testing.T, ctx context.Context, db database.Store, want map[uuid.UUID]int32) {
|
||||
t.Helper()
|
||||
|
||||
for chatID, wantPinOrder := range want {
|
||||
chat, err := db.GetChatByID(ctx, chatID)
|
||||
require.NoError(t, err)
|
||||
require.EqualValues(t, wantPinOrder, chat.PinOrder)
|
||||
}
|
||||
}
|
||||
|
||||
t.Run("PinChatByIDAppendsWithinOwner", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctx, db, ownerID, modelCfgID := setup(t)
|
||||
first := createChat(t, ctx, db, ownerID, modelCfgID, "first")
|
||||
second := createChat(t, ctx, db, ownerID, modelCfgID, "second")
|
||||
third := createChat(t, ctx, db, ownerID, modelCfgID, "third")
|
||||
|
||||
otherOwner := dbgen.User(t, db, database.User{})
|
||||
other := createChat(t, ctx, db, otherOwner.ID, modelCfgID, "other-owner")
|
||||
|
||||
require.NoError(t, db.PinChatByID(ctx, other.ID))
|
||||
require.NoError(t, db.PinChatByID(ctx, first.ID))
|
||||
require.NoError(t, db.PinChatByID(ctx, second.ID))
|
||||
require.NoError(t, db.PinChatByID(ctx, third.ID))
|
||||
|
||||
requirePinOrders(t, ctx, db, map[uuid.UUID]int32{
|
||||
first.ID: 1,
|
||||
second.ID: 2,
|
||||
third.ID: 3,
|
||||
other.ID: 1,
|
||||
})
|
||||
})
|
||||
|
||||
t.Run("UpdateChatPinOrderShiftsNeighborsAndClamps", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctx, db, ownerID, modelCfgID := setup(t)
|
||||
first := createChat(t, ctx, db, ownerID, modelCfgID, "first")
|
||||
second := createChat(t, ctx, db, ownerID, modelCfgID, "second")
|
||||
third := createChat(t, ctx, db, ownerID, modelCfgID, "third")
|
||||
|
||||
for _, chat := range []database.Chat{first, second, third} {
|
||||
require.NoError(t, db.PinChatByID(ctx, chat.ID))
|
||||
}
|
||||
|
||||
require.NoError(t, db.UpdateChatPinOrder(ctx, database.UpdateChatPinOrderParams{
|
||||
ID: third.ID,
|
||||
PinOrder: 1,
|
||||
}))
|
||||
requirePinOrders(t, ctx, db, map[uuid.UUID]int32{
|
||||
first.ID: 2,
|
||||
second.ID: 3,
|
||||
third.ID: 1,
|
||||
})
|
||||
|
||||
require.NoError(t, db.UpdateChatPinOrder(ctx, database.UpdateChatPinOrderParams{
|
||||
ID: third.ID,
|
||||
PinOrder: 99,
|
||||
}))
|
||||
requirePinOrders(t, ctx, db, map[uuid.UUID]int32{
|
||||
first.ID: 1,
|
||||
second.ID: 2,
|
||||
third.ID: 3,
|
||||
})
|
||||
})
|
||||
|
||||
t.Run("UnpinChatByIDCompactsPinnedChats", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctx, db, ownerID, modelCfgID := setup(t)
|
||||
first := createChat(t, ctx, db, ownerID, modelCfgID, "first")
|
||||
second := createChat(t, ctx, db, ownerID, modelCfgID, "second")
|
||||
third := createChat(t, ctx, db, ownerID, modelCfgID, "third")
|
||||
|
||||
for _, chat := range []database.Chat{first, second, third} {
|
||||
require.NoError(t, db.PinChatByID(ctx, chat.ID))
|
||||
}
|
||||
|
||||
require.NoError(t, db.UnpinChatByID(ctx, second.ID))
|
||||
requirePinOrders(t, ctx, db, map[uuid.UUID]int32{
|
||||
first.ID: 1,
|
||||
second.ID: 0,
|
||||
third.ID: 2,
|
||||
})
|
||||
})
|
||||
|
||||
t.Run("ArchiveClearsPinAndExcludesFromRanking", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctx, db, ownerID, modelCfgID := setup(t)
|
||||
first := createChat(t, ctx, db, ownerID, modelCfgID, "first")
|
||||
second := createChat(t, ctx, db, ownerID, modelCfgID, "second")
|
||||
third := createChat(t, ctx, db, ownerID, modelCfgID, "third")
|
||||
|
||||
for _, chat := range []database.Chat{first, second, third} {
|
||||
require.NoError(t, db.PinChatByID(ctx, chat.ID))
|
||||
}
|
||||
|
||||
// Archive the middle pin.
|
||||
_, err := db.ArchiveChatByID(ctx, second.ID)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Archived chat should have pin_order cleared. Remaining
|
||||
// pins keep their original positions; the next mutation
|
||||
// compacts via ROW_NUMBER().
|
||||
requirePinOrders(t, ctx, db, map[uuid.UUID]int32{
|
||||
first.ID: 1,
|
||||
second.ID: 0,
|
||||
third.ID: 3,
|
||||
})
|
||||
|
||||
// Reorder among remaining active pins — archived chat
|
||||
// should not interfere with position calculation.
|
||||
require.NoError(t, db.UpdateChatPinOrder(ctx, database.UpdateChatPinOrderParams{
|
||||
ID: third.ID,
|
||||
PinOrder: 1,
|
||||
}))
|
||||
// After reorder, ROW_NUMBER() compacts the sequence.
|
||||
requirePinOrders(t, ctx, db, map[uuid.UUID]int32{
|
||||
first.ID: 2,
|
||||
second.ID: 0,
|
||||
third.ID: 1,
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
func TestChatLabels(t *testing.T) {
|
||||
t.Parallel()
|
||||
if testing.Short() {
|
||||
@@ -10532,6 +10726,7 @@ func TestChatLabels(t *testing.T) {
|
||||
require.NoError(t, err)
|
||||
|
||||
chat, err := db.InsertChat(ctx, database.InsertChatParams{
|
||||
Status: database.ChatStatusWaiting,
|
||||
OwnerID: owner.ID,
|
||||
LastModelConfigID: modelCfg.ID,
|
||||
Title: "labeled-chat",
|
||||
@@ -10554,6 +10749,7 @@ func TestChatLabels(t *testing.T) {
|
||||
ctx := testutil.Context(t, testutil.WaitMedium)
|
||||
|
||||
chat, err := db.InsertChat(ctx, database.InsertChatParams{
|
||||
Status: database.ChatStatusWaiting,
|
||||
OwnerID: owner.ID,
|
||||
LastModelConfigID: modelCfg.ID,
|
||||
Title: "no-labels-chat",
|
||||
@@ -10569,6 +10765,7 @@ func TestChatLabels(t *testing.T) {
|
||||
ctx := testutil.Context(t, testutil.WaitMedium)
|
||||
|
||||
chat, err := db.InsertChat(ctx, database.InsertChatParams{
|
||||
Status: database.ChatStatusWaiting,
|
||||
OwnerID: owner.ID,
|
||||
LastModelConfigID: modelCfg.ID,
|
||||
Title: "update-labels-chat",
|
||||
@@ -10609,6 +10806,7 @@ func TestChatLabels(t *testing.T) {
|
||||
require.NoError(t, err)
|
||||
|
||||
chat, err := db.InsertChat(ctx, database.InsertChatParams{
|
||||
Status: database.ChatStatusWaiting,
|
||||
OwnerID: owner.ID,
|
||||
LastModelConfigID: modelCfg.ID,
|
||||
Title: "original-title",
|
||||
@@ -10645,6 +10843,7 @@ func TestChatLabels(t *testing.T) {
|
||||
labelsJSON, err := json.Marshal(tc.labels)
|
||||
require.NoError(t, err)
|
||||
_, err = db.InsertChat(ctx, database.InsertChatParams{
|
||||
Status: database.ChatStatusWaiting,
|
||||
OwnerID: owner.ID,
|
||||
LastModelConfigID: modelCfg.ID,
|
||||
Title: tc.title,
|
||||
@@ -10670,7 +10869,7 @@ func TestChatLabels(t *testing.T) {
|
||||
|
||||
titles := make([]string, 0, len(results))
|
||||
for _, c := range results {
|
||||
titles = append(titles, c.Title)
|
||||
titles = append(titles, c.Chat.Title)
|
||||
}
|
||||
require.Contains(t, titles, "filter-a")
|
||||
require.Contains(t, titles, "filter-b")
|
||||
@@ -10688,8 +10887,7 @@ func TestChatLabels(t *testing.T) {
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.Len(t, results, 1)
|
||||
require.Equal(t, "filter-a", results[0].Title)
|
||||
|
||||
require.Equal(t, "filter-a", results[0].Chat.Title)
|
||||
// No filter — should return all chats for this owner.
|
||||
allChats, err := db.GetChats(ctx, database.GetChatsParams{
|
||||
OwnerID: owner.ID,
|
||||
@@ -10698,3 +10896,122 @@ func TestChatLabels(t *testing.T) {
|
||||
require.GreaterOrEqual(t, len(allChats), 3)
|
||||
})
|
||||
}
|
||||
|
||||
func TestChatHasUnread(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
store, _ := dbtestutil.NewDB(t)
|
||||
ctx := context.Background()
|
||||
|
||||
dbgen.Organization(t, store, database.Organization{})
|
||||
user := dbgen.User(t, store, database.User{})
|
||||
|
||||
_, err := store.InsertChatProvider(ctx, database.InsertChatProviderParams{
|
||||
Provider: "openai",
|
||||
DisplayName: "OpenAI",
|
||||
APIKey: "test-key",
|
||||
Enabled: true,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
modelCfg, err := store.InsertChatModelConfig(ctx, database.InsertChatModelConfigParams{
|
||||
Provider: "openai",
|
||||
Model: "test-model-" + uuid.NewString(),
|
||||
DisplayName: "Test Model",
|
||||
CreatedBy: uuid.NullUUID{UUID: user.ID, Valid: true},
|
||||
UpdatedBy: uuid.NullUUID{UUID: user.ID, Valid: true},
|
||||
Enabled: true,
|
||||
IsDefault: true,
|
||||
ContextLimit: 128000,
|
||||
CompressionThreshold: 80,
|
||||
Options: json.RawMessage(`{}`),
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
chat, err := store.InsertChat(ctx, database.InsertChatParams{
|
||||
Status: database.ChatStatusWaiting,
|
||||
OwnerID: user.ID,
|
||||
LastModelConfigID: modelCfg.ID,
|
||||
Title: "test-chat-" + uuid.NewString(),
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
getHasUnread := func() bool {
|
||||
rows, err := store.GetChats(ctx, database.GetChatsParams{
|
||||
OwnerID: user.ID,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
for _, row := range rows {
|
||||
if row.Chat.ID == chat.ID {
|
||||
return row.HasUnread
|
||||
}
|
||||
}
|
||||
t.Fatal("chat not found in GetChats result")
|
||||
return false
|
||||
}
|
||||
|
||||
// New chat with no messages: not unread.
|
||||
require.False(t, getHasUnread(), "new chat with no messages should not be unread")
|
||||
|
||||
// Helper to insert a single chat message.
|
||||
insertMsg := func(role database.ChatMessageRole, text string) {
|
||||
t.Helper()
|
||||
_, err := store.InsertChatMessages(ctx, database.InsertChatMessagesParams{
|
||||
ChatID: chat.ID,
|
||||
CreatedBy: []uuid.UUID{user.ID},
|
||||
ModelConfigID: []uuid.UUID{modelCfg.ID},
|
||||
Role: []database.ChatMessageRole{role},
|
||||
Content: []string{fmt.Sprintf(`[{"type":"text","text":%q}]`, text)},
|
||||
ContentVersion: []int16{0},
|
||||
Visibility: []database.ChatMessageVisibility{database.ChatMessageVisibilityBoth},
|
||||
InputTokens: []int64{0},
|
||||
OutputTokens: []int64{0},
|
||||
TotalTokens: []int64{0},
|
||||
ReasoningTokens: []int64{0},
|
||||
CacheCreationTokens: []int64{0},
|
||||
CacheReadTokens: []int64{0},
|
||||
ContextLimit: []int64{0},
|
||||
Compressed: []bool{false},
|
||||
TotalCostMicros: []int64{0},
|
||||
RuntimeMs: []int64{0},
|
||||
ProviderResponseID: []string{""},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
// Insert an assistant message: becomes unread.
|
||||
insertMsg(database.ChatMessageRoleAssistant, "hello")
|
||||
require.True(t, getHasUnread(), "chat with unread assistant message should be unread")
|
||||
|
||||
// Mark as read: no longer unread.
|
||||
lastMsg, err := store.GetLastChatMessageByRole(ctx, database.GetLastChatMessageByRoleParams{
|
||||
ChatID: chat.ID,
|
||||
Role: database.ChatMessageRoleAssistant,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
err = store.UpdateChatLastReadMessageID(ctx, database.UpdateChatLastReadMessageIDParams{
|
||||
ID: chat.ID,
|
||||
LastReadMessageID: lastMsg.ID,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.False(t, getHasUnread(), "chat should not be unread after marking as read")
|
||||
|
||||
// Insert another assistant message: becomes unread again.
|
||||
insertMsg(database.ChatMessageRoleAssistant, "new message")
|
||||
require.True(t, getHasUnread(), "new assistant message after read should be unread")
|
||||
|
||||
// Mark as read again, then verify user messages don't
|
||||
// trigger unread.
|
||||
lastMsg, err = store.GetLastChatMessageByRole(ctx, database.GetLastChatMessageByRoleParams{
|
||||
ChatID: chat.ID,
|
||||
Role: database.ChatMessageRoleAssistant,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
err = store.UpdateChatLastReadMessageID(ctx, database.UpdateChatLastReadMessageIDParams{
|
||||
ID: chat.ID,
|
||||
LastReadMessageID: lastMsg.ID,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
insertMsg(database.ChatMessageRoleUser, "user msg")
|
||||
require.False(t, getHasUnread(), "user messages should not trigger unread")
|
||||
}
|
||||
|
||||
+1201
-137
File diff suppressed because it is too large
Load Diff
@@ -1,8 +1,8 @@
|
||||
-- name: InsertAIBridgeInterception :one
|
||||
INSERT INTO aibridge_interceptions (
|
||||
id, api_key_id, initiator_id, provider, model, metadata, started_at, client, client_session_id, thread_parent_id, thread_root_id
|
||||
id, api_key_id, initiator_id, provider, provider_name, model, metadata, started_at, client, client_session_id, thread_parent_id, thread_root_id
|
||||
) VALUES (
|
||||
@id, @api_key_id, @initiator_id, @provider, @model, COALESCE(@metadata::jsonb, '{}'::jsonb), @started_at, @client, sqlc.narg('client_session_id'), sqlc.narg('thread_parent_interception_id')::uuid, sqlc.narg('thread_root_interception_id')::uuid
|
||||
@id, @api_key_id, @initiator_id, @provider, @provider_name, @model, COALESCE(@metadata::jsonb, '{}'::jsonb), @started_at, @client, sqlc.narg('client_session_id'), sqlc.narg('thread_parent_interception_id')::uuid, sqlc.narg('thread_root_interception_id')::uuid
|
||||
)
|
||||
RETURNING *;
|
||||
|
||||
@@ -454,95 +454,91 @@ WHERE
|
||||
-- Returns paginated sessions with aggregated metadata, token counts, and
|
||||
-- the most recent user prompt. A "session" is a logical grouping of
|
||||
-- interceptions that share the same session_id (set by the client).
|
||||
WITH filtered_interceptions AS (
|
||||
--
|
||||
-- Pagination-first strategy: identify the page of sessions cheaply via a
|
||||
-- single GROUP BY scan, then do expensive lateral joins (tokens, prompts,
|
||||
-- first-interception metadata) only for the ~page-size result set.
|
||||
WITH cursor_pos AS (
|
||||
-- Resolve the cursor's started_at once, outside the HAVING clause,
|
||||
-- so the planner cannot accidentally re-evaluate it per group.
|
||||
SELECT MIN(aibridge_interceptions.started_at) AS started_at
|
||||
FROM aibridge_interceptions
|
||||
WHERE aibridge_interceptions.session_id = @after_session_id AND aibridge_interceptions.ended_at IS NOT NULL
|
||||
),
|
||||
session_page AS (
|
||||
-- Paginate at the session level first; only cheap aggregates here.
|
||||
SELECT
|
||||
aibridge_interceptions.*
|
||||
ai.session_id,
|
||||
ai.initiator_id,
|
||||
MIN(ai.started_at) AS started_at,
|
||||
MAX(ai.ended_at) AS ended_at,
|
||||
COUNT(*) FILTER (WHERE ai.thread_root_id IS NULL) AS threads
|
||||
FROM
|
||||
aibridge_interceptions
|
||||
aibridge_interceptions ai
|
||||
WHERE
|
||||
-- Remove inflight interceptions (ones which lack an ended_at value).
|
||||
aibridge_interceptions.ended_at IS NOT NULL
|
||||
ai.ended_at IS NOT NULL
|
||||
-- Filter by time frame
|
||||
AND CASE
|
||||
WHEN @started_after::timestamptz != '0001-01-01 00:00:00+00'::timestamptz THEN aibridge_interceptions.started_at >= @started_after::timestamptz
|
||||
WHEN @started_after::timestamptz != '0001-01-01 00:00:00+00'::timestamptz THEN ai.started_at >= @started_after::timestamptz
|
||||
ELSE true
|
||||
END
|
||||
AND CASE
|
||||
WHEN @started_before::timestamptz != '0001-01-01 00:00:00+00'::timestamptz THEN aibridge_interceptions.started_at <= @started_before::timestamptz
|
||||
WHEN @started_before::timestamptz != '0001-01-01 00:00:00+00'::timestamptz THEN ai.started_at <= @started_before::timestamptz
|
||||
ELSE true
|
||||
END
|
||||
-- Filter initiator_id
|
||||
AND CASE
|
||||
WHEN @initiator_id::uuid != '00000000-0000-0000-0000-000000000000'::uuid THEN aibridge_interceptions.initiator_id = @initiator_id::uuid
|
||||
WHEN @initiator_id::uuid != '00000000-0000-0000-0000-000000000000'::uuid THEN ai.initiator_id = @initiator_id::uuid
|
||||
ELSE true
|
||||
END
|
||||
-- Filter provider
|
||||
AND CASE
|
||||
WHEN @provider::text != '' THEN aibridge_interceptions.provider = @provider::text
|
||||
WHEN @provider::text != '' THEN ai.provider = @provider::text
|
||||
ELSE true
|
||||
END
|
||||
-- Filter model
|
||||
AND CASE
|
||||
WHEN @model::text != '' THEN aibridge_interceptions.model = @model::text
|
||||
WHEN @model::text != '' THEN ai.model = @model::text
|
||||
ELSE true
|
||||
END
|
||||
-- Filter client
|
||||
AND CASE
|
||||
WHEN @client::text != '' THEN COALESCE(aibridge_interceptions.client, 'Unknown') = @client::text
|
||||
WHEN @client::text != '' THEN COALESCE(ai.client, 'Unknown') = @client::text
|
||||
ELSE true
|
||||
END
|
||||
-- Filter session_id
|
||||
AND CASE
|
||||
WHEN @session_id::text != '' THEN aibridge_interceptions.session_id = @session_id::text
|
||||
WHEN @session_id::text != '' THEN ai.session_id = @session_id::text
|
||||
ELSE true
|
||||
END
|
||||
-- Authorize Filter clause will be injected below in ListAuthorizedAIBridgeSessions
|
||||
-- @authorize_filter
|
||||
),
|
||||
session_tokens AS (
|
||||
-- Aggregate token usage across all interceptions in each session.
|
||||
-- Group by (session_id, initiator_id) to avoid merging sessions from
|
||||
-- different users who happen to share the same client_session_id.
|
||||
SELECT
|
||||
fi.session_id,
|
||||
fi.initiator_id,
|
||||
COALESCE(SUM(tu.input_tokens), 0)::bigint AS input_tokens,
|
||||
COALESCE(SUM(tu.output_tokens), 0)::bigint AS output_tokens
|
||||
-- TODO: add extra token types once https://github.com/coder/aibridge/issues/150 lands.
|
||||
FROM
|
||||
filtered_interceptions fi
|
||||
LEFT JOIN
|
||||
aibridge_token_usages tu ON fi.id = tu.interception_id
|
||||
GROUP BY
|
||||
fi.session_id, fi.initiator_id
|
||||
),
|
||||
session_root AS (
|
||||
-- Build one summary row per session. Group by (session_id, initiator_id)
|
||||
-- to avoid merging sessions from different users who happen to share the
|
||||
-- same client_session_id. The ARRAY_AGG with ORDER BY picks values from
|
||||
-- the chronologically first interception for fields that should represent
|
||||
-- the session as a whole (client, metadata). Threads are counted as
|
||||
-- distinct root interception IDs: an interception with a NULL
|
||||
-- thread_root_id is itself a thread root.
|
||||
SELECT
|
||||
fi.session_id,
|
||||
fi.initiator_id,
|
||||
(ARRAY_AGG(fi.client ORDER BY fi.started_at, fi.id))[1] AS client,
|
||||
(ARRAY_AGG(fi.metadata ORDER BY fi.started_at, fi.id))[1] AS metadata,
|
||||
ARRAY_AGG(DISTINCT fi.provider ORDER BY fi.provider) AS providers,
|
||||
ARRAY_AGG(DISTINCT fi.model ORDER BY fi.model) AS models,
|
||||
MIN(fi.started_at) AS started_at,
|
||||
MAX(fi.ended_at) AS ended_at,
|
||||
COUNT(DISTINCT COALESCE(fi.thread_root_id, fi.id)) AS threads,
|
||||
-- Collect IDs for lateral prompt lookup.
|
||||
ARRAY_AGG(fi.id) AS interception_ids
|
||||
FROM
|
||||
filtered_interceptions fi
|
||||
GROUP BY
|
||||
fi.session_id, fi.initiator_id
|
||||
ai.session_id, ai.initiator_id
|
||||
HAVING
|
||||
-- Cursor pagination: uses a composite (started_at, session_id)
|
||||
-- cursor to support keyset pagination. The less-than comparison
|
||||
-- matches the DESC sort order so rows after the cursor come
|
||||
-- later in results. The cursor value comes from cursor_pos to
|
||||
-- guarantee single evaluation.
|
||||
CASE
|
||||
WHEN @after_session_id::text != '' THEN (
|
||||
(MIN(ai.started_at), ai.session_id) < (
|
||||
(SELECT started_at FROM cursor_pos),
|
||||
@after_session_id::text
|
||||
)
|
||||
)
|
||||
ELSE true
|
||||
END
|
||||
ORDER BY
|
||||
MIN(ai.started_at) DESC,
|
||||
ai.session_id DESC
|
||||
LIMIT COALESCE(NULLIF(@limit_::integer, 0), 100)
|
||||
OFFSET @offset_
|
||||
)
|
||||
SELECT
|
||||
sr.session_id,
|
||||
sp.session_id,
|
||||
visible_users.id AS user_id,
|
||||
visible_users.username AS user_username,
|
||||
visible_users.name AS user_name,
|
||||
@@ -551,47 +547,114 @@ SELECT
|
||||
sr.models::text[] AS models,
|
||||
COALESCE(sr.client, '')::varchar(64) AS client,
|
||||
sr.metadata::jsonb AS metadata,
|
||||
sr.started_at::timestamptz AS started_at,
|
||||
sr.ended_at::timestamptz AS ended_at,
|
||||
sr.threads,
|
||||
sp.started_at::timestamptz AS started_at,
|
||||
sp.ended_at::timestamptz AS ended_at,
|
||||
sp.threads,
|
||||
COALESCE(st.input_tokens, 0)::bigint AS input_tokens,
|
||||
COALESCE(st.output_tokens, 0)::bigint AS output_tokens,
|
||||
COALESCE(slp.prompt, '') AS last_prompt
|
||||
FROM
|
||||
session_root sr
|
||||
session_page sp
|
||||
JOIN
|
||||
visible_users ON visible_users.id = sr.initiator_id
|
||||
LEFT JOIN
|
||||
session_tokens st ON st.session_id = sr.session_id AND st.initiator_id = sr.initiator_id
|
||||
visible_users ON visible_users.id = sp.initiator_id
|
||||
LEFT JOIN LATERAL (
|
||||
-- Lateral join to efficiently fetch only the most recent user prompt
|
||||
-- across all interceptions in the session, avoiding a full aggregation.
|
||||
SELECT
|
||||
(ARRAY_AGG(ai.client ORDER BY ai.started_at, ai.id))[1] AS client,
|
||||
(ARRAY_AGG(ai.metadata ORDER BY ai.started_at, ai.id))[1] AS metadata,
|
||||
ARRAY_AGG(DISTINCT ai.provider ORDER BY ai.provider) AS providers,
|
||||
ARRAY_AGG(DISTINCT ai.model ORDER BY ai.model) AS models,
|
||||
ARRAY_AGG(ai.id) AS interception_ids
|
||||
FROM aibridge_interceptions ai
|
||||
WHERE ai.session_id = sp.session_id
|
||||
AND ai.initiator_id = sp.initiator_id
|
||||
AND ai.ended_at IS NOT NULL
|
||||
) sr ON true
|
||||
LEFT JOIN LATERAL (
|
||||
-- Aggregate tokens only for this session's interceptions.
|
||||
SELECT
|
||||
COALESCE(SUM(tu.input_tokens), 0)::bigint AS input_tokens,
|
||||
COALESCE(SUM(tu.output_tokens), 0)::bigint AS output_tokens
|
||||
FROM aibridge_token_usages tu
|
||||
WHERE tu.interception_id = ANY(sr.interception_ids)
|
||||
) st ON true
|
||||
LEFT JOIN LATERAL (
|
||||
-- Fetch only the most recent user prompt across all interceptions
|
||||
-- in the session.
|
||||
SELECT up.prompt
|
||||
FROM aibridge_user_prompts up
|
||||
WHERE up.interception_id = ANY(sr.interception_ids)
|
||||
ORDER BY up.created_at DESC, up.id DESC
|
||||
LIMIT 1
|
||||
) slp ON true
|
||||
WHERE
|
||||
-- Cursor pagination: uses a composite (started_at, session_id) cursor
|
||||
-- to support keyset pagination. The less-than comparison matches the
|
||||
-- DESC sort order so that rows after the cursor come later in results.
|
||||
CASE
|
||||
WHEN @after_session_id::text != '' THEN (
|
||||
(sr.started_at, sr.session_id) < (
|
||||
(SELECT started_at FROM session_root WHERE session_id = @after_session_id),
|
||||
@after_session_id::text
|
||||
ORDER BY
|
||||
sp.started_at DESC,
|
||||
sp.session_id DESC
|
||||
;
|
||||
|
||||
-- name: ListAIBridgeSessionThreads :many
|
||||
-- Returns all interceptions belonging to paginated threads within a session.
|
||||
-- Threads are paginated by (started_at, thread_id) cursor.
|
||||
WITH paginated_threads AS (
|
||||
SELECT
|
||||
-- Find thread root interceptions (thread_root_id IS NULL), apply cursor
|
||||
-- pagination, and return the page.
|
||||
aibridge_interceptions.id AS thread_id,
|
||||
aibridge_interceptions.started_at
|
||||
FROM
|
||||
aibridge_interceptions
|
||||
WHERE
|
||||
aibridge_interceptions.session_id = @session_id::text
|
||||
AND aibridge_interceptions.ended_at IS NOT NULL
|
||||
AND aibridge_interceptions.thread_root_id IS NULL
|
||||
-- Pagination cursor.
|
||||
AND (@after_id::uuid = '00000000-0000-0000-0000-000000000000'::uuid OR
|
||||
(aibridge_interceptions.started_at, aibridge_interceptions.id) > (
|
||||
(SELECT started_at FROM aibridge_interceptions ai2 WHERE ai2.id = @after_id),
|
||||
@after_id::uuid
|
||||
)
|
||||
)
|
||||
ELSE true
|
||||
END
|
||||
AND (@before_id::uuid = '00000000-0000-0000-0000-000000000000'::uuid OR
|
||||
(aibridge_interceptions.started_at, aibridge_interceptions.id) < (
|
||||
(SELECT started_at FROM aibridge_interceptions ai2 WHERE ai2.id = @before_id),
|
||||
@before_id::uuid
|
||||
)
|
||||
)
|
||||
-- @authorize_filter
|
||||
ORDER BY
|
||||
aibridge_interceptions.started_at ASC,
|
||||
aibridge_interceptions.id ASC
|
||||
LIMIT COALESCE(NULLIF(@limit_::integer, 0), 50)
|
||||
)
|
||||
SELECT
|
||||
COALESCE(aibridge_interceptions.thread_root_id, aibridge_interceptions.id) AS thread_id,
|
||||
sqlc.embed(aibridge_interceptions)
|
||||
FROM
|
||||
aibridge_interceptions
|
||||
JOIN
|
||||
paginated_threads pt
|
||||
ON pt.thread_id = COALESCE(aibridge_interceptions.thread_root_id, aibridge_interceptions.id)
|
||||
WHERE
|
||||
aibridge_interceptions.session_id = @session_id::text
|
||||
AND aibridge_interceptions.ended_at IS NOT NULL
|
||||
-- @authorize_filter
|
||||
ORDER BY
|
||||
sr.started_at DESC,
|
||||
sr.session_id DESC
|
||||
LIMIT COALESCE(NULLIF(@limit_::integer, 0), 100)
|
||||
OFFSET @offset_
|
||||
-- Ensure threads and their associated interceptions (agentic loops) are sorted chronologically.
|
||||
pt.started_at ASC,
|
||||
pt.thread_id ASC,
|
||||
aibridge_interceptions.started_at ASC,
|
||||
aibridge_interceptions.id ASC
|
||||
;
|
||||
|
||||
-- name: ListAIBridgeModelThoughtsByInterceptionIDs :many
|
||||
SELECT
|
||||
*
|
||||
FROM
|
||||
aibridge_model_thoughts
|
||||
WHERE
|
||||
interception_id = ANY(@interception_ids::uuid[])
|
||||
ORDER BY
|
||||
created_at ASC;
|
||||
|
||||
-- name: ListAIBridgeModels :many
|
||||
SELECT
|
||||
model
|
||||
@@ -616,3 +679,27 @@ ORDER BY
|
||||
LIMIT COALESCE(NULLIF(@limit_::integer, 0), 100)
|
||||
OFFSET @offset_
|
||||
;
|
||||
|
||||
|
||||
-- name: ListAIBridgeClients :many
|
||||
SELECT
|
||||
COALESCE(client, 'Unknown') AS client
|
||||
FROM
|
||||
aibridge_interceptions
|
||||
WHERE
|
||||
ended_at IS NOT NULL
|
||||
-- Filter client (prefix match to allow B-tree index usage).
|
||||
AND CASE
|
||||
WHEN @client::text != '' THEN COALESCE(aibridge_interceptions.client, 'Unknown') LIKE @client::text || '%'
|
||||
ELSE true
|
||||
END
|
||||
-- We use an `@authorize_filter` as we are attempting to list clients
|
||||
-- that are relevant to the user and what they are allowed to see.
|
||||
-- Authorize Filter clause will be injected below in
|
||||
-- ListAIBridgeClientsAuthorized.
|
||||
-- @authorize_filter
|
||||
GROUP BY
|
||||
client
|
||||
LIMIT COALESCE(NULLIF(@limit_::integer, 0), 100)
|
||||
OFFSET @offset_
|
||||
;
|
||||
|
||||
@@ -0,0 +1,17 @@
|
||||
-- name: GetUserAISeatStates :many
|
||||
-- Returns user IDs from the provided list that are consuming an AI seat.
|
||||
-- Filters to active, non-deleted, non-system users to match the canonical
|
||||
-- seat count query (GetActiveAISeatCount).
|
||||
SELECT
|
||||
ais.user_id
|
||||
FROM
|
||||
ai_seat_state ais
|
||||
JOIN
|
||||
users u
|
||||
ON
|
||||
ais.user_id = u.id
|
||||
WHERE
|
||||
ais.user_id = ANY(@user_ids::uuid[])
|
||||
AND u.status = 'active'::user_status
|
||||
AND u.deleted = false
|
||||
AND u.is_system = false;
|
||||
@@ -1,9 +1,192 @@
|
||||
-- name: ArchiveChatByID :exec
|
||||
UPDATE chats SET archived = true, updated_at = NOW()
|
||||
WHERE id = @id OR root_chat_id = @id;
|
||||
-- name: ArchiveChatByID :many
|
||||
WITH chats AS (
|
||||
UPDATE chats
|
||||
SET archived = true, pin_order = 0, updated_at = NOW()
|
||||
WHERE id = @id::uuid OR root_chat_id = @id::uuid
|
||||
RETURNING *
|
||||
)
|
||||
SELECT *
|
||||
FROM chats
|
||||
ORDER BY (id = @id::uuid) DESC, created_at ASC, id ASC;
|
||||
|
||||
-- name: UnarchiveChatByID :exec
|
||||
UPDATE chats SET archived = false, updated_at = NOW() WHERE id = @id::uuid;
|
||||
-- name: UnarchiveChatByID :many
|
||||
WITH chats AS (
|
||||
UPDATE chats
|
||||
SET archived = false, updated_at = NOW()
|
||||
WHERE id = @id::uuid OR root_chat_id = @id::uuid
|
||||
RETURNING *
|
||||
)
|
||||
SELECT *
|
||||
FROM chats
|
||||
ORDER BY (id = @id::uuid) DESC, created_at ASC, id ASC;
|
||||
|
||||
-- name: PinChatByID :exec
|
||||
WITH target_chat AS (
|
||||
SELECT
|
||||
id,
|
||||
owner_id
|
||||
FROM
|
||||
chats
|
||||
WHERE
|
||||
id = @id::uuid
|
||||
),
|
||||
-- Under READ COMMITTED, concurrent pin operations for the same
|
||||
-- owner may momentarily produce duplicate pin_order values because
|
||||
-- each CTE snapshot does not see the other's writes. The next
|
||||
-- pin/unpin/reorder operation's ROW_NUMBER() self-heals the
|
||||
-- sequence, so this is acceptable.
|
||||
ranked AS (
|
||||
SELECT
|
||||
c.id,
|
||||
ROW_NUMBER() OVER (ORDER BY c.pin_order ASC, c.id ASC) :: integer AS next_pin_order
|
||||
FROM
|
||||
chats c
|
||||
JOIN
|
||||
target_chat ON c.owner_id = target_chat.owner_id
|
||||
WHERE
|
||||
c.pin_order > 0
|
||||
AND c.archived = FALSE
|
||||
AND c.id <> target_chat.id
|
||||
),
|
||||
updates AS (
|
||||
SELECT
|
||||
ranked.id,
|
||||
ranked.next_pin_order AS pin_order
|
||||
FROM
|
||||
ranked
|
||||
UNION ALL
|
||||
SELECT
|
||||
target_chat.id,
|
||||
COALESCE((
|
||||
SELECT
|
||||
MAX(ranked.next_pin_order)
|
||||
FROM
|
||||
ranked
|
||||
), 0) + 1 AS pin_order
|
||||
FROM
|
||||
target_chat
|
||||
)
|
||||
UPDATE
|
||||
chats c
|
||||
SET
|
||||
pin_order = updates.pin_order
|
||||
FROM
|
||||
updates
|
||||
WHERE
|
||||
c.id = updates.id;
|
||||
|
||||
-- name: UnpinChatByID :exec
|
||||
WITH target_chat AS (
|
||||
SELECT
|
||||
id,
|
||||
owner_id
|
||||
FROM
|
||||
chats
|
||||
WHERE
|
||||
id = @id::uuid
|
||||
),
|
||||
ranked AS (
|
||||
SELECT
|
||||
c.id,
|
||||
ROW_NUMBER() OVER (ORDER BY c.pin_order ASC, c.id ASC) :: integer AS current_position
|
||||
FROM
|
||||
chats c
|
||||
JOIN
|
||||
target_chat ON c.owner_id = target_chat.owner_id
|
||||
WHERE
|
||||
c.pin_order > 0
|
||||
AND c.archived = FALSE
|
||||
),
|
||||
target AS (
|
||||
SELECT
|
||||
ranked.id,
|
||||
ranked.current_position
|
||||
FROM
|
||||
ranked
|
||||
WHERE
|
||||
ranked.id = @id::uuid
|
||||
),
|
||||
updates AS (
|
||||
SELECT
|
||||
ranked.id,
|
||||
CASE
|
||||
WHEN ranked.id = target.id THEN 0
|
||||
WHEN ranked.current_position > target.current_position THEN ranked.current_position - 1
|
||||
ELSE ranked.current_position
|
||||
END AS pin_order
|
||||
FROM
|
||||
ranked
|
||||
CROSS JOIN
|
||||
target
|
||||
)
|
||||
UPDATE
|
||||
chats c
|
||||
SET
|
||||
pin_order = updates.pin_order
|
||||
FROM
|
||||
updates
|
||||
WHERE
|
||||
c.id = updates.id;
|
||||
|
||||
-- name: UpdateChatPinOrder :exec
|
||||
WITH target_chat AS (
|
||||
SELECT
|
||||
id,
|
||||
owner_id
|
||||
FROM
|
||||
chats
|
||||
WHERE
|
||||
id = @id::uuid
|
||||
),
|
||||
ranked AS (
|
||||
SELECT
|
||||
c.id,
|
||||
ROW_NUMBER() OVER (ORDER BY c.pin_order ASC, c.id ASC) :: integer AS current_position,
|
||||
COUNT(*) OVER () :: integer AS pinned_count
|
||||
FROM
|
||||
chats c
|
||||
JOIN
|
||||
target_chat ON c.owner_id = target_chat.owner_id
|
||||
WHERE
|
||||
c.pin_order > 0
|
||||
AND c.archived = FALSE
|
||||
),
|
||||
target AS (
|
||||
SELECT
|
||||
ranked.id,
|
||||
ranked.current_position,
|
||||
LEAST(GREATEST(@pin_order::integer, 1), ranked.pinned_count) AS desired_position
|
||||
FROM
|
||||
ranked
|
||||
WHERE
|
||||
ranked.id = @id::uuid
|
||||
),
|
||||
updates AS (
|
||||
SELECT
|
||||
ranked.id,
|
||||
CASE
|
||||
WHEN ranked.id = target.id THEN target.desired_position
|
||||
WHEN target.desired_position < target.current_position
|
||||
AND ranked.current_position >= target.desired_position
|
||||
AND ranked.current_position < target.current_position THEN ranked.current_position + 1
|
||||
WHEN target.desired_position > target.current_position
|
||||
AND ranked.current_position > target.current_position
|
||||
AND ranked.current_position <= target.desired_position THEN ranked.current_position - 1
|
||||
ELSE ranked.current_position
|
||||
END AS pin_order
|
||||
FROM
|
||||
ranked
|
||||
CROSS JOIN
|
||||
target
|
||||
)
|
||||
UPDATE
|
||||
chats c
|
||||
SET
|
||||
pin_order = updates.pin_order
|
||||
FROM
|
||||
updates
|
||||
WHERE
|
||||
c.id = updates.id;
|
||||
|
||||
-- name: SoftDeleteChatMessagesAfterID :exec
|
||||
UPDATE
|
||||
@@ -52,6 +235,21 @@ WHERE
|
||||
ORDER BY
|
||||
created_at ASC;
|
||||
|
||||
-- name: GetChatMessagesByChatIDAscPaginated :many
|
||||
SELECT
|
||||
*
|
||||
FROM
|
||||
chat_messages
|
||||
WHERE
|
||||
chat_id = @chat_id::uuid
|
||||
AND id > @after_id::bigint
|
||||
AND visibility IN ('user', 'both')
|
||||
AND deleted = false
|
||||
ORDER BY
|
||||
id ASC
|
||||
LIMIT
|
||||
COALESCE(NULLIF(@limit_val::int, 0), 50);
|
||||
|
||||
-- name: GetChatMessagesByChatIDDescPaginated :many
|
||||
SELECT
|
||||
*
|
||||
@@ -130,7 +328,14 @@ ORDER BY
|
||||
|
||||
-- name: GetChats :many
|
||||
SELECT
|
||||
*
|
||||
sqlc.embed(chats),
|
||||
EXISTS (
|
||||
SELECT 1 FROM chat_messages cm
|
||||
WHERE cm.chat_id = chats.id
|
||||
AND cm.role = 'assistant'
|
||||
AND cm.deleted = false
|
||||
AND cm.id > COALESCE(chats.last_read_message_id, 0)
|
||||
) AS has_unread
|
||||
FROM
|
||||
chats
|
||||
WHERE
|
||||
@@ -180,21 +385,27 @@ LIMIT
|
||||
INSERT INTO chats (
|
||||
owner_id,
|
||||
workspace_id,
|
||||
build_id,
|
||||
agent_id,
|
||||
parent_chat_id,
|
||||
root_chat_id,
|
||||
last_model_config_id,
|
||||
title,
|
||||
mode,
|
||||
status,
|
||||
mcp_server_ids,
|
||||
labels
|
||||
) VALUES (
|
||||
@owner_id::uuid,
|
||||
sqlc.narg('workspace_id')::uuid,
|
||||
sqlc.narg('build_id')::uuid,
|
||||
sqlc.narg('agent_id')::uuid,
|
||||
sqlc.narg('parent_chat_id')::uuid,
|
||||
sqlc.narg('root_chat_id')::uuid,
|
||||
@last_model_config_id::uuid,
|
||||
@title::text,
|
||||
sqlc.narg('mode')::chat_mode,
|
||||
@status::chat_status,
|
||||
COALESCE(@mcp_server_ids::uuid[], '{}'::uuid[]),
|
||||
COALESCE(sqlc.narg('labels')::jsonb, '{}'::jsonb)
|
||||
)
|
||||
@@ -294,6 +505,17 @@ WHERE
|
||||
RETURNING
|
||||
*;
|
||||
|
||||
-- name: UpdateChatLastModelConfigByID :one
|
||||
UPDATE
|
||||
chats
|
||||
SET
|
||||
-- NOTE: updated_at is intentionally NOT touched here to avoid changing list ordering.
|
||||
last_model_config_id = @last_model_config_id::uuid
|
||||
WHERE
|
||||
id = @id::uuid
|
||||
RETURNING
|
||||
*;
|
||||
|
||||
-- name: UpdateChatLabelsByID :one
|
||||
UPDATE
|
||||
chats
|
||||
@@ -305,16 +527,34 @@ WHERE
|
||||
RETURNING
|
||||
*;
|
||||
|
||||
-- name: UpdateChatWorkspace :one
|
||||
UPDATE
|
||||
chats
|
||||
SET
|
||||
-- name: UpdateChatWorkspaceBinding :one
|
||||
UPDATE chats SET
|
||||
workspace_id = sqlc.narg('workspace_id')::uuid,
|
||||
build_id = sqlc.narg('build_id')::uuid,
|
||||
agent_id = sqlc.narg('agent_id')::uuid,
|
||||
updated_at = NOW()
|
||||
WHERE id = @id::uuid
|
||||
RETURNING *;
|
||||
|
||||
-- name: UpdateChatBuildAgentBinding :one
|
||||
UPDATE chats SET
|
||||
build_id = sqlc.narg('build_id')::uuid,
|
||||
agent_id = sqlc.narg('agent_id')::uuid,
|
||||
updated_at = NOW()
|
||||
WHERE
|
||||
id = @id::uuid
|
||||
RETURNING
|
||||
*;
|
||||
RETURNING *;
|
||||
|
||||
-- name: UpdateChatLastInjectedContext :one
|
||||
-- Updates the cached injected context parts (AGENTS.md +
|
||||
-- skills) on the chat row. Called only when context changes
|
||||
-- (first workspace attach or agent change). updated_at is
|
||||
-- intentionally not touched to avoid reordering the chat list.
|
||||
UPDATE chats SET
|
||||
last_injected_context = sqlc.narg('last_injected_context')::jsonb
|
||||
WHERE
|
||||
id = @id::uuid
|
||||
RETURNING *;
|
||||
|
||||
-- name: UpdateChatMCPServerIDs :one
|
||||
UPDATE
|
||||
@@ -371,6 +611,21 @@ WHERE
|
||||
RETURNING
|
||||
*;
|
||||
|
||||
-- name: UpdateChatStatusPreserveUpdatedAt :one
|
||||
UPDATE
|
||||
chats
|
||||
SET
|
||||
status = @status::chat_status,
|
||||
worker_id = sqlc.narg('worker_id')::uuid,
|
||||
started_at = sqlc.narg('started_at')::timestamptz,
|
||||
heartbeat_at = sqlc.narg('heartbeat_at')::timestamptz,
|
||||
last_error = sqlc.narg('last_error')::text,
|
||||
updated_at = @updated_at::timestamptz
|
||||
WHERE
|
||||
id = @id::uuid
|
||||
RETURNING
|
||||
*;
|
||||
|
||||
-- name: GetStaleChats :many
|
||||
-- Find chats that appear stuck (running but heartbeat has expired).
|
||||
-- Used for recovery after coderd crashes or long hangs.
|
||||
@@ -878,6 +1133,13 @@ JOIN group_members_expanded gme ON gme.group_id = g.id
|
||||
WHERE gme.user_id = @user_id::uuid
|
||||
AND g.chat_spend_limit_micros IS NOT NULL;
|
||||
|
||||
-- name: GetChatsByWorkspaceIDs :many
|
||||
SELECT *
|
||||
FROM chats
|
||||
WHERE archived = false
|
||||
AND workspace_id = ANY(@ids::uuid[])
|
||||
ORDER BY workspace_id, updated_at DESC;
|
||||
|
||||
-- name: ResolveUserChatSpendLimit :one
|
||||
-- Resolves the effective spend limit for a user using the hierarchy:
|
||||
-- 1. Individual user override (highest priority)
|
||||
@@ -905,3 +1167,10 @@ LEFT JOIN LATERAL (
|
||||
) gl ON TRUE
|
||||
WHERE u.id = @user_id::uuid
|
||||
LIMIT 1;
|
||||
|
||||
-- name: UpdateChatLastReadMessageID :exec
|
||||
-- Updates the last read message ID for a chat. This is used to track
|
||||
-- which messages the owner has seen, enabling unread indicators.
|
||||
UPDATE chats
|
||||
SET last_read_message_id = @last_read_message_id::bigint
|
||||
WHERE id = @id::uuid;
|
||||
|
||||
@@ -77,6 +77,7 @@ INSERT INTO mcp_server_configs (
|
||||
tool_deny_list,
|
||||
availability,
|
||||
enabled,
|
||||
model_intent,
|
||||
created_by,
|
||||
updated_by
|
||||
) VALUES (
|
||||
@@ -102,6 +103,7 @@ INSERT INTO mcp_server_configs (
|
||||
@tool_deny_list::text[],
|
||||
@availability::text,
|
||||
@enabled::boolean,
|
||||
@model_intent::boolean,
|
||||
@created_by::uuid,
|
||||
@updated_by::uuid
|
||||
)
|
||||
@@ -134,6 +136,7 @@ SET
|
||||
tool_deny_list = @tool_deny_list::text[],
|
||||
availability = @availability::text,
|
||||
enabled = @enabled::boolean,
|
||||
model_intent = @model_intent::boolean,
|
||||
updated_by = @updated_by::uuid,
|
||||
updated_at = NOW()
|
||||
WHERE
|
||||
|
||||
@@ -137,6 +137,24 @@ SELECT
|
||||
SELECT
|
||||
COALESCE((SELECT value FROM site_configs WHERE key = 'agents_chat_system_prompt'), '') :: text AS chat_system_prompt;
|
||||
|
||||
-- GetChatSystemPromptConfig returns both chat system prompt settings in a
|
||||
-- single read to avoid torn reads between separate site-config lookups.
|
||||
-- The include-default fallback preserves the legacy behavior where a
|
||||
-- non-empty custom prompt implied opting out before the explicit toggle
|
||||
-- existed.
|
||||
-- name: GetChatSystemPromptConfig :one
|
||||
SELECT
|
||||
COALESCE((SELECT value FROM site_configs WHERE key = 'agents_chat_system_prompt'), '') :: text AS chat_system_prompt,
|
||||
COALESCE(
|
||||
(SELECT value = 'true' FROM site_configs WHERE key = 'agents_chat_include_default_system_prompt'),
|
||||
NOT EXISTS (
|
||||
SELECT 1
|
||||
FROM site_configs
|
||||
WHERE key = 'agents_chat_system_prompt'
|
||||
AND value != ''
|
||||
)
|
||||
) :: boolean AS include_default_system_prompt;
|
||||
|
||||
-- name: UpsertChatSystemPrompt :exec
|
||||
INSERT INTO site_configs (key, value) VALUES ('agents_chat_system_prompt', $1)
|
||||
ON CONFLICT (key) DO UPDATE SET value = $1 WHERE site_configs.key = 'agents_chat_system_prompt';
|
||||
@@ -167,6 +185,38 @@ WHERE site_configs.key = 'agents_desktop_enabled';
|
||||
SELECT
|
||||
COALESCE((SELECT value FROM site_configs WHERE key = 'agents_template_allowlist'), '') :: text AS template_allowlist;
|
||||
|
||||
-- GetChatIncludeDefaultSystemPrompt preserves the legacy default
|
||||
-- for deployments created before the explicit include-default toggle.
|
||||
-- When the toggle is unset, a non-empty custom prompt implies false;
|
||||
-- otherwise the setting defaults to true.
|
||||
-- name: GetChatIncludeDefaultSystemPrompt :one
|
||||
SELECT
|
||||
COALESCE(
|
||||
(SELECT value = 'true' FROM site_configs WHERE key = 'agents_chat_include_default_system_prompt'),
|
||||
NOT EXISTS (
|
||||
SELECT 1
|
||||
FROM site_configs
|
||||
WHERE key = 'agents_chat_system_prompt'
|
||||
AND value != ''
|
||||
)
|
||||
) :: boolean AS include_default_system_prompt;
|
||||
|
||||
-- name: UpsertChatIncludeDefaultSystemPrompt :exec
|
||||
INSERT INTO site_configs (key, value)
|
||||
VALUES (
|
||||
'agents_chat_include_default_system_prompt',
|
||||
CASE
|
||||
WHEN sqlc.arg(include_default_system_prompt)::bool THEN 'true'
|
||||
ELSE 'false'
|
||||
END
|
||||
)
|
||||
ON CONFLICT (key) DO UPDATE
|
||||
SET value = CASE
|
||||
WHEN sqlc.arg(include_default_system_prompt)::bool THEN 'true'
|
||||
ELSE 'false'
|
||||
END
|
||||
WHERE site_configs.key = 'agents_chat_include_default_system_prompt';
|
||||
|
||||
-- name: GetChatWorkspaceTTL :one
|
||||
-- Returns the global TTL for chat workspaces as a Go duration string.
|
||||
-- Returns "0s" (disabled) when no value has been configured.
|
||||
|
||||
+357
-177
@@ -33,6 +33,7 @@ import (
|
||||
"github.com/coder/coder/v2/coderd/database"
|
||||
"github.com/coder/coder/v2/coderd/database/db2sdk"
|
||||
"github.com/coder/coder/v2/coderd/database/dbauthz"
|
||||
dbpubsub "github.com/coder/coder/v2/coderd/database/pubsub"
|
||||
"github.com/coder/coder/v2/coderd/externalauth"
|
||||
"github.com/coder/coder/v2/coderd/externalauth/gitprovider"
|
||||
"github.com/coder/coder/v2/coderd/httpapi"
|
||||
@@ -55,7 +56,6 @@ import (
|
||||
)
|
||||
|
||||
const (
|
||||
chatDiffStatusTTL = gitsync.DiffStatusTTL
|
||||
chatStreamBatchSize = 256
|
||||
|
||||
chatContextLimitModelConfigKey = "context_limit"
|
||||
@@ -66,11 +66,12 @@ const (
|
||||
maxSystemPromptLenBytes = 131072 // 128 KiB
|
||||
)
|
||||
|
||||
// chatGitRef holds the branch and remote origin reported by the
|
||||
// workspace agent during a git operation.
|
||||
// chatGitRef holds the branch, remote origin, and optional chat
|
||||
// ID reported by the workspace agent during a git operation.
|
||||
type chatGitRef struct {
|
||||
Branch string
|
||||
RemoteOrigin string
|
||||
ChatID uuid.UUID
|
||||
}
|
||||
|
||||
type chatRepositoryRef struct {
|
||||
@@ -110,6 +111,28 @@ func maybeWriteLimitErr(ctx context.Context, rw http.ResponseWriter, err error)
|
||||
return false
|
||||
}
|
||||
|
||||
func publishChatConfigEvent(logger slog.Logger, ps dbpubsub.Pubsub, kind pubsub.ChatConfigEventKind, entityID uuid.UUID) {
|
||||
payload, err := json.Marshal(pubsub.ChatConfigEvent{
|
||||
Kind: kind,
|
||||
EntityID: entityID,
|
||||
})
|
||||
if err != nil {
|
||||
logger.Error(context.Background(), "failed to marshal chat config event",
|
||||
slog.F("kind", kind),
|
||||
slog.F("entity_id", entityID),
|
||||
slog.Error(err),
|
||||
)
|
||||
return
|
||||
}
|
||||
if err := ps.Publish(pubsub.ChatConfigEventChannel, payload); err != nil {
|
||||
logger.Error(context.Background(), "failed to publish chat config event",
|
||||
slog.F("kind", kind),
|
||||
slog.F("entity_id", entityID),
|
||||
slog.Error(err),
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
// EXPERIMENTAL: this endpoint is experimental and is subject to change.
|
||||
func (api *API) watchChats(rw http.ResponseWriter, r *http.Request) {
|
||||
ctx := r.Context()
|
||||
@@ -173,6 +196,88 @@ func (api *API) watchChats(rw http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
}
|
||||
|
||||
// EXPERIMENTAL: chatsByWorkspace returns a mapping of workspace ID to
|
||||
// the latest non-archived chat ID for each requested workspace.
|
||||
// The query returns all matching chats and RBAC post-filters them;
|
||||
// the handler then picks the latest per workspace in Go. This avoids
|
||||
// the DISTINCT ON + post-filter bug where the sole candidate is
|
||||
// silently dropped when the caller can't read it.
|
||||
//
|
||||
// TODO:
|
||||
// 1. move aggregation to a SQL view with proper in-query authz so we
|
||||
// can return a single row per workspace without this two-pass approach.
|
||||
// 2. Restore the below router annotation and un-skip docs gen
|
||||
// <at>Router /experimental/chats/by-workspace [post]
|
||||
//
|
||||
// @Summary Get latest chats by workspace IDs
|
||||
// @ID get-latest-chats-by-workspace-ids
|
||||
// @Security CoderSessionToken
|
||||
// @Tags Chats
|
||||
// @Accept json
|
||||
// @Produce json
|
||||
// @Success 200
|
||||
// @x-apidocgen {"skip": true}
|
||||
func (api *API) chatsByWorkspace(rw http.ResponseWriter, r *http.Request) {
|
||||
ctx := r.Context()
|
||||
|
||||
idsParam := r.URL.Query().Get("workspace_ids")
|
||||
if idsParam == "" {
|
||||
httpapi.Write(ctx, rw, http.StatusOK, map[uuid.UUID]uuid.UUID{})
|
||||
return
|
||||
}
|
||||
|
||||
raw := strings.Split(idsParam, ",")
|
||||
|
||||
// maxWorkspaceIDs is coupled to DEFAULT_RECORDS_PER_PAGE (25) in
|
||||
// site/src/components/PaginationWidget/utils.ts.
|
||||
// If the page size changes, this limit should too.
|
||||
const maxWorkspaceIDs = 25
|
||||
if len(raw) > maxWorkspaceIDs {
|
||||
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
|
||||
Message: fmt.Sprintf("Too many workspace IDs, maximum is %d.", maxWorkspaceIDs),
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
workspaceIDs := make([]uuid.UUID, 0, len(raw))
|
||||
for _, s := range raw {
|
||||
id, err := uuid.Parse(strings.TrimSpace(s))
|
||||
if err != nil {
|
||||
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
|
||||
Message: fmt.Sprintf("Invalid workspace ID %q: %s", s, err),
|
||||
})
|
||||
return
|
||||
}
|
||||
workspaceIDs = append(workspaceIDs, id)
|
||||
}
|
||||
|
||||
chats, err := api.Database.GetChatsByWorkspaceIDs(ctx, workspaceIDs)
|
||||
if httpapi.Is404Error(err) {
|
||||
httpapi.ResourceNotFound(rw)
|
||||
return
|
||||
} else if err != nil {
|
||||
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
|
||||
Message: "Failed to get chats by workspace.",
|
||||
Detail: err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// The SQL orders by (workspace_id, updated_at DESC), so the first
|
||||
// chat seen per workspace after RBAC filtering is the latest
|
||||
// readable one.
|
||||
result := make(map[uuid.UUID]uuid.UUID, len(chats))
|
||||
for _, chat := range chats {
|
||||
if chat.WorkspaceID.Valid {
|
||||
if _, exists := result[chat.WorkspaceID.UUID]; !exists {
|
||||
result[chat.WorkspaceID.UUID] = chat.ID
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
httpapi.Write(ctx, rw, http.StatusOK, result)
|
||||
}
|
||||
|
||||
// EXPERIMENTAL: this endpoint is experimental and is subject to change.
|
||||
func (api *API) listChats(rw http.ResponseWriter, r *http.Request) {
|
||||
ctx := r.Context()
|
||||
@@ -231,7 +336,7 @@ func (api *API) listChats(rw http.ResponseWriter, r *http.Request) {
|
||||
LimitOpt: int32(paginationParams.Limit),
|
||||
}
|
||||
|
||||
chats, err := api.Database.GetChats(ctx, params)
|
||||
chatRows, err := api.Database.GetChats(ctx, params)
|
||||
if err != nil {
|
||||
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
|
||||
Message: "Failed to list chats.",
|
||||
@@ -240,7 +345,13 @@ func (api *API) listChats(rw http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
diffStatusesByChatID, err := api.getChatDiffStatusesByChatID(ctx, chats)
|
||||
// Extract the Chat objects for diff status lookup.
|
||||
dbChats := make([]database.Chat, len(chatRows))
|
||||
for i, row := range chatRows {
|
||||
dbChats[i] = row.Chat
|
||||
}
|
||||
|
||||
diffStatusesByChatID, err := api.getChatDiffStatusesByChatID(ctx, dbChats)
|
||||
if err != nil {
|
||||
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
|
||||
Message: "Failed to list chats.",
|
||||
@@ -249,7 +360,7 @@ func (api *API) listChats(rw http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
httpapi.Write(ctx, rw, http.StatusOK, convertChats(chats, diffStatusesByChatID))
|
||||
httpapi.Write(ctx, rw, http.StatusOK, db2sdk.ChatRows(chatRows, diffStatusesByChatID))
|
||||
}
|
||||
|
||||
func (api *API) getChatDiffStatusesByChatID(
|
||||
@@ -282,6 +393,11 @@ func (api *API) postChats(rw http.ResponseWriter, r *http.Request) {
|
||||
ctx := r.Context()
|
||||
apiKey := httpmw.APIKey(r)
|
||||
|
||||
if !api.Authorize(r, policy.ActionCreate, rbac.ResourceChat.WithOwner(apiKey.UserID.String())) {
|
||||
httpapi.Forbidden(rw)
|
||||
return
|
||||
}
|
||||
|
||||
var req codersdk.CreateChatRequest
|
||||
if !httpapi.Read(ctx, rw, r, &req) {
|
||||
return
|
||||
@@ -387,6 +503,10 @@ func (api *API) postChats(rw http.ResponseWriter, r *http.Request) {
|
||||
})
|
||||
return
|
||||
}
|
||||
if dbauthz.IsNotAuthorizedError(err) {
|
||||
httpapi.Forbidden(rw)
|
||||
return
|
||||
}
|
||||
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
|
||||
Message: "Failed to create chat.",
|
||||
Detail: err.Error(),
|
||||
@@ -394,7 +514,7 @@ func (api *API) postChats(rw http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
httpapi.Write(ctx, rw, http.StatusCreated, convertChat(chat, nil))
|
||||
httpapi.Write(ctx, rw, http.StatusCreated, db2sdk.Chat(chat, nil))
|
||||
}
|
||||
|
||||
// EXPERIMENTAL: this endpoint is experimental and is subject to change.
|
||||
@@ -505,6 +625,10 @@ func (api *API) chatCostSummary(rw http.ResponseWriter, r *http.Request) {
|
||||
EndDate: endDate,
|
||||
})
|
||||
if err != nil {
|
||||
if dbauthz.IsNotAuthorizedError(err) {
|
||||
httpapi.Forbidden(rw)
|
||||
return
|
||||
}
|
||||
httpapi.InternalServerError(rw, err)
|
||||
return
|
||||
}
|
||||
@@ -515,6 +639,10 @@ func (api *API) chatCostSummary(rw http.ResponseWriter, r *http.Request) {
|
||||
EndDate: endDate,
|
||||
})
|
||||
if err != nil {
|
||||
if dbauthz.IsNotAuthorizedError(err) {
|
||||
httpapi.Forbidden(rw)
|
||||
return
|
||||
}
|
||||
httpapi.InternalServerError(rw, err)
|
||||
return
|
||||
}
|
||||
@@ -525,6 +653,10 @@ func (api *API) chatCostSummary(rw http.ResponseWriter, r *http.Request) {
|
||||
EndDate: endDate,
|
||||
})
|
||||
if err != nil {
|
||||
if dbauthz.IsNotAuthorizedError(err) {
|
||||
httpapi.Forbidden(rw)
|
||||
return
|
||||
}
|
||||
httpapi.InternalServerError(rw, err)
|
||||
return
|
||||
}
|
||||
@@ -1111,15 +1243,23 @@ func (api *API) getChat(rw http.ResponseWriter, r *http.Request) {
|
||||
ctx := r.Context()
|
||||
chat := httpmw.ChatParam(r)
|
||||
|
||||
diffStatus, err := api.resolveChatDiffStatus(ctx, chat)
|
||||
if err != nil {
|
||||
// Log but don't fail - diff status is supplementary.
|
||||
api.Logger.Error(ctx, "failed to resolve chat diff status",
|
||||
// Use the cached diff status from the database rather than
|
||||
// resolving it inline. Inline resolution calls out to the
|
||||
// git provider API (e.g. GitHub) on every request which
|
||||
// blocks the response for 200-800ms. The background gitsync
|
||||
// worker keeps the cached status fresh.
|
||||
var diffStatus *database.ChatDiffStatus
|
||||
status, err := api.Database.GetChatDiffStatusByChatID(ctx, chat.ID)
|
||||
switch {
|
||||
case err == nil:
|
||||
diffStatus = &status
|
||||
case !xerrors.Is(err, sql.ErrNoRows):
|
||||
api.Logger.Error(ctx, "failed to get cached chat diff status",
|
||||
slog.F("chat_id", chat.ID),
|
||||
slog.Error(err),
|
||||
)
|
||||
}
|
||||
httpapi.Write(ctx, rw, http.StatusOK, convertChat(chat, diffStatus))
|
||||
httpapi.Write(ctx, rw, http.StatusOK, db2sdk.Chat(chat, diffStatus))
|
||||
}
|
||||
|
||||
// EXPERIMENTAL: this endpoint is experimental and is subject to change.
|
||||
@@ -1450,8 +1590,8 @@ func (api *API) watchChatDesktop(rw http.ResponseWriter, r *http.Request) {
|
||||
logger.Debug(ctx, "desktop Bicopy finished")
|
||||
}
|
||||
|
||||
// patchChat updates a chat resource. Supports updating labels and
|
||||
// toggling the archived state.
|
||||
// patchChat updates a chat resource. Supports updating labels,
|
||||
// archiving, pinning, and pinned-chat ordering.
|
||||
func (api *API) patchChat(rw http.ResponseWriter, r *http.Request) {
|
||||
ctx := r.Context()
|
||||
chat := httpmw.ChatParam(r)
|
||||
@@ -1509,20 +1649,20 @@ func (api *API) patchChat(rw http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
|
||||
var err error
|
||||
// Use chatDaemon when available so it can notify active
|
||||
// subscribers. Fall back to direct DB for the simple
|
||||
// archive flag — no streaming state is involved.
|
||||
// Use chatDaemon when available so it can interrupt active
|
||||
// processing before broadcasting archive state. Fall back to
|
||||
// direct DB when no daemon is running.
|
||||
if archived {
|
||||
if api.chatDaemon != nil {
|
||||
err = api.chatDaemon.ArchiveChat(ctx, chat)
|
||||
} else {
|
||||
err = api.Database.ArchiveChatByID(ctx, chat.ID)
|
||||
_, err = api.Database.ArchiveChatByID(ctx, chat.ID)
|
||||
}
|
||||
} else {
|
||||
if api.chatDaemon != nil {
|
||||
err = api.chatDaemon.UnarchiveChat(ctx, chat)
|
||||
} else {
|
||||
err = api.Database.UnarchiveChatByID(ctx, chat.ID)
|
||||
_, err = api.Database.UnarchiveChatByID(ctx, chat.ID)
|
||||
}
|
||||
}
|
||||
if err != nil {
|
||||
@@ -1538,6 +1678,54 @@ func (api *API) patchChat(rw http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
}
|
||||
|
||||
if req.PinOrder != nil {
|
||||
pinOrder := *req.PinOrder
|
||||
if pinOrder < 0 {
|
||||
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
|
||||
Message: "Pin order must be non-negative.",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
if pinOrder > 0 && chat.Archived {
|
||||
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
|
||||
Message: "Cannot pin an archived chat.",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// The behavior depends on current pin state:
|
||||
// - pinOrder == 0: unpin.
|
||||
// - pinOrder > 0 && already pinned: reorder (shift
|
||||
// neighbors, clamp to [1, count]).
|
||||
// - pinOrder > 0 && not pinned: append to end. The
|
||||
// requested value is intentionally ignored because
|
||||
// PinChatByID also bumps updated_at to keep the
|
||||
// chat visible in the paginated sidebar.
|
||||
var err error
|
||||
errMsg := "Failed to pin chat."
|
||||
switch {
|
||||
case pinOrder == 0:
|
||||
errMsg = "Failed to unpin chat."
|
||||
err = api.Database.UnpinChatByID(ctx, chat.ID)
|
||||
case chat.PinOrder > 0:
|
||||
errMsg = "Failed to reorder pinned chat."
|
||||
err = api.Database.UpdateChatPinOrder(ctx, database.UpdateChatPinOrderParams{
|
||||
ID: chat.ID,
|
||||
PinOrder: pinOrder,
|
||||
})
|
||||
default:
|
||||
err = api.Database.PinChatByID(ctx, chat.ID)
|
||||
}
|
||||
if err != nil {
|
||||
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
|
||||
Message: errMsg,
|
||||
Detail: err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
rw.WriteHeader(http.StatusNoContent)
|
||||
}
|
||||
|
||||
@@ -1794,6 +1982,39 @@ func (api *API) promoteChatQueuedMessage(rw http.ResponseWriter, r *http.Request
|
||||
httpapi.Write(ctx, rw, http.StatusOK, convertChatMessage(promoteResult.PromotedMessage))
|
||||
}
|
||||
|
||||
// markChatAsRead updates the last read message ID for a chat to the
|
||||
// latest message, so subsequent unread checks treat all current
|
||||
// messages as seen. This is called on stream connect and disconnect
|
||||
// to avoid per-message API calls during active streaming.
|
||||
func (api *API) markChatAsRead(ctx context.Context, chatID uuid.UUID) {
|
||||
lastMsg, err := api.Database.GetLastChatMessageByRole(ctx, database.GetLastChatMessageByRoleParams{
|
||||
ChatID: chatID,
|
||||
Role: database.ChatMessageRoleAssistant,
|
||||
})
|
||||
if errors.Is(err, sql.ErrNoRows) {
|
||||
// No assistant messages yet, nothing to mark as read.
|
||||
return
|
||||
}
|
||||
if err != nil {
|
||||
api.Logger.Warn(ctx, "failed to get last assistant message for read marker",
|
||||
slog.F("chat_id", chatID),
|
||||
slog.Error(err),
|
||||
)
|
||||
return
|
||||
}
|
||||
|
||||
err = api.Database.UpdateChatLastReadMessageID(ctx, database.UpdateChatLastReadMessageIDParams{
|
||||
ID: chatID,
|
||||
LastReadMessageID: lastMsg.ID,
|
||||
})
|
||||
if err != nil {
|
||||
api.Logger.Warn(ctx, "failed to update chat last read message ID",
|
||||
slog.F("chat_id", chatID),
|
||||
slog.Error(err),
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
// EXPERIMENTAL: this endpoint is experimental and is subject to change.
|
||||
func (api *API) streamChat(rw http.ResponseWriter, r *http.Request) {
|
||||
ctx := r.Context()
|
||||
@@ -1850,6 +2071,12 @@ func (api *API) streamChat(rw http.ResponseWriter, r *http.Request) {
|
||||
}()
|
||||
defer cancel()
|
||||
|
||||
// Mark the chat as read when the stream connects and again
|
||||
// when it disconnects so we avoid per-message API calls while
|
||||
// messages are actively streaming.
|
||||
api.markChatAsRead(ctx, chatID)
|
||||
defer api.markChatAsRead(context.WithoutCancel(ctx), chatID)
|
||||
|
||||
sendChatStreamBatch := func(batch []codersdk.ChatStreamEvent) error {
|
||||
if len(batch) == 0 {
|
||||
return nil
|
||||
@@ -1949,7 +2176,51 @@ func (api *API) interruptChat(rw http.ResponseWriter, r *http.Request) {
|
||||
chat = updatedChat
|
||||
}
|
||||
|
||||
httpapi.Write(ctx, rw, http.StatusOK, convertChat(chat, nil))
|
||||
httpapi.Write(ctx, rw, http.StatusOK, db2sdk.Chat(chat, nil))
|
||||
}
|
||||
|
||||
// EXPERIMENTAL: this endpoint is experimental and is subject to change.
|
||||
//
|
||||
//nolint:revive // HTTP handler writes to ResponseWriter.
|
||||
func (api *API) regenerateChatTitle(rw http.ResponseWriter, r *http.Request) {
|
||||
ctx := r.Context()
|
||||
chat := httpmw.ChatParam(r)
|
||||
|
||||
if !api.Authorize(r, policy.ActionUpdate, chat.RBACObject()) {
|
||||
httpapi.ResourceNotFound(rw)
|
||||
return
|
||||
}
|
||||
if api.chatDaemon == nil {
|
||||
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
|
||||
Message: "Chat processor is unavailable.",
|
||||
Detail: "Chat processor is not configured.",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
updatedChat, err := api.chatDaemon.RegenerateChatTitle(ctx, chat)
|
||||
if err != nil {
|
||||
if errors.Is(err, chatd.ErrManualTitleRegenerationInProgress) {
|
||||
httpapi.Write(ctx, rw, http.StatusConflict, codersdk.Response{
|
||||
Message: "Title regeneration already in progress for this chat.",
|
||||
})
|
||||
return
|
||||
}
|
||||
if maybeWriteLimitErr(ctx, rw, err) {
|
||||
return
|
||||
}
|
||||
if httpapi.Is404Error(err) {
|
||||
httpapi.ResourceNotFound(rw)
|
||||
return
|
||||
}
|
||||
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
|
||||
Message: "Failed to regenerate chat title.",
|
||||
Detail: err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
httpapi.Write(ctx, rw, http.StatusOK, db2sdk.Chat(updatedChat, nil))
|
||||
}
|
||||
|
||||
// EXPERIMENTAL: this endpoint is experimental and is subject to change.
|
||||
@@ -2095,68 +2366,6 @@ func chatWorkspaceAuditStatus(err error) int {
|
||||
return http.StatusInternalServerError
|
||||
}
|
||||
|
||||
func (api *API) resolveChatDiffStatus(
|
||||
ctx context.Context,
|
||||
chat database.Chat,
|
||||
) (*database.ChatDiffStatus, error) {
|
||||
status, found, err := api.getCachedChatDiffStatus(ctx, chat.ID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
now := time.Now().UTC()
|
||||
|
||||
reference, err := api.resolveChatDiffReference(ctx, chat, found, status)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if reference.PullRequestURL != "" {
|
||||
if !found || !strings.EqualFold(strings.TrimSpace(status.Url.String), reference.PullRequestURL) {
|
||||
status, err = api.upsertChatDiffStatusReference(ctx, chat.ID, reference.PullRequestURL, now.Add(-time.Second))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
found = true
|
||||
}
|
||||
}
|
||||
|
||||
if !found {
|
||||
return nil, nil //nolint:nilnil // Callers handle nil status explicitly.
|
||||
}
|
||||
if !chatDiffStatusIsStale(status, now) {
|
||||
return &status, nil
|
||||
}
|
||||
|
||||
// Use the same refresh pipeline as the background worker
|
||||
// so both paths share identical provider/token resolution.
|
||||
refreshed, err := api.gitSyncWorker.RefreshChat(
|
||||
ctx, status, chat.OwnerID,
|
||||
)
|
||||
if err == nil && refreshed != nil {
|
||||
return refreshed, nil
|
||||
}
|
||||
if err == nil {
|
||||
// No PR exists yet; return what we have.
|
||||
return &status, nil
|
||||
}
|
||||
|
||||
api.Logger.Warn(ctx, "failed to refresh chat diff status",
|
||||
slog.F("chat_id", chat.ID),
|
||||
slog.Error(err),
|
||||
)
|
||||
|
||||
backoffStatus, backoffErr := api.upsertChatDiffStatusReference(ctx, chat.ID, reference.PullRequestURL, now.Add(chatDiffStatusTTL))
|
||||
if backoffErr != nil {
|
||||
api.Logger.Warn(ctx, "failed to extend chat diff status stale timestamp",
|
||||
slog.F("chat_id", chat.ID),
|
||||
slog.Error(backoffErr),
|
||||
)
|
||||
return &status, nil
|
||||
}
|
||||
|
||||
return &backoffStatus, nil
|
||||
}
|
||||
|
||||
func (api *API) resolveChatDiffContents(
|
||||
ctx context.Context,
|
||||
chat database.Chat,
|
||||
@@ -2423,13 +2632,6 @@ func (api *API) resolveGitProvider(origin string) gitprovider.Provider {
|
||||
return gp
|
||||
}
|
||||
|
||||
func chatDiffStatusIsStale(status database.ChatDiffStatus, now time.Time) bool {
|
||||
if !status.RefreshedAt.Valid {
|
||||
return true
|
||||
}
|
||||
return !status.StaleAt.After(now)
|
||||
}
|
||||
|
||||
func (api *API) resolveChatGitAccessToken(
|
||||
ctx context.Context,
|
||||
userID uuid.UUID,
|
||||
@@ -2680,25 +2882,35 @@ func detectChatFileType(data []byte) string {
|
||||
//nolint:revive // get-return: revive assumes get* must be a getter, but this is an HTTP handler.
|
||||
func (api *API) getChatSystemPrompt(rw http.ResponseWriter, r *http.Request) {
|
||||
ctx := r.Context()
|
||||
prompt, err := api.Database.GetChatSystemPrompt(ctx)
|
||||
if !api.Authorize(r, policy.ActionUpdate, rbac.ResourceDeploymentConfig) {
|
||||
httpapi.ResourceNotFound(rw)
|
||||
return
|
||||
}
|
||||
config, err := api.Database.GetChatSystemPromptConfig(ctx)
|
||||
if err != nil {
|
||||
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
|
||||
Message: "Internal error fetching chat system prompt.",
|
||||
Message: "Internal error fetching chat system prompt configuration.",
|
||||
Detail: err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
httpapi.Write(ctx, rw, http.StatusOK, codersdk.ChatSystemPrompt{
|
||||
SystemPrompt: prompt,
|
||||
httpapi.Write(ctx, rw, http.StatusOK, codersdk.ChatSystemPromptResponse{
|
||||
SystemPrompt: config.ChatSystemPrompt,
|
||||
IncludeDefaultSystemPrompt: config.IncludeDefaultSystemPrompt,
|
||||
DefaultSystemPrompt: chatd.DefaultSystemPrompt,
|
||||
})
|
||||
}
|
||||
|
||||
func (api *API) putChatSystemPrompt(rw http.ResponseWriter, r *http.Request) {
|
||||
ctx := r.Context()
|
||||
if !api.Authorize(r, policy.ActionUpdate, rbac.ResourceDeploymentConfig) {
|
||||
httpapi.Forbidden(rw)
|
||||
return
|
||||
}
|
||||
// Cap the raw request body to prevent excessive memory use from
|
||||
// payloads padded with invisible characters that sanitize away.
|
||||
r.Body = http.MaxBytesReader(rw, r.Body, int64(2*maxSystemPromptLenBytes))
|
||||
var req codersdk.ChatSystemPrompt
|
||||
var req codersdk.UpdateChatSystemPromptRequest
|
||||
if !httpapi.Read(ctx, rw, r, &req) {
|
||||
return
|
||||
}
|
||||
@@ -2712,13 +2924,23 @@ func (api *API) putChatSystemPrompt(rw http.ResponseWriter, r *http.Request) {
|
||||
})
|
||||
return
|
||||
}
|
||||
err := api.Database.UpsertChatSystemPrompt(ctx, sanitizedPrompt)
|
||||
if httpapi.Is404Error(err) { // also catches authz error
|
||||
httpapi.ResourceNotFound(rw)
|
||||
return
|
||||
} else if err != nil {
|
||||
err := api.Database.InTx(func(tx database.Store) error {
|
||||
if err := tx.UpsertChatSystemPrompt(ctx, sanitizedPrompt); err != nil {
|
||||
return err
|
||||
}
|
||||
// Only update the include-default flag when the caller explicitly
|
||||
// provides it. Omitting the field preserves whatever is currently
|
||||
// stored (or the schema-level default for new deployments),
|
||||
// avoiding a backward-compatibility regression for older clients
|
||||
// that only send system_prompt.
|
||||
if req.IncludeDefaultSystemPrompt != nil {
|
||||
return tx.UpsertChatIncludeDefaultSystemPrompt(ctx, *req.IncludeDefaultSystemPrompt)
|
||||
}
|
||||
return nil
|
||||
}, nil)
|
||||
if err != nil {
|
||||
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
|
||||
Message: "Internal error updating chat system prompt.",
|
||||
Message: "Internal error updating chat system prompt configuration.",
|
||||
Detail: err.Error(),
|
||||
})
|
||||
return
|
||||
@@ -3052,6 +3274,8 @@ func (api *API) putUserChatCustomPrompt(rw http.ResponseWriter, r *http.Request)
|
||||
return
|
||||
}
|
||||
|
||||
publishChatConfigEvent(api.Logger, api.Pubsub, pubsub.ChatConfigEventUserPrompt, apiKey.UserID)
|
||||
|
||||
httpapi.Write(ctx, rw, http.StatusOK, codersdk.UserChatCustomPrompt{
|
||||
CustomPrompt: updatedConfig.Value,
|
||||
})
|
||||
@@ -3222,21 +3446,32 @@ func (api *API) deleteUserChatCompactionThreshold(rw http.ResponseWriter, r *htt
|
||||
}
|
||||
|
||||
func (api *API) resolvedChatSystemPrompt(ctx context.Context) string {
|
||||
custom, err := api.Database.GetChatSystemPrompt(ctx)
|
||||
config, err := api.Database.GetChatSystemPromptConfig(ctx)
|
||||
if err != nil {
|
||||
// Log but don't fail chat creation — fall back to the
|
||||
// built-in default so the user isn't blocked.
|
||||
api.Logger.Error(ctx, "failed to fetch custom chat system prompt, using default", slog.Error(err))
|
||||
// We intentionally fail open here. When the prompt configuration
|
||||
// cannot be read, returning the built-in default keeps the chat
|
||||
// grounded instead of sending no system guidance at all.
|
||||
api.Logger.Error(ctx, "failed to fetch chat system prompt configuration, using default", slog.Error(err))
|
||||
return chatd.DefaultSystemPrompt
|
||||
}
|
||||
sanitized := chatd.SanitizePromptText(custom)
|
||||
if sanitized == "" && strings.TrimSpace(custom) != "" {
|
||||
api.Logger.Warn(ctx, "custom system prompt became empty after sanitization, using default")
|
||||
|
||||
sanitizedCustom := chatd.SanitizePromptText(config.ChatSystemPrompt)
|
||||
if sanitizedCustom == "" && strings.TrimSpace(config.ChatSystemPrompt) != "" {
|
||||
api.Logger.Warn(ctx, "custom system prompt became empty after sanitization, omitting custom portion")
|
||||
}
|
||||
if sanitized != "" {
|
||||
return sanitized
|
||||
|
||||
var parts []string
|
||||
if config.IncludeDefaultSystemPrompt {
|
||||
parts = append(parts, chatd.DefaultSystemPrompt)
|
||||
}
|
||||
return chatd.DefaultSystemPrompt
|
||||
if sanitizedCustom != "" {
|
||||
parts = append(parts, sanitizedCustom)
|
||||
}
|
||||
result := strings.Join(parts, "\n\n")
|
||||
if result == "" {
|
||||
api.Logger.Warn(ctx, "resolved system prompt is empty, no system prompt will be injected into chats")
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
func (api *API) postChatFile(rw http.ResponseWriter, r *http.Request) {
|
||||
@@ -3558,73 +3793,6 @@ func truncateRunes(value string, maxLen int) string {
|
||||
return string(runes[:maxLen])
|
||||
}
|
||||
|
||||
func convertChat(c database.Chat, diffStatus *database.ChatDiffStatus) codersdk.Chat {
|
||||
mcpServerIDs := c.MCPServerIDs
|
||||
if mcpServerIDs == nil {
|
||||
mcpServerIDs = []uuid.UUID{}
|
||||
}
|
||||
labels := map[string]string(c.Labels)
|
||||
if labels == nil {
|
||||
labels = map[string]string{}
|
||||
}
|
||||
chat := codersdk.Chat{
|
||||
ID: c.ID,
|
||||
OwnerID: c.OwnerID,
|
||||
LastModelConfigID: c.LastModelConfigID,
|
||||
Title: c.Title,
|
||||
Status: codersdk.ChatStatus(c.Status),
|
||||
Archived: c.Archived,
|
||||
CreatedAt: c.CreatedAt,
|
||||
UpdatedAt: c.UpdatedAt,
|
||||
MCPServerIDs: mcpServerIDs,
|
||||
Labels: labels,
|
||||
}
|
||||
if c.LastError.Valid {
|
||||
chat.LastError = &c.LastError.String
|
||||
}
|
||||
if c.ParentChatID.Valid {
|
||||
parentChatID := c.ParentChatID.UUID
|
||||
chat.ParentChatID = &parentChatID
|
||||
}
|
||||
switch {
|
||||
case c.RootChatID.Valid:
|
||||
rootChatID := c.RootChatID.UUID
|
||||
chat.RootChatID = &rootChatID
|
||||
case c.ParentChatID.Valid:
|
||||
rootChatID := c.ParentChatID.UUID
|
||||
chat.RootChatID = &rootChatID
|
||||
default:
|
||||
rootChatID := c.ID
|
||||
chat.RootChatID = &rootChatID
|
||||
}
|
||||
if c.WorkspaceID.Valid {
|
||||
chat.WorkspaceID = &c.WorkspaceID.UUID
|
||||
}
|
||||
if diffStatus != nil {
|
||||
convertedDiffStatus := db2sdk.ChatDiffStatus(c.ID, diffStatus)
|
||||
chat.DiffStatus = &convertedDiffStatus
|
||||
}
|
||||
return chat
|
||||
}
|
||||
|
||||
func convertChats(chats []database.Chat, diffStatusesByChatID map[uuid.UUID]database.ChatDiffStatus) []codersdk.Chat {
|
||||
result := make([]codersdk.Chat, len(chats))
|
||||
for i, c := range chats {
|
||||
diffStatus, ok := diffStatusesByChatID[c.ID]
|
||||
if ok {
|
||||
result[i] = convertChat(c, &diffStatus)
|
||||
continue
|
||||
}
|
||||
|
||||
result[i] = convertChat(c, nil)
|
||||
if diffStatusesByChatID != nil {
|
||||
emptyDiffStatus := db2sdk.ChatDiffStatus(c.ID, nil)
|
||||
result[i].DiffStatus = &emptyDiffStatus
|
||||
}
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
func convertChatCostModelBreakdown(model database.GetChatCostPerModelRow) codersdk.ChatCostModelBreakdown {
|
||||
displayName := strings.TrimSpace(model.DisplayName)
|
||||
if displayName == "" {
|
||||
@@ -3880,6 +4048,8 @@ func (api *API) createChatProvider(rw http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
}
|
||||
|
||||
publishChatConfigEvent(api.Logger, api.Pubsub, pubsub.ChatConfigEventProviders, uuid.Nil)
|
||||
|
||||
httpapi.Write(
|
||||
ctx,
|
||||
rw,
|
||||
@@ -3966,6 +4136,8 @@ func (api *API) updateChatProvider(rw http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
publishChatConfigEvent(api.Logger, api.Pubsub, pubsub.ChatConfigEventProviders, uuid.Nil)
|
||||
|
||||
httpapi.Write(
|
||||
ctx,
|
||||
rw,
|
||||
@@ -4020,6 +4192,8 @@ func (api *API) deleteChatProvider(rw http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
publishChatConfigEvent(api.Logger, api.Pubsub, pubsub.ChatConfigEventProviders, uuid.Nil)
|
||||
|
||||
rw.WriteHeader(http.StatusNoContent)
|
||||
}
|
||||
|
||||
@@ -4199,6 +4373,8 @@ func (api *API) createChatModelConfig(rw http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
}
|
||||
|
||||
publishChatConfigEvent(api.Logger, api.Pubsub, pubsub.ChatConfigEventModelConfig, inserted.ID)
|
||||
|
||||
httpapi.Write(ctx, rw, http.StatusCreated, convertChatModelConfig(inserted))
|
||||
}
|
||||
|
||||
@@ -4370,6 +4546,8 @@ func (api *API) updateChatModelConfig(rw http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
}
|
||||
|
||||
publishChatConfigEvent(api.Logger, api.Pubsub, pubsub.ChatConfigEventModelConfig, updated.ID)
|
||||
|
||||
httpapi.Write(ctx, rw, http.StatusOK, convertChatModelConfig(updated))
|
||||
}
|
||||
|
||||
@@ -4410,6 +4588,8 @@ func (api *API) deleteChatModelConfig(rw http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
publishChatConfigEvent(api.Logger, api.Pubsub, pubsub.ChatConfigEventModelConfig, modelConfigID)
|
||||
|
||||
rw.WriteHeader(http.StatusNoContent)
|
||||
}
|
||||
|
||||
|
||||
+1487
-165
File diff suppressed because it is too large
Load Diff
@@ -71,8 +71,8 @@ func (r *ProvisionerDaemonsReport) Run(ctx context.Context, opts *ProvisionerDae
|
||||
return
|
||||
}
|
||||
|
||||
// nolint: gocritic // need an actor to fetch provisioner daemons
|
||||
daemons, err := opts.Store.GetProvisionerDaemons(dbauthz.AsSystemRestricted(ctx))
|
||||
// nolint: gocritic // Read-only access to provisioner daemons for health check
|
||||
daemons, err := opts.Store.GetProvisionerDaemons(dbauthz.AsSystemReadProvisionerDaemons(ctx))
|
||||
if err != nil {
|
||||
r.Severity = health.SeverityError
|
||||
r.Error = ptr.Ref("error fetching provisioner daemons: " + err.Error())
|
||||
|
||||
@@ -438,7 +438,7 @@ func OneWayWebSocketEventSender(log slog.Logger) func(rw http.ResponseWriter, r
|
||||
}
|
||||
go HeartbeatClose(ctx, log, cancel, socket)
|
||||
|
||||
eventC := make(chan codersdk.ServerSentEvent)
|
||||
eventC := make(chan codersdk.ServerSentEvent, 64)
|
||||
socketErrC := make(chan websocket.CloseError, 1)
|
||||
closed := make(chan struct{})
|
||||
go func() {
|
||||
@@ -488,6 +488,16 @@ func OneWayWebSocketEventSender(log slog.Logger) func(rw http.ResponseWriter, r
|
||||
}()
|
||||
|
||||
sendEvent := func(event codersdk.ServerSentEvent) error {
|
||||
// Prioritize context cancellation over sending to the
|
||||
// buffered channel. Without this check, both cases in
|
||||
// the select below can fire simultaneously when the
|
||||
// context is already done and the channel has capacity,
|
||||
// making the result nondeterministic.
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
default:
|
||||
}
|
||||
select {
|
||||
case eventC <- event:
|
||||
case <-ctx.Done():
|
||||
|
||||
@@ -699,8 +699,8 @@ func ExtractAPIKey(rw http.ResponseWriter, r *http.Request, cfg ExtractAPIKeyCon
|
||||
// is being used with the correct audience/resource server (RFC 8707).
|
||||
func validateOAuth2ProviderAppTokenAudience(ctx context.Context, db database.Store, key database.APIKey, accessURL *url.URL, r *http.Request) error {
|
||||
// Get the OAuth2 provider app token to check its audience
|
||||
//nolint:gocritic // System needs to access token for audience validation
|
||||
token, err := db.GetOAuth2ProviderAppTokenByAPIKeyID(dbauthz.AsSystemRestricted(ctx), key.ID)
|
||||
//nolint:gocritic // OAuth2 system context — audience validation for provider app tokens
|
||||
token, err := db.GetOAuth2ProviderAppTokenByAPIKeyID(dbauthz.AsSystemOAuth2(ctx), key.ID)
|
||||
if err != nil {
|
||||
return xerrors.Errorf("failed to get OAuth2 token: %w", err)
|
||||
}
|
||||
|
||||
@@ -62,6 +62,7 @@ func TestChatParam(t *testing.T) {
|
||||
require.NoError(t, err)
|
||||
|
||||
chat, err := db.InsertChat(context.Background(), database.InsertChatParams{
|
||||
Status: database.ChatStatusWaiting,
|
||||
OwnerID: ownerID,
|
||||
WorkspaceID: uuid.NullUUID{},
|
||||
ParentChatID: uuid.NullUUID{},
|
||||
|
||||
@@ -73,7 +73,6 @@ func CSRF(cookieCfg codersdk.HTTPCookieConfig) func(next http.Handler) http.Hand
|
||||
|
||||
// CSRF only affects requests that automatically attach credentials via a cookie.
|
||||
// If no cookie is present, then there is no risk of CSRF.
|
||||
//nolint:govet
|
||||
sessCookie, err := r.Cookie(codersdk.SessionTokenCookie)
|
||||
if xerrors.Is(err, http.ErrNoCookie) {
|
||||
return true
|
||||
|
||||
+95
-10
@@ -25,6 +25,7 @@ import (
|
||||
"github.com/coder/coder/v2/coderd/httpmw"
|
||||
"github.com/coder/coder/v2/coderd/rbac"
|
||||
"github.com/coder/coder/v2/coderd/rbac/policy"
|
||||
"github.com/coder/coder/v2/coderd/x/chatd/mcpclient"
|
||||
"github.com/coder/coder/v2/codersdk"
|
||||
)
|
||||
|
||||
@@ -59,7 +60,8 @@ func (api *API) listMCPServerConfigs(rw http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
|
||||
// Look up the calling user's OAuth2 tokens so we can populate
|
||||
// auth_connected per server.
|
||||
// auth_connected per server. Attempt to refresh expired tokens
|
||||
// so the status is accurate and the token is ready for use.
|
||||
//nolint:gocritic // Need to check user tokens across all servers.
|
||||
userTokens, err := api.Database.GetMCPServerUserTokensByUserID(dbauthz.AsSystemRestricted(ctx), apiKey.UserID)
|
||||
if err != nil {
|
||||
@@ -69,9 +71,20 @@ func (api *API) listMCPServerConfigs(rw http.ResponseWriter, r *http.Request) {
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// Build a config lookup for the refresh helper.
|
||||
configByID := make(map[uuid.UUID]database.MCPServerConfig, len(configs))
|
||||
for _, c := range configs {
|
||||
configByID[c.ID] = c
|
||||
}
|
||||
|
||||
tokenMap := make(map[uuid.UUID]bool, len(userTokens))
|
||||
for _, t := range userTokens {
|
||||
tokenMap[t.MCPServerConfigID] = true
|
||||
for _, tok := range userTokens {
|
||||
cfg, ok := configByID[tok.MCPServerConfigID]
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
tokenMap[tok.MCPServerConfigID] = api.refreshMCPUserToken(ctx, cfg, tok, apiKey.UserID)
|
||||
}
|
||||
|
||||
resp := make([]codersdk.MCPServerConfig, 0, len(configs))
|
||||
@@ -157,6 +170,7 @@ func (api *API) createMCPServerConfig(rw http.ResponseWriter, r *http.Request) {
|
||||
ToolDenyList: coalesceStringSlice(trimStringSlice(req.ToolDenyList)),
|
||||
Availability: strings.TrimSpace(req.Availability),
|
||||
Enabled: req.Enabled,
|
||||
ModelIntent: req.ModelIntent,
|
||||
CreatedBy: apiKey.UserID,
|
||||
UpdatedBy: apiKey.UserID,
|
||||
})
|
||||
@@ -243,6 +257,7 @@ func (api *API) createMCPServerConfig(rw http.ResponseWriter, r *http.Request) {
|
||||
ToolDenyList: inserted.ToolDenyList,
|
||||
Availability: inserted.Availability,
|
||||
Enabled: inserted.Enabled,
|
||||
ModelIntent: inserted.ModelIntent,
|
||||
UpdatedBy: apiKey.UserID,
|
||||
})
|
||||
if err != nil {
|
||||
@@ -310,6 +325,7 @@ func (api *API) createMCPServerConfig(rw http.ResponseWriter, r *http.Request) {
|
||||
ToolDenyList: coalesceStringSlice(trimStringSlice(req.ToolDenyList)),
|
||||
Availability: strings.TrimSpace(req.Availability),
|
||||
Enabled: req.Enabled,
|
||||
ModelIntent: req.ModelIntent,
|
||||
CreatedBy: apiKey.UserID,
|
||||
UpdatedBy: apiKey.UserID,
|
||||
})
|
||||
@@ -386,7 +402,8 @@ func (api *API) getMCPServerConfig(rw http.ResponseWriter, r *http.Request) {
|
||||
sdkConfig = convertMCPServerConfigRedacted(config)
|
||||
}
|
||||
|
||||
// Populate AuthConnected for the calling user.
|
||||
// Populate AuthConnected for the calling user. Attempt to
|
||||
// refresh the token so the status is accurate.
|
||||
if config.AuthType == "oauth2" {
|
||||
//nolint:gocritic // Need to check user token for this server.
|
||||
userTokens, err := api.Database.GetMCPServerUserTokensByUserID(dbauthz.AsSystemRestricted(ctx), apiKey.UserID)
|
||||
@@ -397,9 +414,9 @@ func (api *API) getMCPServerConfig(rw http.ResponseWriter, r *http.Request) {
|
||||
})
|
||||
return
|
||||
}
|
||||
for _, t := range userTokens {
|
||||
if t.MCPServerConfigID == config.ID {
|
||||
sdkConfig.AuthConnected = true
|
||||
for _, tok := range userTokens {
|
||||
if tok.MCPServerConfigID == config.ID {
|
||||
sdkConfig.AuthConnected = api.refreshMCPUserToken(ctx, config, tok, apiKey.UserID)
|
||||
break
|
||||
}
|
||||
}
|
||||
@@ -558,6 +575,11 @@ func (api *API) updateMCPServerConfig(rw http.ResponseWriter, r *http.Request) {
|
||||
enabled = *req.Enabled
|
||||
}
|
||||
|
||||
modelIntent := existing.ModelIntent
|
||||
if req.ModelIntent != nil {
|
||||
modelIntent = *req.ModelIntent
|
||||
}
|
||||
|
||||
// When auth_type changes, clear fields belonging to the
|
||||
// previous auth type so stale secrets don't persist.
|
||||
if authType != existing.AuthType {
|
||||
@@ -625,6 +647,7 @@ func (api *API) updateMCPServerConfig(rw http.ResponseWriter, r *http.Request) {
|
||||
ToolDenyList: toolDenyList,
|
||||
Availability: availability,
|
||||
Enabled: enabled,
|
||||
ModelIntent: modelIntent,
|
||||
UpdatedBy: apiKey.UserID,
|
||||
ID: existing.ID,
|
||||
})
|
||||
@@ -1002,6 +1025,67 @@ func (api *API) mcpServerOAuth2Disconnect(rw http.ResponseWriter, r *http.Reques
|
||||
|
||||
// parseMCPServerConfigID extracts the MCP server config UUID from the
|
||||
// "mcpServer" path parameter.
|
||||
// refreshMCPUserToken attempts to refresh an expired OAuth2 token
|
||||
// for the given MCP server config. Returns true when the token is
|
||||
// valid (either still fresh or successfully refreshed), false when
|
||||
// the token is expired and cannot be refreshed.
|
||||
func (api *API) refreshMCPUserToken(
|
||||
ctx context.Context,
|
||||
cfg database.MCPServerConfig,
|
||||
tok database.MCPServerUserToken,
|
||||
userID uuid.UUID,
|
||||
) bool {
|
||||
if cfg.AuthType != "oauth2" {
|
||||
return true
|
||||
}
|
||||
if tok.RefreshToken == "" {
|
||||
// No refresh token — consider connected only if not
|
||||
// expired (or no expiry set).
|
||||
return !tok.Expiry.Valid || tok.Expiry.Time.After(time.Now())
|
||||
}
|
||||
|
||||
result, err := mcpclient.RefreshOAuth2Token(ctx, cfg, tok)
|
||||
if err != nil {
|
||||
api.Logger.Warn(ctx, "failed to refresh MCP oauth2 token",
|
||||
slog.F("server_slug", cfg.Slug),
|
||||
slog.Error(err),
|
||||
)
|
||||
// Refresh failed — token is dead.
|
||||
return false
|
||||
}
|
||||
|
||||
if result.Refreshed {
|
||||
var expiry sql.NullTime
|
||||
if !result.Expiry.IsZero() {
|
||||
expiry = sql.NullTime{Time: result.Expiry, Valid: true}
|
||||
}
|
||||
|
||||
//nolint:gocritic // Need system-level write access to
|
||||
// persist the refreshed OAuth2 token.
|
||||
_, err = api.Database.UpsertMCPServerUserToken(
|
||||
dbauthz.AsSystemRestricted(ctx),
|
||||
database.UpsertMCPServerUserTokenParams{
|
||||
MCPServerConfigID: tok.MCPServerConfigID,
|
||||
UserID: userID,
|
||||
AccessToken: result.AccessToken,
|
||||
AccessTokenKeyID: sql.NullString{},
|
||||
RefreshToken: result.RefreshToken,
|
||||
RefreshTokenKeyID: sql.NullString{},
|
||||
TokenType: result.TokenType,
|
||||
Expiry: expiry,
|
||||
},
|
||||
)
|
||||
if err != nil {
|
||||
api.Logger.Warn(ctx, "failed to persist refreshed MCP oauth2 token",
|
||||
slog.F("server_slug", cfg.Slug),
|
||||
slog.Error(err),
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
func parseMCPServerConfigID(rw http.ResponseWriter, r *http.Request) (uuid.UUID, bool) {
|
||||
mcpServerID, err := uuid.Parse(chi.URLParam(r, "mcpServer"))
|
||||
if err != nil {
|
||||
@@ -1045,9 +1129,10 @@ func convertMCPServerConfig(config database.MCPServerConfig) codersdk.MCPServerC
|
||||
|
||||
Availability: config.Availability,
|
||||
|
||||
Enabled: config.Enabled,
|
||||
CreatedAt: config.CreatedAt,
|
||||
UpdatedAt: config.UpdatedAt,
|
||||
Enabled: config.Enabled,
|
||||
ModelIntent: config.ModelIntent,
|
||||
CreatedAt: config.CreatedAt,
|
||||
UpdatedAt: config.UpdatedAt,
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
+62
-4
@@ -2,6 +2,7 @@ package coderd
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"net/http"
|
||||
|
||||
@@ -179,7 +180,17 @@ func (api *API) organizationMember(rw http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
resp, err := convertOrganizationMembersWithUserData(ctx, api.Database, rows)
|
||||
var aiSeatSet map[uuid.UUID]struct{}
|
||||
if api.Entitlements.Enabled(codersdk.FeatureAIGovernanceUserLimit) {
|
||||
//nolint:gocritic // AI seat state is a system-level read gated by entitlement.
|
||||
aiSeatSet, err = getAISeatSetByUserIDs(dbauthz.AsSystemRestricted(ctx), api.Database, []uuid.UUID{member.UserID})
|
||||
if err != nil {
|
||||
httpapi.InternalServerError(rw, err)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
resp, err := convertOrganizationMembersWithUserData(ctx, api.Database, rows, aiSeatSet)
|
||||
if err != nil {
|
||||
httpapi.InternalServerError(rw, err)
|
||||
return
|
||||
@@ -227,7 +238,21 @@ func (api *API) listMembers(rw http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
resp, err := convertOrganizationMembersWithUserData(ctx, api.Database, members)
|
||||
userIDs := make([]uuid.UUID, 0, len(members))
|
||||
for _, member := range members {
|
||||
userIDs = append(userIDs, member.OrganizationMember.UserID)
|
||||
}
|
||||
var aiSeatSet map[uuid.UUID]struct{}
|
||||
if api.Entitlements.Enabled(codersdk.FeatureAIGovernanceUserLimit) {
|
||||
//nolint:gocritic // AI seat state is a system-level read gated by entitlement.
|
||||
aiSeatSet, err = getAISeatSetByUserIDs(dbauthz.AsSystemRestricted(ctx), api.Database, userIDs)
|
||||
if err != nil {
|
||||
httpapi.InternalServerError(rw, err)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
resp, err := convertOrganizationMembersWithUserData(ctx, api.Database, members, aiSeatSet)
|
||||
if err != nil {
|
||||
httpapi.InternalServerError(rw, err)
|
||||
return
|
||||
@@ -324,7 +349,21 @@ func (api *API) paginatedMembers(rw http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
members, err := convertOrganizationMembersWithUserData(ctx, api.Database, memberRows)
|
||||
userIDs := make([]uuid.UUID, 0, len(memberRows))
|
||||
for _, member := range memberRows {
|
||||
userIDs = append(userIDs, member.OrganizationMember.UserID)
|
||||
}
|
||||
var aiSeatSet map[uuid.UUID]struct{}
|
||||
if api.Entitlements.Enabled(codersdk.FeatureAIGovernanceUserLimit) {
|
||||
//nolint:gocritic // AI seat state is a system-level read gated by entitlement.
|
||||
aiSeatSet, err = getAISeatSetByUserIDs(dbauthz.AsSystemRestricted(ctx), api.Database, userIDs)
|
||||
if err != nil {
|
||||
httpapi.InternalServerError(rw, err)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
members, err := convertOrganizationMembersWithUserData(ctx, api.Database, memberRows, aiSeatSet)
|
||||
if err != nil {
|
||||
httpapi.InternalServerError(rw, err)
|
||||
return
|
||||
@@ -337,6 +376,23 @@ func (api *API) paginatedMembers(rw http.ResponseWriter, r *http.Request) {
|
||||
httpapi.Write(ctx, rw, http.StatusOK, resp)
|
||||
}
|
||||
|
||||
func getAISeatSetByUserIDs(ctx context.Context, db database.Store, userIDs []uuid.UUID) (map[uuid.UUID]struct{}, error) {
|
||||
aiSeatUserIDs, err := db.GetUserAISeatStates(ctx, userIDs)
|
||||
if xerrors.Is(err, sql.ErrNoRows) {
|
||||
err = nil
|
||||
}
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
aiSeatSet := make(map[uuid.UUID]struct{}, len(aiSeatUserIDs))
|
||||
for _, uid := range aiSeatUserIDs {
|
||||
aiSeatSet[uid] = struct{}{}
|
||||
}
|
||||
|
||||
return aiSeatSet, nil
|
||||
}
|
||||
|
||||
// @Summary Assign role to organization member
|
||||
// @ID assign-role-to-organization-member
|
||||
// @Security CoderSessionToken
|
||||
@@ -508,7 +564,7 @@ func convertOrganizationMembers(ctx context.Context, db database.Store, mems []d
|
||||
return converted, nil
|
||||
}
|
||||
|
||||
func convertOrganizationMembersWithUserData(ctx context.Context, db database.Store, rows []database.OrganizationMembersRow) ([]codersdk.OrganizationMemberWithUserData, error) {
|
||||
func convertOrganizationMembersWithUserData(ctx context.Context, db database.Store, rows []database.OrganizationMembersRow, aiSeatSet map[uuid.UUID]struct{}) ([]codersdk.OrganizationMemberWithUserData, error) {
|
||||
members := make([]database.OrganizationMember, 0)
|
||||
for _, row := range rows {
|
||||
members = append(members, row.OrganizationMember)
|
||||
@@ -524,12 +580,14 @@ func convertOrganizationMembersWithUserData(ctx context.Context, db database.Sto
|
||||
|
||||
converted := make([]codersdk.OrganizationMemberWithUserData, 0)
|
||||
for i := range convertedMembers {
|
||||
_, hasAISeat := aiSeatSet[rows[i].OrganizationMember.UserID]
|
||||
converted = append(converted, codersdk.OrganizationMemberWithUserData{
|
||||
Username: rows[i].Username,
|
||||
AvatarURL: rows[i].AvatarURL,
|
||||
Name: rows[i].Name,
|
||||
Email: rows[i].Email,
|
||||
GlobalRoles: db2sdk.SlimRolesFromNames(rows[i].GlobalRoles),
|
||||
HasAISeat: hasAISeat,
|
||||
LastSeenAt: rows[i].LastSeenAt,
|
||||
Status: codersdk.UserStatus(rows[i].Status),
|
||||
IsServiceAccount: rows[i].IsServiceAccount,
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user