Compare commits
92 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| f92b50dc43 | |||
| 3db758c6de | |||
| cb0b84a2d3 | |||
| 982739f3bf | |||
| 7b02a51841 | |||
| bd467ce443 | |||
| c67c93982b | |||
| 2f52de7cfc | |||
| 0552b927b2 | |||
| 16b1b6865d | |||
| 897533f08d | |||
| 3e25cc9238 | |||
| bb64cab8a5 | |||
| b149433138 | |||
| 8dff1cbc57 | |||
| a62ead8588 | |||
| b68c14dd04 | |||
| 508114d484 | |||
| e0fbb0e4ec | |||
| 7bde763b66 | |||
| 36141fafad | |||
| 3462c31f43 | |||
| a0ea71b74c | |||
| 0a14bb529e | |||
| 2c32d84f12 | |||
| 76d89f59af | |||
| 1a3a92bd1b | |||
| 4018320614 | |||
| d9700baa8d | |||
| 82456ff62e | |||
| 83fd4cf5c2 | |||
| 38d4da82b9 | |||
| 19e0e0e8e6 | |||
| 1d0653cdab | |||
| 95cff8c5fb | |||
| ad2415ede7 | |||
| 1e40cea199 | |||
| 9d6557d173 | |||
| 224db483d7 | |||
| 8237822441 | |||
| 65bf7c3b18 | |||
| 76cbc580f0 | |||
| 391b22aef7 | |||
| f8e8f979a2 | |||
| fb0ed1162b | |||
| 3f519744aa | |||
| 2505f6245f | |||
| 29ad2c6201 | |||
| 27e5ff0a8e | |||
| 128a7c23e6 | |||
| efb19eb748 | |||
| 2c499484b7 | |||
| 33d9d0d875 | |||
| f219834f5c | |||
| 7a94a683c4 | |||
| 2e6fdf2344 | |||
| 3d139c1a24 | |||
| f957981c8b | |||
| 584c61acb5 | |||
| f95a5202bf | |||
| d954460380 | |||
| f4240bb8c1 | |||
| 7caef4987f | |||
| 9b91af8ab7 | |||
| 506fba9ebf | |||
| 461a31e5d8 | |||
| e3a0dcd6fc | |||
| 12ada0115f | |||
| 7b0421d8c6 | |||
| 477d6d0cde | |||
| de61ac529d | |||
| 7f496c2f18 | |||
| 590235138f | |||
| 543c448b72 | |||
| 35c26ce22a | |||
| c2592c9f12 | |||
| b969d66978 | |||
| 1f808cdc62 | |||
| 497f637f58 | |||
| be686a8d0d | |||
| 7b7baea851 | |||
| a3de0fc78d | |||
| ab77154975 | |||
| c5d720f73d | |||
| 983819860f | |||
| f820945d9f | |||
| da5395a8ae | |||
| 86b919e4f7 | |||
| 233343c010 | |||
| 3a612898c6 | |||
| 3f7a3e3354 | |||
| 17a71aea72 |
@@ -18,35 +18,35 @@ The 5.x era resolves years of module system ambiguity and cleans house on legacy
|
||||
|
||||
The left column reflects patterns still common before TypeScript 5.x. Write the right column instead. The "Since" column tells you the minimum TypeScript version required.
|
||||
|
||||
| Old pattern | Modern replacement | Since |
|
||||
| ---------------------------------------------------------------------------------------------------------------- | ---------------------------------------------------------------------------------------------------------------------------- | ------ |
|
||||
| `--experimentalDecorators` + legacy decorator signatures | Standard decorators (TC39): `function dec(target, context: ClassMethodDecoratorContext)` — no flag needed | 5.0 |
|
||||
| Requiring callers to add `as const` at call sites | `<const T extends HasNames>(arg: T)` — `const` modifier on type parameter | 5.0 |
|
||||
| `--importsNotUsedAsValues` + `--preserveValueImports` | `--verbatimModuleSyntax` | 5.0 |
|
||||
| `import { Foo } from "..."` when `Foo` is only used as a type | `import { type Foo } from "..."` or `import type { Foo } from "..."` | 5.0 |
|
||||
| `"extends": "@tsconfig/strictest/tsconfig.json"` chain | `"extends": ["@tsconfig/strictest/tsconfig.json", "./tsconfig.base.json"]` (array form) | 5.0 |
|
||||
| `try { ... } finally { resource.close(); resource.delete(); }` | `using resource = acquireResource()` — calls `[Symbol.dispose]()` automatically | 5.2 |
|
||||
| `try { ... } finally { await resource.close() }` | `await using resource = acquireAsyncResource()` | 5.2 |
|
||||
| Ad-hoc cleanup with multiple `try/finally` blocks | `using cleanup = new DisposableStack(); cleanup.defer(() => ...)` | 5.2 |
|
||||
| `import data from "./data.json" assert { type: "json" }` | `import data from "./data.json" with { type: "json" }` | 5.3 |
|
||||
| `.filter(Boolean)` or `.filter(x => !!x)` to remove nulls | `.filter(x => x !== undefined)` or `.filter(x => x !== null)` (infers type predicate) | 5.5 |
|
||||
| Extra phantom type param to block inference bleed: `<C extends string, D extends C>` | `NoInfer<C>` on the parameter you don't want to drive inference | 5.4 |
|
||||
| `/** @typedef {import("./types").Foo} Foo */` in JS files | `/** @import { Foo } from "./types" */` (JSDoc `@import` tag) | 5.5 |
|
||||
| `myArray.reverse()` mutating in place | `myArray.toReversed()` (returns new array) | 5.2 |
|
||||
| `myArray.sort(cmp)` mutating in place | `myArray.toSorted(cmp)` (returns new array) | 5.2 |
|
||||
| `const copy = [...arr]; copy[i] = v` | `arr.with(i, v)` (returns new array) | 5.2 |
|
||||
| Manual `has`/`get`/`set` pattern on `Map` | `map.getOrInsert(key, defaultValue)` or `getOrInsertComputed(key, fn)` | 6.0 RC |
|
||||
| `new RegExp(str.replace(/[.\*+?^${}()\|[\]\\]/g, '\\$&'))` | `new RegExp(RegExp.escape(str))` | 6.0 RC |
|
||||
| `--moduleResolution node` (node10) | `--moduleResolution nodenext` (Node.js) or `--moduleResolution bundler` (bundlers/Bun) | 6.0 RC |
|
||||
| `"baseUrl": "./src"` + `"@app/*": ["app/*"]` in paths | Remove `baseUrl`; use `"@app/*": ["./src/app/*"]` in paths directly | 6.0 RC |
|
||||
| `module Foo { export const x = 1; }` | `namespace Foo { export const x = 1; }` | 6.0 RC |
|
||||
| `export * from "..."` when all re-exported members are types | `export type * from "..."` (or `export type * as ns from "..."`) | 5.0 |
|
||||
| `function f(): undefined { return undefined; }` — explicit return required in `: undefined`-returning function | Remove the `return` entirely; `undefined`-returning functions no longer require any return statement | 5.1 |
|
||||
| Manual type predicate annotation on a simple arrow: `(x: T \| undefined): x is T => x !== undefined` | Remove the annotation; TypeScript infers `x is T` from `!== null/undefined` and `instanceof` checks automatically | 5.5 |
|
||||
| `const val = obj[key]; if (typeof val === "string") { use(val); }` — extract to const to narrow indexed access | `if (typeof obj[key] === "string") { obj[key].toUpperCase(); }` directly — both `obj` and `key` must be effectively constant | 5.5 |
|
||||
| Copy narrowed `let`/param to a `const`, or restructure code to escape stale closure narrowing after reassignment | Remove the copy; narrowing survives into closures created after the last assignment to the variable | 5.4 |
|
||||
| `(arr as string[]).filter(...)` or restructure to avoid "not callable" errors on `string[] \| number[]` | Call `.filter`, `.find`, `.some`, `.every`, `.reduce` directly on union-of-array types | 5.2 |
|
||||
| `if`/`else` chain used to work around lack of narrowing inside a `switch (true)` body | `switch (true)` — each `case` condition now narrows the tested variable in its clause | 5.3 |
|
||||
| Old pattern | Modern replacement | Since |
|
||||
| ---------------------------------------------------------------------------------------------------------------- | ---------------------------------------------------------------------------------------------------------------------------- | ------ |
|
||||
| `--experimentalDecorators` + legacy decorator signatures | Standard decorators (TC39): `function dec(target, context: ClassMethodDecoratorContext)` — no flag needed | 5.0 |
|
||||
| Requiring callers to add `as const` at call sites | `<const T extends HasNames>(arg: T)` — `const` modifier on type parameter | 5.0 |
|
||||
| `--importsNotUsedAsValues` + `--preserveValueImports` | `--verbatimModuleSyntax` | 5.0 |
|
||||
| `import { Foo } from "..."` when `Foo` is only used as a type | `import { type Foo } from "..."` or `import type { Foo } from "..."` | 5.0 |
|
||||
| `"extends": "@tsconfig/strictest/tsconfig.json"` chain | `"extends": ["@tsconfig/strictest/tsconfig.json", "./tsconfig.base.json"]` (array form) | 5.0 |
|
||||
| `try { ... } finally { resource.close(); resource.delete(); }` | `using resource = acquireResource()` — calls `[Symbol.dispose]()` automatically | 5.2 |
|
||||
| `try { ... } finally { await resource.close() }` | `await using resource = acquireAsyncResource()` | 5.2 |
|
||||
| Ad-hoc cleanup with multiple `try/finally` blocks | `using cleanup = new DisposableStack(); cleanup.defer(() => ...)` | 5.2 |
|
||||
| `import data from "./data.json" assert { type: "json" }` | `import data from "./data.json" with { type: "json" }` | 5.3 |
|
||||
| `.filter(Boolean)` or `.filter(x => !!x)` to remove nulls | `.filter(x => x !== undefined)` or `.filter(x => x !== null)` (infers type predicate) | 5.5 |
|
||||
| Extra phantom type param to block inference bleed: `<C extends string, D extends C>` | `NoInfer<C>` on the parameter you don't want to drive inference | 5.4 |
|
||||
| `/** @typedef {import("./types").Foo} Foo */` in JS files | `/** @import { Foo } from "./types" */` (JSDoc `@import` tag) | 5.5 |
|
||||
| `myArray.reverse()` mutating in place | `myArray.toReversed()` (returns new array) | 5.2 |
|
||||
| `myArray.sort(cmp)` mutating in place | `myArray.toSorted(cmp)` (returns new array) | 5.2 |
|
||||
| `const copy = [...arr]; copy[i] = v` | `arr.with(i, v)` (returns new array) | 5.2 |
|
||||
| Manual `has`/`get`/`set` pattern on `Map` | `map.getOrInsert(key, defaultValue)` or `getOrInsertComputed(key, fn)` | 6.0 RC |
|
||||
| `new RegExp(str.replace(/[.\*+?^${}()\[\]\\]/g, '\\$&'))` | `new RegExp(RegExp.escape(str))` | 6.0 RC |
|
||||
| `--moduleResolution node` (node10) | `--moduleResolution nodenext` (Node.js) or `--moduleResolution bundler` (bundlers/Bun) | 6.0 RC |
|
||||
| `"baseUrl": "./src"` + `"@app/*": ["app/*"]` in paths | Remove `baseUrl`; use `"@app/*": ["./src/app/*"]` in paths directly | 6.0 RC |
|
||||
| `module Foo { export const x = 1; }` | `namespace Foo { export const x = 1; }` | 6.0 RC |
|
||||
| `export * from "..."` when all re-exported members are types | `export type * from "..."` (or `export type * as ns from "..."`) | 5.0 |
|
||||
| `function f(): undefined { return undefined; }` — explicit return required in `: undefined`-returning function | Remove the `return` entirely; `undefined`-returning functions no longer require any return statement | 5.1 |
|
||||
| Manual type predicate annotation on a simple arrow: `(x: T \| undefined): x is T => x !== undefined` | Remove the annotation; TypeScript infers `x is T` from `!== null/undefined` and `instanceof` checks automatically | 5.5 |
|
||||
| `const val = obj[key]; if (typeof val === "string") { use(val); }` — extract to const to narrow indexed access | `if (typeof obj[key] === "string") { obj[key].toUpperCase(); }` directly — both `obj` and `key` must be effectively constant | 5.5 |
|
||||
| Copy narrowed `let`/param to a `const`, or restructure code to escape stale closure narrowing after reassignment | Remove the copy; narrowing survives into closures created after the last assignment to the variable | 5.4 |
|
||||
| `(arr as string[]).filter(...)` or restructure to avoid "not callable" errors on `string[] \| number[]` | Call `.filter`, `.find`, `.some`, `.every`, `.reduce` directly on union-of-array types | 5.2 |
|
||||
| `if`/`else` chain used to work around lack of narrowing inside a `switch (true)` body | `switch (true)` — each `case` condition now narrows the tested variable in its clause | 5.3 |
|
||||
|
||||
## New capabilities
|
||||
|
||||
|
||||
@@ -1,2 +0,0 @@
|
||||
enabled: true
|
||||
preservePullRequestTitle: true
|
||||
@@ -91,12 +91,6 @@ updates:
|
||||
emotion:
|
||||
patterns:
|
||||
- "@emotion*"
|
||||
exclude-patterns:
|
||||
- "jest-runner-eslint"
|
||||
jest:
|
||||
patterns:
|
||||
- "jest"
|
||||
- "@types/jest"
|
||||
vite:
|
||||
patterns:
|
||||
- "vite*"
|
||||
|
||||
@@ -0,0 +1,178 @@
|
||||
# Automatically backport merged PRs to the last N release branches when the
|
||||
# "backport" label is applied. Works whether the label is added before or
|
||||
# after the PR is merged.
|
||||
#
|
||||
# Usage:
|
||||
# 1. Add the "backport" label to a PR targeting main.
|
||||
# 2. When the PR merges (or if already merged), the workflow detects the
|
||||
# latest release/* branches and opens one cherry-pick PR per branch.
|
||||
#
|
||||
# The created backport PRs follow existing repo conventions:
|
||||
# - Branch: backport/<pr>-to-<version>
|
||||
# - Title: <original PR title> (#<pr>)
|
||||
# - Body: links back to the original PR and merge commit
|
||||
|
||||
name: Backport
|
||||
on:
|
||||
pull_request_target:
|
||||
branches:
|
||||
- main
|
||||
types:
|
||||
- closed
|
||||
- labeled
|
||||
|
||||
permissions:
|
||||
contents: write
|
||||
pull-requests: write
|
||||
|
||||
# Prevent duplicate runs for the same PR when both 'closed' and 'labeled'
|
||||
# fire in quick succession.
|
||||
concurrency:
|
||||
group: backport-${{ github.event.pull_request.number }}
|
||||
|
||||
jobs:
|
||||
detect:
|
||||
name: Detect target branches
|
||||
if: >
|
||||
github.event.pull_request.merged == true &&
|
||||
contains(github.event.pull_request.labels.*.name, 'backport')
|
||||
runs-on: ubuntu-latest
|
||||
outputs:
|
||||
branches: ${{ steps.find.outputs.branches }}
|
||||
steps:
|
||||
- uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
|
||||
with:
|
||||
# Need all refs to discover release branches.
|
||||
fetch-depth: 0
|
||||
|
||||
- name: Find latest release branches
|
||||
id: find
|
||||
run: |
|
||||
# List remote release branches matching the exact release/2.X
|
||||
# pattern (no suffixes like release/2.31_hotfix), sort by minor
|
||||
# version descending, and take the top 3.
|
||||
BRANCHES=$(
|
||||
git branch -r \
|
||||
| grep -E '^\s*origin/release/2\.[0-9]+$' \
|
||||
| sed 's|.*origin/||' \
|
||||
| sort -t. -k2 -n -r \
|
||||
| head -3
|
||||
)
|
||||
|
||||
if [ -z "$BRANCHES" ]; then
|
||||
echo "No release branches found."
|
||||
echo "branches=[]" >> "$GITHUB_OUTPUT"
|
||||
exit 0
|
||||
fi
|
||||
|
||||
# Convert to JSON array for the matrix.
|
||||
JSON=$(echo "$BRANCHES" | jq -Rnc '[inputs | select(length > 0)]')
|
||||
echo "branches=$JSON" >> "$GITHUB_OUTPUT"
|
||||
echo "Will backport to: $JSON"
|
||||
|
||||
backport:
|
||||
name: "Backport to ${{ matrix.branch }}"
|
||||
needs: detect
|
||||
if: needs.detect.outputs.branches != '[]'
|
||||
runs-on: ubuntu-latest
|
||||
strategy:
|
||||
matrix:
|
||||
branch: ${{ fromJson(needs.detect.outputs.branches) }}
|
||||
fail-fast: false
|
||||
env:
|
||||
PR_NUMBER: ${{ github.event.pull_request.number }}
|
||||
PR_TITLE: ${{ github.event.pull_request.title }}
|
||||
PR_URL: ${{ github.event.pull_request.html_url }}
|
||||
MERGE_SHA: ${{ github.event.pull_request.merge_commit_sha }}
|
||||
SENDER: ${{ github.event.sender.login }}
|
||||
steps:
|
||||
- uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
|
||||
with:
|
||||
# Full history required for cherry-pick.
|
||||
fetch-depth: 0
|
||||
|
||||
- name: Cherry-pick and open PR
|
||||
env:
|
||||
GH_TOKEN: ${{ secrets.GITHUB_TOKEN }}
|
||||
run: |
|
||||
set -euo pipefail
|
||||
|
||||
RELEASE_VERSION="${{ matrix.branch }}"
|
||||
# Strip the release/ prefix for naming.
|
||||
VERSION="${RELEASE_VERSION#release/}"
|
||||
BACKPORT_BRANCH="backport/${PR_NUMBER}-to-${VERSION}"
|
||||
|
||||
git config user.name "github-actions[bot]"
|
||||
git config user.email "41898282+github-actions[bot]@users.noreply.github.com"
|
||||
|
||||
# Check if backport branch already exists (idempotency for re-runs).
|
||||
if git ls-remote --exit-code origin "refs/heads/${BACKPORT_BRANCH}" >/dev/null 2>&1; then
|
||||
echo "Backport branch ${BACKPORT_BRANCH} already exists, skipping."
|
||||
exit 0
|
||||
fi
|
||||
|
||||
# Create the backport branch from the target release branch.
|
||||
git checkout -b "$BACKPORT_BRANCH" "origin/${RELEASE_VERSION}"
|
||||
|
||||
# Cherry-pick the merge commit. Use -x to record provenance and
|
||||
# -m1 to pick the first parent (the main branch side).
|
||||
CONFLICTS=false
|
||||
if ! git cherry-pick -x -m1 "$MERGE_SHA"; then
|
||||
echo "::warning::Cherry-pick to ${RELEASE_VERSION} had conflicts."
|
||||
CONFLICTS=true
|
||||
|
||||
# Abort the failed cherry-pick and create an empty commit
|
||||
# explaining the situation.
|
||||
git cherry-pick --abort
|
||||
git commit --allow-empty -m "Cherry-pick of #${PR_NUMBER} requires manual resolution
|
||||
|
||||
The automatic cherry-pick of ${MERGE_SHA} to ${RELEASE_VERSION} had conflicts.
|
||||
Please cherry-pick manually:
|
||||
|
||||
git cherry-pick -x -m1 ${MERGE_SHA}"
|
||||
fi
|
||||
|
||||
git push origin "$BACKPORT_BRANCH"
|
||||
|
||||
TITLE="${PR_TITLE} (#${PR_NUMBER})"
|
||||
BODY=$(cat <<EOF
|
||||
Backport of ${PR_URL}
|
||||
|
||||
Original PR: #${PR_NUMBER} — ${PR_TITLE}
|
||||
Merge commit: ${MERGE_SHA}
|
||||
Requested by: @${SENDER}
|
||||
EOF
|
||||
)
|
||||
|
||||
if [ "$CONFLICTS" = true ]; then
|
||||
TITLE="${TITLE} (conflicts)"
|
||||
BODY="${BODY}
|
||||
|
||||
> [!WARNING]
|
||||
> The automatic cherry-pick had conflicts.
|
||||
> Please resolve manually by cherry-picking the original merge commit:
|
||||
>
|
||||
> \`\`\`
|
||||
> git fetch origin ${BACKPORT_BRANCH}
|
||||
> git checkout ${BACKPORT_BRANCH}
|
||||
> git reset --hard origin/${RELEASE_VERSION}
|
||||
> git cherry-pick -x -m1 ${MERGE_SHA}
|
||||
> # resolve conflicts, then push
|
||||
> \`\`\`"
|
||||
fi
|
||||
|
||||
# Check if a PR already exists for this branch (idempotency
|
||||
# for re-runs).
|
||||
EXISTING_PR=$(gh pr list --head "$BACKPORT_BRANCH" --base "$RELEASE_VERSION" --state all --json number --jq '.[0].number // empty')
|
||||
if [ -n "$EXISTING_PR" ]; then
|
||||
echo "PR #${EXISTING_PR} already exists for ${BACKPORT_BRANCH}, skipping."
|
||||
exit 0
|
||||
fi
|
||||
|
||||
gh pr create \
|
||||
--base "$RELEASE_VERSION" \
|
||||
--head "$BACKPORT_BRANCH" \
|
||||
--title "$TITLE" \
|
||||
--body "$BODY" \
|
||||
--assignee "$SENDER" \
|
||||
--reviewer "$SENDER"
|
||||
@@ -0,0 +1,152 @@
|
||||
# Automatically cherry-pick merged PRs to the latest release branch when the
|
||||
# "cherry-pick" label is applied. Works whether the label is added before or
|
||||
# after the PR is merged.
|
||||
#
|
||||
# Usage:
|
||||
# 1. Add the "cherry-pick" label to a PR targeting main.
|
||||
# 2. When the PR merges (or if already merged), the workflow detects the
|
||||
# latest release/* branch and opens a cherry-pick PR against it.
|
||||
#
|
||||
# The created PRs follow existing repo conventions:
|
||||
# - Branch: backport/<pr>-to-<version>
|
||||
# - Title: <original PR title> (#<pr>)
|
||||
# - Body: links back to the original PR and merge commit
|
||||
|
||||
name: Cherry-pick to release
|
||||
on:
|
||||
pull_request_target:
|
||||
branches:
|
||||
- main
|
||||
types:
|
||||
- closed
|
||||
- labeled
|
||||
|
||||
permissions:
|
||||
contents: write
|
||||
pull-requests: write
|
||||
|
||||
# Prevent duplicate runs for the same PR when both 'closed' and 'labeled'
|
||||
# fire in quick succession.
|
||||
concurrency:
|
||||
group: cherry-pick-${{ github.event.pull_request.number }}
|
||||
|
||||
jobs:
|
||||
cherry-pick:
|
||||
name: Cherry-pick to latest release
|
||||
if: >
|
||||
github.event.pull_request.merged == true &&
|
||||
contains(github.event.pull_request.labels.*.name, 'cherry-pick')
|
||||
runs-on: ubuntu-latest
|
||||
env:
|
||||
PR_NUMBER: ${{ github.event.pull_request.number }}
|
||||
PR_TITLE: ${{ github.event.pull_request.title }}
|
||||
PR_URL: ${{ github.event.pull_request.html_url }}
|
||||
MERGE_SHA: ${{ github.event.pull_request.merge_commit_sha }}
|
||||
SENDER: ${{ github.event.sender.login }}
|
||||
steps:
|
||||
- uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
|
||||
with:
|
||||
# Full history required for cherry-pick and branch discovery.
|
||||
fetch-depth: 0
|
||||
|
||||
- name: Cherry-pick and open PR
|
||||
env:
|
||||
GH_TOKEN: ${{ secrets.GITHUB_TOKEN }}
|
||||
run: |
|
||||
set -euo pipefail
|
||||
|
||||
# Find the latest release branch matching the exact release/2.X
|
||||
# pattern (no suffixes like release/2.31_hotfix).
|
||||
RELEASE_BRANCH=$(
|
||||
git branch -r \
|
||||
| grep -E '^\s*origin/release/2\.[0-9]+$' \
|
||||
| sed 's|.*origin/||' \
|
||||
| sort -t. -k2 -n -r \
|
||||
| head -1
|
||||
)
|
||||
|
||||
if [ -z "$RELEASE_BRANCH" ]; then
|
||||
echo "::error::No release branch found."
|
||||
exit 1
|
||||
fi
|
||||
|
||||
# Strip the release/ prefix for naming.
|
||||
VERSION="${RELEASE_BRANCH#release/}"
|
||||
BACKPORT_BRANCH="backport/${PR_NUMBER}-to-${VERSION}"
|
||||
|
||||
echo "Target branch: $RELEASE_BRANCH"
|
||||
echo "Backport branch: $BACKPORT_BRANCH"
|
||||
|
||||
# Check if backport branch already exists (idempotency for re-runs).
|
||||
if git ls-remote --exit-code origin "refs/heads/${BACKPORT_BRANCH}" >/dev/null 2>&1; then
|
||||
echo "Branch ${BACKPORT_BRANCH} already exists, skipping."
|
||||
exit 0
|
||||
fi
|
||||
|
||||
git config user.name "github-actions[bot]"
|
||||
git config user.email "41898282+github-actions[bot]@users.noreply.github.com"
|
||||
|
||||
# Create the backport branch from the target release branch.
|
||||
git checkout -b "$BACKPORT_BRANCH" "origin/${RELEASE_BRANCH}"
|
||||
|
||||
# Cherry-pick the merge commit. Use -x to record provenance and
|
||||
# -m1 to pick the first parent (the main branch side).
|
||||
CONFLICT=false
|
||||
if ! git cherry-pick -x -m1 "$MERGE_SHA"; then
|
||||
CONFLICT=true
|
||||
echo "::warning::Cherry-pick to ${RELEASE_BRANCH} had conflicts."
|
||||
|
||||
# Abort the failed cherry-pick and create an empty commit with
|
||||
# instructions so the PR can still be opened.
|
||||
git cherry-pick --abort
|
||||
git commit --allow-empty -m "cherry-pick of #${PR_NUMBER} failed — resolve conflicts manually
|
||||
|
||||
Cherry-pick of ${MERGE_SHA} onto ${RELEASE_BRANCH} had conflicts.
|
||||
To resolve:
|
||||
git fetch origin ${BACKPORT_BRANCH}
|
||||
git checkout ${BACKPORT_BRANCH}
|
||||
git cherry-pick -x -m1 ${MERGE_SHA}
|
||||
# resolve conflicts
|
||||
git push origin ${BACKPORT_BRANCH}"
|
||||
fi
|
||||
|
||||
git push origin "$BACKPORT_BRANCH"
|
||||
|
||||
BODY=$(cat <<EOF
|
||||
Cherry-pick of ${PR_URL}
|
||||
|
||||
Original PR: #${PR_NUMBER} — ${PR_TITLE}
|
||||
Merge commit: ${MERGE_SHA}
|
||||
Requested by: @${SENDER}
|
||||
EOF
|
||||
)
|
||||
|
||||
TITLE="${PR_TITLE} (#${PR_NUMBER})"
|
||||
if [ "$CONFLICT" = true ]; then
|
||||
TITLE="[CONFLICT] ${TITLE}"
|
||||
fi
|
||||
|
||||
# Check if a PR already exists for this branch (idempotency
|
||||
# for re-runs). Use --state all to catch closed/merged PRs too.
|
||||
EXISTING_PR=$(gh pr list --head "$BACKPORT_BRANCH" --base "$RELEASE_BRANCH" --state all --json number --jq '.[0].number // empty')
|
||||
if [ -n "$EXISTING_PR" ]; then
|
||||
echo "PR #${EXISTING_PR} already exists for ${BACKPORT_BRANCH}, skipping."
|
||||
exit 0
|
||||
fi
|
||||
|
||||
NEW_PR_URL=$(
|
||||
gh pr create \
|
||||
--base "$RELEASE_BRANCH" \
|
||||
--head "$BACKPORT_BRANCH" \
|
||||
--title "$TITLE" \
|
||||
--body "$BODY" \
|
||||
--assignee "$SENDER" \
|
||||
--reviewer "$SENDER"
|
||||
)
|
||||
|
||||
# Comment on the original PR to notify the author.
|
||||
COMMENT="Cherry-pick PR created: ${NEW_PR_URL}"
|
||||
if [ "$CONFLICT" = true ]; then
|
||||
COMMENT="${COMMENT} (⚠️ conflicts need manual resolution)"
|
||||
fi
|
||||
gh pr comment "$PR_NUMBER" --body "$COMMENT"
|
||||
@@ -0,0 +1,93 @@
|
||||
# Ensures that only bug fixes are cherry-picked to release branches.
|
||||
# PRs targeting release/* must have a title starting with "fix:" or "fix(scope):".
|
||||
name: PR Cherry-Pick Check
|
||||
|
||||
on:
|
||||
# zizmor: ignore[dangerous-triggers] Only reads PR metadata and comments; does not checkout PR code.
|
||||
pull_request_target:
|
||||
types: [opened, reopened, edited]
|
||||
branches:
|
||||
- "release/*"
|
||||
|
||||
permissions:
|
||||
pull-requests: write
|
||||
|
||||
jobs:
|
||||
check-cherry-pick:
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: Harden Runner
|
||||
uses: step-security/harden-runner@fe104658747b27e96e4f7e80cd0a94068e53901d # v2.16.1
|
||||
with:
|
||||
egress-policy: audit
|
||||
|
||||
- name: Check PR title for bug fix
|
||||
uses: actions/github-script@60a0d83039c74a4aee543508d2ffcb1c3799cdea # v7.0.1
|
||||
with:
|
||||
script: |
|
||||
const title = context.payload.pull_request.title;
|
||||
const prNumber = context.payload.pull_request.number;
|
||||
const baseBranch = context.payload.pull_request.base.ref;
|
||||
const author = context.payload.pull_request.user.login;
|
||||
|
||||
console.log(`PR #${prNumber}: "${title}" -> ${baseBranch}`);
|
||||
|
||||
// Match conventional commit "fix:" or "fix(scope):" prefix.
|
||||
const isBugFix = /^fix(\(.+\))?:/.test(title);
|
||||
|
||||
if (isBugFix) {
|
||||
console.log("PR title indicates a bug fix. No action needed.");
|
||||
return;
|
||||
}
|
||||
|
||||
console.log("PR title does not indicate a bug fix. Commenting.");
|
||||
|
||||
// Check for an existing comment from this bot to avoid duplicates
|
||||
// on title edits.
|
||||
const { data: comments } = await github.rest.issues.listComments({
|
||||
owner: context.repo.owner,
|
||||
repo: context.repo.repo,
|
||||
issue_number: prNumber,
|
||||
});
|
||||
|
||||
const marker = "<!-- cherry-pick-check -->";
|
||||
const existingComment = comments.find(
|
||||
(c) => c.body && c.body.includes(marker),
|
||||
);
|
||||
|
||||
const body = [
|
||||
marker,
|
||||
`👋 Hey @${author}!`,
|
||||
"",
|
||||
`This PR is targeting the \`${baseBranch}\` release branch, but its title does not start with \`fix:\` or \`fix(scope):\`.`,
|
||||
"",
|
||||
"Only **bug fixes** should be cherry-picked to release branches. If this is a bug fix, please update the PR title to match the conventional commit format:",
|
||||
"",
|
||||
"```",
|
||||
"fix: description of the bug fix",
|
||||
"fix(scope): description of the bug fix",
|
||||
"```",
|
||||
"",
|
||||
"If this is **not** a bug fix, it likely should not target a release branch.",
|
||||
].join("\n");
|
||||
|
||||
if (existingComment) {
|
||||
console.log(`Updating existing comment ${existingComment.id}.`);
|
||||
await github.rest.issues.updateComment({
|
||||
owner: context.repo.owner,
|
||||
repo: context.repo.repo,
|
||||
comment_id: existingComment.id,
|
||||
body,
|
||||
});
|
||||
} else {
|
||||
await github.rest.issues.createComment({
|
||||
owner: context.repo.owner,
|
||||
repo: context.repo.repo,
|
||||
issue_number: prNumber,
|
||||
body,
|
||||
});
|
||||
}
|
||||
|
||||
core.warning(
|
||||
`PR #${prNumber} targets ${baseBranch} but is not a bug fix. Title must start with "fix:" or "fix(scope):".`,
|
||||
);
|
||||
@@ -36,6 +36,7 @@ typ = "typ"
|
||||
styl = "styl"
|
||||
edn = "edn"
|
||||
Inferrable = "Inferrable"
|
||||
IIF = "IIF"
|
||||
|
||||
[files]
|
||||
extend-exclude = [
|
||||
|
||||
@@ -103,3 +103,6 @@ PLAN.md
|
||||
|
||||
# Ignore any dev licenses
|
||||
license.txt
|
||||
-e
|
||||
# Agent planning documents (local working files).
|
||||
docs/plans/
|
||||
|
||||
@@ -91,6 +91,59 @@ define atomic_write
|
||||
mv "$$tmpfile" "$@" && rm -rf "$$tmpdir"
|
||||
endef
|
||||
|
||||
# Helper binary targets. Built with go build -o to avoid caching
|
||||
# link-stage executables in GOCACHE. Each binary is a real Make
|
||||
# target so parallel -j builds serialize correctly instead of
|
||||
# racing on the same output path.
|
||||
|
||||
_gen/bin/apitypings: $(wildcard scripts/apitypings/*.go) | _gen
|
||||
@mkdir -p _gen/bin
|
||||
go build -o $@ ./scripts/apitypings
|
||||
|
||||
_gen/bin/auditdocgen: $(wildcard scripts/auditdocgen/*.go) | _gen
|
||||
@mkdir -p _gen/bin
|
||||
go build -o $@ ./scripts/auditdocgen
|
||||
|
||||
_gen/bin/check-scopes: $(wildcard scripts/check-scopes/*.go) | _gen
|
||||
@mkdir -p _gen/bin
|
||||
go build -o $@ ./scripts/check-scopes
|
||||
|
||||
_gen/bin/clidocgen: $(wildcard scripts/clidocgen/*.go) | _gen
|
||||
@mkdir -p _gen/bin
|
||||
go build -o $@ ./scripts/clidocgen
|
||||
|
||||
_gen/bin/dbdump: $(wildcard coderd/database/gen/dump/*.go) | _gen
|
||||
@mkdir -p _gen/bin
|
||||
go build -o $@ ./coderd/database/gen/dump
|
||||
|
||||
_gen/bin/examplegen: $(wildcard scripts/examplegen/*.go) | _gen
|
||||
@mkdir -p _gen/bin
|
||||
go build -o $@ ./scripts/examplegen
|
||||
|
||||
_gen/bin/gensite: $(wildcard scripts/gensite/*.go) | _gen
|
||||
@mkdir -p _gen/bin
|
||||
go build -o $@ ./scripts/gensite
|
||||
|
||||
_gen/bin/apikeyscopesgen: $(wildcard scripts/apikeyscopesgen/*.go) | _gen
|
||||
@mkdir -p _gen/bin
|
||||
go build -o $@ ./scripts/apikeyscopesgen
|
||||
|
||||
_gen/bin/metricsdocgen: $(wildcard scripts/metricsdocgen/*.go) | _gen
|
||||
@mkdir -p _gen/bin
|
||||
go build -o $@ ./scripts/metricsdocgen
|
||||
|
||||
_gen/bin/metricsdocgen-scanner: $(wildcard scripts/metricsdocgen/scanner/*.go) | _gen
|
||||
@mkdir -p _gen/bin
|
||||
go build -o $@ ./scripts/metricsdocgen/scanner
|
||||
|
||||
_gen/bin/modeloptionsgen: $(wildcard scripts/modeloptionsgen/*.go) | _gen
|
||||
@mkdir -p _gen/bin
|
||||
go build -o $@ ./scripts/modeloptionsgen
|
||||
|
||||
_gen/bin/typegen: $(wildcard scripts/typegen/*.go) | _gen
|
||||
@mkdir -p _gen/bin
|
||||
go build -o $@ ./scripts/typegen
|
||||
|
||||
# Shared temp directory for atomic writes. Lives at the project root
|
||||
# so all targets share the same filesystem, and is gitignored.
|
||||
# Order-only prerequisite: recipes that need it depend on | _gen
|
||||
@@ -201,6 +254,7 @@ endif
|
||||
|
||||
clean:
|
||||
rm -rf build/ site/build/ site/out/
|
||||
rm -rf _gen/bin
|
||||
mkdir -p build/
|
||||
git restore site/out/
|
||||
.PHONY: clean
|
||||
@@ -654,8 +708,8 @@ lint/go:
|
||||
go tool github.com/coder/paralleltestctx/cmd/paralleltestctx -custom-funcs="testutil.Context" ./...
|
||||
.PHONY: lint/go
|
||||
|
||||
lint/examples:
|
||||
go run ./scripts/examplegen/main.go -lint
|
||||
lint/examples: | _gen/bin/examplegen
|
||||
_gen/bin/examplegen -lint
|
||||
.PHONY: lint/examples
|
||||
|
||||
# Use shfmt to determine the shell files, takes editorconfig into consideration.
|
||||
@@ -693,8 +747,8 @@ lint/actions/zizmor:
|
||||
.PHONY: lint/actions/zizmor
|
||||
|
||||
# Verify api_key_scope enum contains all RBAC <resource>:<action> values.
|
||||
lint/check-scopes: coderd/database/dump.sql
|
||||
go run ./scripts/check-scopes
|
||||
lint/check-scopes: coderd/database/dump.sql | _gen/bin/check-scopes
|
||||
_gen/bin/check-scopes
|
||||
.PHONY: lint/check-scopes
|
||||
|
||||
# Verify migrations do not hardcode the public schema.
|
||||
@@ -734,8 +788,8 @@ lint/typos: build/typos-$(TYPOS_VERSION)
|
||||
# The pre-push hook is allowlisted, see scripts/githooks/pre-push.
|
||||
#
|
||||
# pre-commit uses two phases: gen+fmt first, then lint+build. This
|
||||
# avoids races where gen's `go run` creates temporary .go files that
|
||||
# lint's find-based checks pick up. Within each phase, targets run in
|
||||
# avoids races where gen creates temporary .go files that lint's
|
||||
# find-based checks pick up. Within each phase, targets run in
|
||||
# parallel via -j. It fails if any tracked files have unstaged
|
||||
# changes afterward.
|
||||
|
||||
@@ -949,8 +1003,8 @@ gen/mark-fresh:
|
||||
|
||||
# Runs migrations to output a dump of the database schema after migrations are
|
||||
# applied.
|
||||
coderd/database/dump.sql: coderd/database/gen/dump/main.go $(wildcard coderd/database/migrations/*.sql)
|
||||
go run ./coderd/database/gen/dump/main.go
|
||||
coderd/database/dump.sql: coderd/database/gen/dump/main.go $(wildcard coderd/database/migrations/*.sql) | _gen/bin/dbdump
|
||||
_gen/bin/dbdump
|
||||
touch "$@"
|
||||
|
||||
# Generates Go code for querying the database.
|
||||
@@ -1067,88 +1121,88 @@ enterprise/aibridged/proto/aibridged.pb.go: enterprise/aibridged/proto/aibridged
|
||||
--go-drpc_opt=paths=source_relative \
|
||||
./enterprise/aibridged/proto/aibridged.proto
|
||||
|
||||
site/src/api/typesGenerated.ts: site/node_modules/.installed $(wildcard scripts/apitypings/*) $(shell find ./codersdk $(FIND_EXCLUSIONS) -type f -name '*.go') | _gen
|
||||
$(call atomic_write,go run -C ./scripts/apitypings main.go,./scripts/biome_format.sh)
|
||||
site/src/api/typesGenerated.ts: site/node_modules/.installed $(wildcard scripts/apitypings/*) $(shell find ./codersdk $(FIND_EXCLUSIONS) -type f -name '*.go') | _gen _gen/bin/apitypings
|
||||
$(call atomic_write,_gen/bin/apitypings,./scripts/biome_format.sh)
|
||||
|
||||
site/e2e/provisionerGenerated.ts: site/node_modules/.installed provisionerd/proto/provisionerd.pb.go provisionersdk/proto/provisioner.pb.go
|
||||
(cd site/ && pnpm run gen:provisioner)
|
||||
touch "$@"
|
||||
|
||||
site/src/theme/icons.json: site/node_modules/.installed $(wildcard scripts/gensite/*) $(wildcard site/static/icon/*) | _gen
|
||||
site/src/theme/icons.json: site/node_modules/.installed $(wildcard scripts/gensite/*) $(wildcard site/static/icon/*) | _gen _gen/bin/gensite
|
||||
tmpdir=$$(mktemp -d -p _gen) && tmpfile=$$(realpath "$$tmpdir")/$(notdir $@) && \
|
||||
go run ./scripts/gensite/ -icons "$$tmpfile" && \
|
||||
_gen/bin/gensite -icons "$$tmpfile" && \
|
||||
./scripts/biome_format.sh "$$tmpfile" && \
|
||||
mv "$$tmpfile" "$@" && rm -rf "$$tmpdir"
|
||||
|
||||
examples/examples.gen.json: scripts/examplegen/main.go examples/examples.go $(shell find ./examples/templates) | _gen
|
||||
$(call atomic_write,go run ./scripts/examplegen/main.go)
|
||||
examples/examples.gen.json: scripts/examplegen/main.go examples/examples.go $(shell find ./examples/templates) | _gen _gen/bin/examplegen
|
||||
$(call atomic_write,_gen/bin/examplegen)
|
||||
|
||||
coderd/rbac/object_gen.go: scripts/typegen/rbacobject.gotmpl scripts/typegen/main.go coderd/rbac/object.go coderd/rbac/policy/policy.go | _gen
|
||||
$(call atomic_write,go run ./scripts/typegen/main.go rbac object)
|
||||
coderd/rbac/object_gen.go: scripts/typegen/rbacobject.gotmpl scripts/typegen/main.go coderd/rbac/object.go coderd/rbac/policy/policy.go | _gen _gen/bin/typegen
|
||||
$(call atomic_write,_gen/bin/typegen rbac object)
|
||||
touch "$@"
|
||||
|
||||
# NOTE: depends on object_gen.go because `go run` compiles
|
||||
# coderd/rbac which includes it.
|
||||
# NOTE: depends on object_gen.go because the generator build
|
||||
# compiles coderd/rbac which includes it.
|
||||
coderd/rbac/scopes_constants_gen.go: scripts/typegen/scopenames.gotmpl scripts/typegen/main.go coderd/rbac/policy/policy.go \
|
||||
coderd/rbac/object_gen.go | _gen
|
||||
coderd/rbac/object_gen.go | _gen _gen/bin/typegen
|
||||
# Write to a temp file first to avoid truncating the package
|
||||
# during build since the generator imports the rbac package.
|
||||
$(call atomic_write,go run ./scripts/typegen/main.go rbac scopenames)
|
||||
$(call atomic_write,_gen/bin/typegen rbac scopenames)
|
||||
touch "$@"
|
||||
|
||||
# NOTE: depends on object_gen.go and scopes_constants_gen.go because
|
||||
# `go run` compiles coderd/rbac which includes both.
|
||||
# the generator build compiles coderd/rbac which includes both.
|
||||
codersdk/rbacresources_gen.go: scripts/typegen/codersdk.gotmpl scripts/typegen/main.go coderd/rbac/object.go coderd/rbac/policy/policy.go \
|
||||
coderd/rbac/object_gen.go coderd/rbac/scopes_constants_gen.go | _gen
|
||||
coderd/rbac/object_gen.go coderd/rbac/scopes_constants_gen.go | _gen _gen/bin/typegen
|
||||
# Write to a temp file to avoid truncating the target, which
|
||||
# would break the codersdk package and any parallel build targets.
|
||||
$(call atomic_write,go run scripts/typegen/main.go rbac codersdk)
|
||||
$(call atomic_write,_gen/bin/typegen rbac codersdk)
|
||||
touch "$@"
|
||||
|
||||
# NOTE: depends on object_gen.go and scopes_constants_gen.go because
|
||||
# `go run` compiles coderd/rbac which includes both.
|
||||
# the generator build compiles coderd/rbac which includes both.
|
||||
codersdk/apikey_scopes_gen.go: scripts/apikeyscopesgen/main.go coderd/rbac/scopes_catalog.go coderd/rbac/scopes.go \
|
||||
coderd/rbac/object_gen.go coderd/rbac/scopes_constants_gen.go | _gen
|
||||
coderd/rbac/object_gen.go coderd/rbac/scopes_constants_gen.go | _gen _gen/bin/apikeyscopesgen
|
||||
# Generate SDK constants for external API key scopes.
|
||||
$(call atomic_write,go run ./scripts/apikeyscopesgen)
|
||||
$(call atomic_write,_gen/bin/apikeyscopesgen)
|
||||
touch "$@"
|
||||
|
||||
# NOTE: depends on object_gen.go and scopes_constants_gen.go because
|
||||
# `go run` compiles coderd/rbac which includes both.
|
||||
# the generator build compiles coderd/rbac which includes both.
|
||||
site/src/api/rbacresourcesGenerated.ts: site/node_modules/.installed scripts/typegen/codersdk.gotmpl scripts/typegen/main.go coderd/rbac/object.go coderd/rbac/policy/policy.go \
|
||||
coderd/rbac/object_gen.go coderd/rbac/scopes_constants_gen.go | _gen
|
||||
$(call atomic_write,go run scripts/typegen/main.go rbac typescript,./scripts/biome_format.sh)
|
||||
coderd/rbac/object_gen.go coderd/rbac/scopes_constants_gen.go | _gen _gen/bin/typegen
|
||||
$(call atomic_write,_gen/bin/typegen rbac typescript,./scripts/biome_format.sh)
|
||||
|
||||
site/src/api/countriesGenerated.ts: site/node_modules/.installed scripts/typegen/countries.tstmpl scripts/typegen/main.go codersdk/countries.go | _gen
|
||||
$(call atomic_write,go run scripts/typegen/main.go countries,./scripts/biome_format.sh)
|
||||
site/src/api/countriesGenerated.ts: site/node_modules/.installed scripts/typegen/countries.tstmpl scripts/typegen/main.go codersdk/countries.go | _gen _gen/bin/typegen
|
||||
$(call atomic_write,_gen/bin/typegen countries,./scripts/biome_format.sh)
|
||||
|
||||
site/src/api/chatModelOptionsGenerated.json: scripts/modeloptionsgen/main.go codersdk/chats.go | _gen
|
||||
$(call atomic_write,go run ./scripts/modeloptionsgen/main.go | tail -n +2,./scripts/biome_format.sh)
|
||||
site/src/api/chatModelOptionsGenerated.json: scripts/modeloptionsgen/main.go codersdk/chats.go | _gen _gen/bin/modeloptionsgen
|
||||
$(call atomic_write,_gen/bin/modeloptionsgen | tail -n +2,./scripts/biome_format.sh)
|
||||
|
||||
scripts/metricsdocgen/generated_metrics: $(GO_SRC_FILES) | _gen
|
||||
$(call atomic_write,go run ./scripts/metricsdocgen/scanner)
|
||||
scripts/metricsdocgen/generated_metrics: $(GO_SRC_FILES) | _gen _gen/bin/metricsdocgen-scanner
|
||||
$(call atomic_write,_gen/bin/metricsdocgen-scanner)
|
||||
|
||||
docs/admin/integrations/prometheus.md: node_modules/.installed scripts/metricsdocgen/main.go scripts/metricsdocgen/metrics scripts/metricsdocgen/generated_metrics | _gen
|
||||
docs/admin/integrations/prometheus.md: node_modules/.installed scripts/metricsdocgen/main.go scripts/metricsdocgen/metrics scripts/metricsdocgen/generated_metrics | _gen _gen/bin/metricsdocgen
|
||||
tmpdir=$$(mktemp -d -p _gen) && tmpfile=$$(realpath "$$tmpdir")/$(notdir $@) && cp "$@" "$$tmpfile" && \
|
||||
go run scripts/metricsdocgen/main.go --prometheus-doc-file="$$tmpfile" && \
|
||||
_gen/bin/metricsdocgen --prometheus-doc-file="$$tmpfile" && \
|
||||
pnpm exec markdownlint-cli2 --fix "$$tmpfile" && \
|
||||
pnpm exec markdown-table-formatter "$$tmpfile" && \
|
||||
mv "$$tmpfile" "$@" && rm -rf "$$tmpdir"
|
||||
|
||||
docs/reference/cli/index.md: node_modules/.installed scripts/clidocgen/main.go examples/examples.gen.json $(GO_SRC_FILES) | _gen
|
||||
docs/reference/cli/index.md: node_modules/.installed scripts/clidocgen/main.go examples/examples.gen.json $(GO_SRC_FILES) | _gen _gen/bin/clidocgen
|
||||
tmpdir=$$(mktemp -d -p _gen) && \
|
||||
tmpdir=$$(realpath "$$tmpdir") && \
|
||||
mkdir -p "$$tmpdir/docs/reference/cli" && \
|
||||
cp docs/manifest.json "$$tmpdir/docs/manifest.json" && \
|
||||
CI=true DOCS_DIR="$$tmpdir/docs" go run ./scripts/clidocgen && \
|
||||
CI=true DOCS_DIR="$$tmpdir/docs" _gen/bin/clidocgen && \
|
||||
pnpm exec markdownlint-cli2 --fix "$$tmpdir/docs/reference/cli/*.md" && \
|
||||
pnpm exec markdown-table-formatter "$$tmpdir/docs/reference/cli/*.md" && \
|
||||
for f in "$$tmpdir/docs/reference/cli/"*.md; do mv "$$f" "docs/reference/cli/$$(basename "$$f")"; done && \
|
||||
rm -rf "$$tmpdir"
|
||||
|
||||
docs/admin/security/audit-logs.md: node_modules/.installed coderd/database/querier.go scripts/auditdocgen/main.go enterprise/audit/table.go coderd/rbac/object_gen.go | _gen
|
||||
docs/admin/security/audit-logs.md: node_modules/.installed coderd/database/querier.go scripts/auditdocgen/main.go enterprise/audit/table.go coderd/rbac/object_gen.go | _gen _gen/bin/auditdocgen
|
||||
tmpdir=$$(mktemp -d -p _gen) && tmpfile=$$(realpath "$$tmpdir")/$(notdir $@) && cp "$@" "$$tmpfile" && \
|
||||
go run scripts/auditdocgen/main.go --audit-doc-file="$$tmpfile" && \
|
||||
_gen/bin/auditdocgen --audit-doc-file="$$tmpfile" && \
|
||||
pnpm exec markdownlint-cli2 --fix "$$tmpfile" && \
|
||||
pnpm exec markdown-table-formatter "$$tmpfile" && \
|
||||
mv "$$tmpfile" "$@" && rm -rf "$$tmpdir"
|
||||
|
||||
@@ -1,10 +1,10 @@
|
||||
<!-- markdownlint-disable MD041 -->
|
||||
<div align="center">
|
||||
<a href="https://coder.com#gh-light-mode-only">
|
||||
<img src="./docs/images/logo-black.png" alt="Coder Logo Light" style="width: 128px">
|
||||
<a href="https://coder.com#gh-light-mode-only" rel="nofollow">
|
||||
<img src="./docs/images/logo-black.png" alt="Coder Logo Light" style="width: 128px; max-width: 100%;">
|
||||
</a>
|
||||
<a href="https://coder.com#gh-dark-mode-only">
|
||||
<img src="./docs/images/logo-white.png" alt="Coder Logo Dark" style="width: 128px">
|
||||
<a href="https://coder.com#gh-dark-mode-only" rel="nofollow">
|
||||
<img src="./docs/images/logo-white.png" alt="Coder Logo Dark" style="width: 128px; max-width: 100%;">
|
||||
</a>
|
||||
|
||||
<h1>
|
||||
|
||||
+14
-6
@@ -102,6 +102,8 @@ type Options struct {
|
||||
ReportMetadataInterval time.Duration
|
||||
ServiceBannerRefreshInterval time.Duration
|
||||
BlockFileTransfer bool
|
||||
BlockReversePortForwarding bool
|
||||
BlockLocalPortForwarding bool
|
||||
Execer agentexec.Execer
|
||||
Devcontainers bool
|
||||
DevcontainerAPIOptions []agentcontainers.Option // Enable Devcontainers for these to be effective.
|
||||
@@ -214,6 +216,8 @@ func New(options Options) Agent {
|
||||
subsystems: options.Subsystems,
|
||||
logSender: agentsdk.NewLogSender(options.Logger),
|
||||
blockFileTransfer: options.BlockFileTransfer,
|
||||
blockReversePortForwarding: options.BlockReversePortForwarding,
|
||||
blockLocalPortForwarding: options.BlockLocalPortForwarding,
|
||||
|
||||
prometheusRegistry: prometheusRegistry,
|
||||
metrics: newAgentMetrics(prometheusRegistry),
|
||||
@@ -280,6 +284,8 @@ type agent struct {
|
||||
sshServer *agentssh.Server
|
||||
sshMaxTimeout time.Duration
|
||||
blockFileTransfer bool
|
||||
blockReversePortForwarding bool
|
||||
blockLocalPortForwarding bool
|
||||
|
||||
lifecycleUpdate chan struct{}
|
||||
lifecycleReported chan codersdk.WorkspaceAgentLifecycle
|
||||
@@ -331,12 +337,14 @@ func (a *agent) TailnetConn() *tailnet.Conn {
|
||||
func (a *agent) init() {
|
||||
// pass the "hard" context because we explicitly close the SSH server as part of graceful shutdown.
|
||||
sshSrv, err := agentssh.NewServer(a.hardCtx, a.logger.Named("ssh-server"), a.prometheusRegistry, a.filesystem, a.execer, &agentssh.Config{
|
||||
MaxTimeout: a.sshMaxTimeout,
|
||||
MOTDFile: func() string { return a.manifest.Load().MOTDFile },
|
||||
AnnouncementBanners: func() *[]codersdk.BannerConfig { return a.announcementBanners.Load() },
|
||||
UpdateEnv: a.updateCommandEnv,
|
||||
WorkingDirectory: func() string { return a.manifest.Load().Directory },
|
||||
BlockFileTransfer: a.blockFileTransfer,
|
||||
MaxTimeout: a.sshMaxTimeout,
|
||||
MOTDFile: func() string { return a.manifest.Load().MOTDFile },
|
||||
AnnouncementBanners: func() *[]codersdk.BannerConfig { return a.announcementBanners.Load() },
|
||||
UpdateEnv: a.updateCommandEnv,
|
||||
WorkingDirectory: func() string { return a.manifest.Load().Directory },
|
||||
BlockFileTransfer: a.blockFileTransfer,
|
||||
BlockReversePortForwarding: a.blockReversePortForwarding,
|
||||
BlockLocalPortForwarding: a.blockLocalPortForwarding,
|
||||
ReportConnection: func(id uuid.UUID, magicType agentssh.MagicSessionType, ip string) func(code int, reason string) {
|
||||
var connectionType proto.Connection_Type
|
||||
switch magicType {
|
||||
|
||||
@@ -986,6 +986,161 @@ func TestAgent_TCPRemoteForwarding(t *testing.T) {
|
||||
requireEcho(t, conn)
|
||||
}
|
||||
|
||||
func TestAgent_TCPLocalForwardingBlocked(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := testutil.Context(t, testutil.WaitLong)
|
||||
|
||||
rl, err := net.Listen("tcp", "127.0.0.1:0")
|
||||
require.NoError(t, err)
|
||||
defer rl.Close()
|
||||
tcpAddr, valid := rl.Addr().(*net.TCPAddr)
|
||||
require.True(t, valid)
|
||||
remotePort := tcpAddr.Port
|
||||
|
||||
//nolint:dogsled
|
||||
agentConn, _, _, _, _ := setupAgent(t, agentsdk.Manifest{}, 0, func(_ *agenttest.Client, o *agent.Options) {
|
||||
o.BlockLocalPortForwarding = true
|
||||
})
|
||||
sshClient, err := agentConn.SSHClient(ctx)
|
||||
require.NoError(t, err)
|
||||
defer sshClient.Close()
|
||||
|
||||
_, err = sshClient.Dial("tcp", fmt.Sprintf("127.0.0.1:%d", remotePort))
|
||||
require.ErrorContains(t, err, "administratively prohibited")
|
||||
}
|
||||
|
||||
func TestAgent_TCPRemoteForwardingBlocked(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := testutil.Context(t, testutil.WaitLong)
|
||||
|
||||
//nolint:dogsled
|
||||
agentConn, _, _, _, _ := setupAgent(t, agentsdk.Manifest{}, 0, func(_ *agenttest.Client, o *agent.Options) {
|
||||
o.BlockReversePortForwarding = true
|
||||
})
|
||||
sshClient, err := agentConn.SSHClient(ctx)
|
||||
require.NoError(t, err)
|
||||
defer sshClient.Close()
|
||||
|
||||
localhost := netip.MustParseAddr("127.0.0.1")
|
||||
randomPort := testutil.RandomPortNoListen(t)
|
||||
addr := net.TCPAddrFromAddrPort(netip.AddrPortFrom(localhost, randomPort))
|
||||
_, err = sshClient.ListenTCP(addr)
|
||||
require.ErrorContains(t, err, "tcpip-forward request denied by peer")
|
||||
}
|
||||
|
||||
func TestAgent_UnixLocalForwardingBlocked(t *testing.T) {
|
||||
t.Parallel()
|
||||
if runtime.GOOS == "windows" {
|
||||
t.Skip("unix domain sockets are not fully supported on Windows")
|
||||
}
|
||||
ctx := testutil.Context(t, testutil.WaitLong)
|
||||
tmpdir := testutil.TempDirUnixSocket(t)
|
||||
remoteSocketPath := filepath.Join(tmpdir, "remote-socket")
|
||||
|
||||
l, err := net.Listen("unix", remoteSocketPath)
|
||||
require.NoError(t, err)
|
||||
defer l.Close()
|
||||
|
||||
//nolint:dogsled
|
||||
agentConn, _, _, _, _ := setupAgent(t, agentsdk.Manifest{}, 0, func(_ *agenttest.Client, o *agent.Options) {
|
||||
o.BlockLocalPortForwarding = true
|
||||
})
|
||||
sshClient, err := agentConn.SSHClient(ctx)
|
||||
require.NoError(t, err)
|
||||
defer sshClient.Close()
|
||||
|
||||
_, err = sshClient.Dial("unix", remoteSocketPath)
|
||||
require.ErrorContains(t, err, "administratively prohibited")
|
||||
}
|
||||
|
||||
func TestAgent_UnixRemoteForwardingBlocked(t *testing.T) {
|
||||
t.Parallel()
|
||||
if runtime.GOOS == "windows" {
|
||||
t.Skip("unix domain sockets are not fully supported on Windows")
|
||||
}
|
||||
ctx := testutil.Context(t, testutil.WaitLong)
|
||||
tmpdir := testutil.TempDirUnixSocket(t)
|
||||
remoteSocketPath := filepath.Join(tmpdir, "remote-socket")
|
||||
|
||||
//nolint:dogsled
|
||||
agentConn, _, _, _, _ := setupAgent(t, agentsdk.Manifest{}, 0, func(_ *agenttest.Client, o *agent.Options) {
|
||||
o.BlockReversePortForwarding = true
|
||||
})
|
||||
sshClient, err := agentConn.SSHClient(ctx)
|
||||
require.NoError(t, err)
|
||||
defer sshClient.Close()
|
||||
|
||||
_, err = sshClient.ListenUnix(remoteSocketPath)
|
||||
require.ErrorContains(t, err, "streamlocal-forward@openssh.com request denied by peer")
|
||||
}
|
||||
|
||||
// TestAgent_LocalBlockedDoesNotAffectReverse verifies that blocking
|
||||
// local port forwarding does not prevent reverse port forwarding from
|
||||
// working. A field-name transposition at any plumbing hop would cause
|
||||
// both directions to be blocked when only one flag is set.
|
||||
func TestAgent_LocalBlockedDoesNotAffectReverse(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := testutil.Context(t, testutil.WaitLong)
|
||||
|
||||
//nolint:dogsled
|
||||
agentConn, _, _, _, _ := setupAgent(t, agentsdk.Manifest{}, 0, func(_ *agenttest.Client, o *agent.Options) {
|
||||
o.BlockLocalPortForwarding = true
|
||||
})
|
||||
sshClient, err := agentConn.SSHClient(ctx)
|
||||
require.NoError(t, err)
|
||||
defer sshClient.Close()
|
||||
|
||||
// Reverse forwarding must still work.
|
||||
localhost := netip.MustParseAddr("127.0.0.1")
|
||||
var ll net.Listener
|
||||
for {
|
||||
randomPort := testutil.RandomPortNoListen(t)
|
||||
addr := net.TCPAddrFromAddrPort(netip.AddrPortFrom(localhost, randomPort))
|
||||
ll, err = sshClient.ListenTCP(addr)
|
||||
if err != nil {
|
||||
t.Logf("error remote forwarding: %s", err.Error())
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
t.Fatal("timed out getting random listener")
|
||||
default:
|
||||
continue
|
||||
}
|
||||
}
|
||||
break
|
||||
}
|
||||
_ = ll.Close()
|
||||
}
|
||||
|
||||
// TestAgent_ReverseBlockedDoesNotAffectLocal verifies that blocking
|
||||
// reverse port forwarding does not prevent local port forwarding from
|
||||
// working.
|
||||
func TestAgent_ReverseBlockedDoesNotAffectLocal(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := testutil.Context(t, testutil.WaitLong)
|
||||
|
||||
rl, err := net.Listen("tcp", "127.0.0.1:0")
|
||||
require.NoError(t, err)
|
||||
defer rl.Close()
|
||||
tcpAddr, valid := rl.Addr().(*net.TCPAddr)
|
||||
require.True(t, valid)
|
||||
remotePort := tcpAddr.Port
|
||||
go echoOnce(t, rl)
|
||||
|
||||
//nolint:dogsled
|
||||
agentConn, _, _, _, _ := setupAgent(t, agentsdk.Manifest{}, 0, func(_ *agenttest.Client, o *agent.Options) {
|
||||
o.BlockReversePortForwarding = true
|
||||
})
|
||||
sshClient, err := agentConn.SSHClient(ctx)
|
||||
require.NoError(t, err)
|
||||
defer sshClient.Close()
|
||||
|
||||
// Local forwarding must still work.
|
||||
conn, err := sshClient.Dial("tcp", fmt.Sprintf("127.0.0.1:%d", remotePort))
|
||||
require.NoError(t, err)
|
||||
defer conn.Close()
|
||||
requireEcho(t, conn)
|
||||
}
|
||||
|
||||
func TestAgent_UnixLocalForwarding(t *testing.T) {
|
||||
t.Parallel()
|
||||
if runtime.GOOS == "windows" {
|
||||
|
||||
@@ -2862,6 +2862,126 @@ func TestAPI(t *testing.T) {
|
||||
"rebuilt agent should include updated display apps")
|
||||
})
|
||||
|
||||
// Verify that when a terraform-managed subagent is injected into
|
||||
// a devcontainer, the Directory field sent to Create reflects
|
||||
// the container-internal workspaceFolder from devcontainer
|
||||
// read-configuration, not the host-side workspace_folder from
|
||||
// the terraform resource. This is the scenario described in
|
||||
// https://linear.app/codercom/issue/PRODUCT-259:
|
||||
// 1. Non-terraform subagent → directory = /workspaces/foo (correct)
|
||||
// 2. Terraform subagent → directory was stuck on host path (bug)
|
||||
t.Run("TerraformDefinedSubAgentUsesContainerInternalDirectory", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
if runtime.GOOS == "windows" {
|
||||
t.Skip("Dev Container tests are not supported on Windows (this test uses mocks but fails due to Windows paths)")
|
||||
}
|
||||
|
||||
var (
|
||||
ctx = testutil.Context(t, testutil.WaitMedium)
|
||||
logger = slogtest.Make(t, &slogtest.Options{IgnoreErrors: true}).Leveled(slog.LevelDebug)
|
||||
mCtrl = gomock.NewController(t)
|
||||
|
||||
terraformAgentID = uuid.New()
|
||||
containerID = "test-container-id"
|
||||
|
||||
// Given: A container with a host-side workspace folder.
|
||||
terraformContainer = codersdk.WorkspaceAgentContainer{
|
||||
ID: containerID,
|
||||
FriendlyName: "test-container",
|
||||
Image: "test-image",
|
||||
Running: true,
|
||||
CreatedAt: time.Now(),
|
||||
Labels: map[string]string{
|
||||
agentcontainers.DevcontainerLocalFolderLabel: "/home/coder/project",
|
||||
agentcontainers.DevcontainerConfigFileLabel: "/home/coder/project/.devcontainer/devcontainer.json",
|
||||
},
|
||||
}
|
||||
|
||||
// Given: A terraform-defined devcontainer whose
|
||||
// workspace_folder is the HOST-side path (set by provisioner).
|
||||
terraformDevcontainer = codersdk.WorkspaceAgentDevcontainer{
|
||||
ID: uuid.New(),
|
||||
Name: "terraform-devcontainer",
|
||||
WorkspaceFolder: "/home/coder/project",
|
||||
ConfigPath: "/home/coder/project/.devcontainer/devcontainer.json",
|
||||
SubagentID: uuid.NullUUID{UUID: terraformAgentID, Valid: true},
|
||||
}
|
||||
|
||||
fCCLI = &fakeContainerCLI{
|
||||
containers: codersdk.WorkspaceAgentListContainersResponse{
|
||||
Containers: []codersdk.WorkspaceAgentContainer{terraformContainer},
|
||||
},
|
||||
arch: runtime.GOARCH,
|
||||
}
|
||||
|
||||
// Given: devcontainer read-configuration returns the
|
||||
// CONTAINER-INTERNAL workspace folder.
|
||||
fDCCLI = &fakeDevcontainerCLI{
|
||||
upID: containerID,
|
||||
readConfig: agentcontainers.DevcontainerConfig{
|
||||
Workspace: agentcontainers.DevcontainerWorkspace{
|
||||
WorkspaceFolder: "/workspaces/project",
|
||||
},
|
||||
MergedConfiguration: agentcontainers.DevcontainerMergedConfiguration{
|
||||
Customizations: agentcontainers.DevcontainerMergedCustomizations{
|
||||
Coder: []agentcontainers.CoderCustomization{{}},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
mSAC = acmock.NewMockSubAgentClient(mCtrl)
|
||||
createCalls = make(chan agentcontainers.SubAgent, 1)
|
||||
closed bool
|
||||
)
|
||||
|
||||
mSAC.EXPECT().List(gomock.Any()).Return([]agentcontainers.SubAgent{}, nil).AnyTimes()
|
||||
|
||||
mSAC.EXPECT().Create(gomock.Any(), gomock.Any()).DoAndReturn(
|
||||
func(_ context.Context, agent agentcontainers.SubAgent) (agentcontainers.SubAgent, error) {
|
||||
agent.AuthToken = uuid.New()
|
||||
createCalls <- agent
|
||||
return agent, nil
|
||||
},
|
||||
).Times(1)
|
||||
|
||||
mSAC.EXPECT().Delete(gomock.Any(), gomock.Any()).DoAndReturn(func(_ context.Context, _ uuid.UUID) error {
|
||||
assert.True(t, closed, "Delete should only be called after Close")
|
||||
return nil
|
||||
}).AnyTimes()
|
||||
|
||||
api := agentcontainers.NewAPI(logger,
|
||||
agentcontainers.WithContainerCLI(fCCLI),
|
||||
agentcontainers.WithDevcontainerCLI(fDCCLI),
|
||||
agentcontainers.WithDevcontainers(
|
||||
[]codersdk.WorkspaceAgentDevcontainer{terraformDevcontainer},
|
||||
[]codersdk.WorkspaceAgentScript{{ID: terraformDevcontainer.ID, LogSourceID: uuid.New()}},
|
||||
),
|
||||
agentcontainers.WithSubAgentClient(mSAC),
|
||||
agentcontainers.WithSubAgentURL("test-subagent-url"),
|
||||
agentcontainers.WithWatcher(watcher.NewNoop()),
|
||||
)
|
||||
api.Start()
|
||||
defer func() {
|
||||
closed = true
|
||||
api.Close()
|
||||
}()
|
||||
|
||||
// When: The devcontainer is created (triggering injection).
|
||||
err := api.CreateDevcontainer(terraformDevcontainer.WorkspaceFolder, terraformDevcontainer.ConfigPath)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Then: The subagent sent to Create has the correct
|
||||
// container-internal directory, not the host path.
|
||||
createdAgent := testutil.RequireReceive(ctx, t, createCalls)
|
||||
assert.Equal(t, terraformAgentID, createdAgent.ID,
|
||||
"agent should use terraform-defined ID")
|
||||
assert.Equal(t, "/workspaces/project", createdAgent.Directory,
|
||||
"directory should be the container-internal path from devcontainer "+
|
||||
"read-configuration, not the host-side workspace_folder")
|
||||
})
|
||||
|
||||
t.Run("Error", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
|
||||
@@ -134,6 +134,33 @@ func Config(workingDir string) (workspacesdk.ContextConfigResponse, []string) {
|
||||
}, ResolvePaths(mcpConfigFile, workingDir)
|
||||
}
|
||||
|
||||
// ContextPartsFromDir reads instruction files and discovers skills
|
||||
// from a specific directory, using default file names. This is used
|
||||
// by the CLI chat context commands to read context from an arbitrary
|
||||
// directory without consulting agent env vars.
|
||||
func ContextPartsFromDir(dir string) []codersdk.ChatMessagePart {
|
||||
var parts []codersdk.ChatMessagePart
|
||||
|
||||
if entry, found := readInstructionFileFromDir(dir, DefaultInstructionsFile); found {
|
||||
parts = append(parts, entry)
|
||||
}
|
||||
|
||||
// Reuse ResolvePaths so CLI skill discovery follows the same
|
||||
// project-relative path handling as agent config resolution.
|
||||
skillParts := discoverSkills(
|
||||
ResolvePaths(strings.Join([]string{DefaultSkillsDir, "skills"}, ","), dir),
|
||||
DefaultSkillMetaFile,
|
||||
)
|
||||
parts = append(parts, skillParts...)
|
||||
|
||||
// Guarantee non-nil slice.
|
||||
if parts == nil {
|
||||
parts = []codersdk.ChatMessagePart{}
|
||||
}
|
||||
|
||||
return parts
|
||||
}
|
||||
|
||||
// MCPConfigFiles returns the resolved MCP configuration file
|
||||
// paths for the agent's MCP manager.
|
||||
func (api *API) MCPConfigFiles() []string {
|
||||
|
||||
@@ -23,18 +23,144 @@ func filterParts(parts []codersdk.ChatMessagePart, t codersdk.ChatMessagePartTyp
|
||||
return out
|
||||
}
|
||||
|
||||
func TestConfig(t *testing.T) {
|
||||
t.Run("Defaults", func(t *testing.T) {
|
||||
fakeHome := t.TempDir()
|
||||
t.Setenv("HOME", fakeHome)
|
||||
t.Setenv("USERPROFILE", fakeHome)
|
||||
func writeSkillMetaFileInRoot(t *testing.T, skillsRoot, name, description string) string {
|
||||
t.Helper()
|
||||
|
||||
// Clear all env vars so defaults are used.
|
||||
t.Setenv(agentcontextconfig.EnvInstructionsDirs, "")
|
||||
t.Setenv(agentcontextconfig.EnvInstructionsFile, "")
|
||||
t.Setenv(agentcontextconfig.EnvSkillsDirs, "")
|
||||
t.Setenv(agentcontextconfig.EnvSkillMetaFile, "")
|
||||
t.Setenv(agentcontextconfig.EnvMCPConfigFiles, "")
|
||||
skillDir := filepath.Join(skillsRoot, name)
|
||||
require.NoError(t, os.MkdirAll(skillDir, 0o755))
|
||||
require.NoError(t, os.WriteFile(
|
||||
filepath.Join(skillDir, "SKILL.md"),
|
||||
[]byte("---\nname: "+name+"\ndescription: "+description+"\n---\nSkill body"),
|
||||
0o600,
|
||||
))
|
||||
|
||||
return skillDir
|
||||
}
|
||||
|
||||
func writeSkillMetaFile(t *testing.T, dir, name, description string) string {
|
||||
t.Helper()
|
||||
return writeSkillMetaFileInRoot(t, filepath.Join(dir, ".agents", "skills"), name, description)
|
||||
}
|
||||
|
||||
func TestContextPartsFromDir(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
t.Run("ReturnsInstructionFilePart", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
dir := t.TempDir()
|
||||
instructionPath := filepath.Join(dir, "AGENTS.md")
|
||||
require.NoError(t, os.WriteFile(instructionPath, []byte("project instructions"), 0o600))
|
||||
|
||||
parts := agentcontextconfig.ContextPartsFromDir(dir)
|
||||
contextParts := filterParts(parts, codersdk.ChatMessagePartTypeContextFile)
|
||||
skillParts := filterParts(parts, codersdk.ChatMessagePartTypeSkill)
|
||||
|
||||
require.Len(t, parts, 1)
|
||||
require.Len(t, contextParts, 1)
|
||||
require.Empty(t, skillParts)
|
||||
require.Equal(t, instructionPath, contextParts[0].ContextFilePath)
|
||||
require.Equal(t, "project instructions", contextParts[0].ContextFileContent)
|
||||
require.False(t, contextParts[0].ContextFileTruncated)
|
||||
})
|
||||
|
||||
t.Run("ReturnsSkillParts", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
dir := t.TempDir()
|
||||
skillDir := writeSkillMetaFile(t, dir, "my-skill", "A test skill")
|
||||
|
||||
parts := agentcontextconfig.ContextPartsFromDir(dir)
|
||||
contextParts := filterParts(parts, codersdk.ChatMessagePartTypeContextFile)
|
||||
skillParts := filterParts(parts, codersdk.ChatMessagePartTypeSkill)
|
||||
|
||||
require.Len(t, parts, 1)
|
||||
require.Empty(t, contextParts)
|
||||
require.Len(t, skillParts, 1)
|
||||
require.Equal(t, "my-skill", skillParts[0].SkillName)
|
||||
require.Equal(t, "A test skill", skillParts[0].SkillDescription)
|
||||
require.Equal(t, skillDir, skillParts[0].SkillDir)
|
||||
require.Equal(t, "SKILL.md", skillParts[0].ContextFileSkillMetaFile)
|
||||
})
|
||||
|
||||
t.Run("ReturnsSkillPartsFromSkillsDir", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
dir := t.TempDir()
|
||||
skillDir := writeSkillMetaFileInRoot(
|
||||
t,
|
||||
filepath.Join(dir, "skills"),
|
||||
"my-skill",
|
||||
"A test skill",
|
||||
)
|
||||
|
||||
parts := agentcontextconfig.ContextPartsFromDir(dir)
|
||||
contextParts := filterParts(parts, codersdk.ChatMessagePartTypeContextFile)
|
||||
skillParts := filterParts(parts, codersdk.ChatMessagePartTypeSkill)
|
||||
|
||||
require.Len(t, parts, 1)
|
||||
require.Empty(t, contextParts)
|
||||
require.Len(t, skillParts, 1)
|
||||
require.Equal(t, "my-skill", skillParts[0].SkillName)
|
||||
require.Equal(t, "A test skill", skillParts[0].SkillDescription)
|
||||
require.Equal(t, skillDir, skillParts[0].SkillDir)
|
||||
require.Equal(t, "SKILL.md", skillParts[0].ContextFileSkillMetaFile)
|
||||
})
|
||||
|
||||
t.Run("ReturnsEmptyForEmptyDir", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
parts := agentcontextconfig.ContextPartsFromDir(t.TempDir())
|
||||
|
||||
require.NotNil(t, parts)
|
||||
require.Empty(t, parts)
|
||||
})
|
||||
|
||||
t.Run("ReturnsCombinedResults", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
dir := t.TempDir()
|
||||
instructionPath := filepath.Join(dir, "AGENTS.md")
|
||||
require.NoError(t, os.WriteFile(instructionPath, []byte("combined instructions"), 0o600))
|
||||
skillDir := writeSkillMetaFile(t, dir, "combined-skill", "Combined test skill")
|
||||
|
||||
parts := agentcontextconfig.ContextPartsFromDir(dir)
|
||||
contextParts := filterParts(parts, codersdk.ChatMessagePartTypeContextFile)
|
||||
skillParts := filterParts(parts, codersdk.ChatMessagePartTypeSkill)
|
||||
|
||||
require.Len(t, parts, 2)
|
||||
require.Len(t, contextParts, 1)
|
||||
require.Len(t, skillParts, 1)
|
||||
require.Equal(t, instructionPath, contextParts[0].ContextFilePath)
|
||||
require.Equal(t, "combined instructions", contextParts[0].ContextFileContent)
|
||||
require.Equal(t, "combined-skill", skillParts[0].SkillName)
|
||||
require.Equal(t, skillDir, skillParts[0].SkillDir)
|
||||
})
|
||||
}
|
||||
|
||||
func setupConfigTestEnv(t *testing.T, overrides map[string]string) string {
|
||||
t.Helper()
|
||||
|
||||
fakeHome := t.TempDir()
|
||||
t.Setenv("HOME", fakeHome)
|
||||
t.Setenv("USERPROFILE", fakeHome)
|
||||
t.Setenv(agentcontextconfig.EnvInstructionsDirs, "")
|
||||
t.Setenv(agentcontextconfig.EnvInstructionsFile, "")
|
||||
t.Setenv(agentcontextconfig.EnvSkillsDirs, "")
|
||||
t.Setenv(agentcontextconfig.EnvSkillMetaFile, "")
|
||||
t.Setenv(agentcontextconfig.EnvMCPConfigFiles, "")
|
||||
|
||||
for key, value := range overrides {
|
||||
t.Setenv(key, value)
|
||||
}
|
||||
|
||||
return fakeHome
|
||||
}
|
||||
|
||||
func TestConfig(t *testing.T) {
|
||||
//nolint:paralleltest // Uses t.Setenv to mutate process-wide environment.
|
||||
t.Run("Defaults", func(t *testing.T) {
|
||||
setupConfigTestEnv(t, nil)
|
||||
|
||||
workDir := platformAbsPath("work")
|
||||
cfg, mcpFiles := agentcontextconfig.Config(workDir)
|
||||
@@ -46,20 +172,18 @@ func TestConfig(t *testing.T) {
|
||||
require.Equal(t, []string{filepath.Join(workDir, ".mcp.json")}, mcpFiles)
|
||||
})
|
||||
|
||||
//nolint:paralleltest // Uses t.Setenv to mutate process-wide environment.
|
||||
t.Run("CustomEnvVars", func(t *testing.T) {
|
||||
fakeHome := t.TempDir()
|
||||
t.Setenv("HOME", fakeHome)
|
||||
t.Setenv("USERPROFILE", fakeHome)
|
||||
|
||||
optInstructions := t.TempDir()
|
||||
optSkills := t.TempDir()
|
||||
optMCP := platformAbsPath("opt", "mcp.json")
|
||||
|
||||
t.Setenv(agentcontextconfig.EnvInstructionsDirs, optInstructions)
|
||||
t.Setenv(agentcontextconfig.EnvInstructionsFile, "CUSTOM.md")
|
||||
t.Setenv(agentcontextconfig.EnvSkillsDirs, optSkills)
|
||||
t.Setenv(agentcontextconfig.EnvSkillMetaFile, "META.yaml")
|
||||
t.Setenv(agentcontextconfig.EnvMCPConfigFiles, optMCP)
|
||||
setupConfigTestEnv(t, map[string]string{
|
||||
agentcontextconfig.EnvInstructionsDirs: optInstructions,
|
||||
agentcontextconfig.EnvInstructionsFile: "CUSTOM.md",
|
||||
agentcontextconfig.EnvSkillsDirs: optSkills,
|
||||
agentcontextconfig.EnvSkillMetaFile: "META.yaml",
|
||||
agentcontextconfig.EnvMCPConfigFiles: optMCP,
|
||||
})
|
||||
|
||||
// Create files matching the custom names so we can
|
||||
// verify the env vars actually change lookup behavior.
|
||||
@@ -85,15 +209,12 @@ func TestConfig(t *testing.T) {
|
||||
require.Equal(t, "META.yaml", skillParts[0].ContextFileSkillMetaFile)
|
||||
})
|
||||
|
||||
//nolint:paralleltest // Uses t.Setenv to mutate process-wide environment.
|
||||
t.Run("WhitespaceInFileNames", func(t *testing.T) {
|
||||
fakeHome := t.TempDir()
|
||||
t.Setenv("HOME", fakeHome)
|
||||
t.Setenv("USERPROFILE", fakeHome)
|
||||
fakeHome := setupConfigTestEnv(t, map[string]string{
|
||||
agentcontextconfig.EnvInstructionsFile: " CLAUDE.md ",
|
||||
})
|
||||
t.Setenv(agentcontextconfig.EnvInstructionsDirs, fakeHome)
|
||||
t.Setenv(agentcontextconfig.EnvInstructionsFile, " CLAUDE.md ")
|
||||
t.Setenv(agentcontextconfig.EnvSkillsDirs, "")
|
||||
t.Setenv(agentcontextconfig.EnvSkillMetaFile, "")
|
||||
t.Setenv(agentcontextconfig.EnvMCPConfigFiles, "")
|
||||
|
||||
workDir := t.TempDir()
|
||||
// Create a file matching the trimmed name.
|
||||
@@ -106,19 +227,13 @@ func TestConfig(t *testing.T) {
|
||||
require.Equal(t, "hello", ctxFiles[0].ContextFileContent)
|
||||
})
|
||||
|
||||
//nolint:paralleltest // Uses t.Setenv to mutate process-wide environment.
|
||||
t.Run("CommaSeparatedDirs", func(t *testing.T) {
|
||||
fakeHome := t.TempDir()
|
||||
t.Setenv("HOME", fakeHome)
|
||||
t.Setenv("USERPROFILE", fakeHome)
|
||||
|
||||
a := t.TempDir()
|
||||
b := t.TempDir()
|
||||
|
||||
t.Setenv(agentcontextconfig.EnvInstructionsDirs, a+","+b)
|
||||
t.Setenv(agentcontextconfig.EnvInstructionsFile, "")
|
||||
t.Setenv(agentcontextconfig.EnvSkillsDirs, "")
|
||||
t.Setenv(agentcontextconfig.EnvSkillMetaFile, "")
|
||||
t.Setenv(agentcontextconfig.EnvMCPConfigFiles, "")
|
||||
setupConfigTestEnv(t, map[string]string{
|
||||
agentcontextconfig.EnvInstructionsDirs: a + "," + b,
|
||||
})
|
||||
|
||||
// Put instruction files in both dirs.
|
||||
require.NoError(t, os.WriteFile(filepath.Join(a, "AGENTS.md"), []byte("from a"), 0o600))
|
||||
@@ -133,17 +248,10 @@ func TestConfig(t *testing.T) {
|
||||
require.Equal(t, "from b", ctxFiles[1].ContextFileContent)
|
||||
})
|
||||
|
||||
//nolint:paralleltest // Uses t.Setenv to mutate process-wide environment.
|
||||
t.Run("ReadsInstructionFiles", func(t *testing.T) {
|
||||
t.Setenv(agentcontextconfig.EnvInstructionsDirs, "")
|
||||
t.Setenv(agentcontextconfig.EnvInstructionsFile, "")
|
||||
t.Setenv(agentcontextconfig.EnvSkillsDirs, "")
|
||||
t.Setenv(agentcontextconfig.EnvSkillMetaFile, "")
|
||||
t.Setenv(agentcontextconfig.EnvMCPConfigFiles, "")
|
||||
|
||||
workDir := t.TempDir()
|
||||
fakeHome := t.TempDir()
|
||||
t.Setenv("HOME", fakeHome)
|
||||
t.Setenv("USERPROFILE", fakeHome)
|
||||
fakeHome := setupConfigTestEnv(t, nil)
|
||||
|
||||
// Create ~/.coder/AGENTS.md
|
||||
coderDir := filepath.Join(fakeHome, ".coder")
|
||||
@@ -164,16 +272,9 @@ func TestConfig(t *testing.T) {
|
||||
require.False(t, ctxFiles[0].ContextFileTruncated)
|
||||
})
|
||||
|
||||
//nolint:paralleltest // Uses t.Setenv to mutate process-wide environment.
|
||||
t.Run("ReadsWorkingDirInstructionFile", func(t *testing.T) {
|
||||
fakeHome := t.TempDir()
|
||||
t.Setenv("HOME", fakeHome)
|
||||
t.Setenv("USERPROFILE", fakeHome)
|
||||
t.Setenv(agentcontextconfig.EnvInstructionsDirs, "")
|
||||
t.Setenv(agentcontextconfig.EnvInstructionsFile, "")
|
||||
t.Setenv(agentcontextconfig.EnvSkillsDirs, "")
|
||||
t.Setenv(agentcontextconfig.EnvSkillMetaFile, "")
|
||||
t.Setenv(agentcontextconfig.EnvMCPConfigFiles, "")
|
||||
|
||||
setupConfigTestEnv(t, nil)
|
||||
workDir := t.TempDir()
|
||||
|
||||
// Create AGENTS.md in the working directory.
|
||||
@@ -193,16 +294,9 @@ func TestConfig(t *testing.T) {
|
||||
require.Equal(t, filepath.Join(workDir, "AGENTS.md"), ctxFiles[0].ContextFilePath)
|
||||
})
|
||||
|
||||
//nolint:paralleltest // Uses t.Setenv to mutate process-wide environment.
|
||||
t.Run("TruncatesLargeInstructionFile", func(t *testing.T) {
|
||||
fakeHome := t.TempDir()
|
||||
t.Setenv("HOME", fakeHome)
|
||||
t.Setenv("USERPROFILE", fakeHome)
|
||||
t.Setenv(agentcontextconfig.EnvInstructionsDirs, "")
|
||||
t.Setenv(agentcontextconfig.EnvInstructionsFile, "")
|
||||
t.Setenv(agentcontextconfig.EnvSkillsDirs, "")
|
||||
t.Setenv(agentcontextconfig.EnvSkillMetaFile, "")
|
||||
t.Setenv(agentcontextconfig.EnvMCPConfigFiles, "")
|
||||
|
||||
setupConfigTestEnv(t, nil)
|
||||
workDir := t.TempDir()
|
||||
largeContent := strings.Repeat("a", 64*1024+100)
|
||||
require.NoError(t, os.WriteFile(filepath.Join(workDir, "AGENTS.md"), []byte(largeContent), 0o600))
|
||||
@@ -215,79 +309,47 @@ func TestConfig(t *testing.T) {
|
||||
require.Len(t, ctxFiles[0].ContextFileContent, 64*1024)
|
||||
})
|
||||
|
||||
t.Run("SanitizesHTMLComments", func(t *testing.T) {
|
||||
fakeHome := t.TempDir()
|
||||
t.Setenv("HOME", fakeHome)
|
||||
t.Setenv("USERPROFILE", fakeHome)
|
||||
t.Setenv(agentcontextconfig.EnvInstructionsDirs, "")
|
||||
t.Setenv(agentcontextconfig.EnvInstructionsFile, "")
|
||||
t.Setenv(agentcontextconfig.EnvSkillsDirs, "")
|
||||
t.Setenv(agentcontextconfig.EnvSkillMetaFile, "")
|
||||
t.Setenv(agentcontextconfig.EnvMCPConfigFiles, "")
|
||||
sanitizationTests := []struct {
|
||||
name string
|
||||
input string
|
||||
expected string
|
||||
}{
|
||||
{
|
||||
name: "SanitizesHTMLComments",
|
||||
input: "visible\n<!-- hidden -->content",
|
||||
expected: "visible\ncontent",
|
||||
},
|
||||
{
|
||||
name: "SanitizesInvisibleUnicode",
|
||||
input: "before\u200bafter",
|
||||
expected: "beforeafter",
|
||||
},
|
||||
{
|
||||
name: "NormalizesCRLF",
|
||||
input: "line1\r\nline2\rline3",
|
||||
expected: "line1\nline2\nline3",
|
||||
},
|
||||
}
|
||||
//nolint:paralleltest // Uses t.Setenv to mutate process-wide environment.
|
||||
for _, tt := range sanitizationTests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
setupConfigTestEnv(t, nil)
|
||||
workDir := t.TempDir()
|
||||
require.NoError(t, os.WriteFile(
|
||||
filepath.Join(workDir, "AGENTS.md"),
|
||||
[]byte(tt.input),
|
||||
0o600,
|
||||
))
|
||||
|
||||
workDir := t.TempDir()
|
||||
require.NoError(t, os.WriteFile(
|
||||
filepath.Join(workDir, "AGENTS.md"),
|
||||
[]byte("visible\n<!-- hidden -->content"),
|
||||
0o600,
|
||||
))
|
||||
cfg, _ := agentcontextconfig.Config(workDir)
|
||||
|
||||
cfg, _ := agentcontextconfig.Config(workDir)
|
||||
|
||||
ctxFiles := filterParts(cfg.Parts, codersdk.ChatMessagePartTypeContextFile)
|
||||
require.Len(t, ctxFiles, 1)
|
||||
require.Equal(t, "visible\ncontent", ctxFiles[0].ContextFileContent)
|
||||
})
|
||||
|
||||
t.Run("SanitizesInvisibleUnicode", func(t *testing.T) {
|
||||
fakeHome := t.TempDir()
|
||||
t.Setenv("HOME", fakeHome)
|
||||
t.Setenv("USERPROFILE", fakeHome)
|
||||
t.Setenv(agentcontextconfig.EnvInstructionsDirs, "")
|
||||
t.Setenv(agentcontextconfig.EnvInstructionsFile, "")
|
||||
t.Setenv(agentcontextconfig.EnvSkillsDirs, "")
|
||||
t.Setenv(agentcontextconfig.EnvSkillMetaFile, "")
|
||||
t.Setenv(agentcontextconfig.EnvMCPConfigFiles, "")
|
||||
|
||||
workDir := t.TempDir()
|
||||
// U+200B (zero-width space) should be stripped.
|
||||
require.NoError(t, os.WriteFile(
|
||||
filepath.Join(workDir, "AGENTS.md"),
|
||||
[]byte("before\u200bafter"),
|
||||
0o600,
|
||||
))
|
||||
|
||||
cfg, _ := agentcontextconfig.Config(workDir)
|
||||
|
||||
ctxFiles := filterParts(cfg.Parts, codersdk.ChatMessagePartTypeContextFile)
|
||||
require.Len(t, ctxFiles, 1)
|
||||
require.Equal(t, "beforeafter", ctxFiles[0].ContextFileContent)
|
||||
})
|
||||
|
||||
t.Run("NormalizesCRLF", func(t *testing.T) {
|
||||
fakeHome := t.TempDir()
|
||||
t.Setenv("HOME", fakeHome)
|
||||
t.Setenv("USERPROFILE", fakeHome)
|
||||
t.Setenv(agentcontextconfig.EnvInstructionsDirs, "")
|
||||
t.Setenv(agentcontextconfig.EnvInstructionsFile, "")
|
||||
t.Setenv(agentcontextconfig.EnvSkillsDirs, "")
|
||||
t.Setenv(agentcontextconfig.EnvSkillMetaFile, "")
|
||||
t.Setenv(agentcontextconfig.EnvMCPConfigFiles, "")
|
||||
|
||||
workDir := t.TempDir()
|
||||
require.NoError(t, os.WriteFile(
|
||||
filepath.Join(workDir, "AGENTS.md"),
|
||||
[]byte("line1\r\nline2\rline3"),
|
||||
0o600,
|
||||
))
|
||||
|
||||
cfg, _ := agentcontextconfig.Config(workDir)
|
||||
|
||||
ctxFiles := filterParts(cfg.Parts, codersdk.ChatMessagePartTypeContextFile)
|
||||
require.Len(t, ctxFiles, 1)
|
||||
require.Equal(t, "line1\nline2\nline3", ctxFiles[0].ContextFileContent)
|
||||
})
|
||||
ctxFiles := filterParts(cfg.Parts, codersdk.ChatMessagePartTypeContextFile)
|
||||
require.Len(t, ctxFiles, 1)
|
||||
require.Equal(t, tt.expected, ctxFiles[0].ContextFileContent)
|
||||
})
|
||||
}
|
||||
|
||||
//nolint:paralleltest // Uses t.Setenv to mutate process-wide environment.
|
||||
t.Run("DiscoversSkills", func(t *testing.T) {
|
||||
fakeHome := t.TempDir()
|
||||
t.Setenv("HOME", fakeHome)
|
||||
@@ -320,17 +382,13 @@ func TestConfig(t *testing.T) {
|
||||
require.Equal(t, "SKILL.md", skillParts[0].ContextFileSkillMetaFile)
|
||||
})
|
||||
|
||||
//nolint:paralleltest // Uses t.Setenv to mutate process-wide environment.
|
||||
t.Run("SkipsMissingDirs", func(t *testing.T) {
|
||||
fakeHome := t.TempDir()
|
||||
t.Setenv("HOME", fakeHome)
|
||||
t.Setenv("USERPROFILE", fakeHome)
|
||||
|
||||
nonExistent := filepath.Join(t.TempDir(), "does-not-exist")
|
||||
t.Setenv(agentcontextconfig.EnvInstructionsDirs, nonExistent)
|
||||
t.Setenv(agentcontextconfig.EnvInstructionsFile, "")
|
||||
t.Setenv(agentcontextconfig.EnvSkillsDirs, nonExistent)
|
||||
t.Setenv(agentcontextconfig.EnvSkillMetaFile, "")
|
||||
t.Setenv(agentcontextconfig.EnvMCPConfigFiles, "")
|
||||
setupConfigTestEnv(t, map[string]string{
|
||||
agentcontextconfig.EnvInstructionsDirs: nonExistent,
|
||||
agentcontextconfig.EnvSkillsDirs: nonExistent,
|
||||
})
|
||||
|
||||
workDir := t.TempDir()
|
||||
cfg, _ := agentcontextconfig.Config(workDir)
|
||||
@@ -340,17 +398,13 @@ func TestConfig(t *testing.T) {
|
||||
require.Empty(t, cfg.Parts)
|
||||
})
|
||||
|
||||
//nolint:paralleltest // Uses t.Setenv to mutate process-wide environment.
|
||||
t.Run("MCPConfigFilesResolvedSeparately", func(t *testing.T) {
|
||||
fakeHome := t.TempDir()
|
||||
t.Setenv("HOME", fakeHome)
|
||||
t.Setenv("USERPROFILE", fakeHome)
|
||||
t.Setenv(agentcontextconfig.EnvInstructionsDirs, fakeHome)
|
||||
t.Setenv(agentcontextconfig.EnvInstructionsFile, "")
|
||||
t.Setenv(agentcontextconfig.EnvSkillsDirs, "")
|
||||
t.Setenv(agentcontextconfig.EnvSkillMetaFile, "")
|
||||
|
||||
optMCP := platformAbsPath("opt", "custom.json")
|
||||
t.Setenv(agentcontextconfig.EnvMCPConfigFiles, optMCP)
|
||||
fakeHome := setupConfigTestEnv(t, map[string]string{
|
||||
agentcontextconfig.EnvMCPConfigFiles: optMCP,
|
||||
})
|
||||
t.Setenv(agentcontextconfig.EnvInstructionsDirs, fakeHome)
|
||||
|
||||
workDir := t.TempDir()
|
||||
_, mcpFiles := agentcontextconfig.Config(workDir)
|
||||
@@ -358,14 +412,10 @@ func TestConfig(t *testing.T) {
|
||||
require.Equal(t, []string{optMCP}, mcpFiles)
|
||||
})
|
||||
|
||||
//nolint:paralleltest // Uses t.Setenv to mutate process-wide environment.
|
||||
t.Run("SkillNameMustMatchDir", func(t *testing.T) {
|
||||
fakeHome := t.TempDir()
|
||||
t.Setenv("HOME", fakeHome)
|
||||
t.Setenv("USERPROFILE", fakeHome)
|
||||
fakeHome := setupConfigTestEnv(t, nil)
|
||||
t.Setenv(agentcontextconfig.EnvInstructionsDirs, fakeHome)
|
||||
t.Setenv(agentcontextconfig.EnvInstructionsFile, "")
|
||||
t.Setenv(agentcontextconfig.EnvSkillMetaFile, "")
|
||||
t.Setenv(agentcontextconfig.EnvMCPConfigFiles, "")
|
||||
|
||||
workDir := t.TempDir()
|
||||
skillsDir := filepath.Join(workDir, "skills")
|
||||
@@ -385,14 +435,10 @@ func TestConfig(t *testing.T) {
|
||||
require.Empty(t, skillParts)
|
||||
})
|
||||
|
||||
//nolint:paralleltest // Uses t.Setenv to mutate process-wide environment.
|
||||
t.Run("DuplicateSkillsFirstWins", func(t *testing.T) {
|
||||
fakeHome := t.TempDir()
|
||||
t.Setenv("HOME", fakeHome)
|
||||
t.Setenv("USERPROFILE", fakeHome)
|
||||
fakeHome := setupConfigTestEnv(t, nil)
|
||||
t.Setenv(agentcontextconfig.EnvInstructionsDirs, fakeHome)
|
||||
t.Setenv(agentcontextconfig.EnvInstructionsFile, "")
|
||||
t.Setenv(agentcontextconfig.EnvSkillMetaFile, "")
|
||||
t.Setenv(agentcontextconfig.EnvMCPConfigFiles, "")
|
||||
|
||||
workDir := t.TempDir()
|
||||
skillsDir1 := filepath.Join(workDir, "skills1")
|
||||
|
||||
@@ -117,6 +117,10 @@ type Config struct {
|
||||
X11MaxPort *int
|
||||
// BlockFileTransfer restricts use of file transfer applications.
|
||||
BlockFileTransfer bool
|
||||
// BlockReversePortForwarding disables reverse port forwarding (ssh -R).
|
||||
BlockReversePortForwarding bool
|
||||
// BlockLocalPortForwarding disables local port forwarding (ssh -L).
|
||||
BlockLocalPortForwarding bool
|
||||
// ReportConnection.
|
||||
ReportConnection reportConnectionFunc
|
||||
// Experimental: allow connecting to running containers via Docker exec.
|
||||
@@ -190,7 +194,7 @@ func NewServer(ctx context.Context, logger slog.Logger, prometheusRegistry *prom
|
||||
}
|
||||
|
||||
forwardHandler := &ssh.ForwardedTCPHandler{}
|
||||
unixForwardHandler := newForwardedUnixHandler(logger)
|
||||
unixForwardHandler := newForwardedUnixHandler(logger, config.BlockReversePortForwarding)
|
||||
|
||||
metrics := newSSHServerMetrics(prometheusRegistry)
|
||||
s := &Server{
|
||||
@@ -229,8 +233,15 @@ func NewServer(ctx context.Context, logger slog.Logger, prometheusRegistry *prom
|
||||
wrapped := NewJetbrainsChannelWatcher(ctx, s.logger, s.config.ReportConnection, newChan, &s.connCountJetBrains)
|
||||
ssh.DirectTCPIPHandler(srv, conn, wrapped, ctx)
|
||||
},
|
||||
"direct-streamlocal@openssh.com": directStreamLocalHandler,
|
||||
"session": ssh.DefaultSessionHandler,
|
||||
"direct-streamlocal@openssh.com": func(srv *ssh.Server, conn *gossh.ServerConn, newChan gossh.NewChannel, ctx ssh.Context) {
|
||||
if s.config.BlockLocalPortForwarding {
|
||||
s.logger.Warn(ctx, "unix local port forward blocked")
|
||||
_ = newChan.Reject(gossh.Prohibited, "local port forwarding is disabled")
|
||||
return
|
||||
}
|
||||
directStreamLocalHandler(srv, conn, newChan, ctx)
|
||||
},
|
||||
"session": ssh.DefaultSessionHandler,
|
||||
},
|
||||
ConnectionFailedCallback: func(conn net.Conn, err error) {
|
||||
s.logger.Warn(ctx, "ssh connection failed",
|
||||
@@ -250,6 +261,12 @@ func NewServer(ctx context.Context, logger slog.Logger, prometheusRegistry *prom
|
||||
// be set before we start listening.
|
||||
HostSigners: []ssh.Signer{},
|
||||
LocalPortForwardingCallback: func(ctx ssh.Context, destinationHost string, destinationPort uint32) bool {
|
||||
if s.config.BlockLocalPortForwarding {
|
||||
s.logger.Warn(ctx, "local port forward blocked",
|
||||
slog.F("destination_host", destinationHost),
|
||||
slog.F("destination_port", destinationPort))
|
||||
return false
|
||||
}
|
||||
// Allow local port forwarding all!
|
||||
s.logger.Debug(ctx, "local port forward",
|
||||
slog.F("destination_host", destinationHost),
|
||||
@@ -260,6 +277,12 @@ func NewServer(ctx context.Context, logger slog.Logger, prometheusRegistry *prom
|
||||
return true
|
||||
},
|
||||
ReversePortForwardingCallback: func(ctx ssh.Context, bindHost string, bindPort uint32) bool {
|
||||
if s.config.BlockReversePortForwarding {
|
||||
s.logger.Warn(ctx, "reverse port forward blocked",
|
||||
slog.F("bind_host", bindHost),
|
||||
slog.F("bind_port", bindPort))
|
||||
return false
|
||||
}
|
||||
// Allow reverse port forwarding all!
|
||||
s.logger.Debug(ctx, "reverse port forward",
|
||||
slog.F("bind_host", bindHost),
|
||||
|
||||
@@ -35,8 +35,9 @@ type forwardedStreamLocalPayload struct {
|
||||
// streamlocal forwarding (aka. unix forwarding) instead of TCP forwarding.
|
||||
type forwardedUnixHandler struct {
|
||||
sync.Mutex
|
||||
log slog.Logger
|
||||
forwards map[forwardKey]net.Listener
|
||||
log slog.Logger
|
||||
forwards map[forwardKey]net.Listener
|
||||
blockReversePortForwarding bool
|
||||
}
|
||||
|
||||
type forwardKey struct {
|
||||
@@ -44,10 +45,11 @@ type forwardKey struct {
|
||||
addr string
|
||||
}
|
||||
|
||||
func newForwardedUnixHandler(log slog.Logger) *forwardedUnixHandler {
|
||||
func newForwardedUnixHandler(log slog.Logger, blockReversePortForwarding bool) *forwardedUnixHandler {
|
||||
return &forwardedUnixHandler{
|
||||
log: log,
|
||||
forwards: make(map[forwardKey]net.Listener),
|
||||
log: log,
|
||||
forwards: make(map[forwardKey]net.Listener),
|
||||
blockReversePortForwarding: blockReversePortForwarding,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -62,6 +64,10 @@ func (h *forwardedUnixHandler) HandleSSHRequest(ctx ssh.Context, _ *ssh.Server,
|
||||
|
||||
switch req.Type {
|
||||
case "streamlocal-forward@openssh.com":
|
||||
if h.blockReversePortForwarding {
|
||||
log.Warn(ctx, "unix reverse port forward blocked")
|
||||
return false, nil
|
||||
}
|
||||
var reqPayload streamLocalForwardPayload
|
||||
err := gossh.Unmarshal(req.Payload, &reqPayload)
|
||||
if err != nil {
|
||||
|
||||
+1141
-1038
File diff suppressed because it is too large
Load Diff
@@ -98,6 +98,21 @@ message Manifest {
|
||||
repeated WorkspaceApp apps = 11;
|
||||
repeated WorkspaceAgentMetadata.Description metadata = 12;
|
||||
repeated WorkspaceAgentDevcontainer devcontainers = 17;
|
||||
repeated WorkspaceSecret secrets = 19;
|
||||
}
|
||||
|
||||
// WorkspaceSecret is a secret included in the agent manifest
|
||||
// for injection into a workspace.
|
||||
message WorkspaceSecret {
|
||||
// Environment variable name to inject (e.g. "GITHUB_TOKEN").
|
||||
// Empty string means this secret is not injected as an env var.
|
||||
string env_name = 1;
|
||||
// File path to write the secret value to (e.g.
|
||||
// "~/.aws/credentials"). Empty string means this secret is not
|
||||
// written to a file.
|
||||
string file_path = 2;
|
||||
// The decrypted secret value.
|
||||
bytes value = 3;
|
||||
}
|
||||
|
||||
message WorkspaceAgentDevcontainer {
|
||||
|
||||
@@ -5,7 +5,9 @@ import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"io"
|
||||
"mime/multipart"
|
||||
"net/http"
|
||||
"net/textproto"
|
||||
"strconv"
|
||||
"sync"
|
||||
"time"
|
||||
@@ -620,6 +622,11 @@ func (a *API) handleRecordingStop(rw http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
defer artifact.Reader.Close()
|
||||
defer func() {
|
||||
if artifact.ThumbnailReader != nil {
|
||||
_ = artifact.ThumbnailReader.Close()
|
||||
}
|
||||
}()
|
||||
|
||||
if artifact.Size > workspacesdk.MaxRecordingSize {
|
||||
a.logger.Warn(ctx, "recording file exceeds maximum size",
|
||||
@@ -633,10 +640,60 @@ func (a *API) handleRecordingStop(rw http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
rw.Header().Set("Content-Type", "video/mp4")
|
||||
rw.Header().Set("Content-Length", strconv.FormatInt(artifact.Size, 10))
|
||||
// Discard the thumbnail if it exceeds the maximum size.
|
||||
// The server-side consumer also enforces this per-part, but
|
||||
// rejecting it here avoids streaming a large thumbnail over
|
||||
// the wire for nothing.
|
||||
if artifact.ThumbnailReader != nil && artifact.ThumbnailSize > workspacesdk.MaxThumbnailSize {
|
||||
a.logger.Warn(ctx, "thumbnail file exceeds maximum size, omitting",
|
||||
slog.F("recording_id", recordingID),
|
||||
slog.F("size", artifact.ThumbnailSize),
|
||||
slog.F("max_size", workspacesdk.MaxThumbnailSize),
|
||||
)
|
||||
_ = artifact.ThumbnailReader.Close()
|
||||
artifact.ThumbnailReader = nil
|
||||
artifact.ThumbnailSize = 0
|
||||
}
|
||||
|
||||
// The multipart response is best-effort: once WriteHeader(200) is
|
||||
// called, CreatePart failures produce a truncated response without
|
||||
// the closing boundary. The server-side consumer handles this
|
||||
// gracefully, preserving any parts read before the error.
|
||||
mw := multipart.NewWriter(rw)
|
||||
defer mw.Close()
|
||||
rw.Header().Set("Content-Type", "multipart/mixed; boundary="+mw.Boundary())
|
||||
rw.WriteHeader(http.StatusOK)
|
||||
_, _ = io.Copy(rw, artifact.Reader)
|
||||
|
||||
// Part 1: video/mp4 (always present).
|
||||
videoPart, err := mw.CreatePart(textproto.MIMEHeader{
|
||||
"Content-Type": {"video/mp4"},
|
||||
})
|
||||
if err != nil {
|
||||
a.logger.Warn(ctx, "failed to create video multipart part",
|
||||
slog.F("recording_id", recordingID),
|
||||
slog.Error(err))
|
||||
return
|
||||
}
|
||||
if _, err := io.Copy(videoPart, artifact.Reader); err != nil {
|
||||
a.logger.Warn(ctx, "failed to write video multipart part",
|
||||
slog.F("recording_id", recordingID),
|
||||
slog.Error(err))
|
||||
return
|
||||
}
|
||||
|
||||
// Part 2: image/jpeg (present only when thumbnail was extracted).
|
||||
if artifact.ThumbnailReader != nil {
|
||||
thumbPart, err := mw.CreatePart(textproto.MIMEHeader{
|
||||
"Content-Type": {"image/jpeg"},
|
||||
})
|
||||
if err != nil {
|
||||
a.logger.Warn(ctx, "failed to create thumbnail multipart part",
|
||||
slog.F("recording_id", recordingID),
|
||||
slog.Error(err))
|
||||
return
|
||||
}
|
||||
_, _ = io.Copy(thumbPart, artifact.ThumbnailReader)
|
||||
}
|
||||
}
|
||||
|
||||
// coordFromAction extracts the coordinate pair from a DesktopAction,
|
||||
|
||||
@@ -4,12 +4,17 @@ import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"mime"
|
||||
"mime/multipart"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"os"
|
||||
"slices"
|
||||
"strings"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
@@ -59,6 +64,8 @@ type fakeDesktop struct {
|
||||
lastKeyDown string
|
||||
lastKeyUp string
|
||||
|
||||
thumbnailData []byte // if set, StopRecording includes a thumbnail
|
||||
|
||||
// Recording tracking (guarded by recMu).
|
||||
recMu sync.Mutex
|
||||
recordings map[string]string // ID → file path
|
||||
@@ -187,10 +194,15 @@ func (f *fakeDesktop) StopRecording(_ context.Context, recordingID string) (*age
|
||||
_ = file.Close()
|
||||
return nil, err
|
||||
}
|
||||
return &agentdesktop.RecordingArtifact{
|
||||
artifact := &agentdesktop.RecordingArtifact{
|
||||
Reader: file,
|
||||
Size: info.Size(),
|
||||
}, nil
|
||||
}
|
||||
if f.thumbnailData != nil {
|
||||
artifact.ThumbnailReader = io.NopCloser(bytes.NewReader(f.thumbnailData))
|
||||
artifact.ThumbnailSize = int64(len(f.thumbnailData))
|
||||
}
|
||||
return artifact, nil
|
||||
}
|
||||
|
||||
func (f *fakeDesktop) RecordActivity() {
|
||||
@@ -785,8 +797,8 @@ func TestRecordingStartStop(t *testing.T) {
|
||||
req = httptest.NewRequest(http.MethodPost, "/recording/stop", bytes.NewReader(stopBody))
|
||||
handler.ServeHTTP(rr, req)
|
||||
require.Equal(t, http.StatusOK, rr.Code)
|
||||
assert.Equal(t, "video/mp4", rr.Header().Get("Content-Type"))
|
||||
assert.Equal(t, []byte("fake-mp4-data-"+testRecIDDefault+"-1"), rr.Body.Bytes())
|
||||
parts := parseMultipartParts(t, rr.Header().Get("Content-Type"), rr.Body.Bytes())
|
||||
assert.Equal(t, []byte("fake-mp4-data-"+testRecIDDefault+"-1"), parts["video/mp4"])
|
||||
}
|
||||
|
||||
func TestRecordingStartFails(t *testing.T) {
|
||||
@@ -847,8 +859,8 @@ func TestRecordingStartIdempotent(t *testing.T) {
|
||||
req := httptest.NewRequest(http.MethodPost, "/recording/stop", bytes.NewReader(stopBody))
|
||||
handler.ServeHTTP(rr, req)
|
||||
require.Equal(t, http.StatusOK, rr.Code)
|
||||
assert.Equal(t, "video/mp4", rr.Header().Get("Content-Type"))
|
||||
assert.Equal(t, []byte("fake-mp4-data-"+testRecIDStartIdempotent+"-1"), rr.Body.Bytes())
|
||||
parts := parseMultipartParts(t, rr.Header().Get("Content-Type"), rr.Body.Bytes())
|
||||
assert.Equal(t, []byte("fake-mp4-data-"+testRecIDStartIdempotent+"-1"), parts["video/mp4"])
|
||||
}
|
||||
|
||||
func TestRecordingStopIdempotent(t *testing.T) {
|
||||
@@ -872,7 +884,7 @@ func TestRecordingStopIdempotent(t *testing.T) {
|
||||
require.Equal(t, http.StatusOK, rr.Code)
|
||||
|
||||
// Stop twice - both should succeed with identical data.
|
||||
var bodies [2][]byte
|
||||
var videoParts [2][]byte
|
||||
for i := range 2 {
|
||||
body, err := json.Marshal(map[string]string{"recording_id": testRecIDStopIdempotent})
|
||||
require.NoError(t, err)
|
||||
@@ -880,10 +892,10 @@ func TestRecordingStopIdempotent(t *testing.T) {
|
||||
request := httptest.NewRequest(http.MethodPost, "/recording/stop", bytes.NewReader(body))
|
||||
handler.ServeHTTP(recorder, request)
|
||||
require.Equal(t, http.StatusOK, recorder.Code)
|
||||
assert.Equal(t, "video/mp4", recorder.Header().Get("Content-Type"))
|
||||
bodies[i] = recorder.Body.Bytes()
|
||||
parts := parseMultipartParts(t, recorder.Header().Get("Content-Type"), recorder.Body.Bytes())
|
||||
videoParts[i] = parts["video/mp4"]
|
||||
}
|
||||
assert.Equal(t, bodies[0], bodies[1])
|
||||
assert.Equal(t, videoParts[0], videoParts[1])
|
||||
}
|
||||
|
||||
func TestRecordingStopInvalidIDFormat(t *testing.T) {
|
||||
@@ -1004,8 +1016,8 @@ func TestRecordingMultipleSimultaneous(t *testing.T) {
|
||||
req := httptest.NewRequest(http.MethodPost, "/recording/stop", bytes.NewReader(body))
|
||||
handler.ServeHTTP(rr, req)
|
||||
require.Equal(t, http.StatusOK, rr.Code)
|
||||
assert.Equal(t, "video/mp4", rr.Header().Get("Content-Type"))
|
||||
assert.Equal(t, expected[id], rr.Body.Bytes())
|
||||
parts := parseMultipartParts(t, rr.Header().Get("Content-Type"), rr.Body.Bytes())
|
||||
assert.Equal(t, expected[id], parts["video/mp4"])
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1112,8 +1124,8 @@ func TestRecordingStartAfterCompleted(t *testing.T) {
|
||||
req = httptest.NewRequest(http.MethodPost, "/recording/stop", bytes.NewReader(stopBody))
|
||||
handler.ServeHTTP(rr, req)
|
||||
require.Equal(t, http.StatusOK, rr.Code)
|
||||
assert.Equal(t, "video/mp4", rr.Header().Get("Content-Type"))
|
||||
firstData := rr.Body.Bytes()
|
||||
firstParts := parseMultipartParts(t, rr.Header().Get("Content-Type"), rr.Body.Bytes())
|
||||
firstData := firstParts["video/mp4"]
|
||||
require.NotEmpty(t, firstData)
|
||||
|
||||
// Step 3: Start again with the same ID - should succeed
|
||||
@@ -1128,8 +1140,8 @@ func TestRecordingStartAfterCompleted(t *testing.T) {
|
||||
req = httptest.NewRequest(http.MethodPost, "/recording/stop", bytes.NewReader(stopBody))
|
||||
handler.ServeHTTP(rr, req)
|
||||
require.Equal(t, http.StatusOK, rr.Code)
|
||||
assert.Equal(t, "video/mp4", rr.Header().Get("Content-Type"))
|
||||
secondData := rr.Body.Bytes()
|
||||
secondParts := parseMultipartParts(t, rr.Header().Get("Content-Type"), rr.Body.Bytes())
|
||||
secondData := secondParts["video/mp4"]
|
||||
require.NotEmpty(t, secondData)
|
||||
|
||||
// The two recordings should have different data because the
|
||||
@@ -1235,3 +1247,166 @@ func TestRecordingStopCorrupted(t *testing.T) {
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "Recording is corrupted.", respStop.Message)
|
||||
}
|
||||
|
||||
// parseMultipartParts parses a multipart/mixed response and returns
|
||||
// a map from Content-Type to body bytes.
|
||||
func parseMultipartParts(t *testing.T, contentType string, body []byte) map[string][]byte {
|
||||
t.Helper()
|
||||
_, params, err := mime.ParseMediaType(contentType)
|
||||
require.NoError(t, err, "parse Content-Type")
|
||||
boundary := params["boundary"]
|
||||
require.NotEmpty(t, boundary, "missing boundary")
|
||||
mr := multipart.NewReader(bytes.NewReader(body), boundary)
|
||||
parts := make(map[string][]byte)
|
||||
for {
|
||||
part, err := mr.NextPart()
|
||||
if errors.Is(err, io.EOF) {
|
||||
break
|
||||
}
|
||||
require.NoError(t, err, "unexpected multipart parse error")
|
||||
ct := part.Header.Get("Content-Type")
|
||||
data, readErr := io.ReadAll(part)
|
||||
require.NoError(t, readErr)
|
||||
parts[ct] = data
|
||||
}
|
||||
return parts
|
||||
}
|
||||
|
||||
func TestHandleRecordingStop_WithThumbnail(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
logger := slogtest.Make(t, nil)
|
||||
// Create a fake JPEG header: 0xFF 0xD8 0xFF followed by 509 zero bytes.
|
||||
thumbnail := make([]byte, 512)
|
||||
thumbnail[0] = 0xff
|
||||
thumbnail[1] = 0xd8
|
||||
thumbnail[2] = 0xff
|
||||
|
||||
fake := &fakeDesktop{
|
||||
startCfg: agentdesktop.DisplayConfig{Width: 1920, Height: 1080},
|
||||
thumbnailData: thumbnail,
|
||||
}
|
||||
api := agentdesktop.NewAPI(logger, fake, nil)
|
||||
defer api.Close()
|
||||
|
||||
handler := api.Routes()
|
||||
|
||||
// Start recording.
|
||||
recID := uuid.New().String()
|
||||
startBody, err := json.Marshal(map[string]string{"recording_id": recID})
|
||||
require.NoError(t, err)
|
||||
rr := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodPost, "/recording/start", bytes.NewReader(startBody))
|
||||
handler.ServeHTTP(rr, req)
|
||||
require.Equal(t, http.StatusOK, rr.Code)
|
||||
|
||||
// Stop recording.
|
||||
stopBody, err := json.Marshal(map[string]string{"recording_id": recID})
|
||||
require.NoError(t, err)
|
||||
rr = httptest.NewRecorder()
|
||||
req = httptest.NewRequest(http.MethodPost, "/recording/stop", bytes.NewReader(stopBody))
|
||||
handler.ServeHTTP(rr, req)
|
||||
require.Equal(t, http.StatusOK, rr.Code)
|
||||
|
||||
// Verify multipart response.
|
||||
ct := rr.Header().Get("Content-Type")
|
||||
assert.True(t, strings.HasPrefix(ct, "multipart/mixed"),
|
||||
"expected multipart/mixed Content-Type, got %s", ct)
|
||||
|
||||
parts := parseMultipartParts(t, ct, rr.Body.Bytes())
|
||||
assert.Len(t, parts, 2, "expected exactly 2 parts (video + thumbnail)")
|
||||
|
||||
// The fake writes "fake-mp4-data-<id>-<counter>" as the MP4 content.
|
||||
expectedMP4 := []byte("fake-mp4-data-" + recID + "-1")
|
||||
assert.Equal(t, expectedMP4, parts["video/mp4"])
|
||||
assert.Equal(t, thumbnail, parts["image/jpeg"])
|
||||
}
|
||||
|
||||
// TestHandleRecordingStop_NoThumbnail verifies that stopping a
// recording whose fake desktop produced no thumbnail yields a
// multipart/mixed response containing exactly one part: the video.
func TestHandleRecordingStop_NoThumbnail(t *testing.T) {
	t.Parallel()

	logger := slogtest.Make(t, nil)
	// No thumbnailData is set on the fake, so no thumbnail part is
	// expected in the stop response.
	fake := &fakeDesktop{
		startCfg: agentdesktop.DisplayConfig{Width: 1920, Height: 1080},
	}
	api := agentdesktop.NewAPI(logger, fake, nil)
	defer api.Close()

	handler := api.Routes()

	// Start recording.
	recID := uuid.New().String()
	startBody, err := json.Marshal(map[string]string{"recording_id": recID})
	require.NoError(t, err)
	rr := httptest.NewRecorder()
	req := httptest.NewRequest(http.MethodPost, "/recording/start", bytes.NewReader(startBody))
	handler.ServeHTTP(rr, req)
	require.Equal(t, http.StatusOK, rr.Code)

	// Stop recording.
	stopBody, err := json.Marshal(map[string]string{"recording_id": recID})
	require.NoError(t, err)
	rr = httptest.NewRecorder()
	req = httptest.NewRequest(http.MethodPost, "/recording/stop", bytes.NewReader(stopBody))
	handler.ServeHTTP(rr, req)
	require.Equal(t, http.StatusOK, rr.Code)

	// Verify multipart response.
	ct := rr.Header().Get("Content-Type")
	assert.True(t, strings.HasPrefix(ct, "multipart/mixed"),
		"expected multipart/mixed Content-Type, got %s", ct)

	parts := parseMultipartParts(t, ct, rr.Body.Bytes())
	assert.Len(t, parts, 1, "expected exactly 1 part (video only)")

	// The fake writes "fake-mp4-data-<id>-<counter>" as the MP4 content.
	expectedMP4 := []byte("fake-mp4-data-" + recID + "-1")
	assert.Equal(t, expectedMP4, parts["video/mp4"])
}
|
||||
|
||||
// TestHandleRecordingStop_OversizedThumbnail verifies that a
// thumbnail larger than workspacesdk.MaxThumbnailSize is discarded:
// the stop response contains only the video part.
func TestHandleRecordingStop_OversizedThumbnail(t *testing.T) {
	t.Parallel()

	logger := slogtest.Make(t, nil)
	// Create thumbnail data that exceeds MaxThumbnailSize.
	oversizedThumb := make([]byte, workspacesdk.MaxThumbnailSize+1)
	// Leading 0xff 0xd8 0xff bytes — presumably to mimic a JPEG
	// header so the data resembles a real thumbnail; confirm if the
	// handler actually sniffs content.
	oversizedThumb[0] = 0xff
	oversizedThumb[1] = 0xd8
	oversizedThumb[2] = 0xff

	fake := &fakeDesktop{
		startCfg:      agentdesktop.DisplayConfig{Width: 1920, Height: 1080},
		thumbnailData: oversizedThumb,
	}
	api := agentdesktop.NewAPI(logger, fake, nil)
	defer api.Close()

	handler := api.Routes()

	// Start recording.
	recID := uuid.New().String()
	startBody, err := json.Marshal(map[string]string{"recording_id": recID})
	require.NoError(t, err)
	rr := httptest.NewRecorder()
	req := httptest.NewRequest(http.MethodPost, "/recording/start", bytes.NewReader(startBody))
	handler.ServeHTTP(rr, req)
	require.Equal(t, http.StatusOK, rr.Code)

	// Stop recording.
	stopBody, err := json.Marshal(map[string]string{"recording_id": recID})
	require.NoError(t, err)
	rr = httptest.NewRecorder()
	req = httptest.NewRequest(http.MethodPost, "/recording/stop", bytes.NewReader(stopBody))
	handler.ServeHTTP(rr, req)
	require.Equal(t, http.StatusOK, rr.Code)

	// Verify multipart response contains only the video part.
	ct := rr.Header().Get("Content-Type")
	assert.True(t, strings.HasPrefix(ct, "multipart/mixed"),
		"expected multipart/mixed Content-Type, got %s", ct)

	parts := parseMultipartParts(t, ct, rr.Body.Bytes())
	assert.Len(t, parts, 1, "expected exactly 1 part (video only, oversized thumbnail discarded)")

	// The fake writes "fake-mp4-data-<id>-<counter>" as the MP4 content.
	expectedMP4 := []byte("fake-mp4-data-" + recID + "-1")
	assert.Equal(t, expectedMP4, parts["video/mp4"])
}
|
||||
|
||||
@@ -105,6 +105,11 @@ type RecordingArtifact struct {
|
||||
Reader io.ReadCloser
|
||||
// Size is the byte length of the MP4 content.
|
||||
Size int64
|
||||
// ThumbnailReader is the JPEG thumbnail. May be nil if no
|
||||
// thumbnail was produced. Callers must close it when done.
|
||||
ThumbnailReader io.ReadCloser
|
||||
// ThumbnailSize is the byte length of the thumbnail.
|
||||
ThumbnailSize int64
|
||||
}
|
||||
|
||||
// DisplayConfig describes a running desktop session.
|
||||
|
||||
@@ -56,6 +56,7 @@ type screenshotOutput struct {
|
||||
type recordingProcess struct {
|
||||
cmd *exec.Cmd
|
||||
filePath string
|
||||
thumbPath string
|
||||
stopped bool
|
||||
killed bool // true when the process was SIGKILLed
|
||||
done chan struct{} // closed when cmd.Wait() returns
|
||||
@@ -383,13 +384,20 @@ func (p *portableDesktop) StartRecording(ctx context.Context, recordingID string
|
||||
}
|
||||
}
|
||||
// Completed recording - discard old file, start fresh.
|
||||
if err := os.Remove(rec.filePath); err != nil && !os.IsNotExist(err) {
|
||||
if err := os.Remove(rec.filePath); err != nil && !errors.Is(err, os.ErrNotExist) {
|
||||
p.logger.Warn(ctx, "failed to remove old recording file",
|
||||
slog.F("recording_id", recordingID),
|
||||
slog.F("file_path", rec.filePath),
|
||||
slog.Error(err),
|
||||
)
|
||||
}
|
||||
if err := os.Remove(rec.thumbPath); err != nil && !errors.Is(err, os.ErrNotExist) {
|
||||
p.logger.Warn(ctx, "failed to remove old thumbnail file",
|
||||
slog.F("recording_id", recordingID),
|
||||
slog.F("thumbnail_path", rec.thumbPath),
|
||||
slog.Error(err),
|
||||
)
|
||||
}
|
||||
delete(p.recordings, recordingID)
|
||||
}
|
||||
|
||||
@@ -406,6 +414,7 @@ func (p *portableDesktop) StartRecording(ctx context.Context, recordingID string
|
||||
}
|
||||
|
||||
filePath := filepath.Join(os.TempDir(), "coder-recording-"+recordingID+".mp4")
|
||||
thumbPath := filepath.Join(os.TempDir(), "coder-recording-"+recordingID+".thumb.jpg")
|
||||
|
||||
// Use a background context so the process outlives the HTTP
|
||||
// request that triggered it.
|
||||
@@ -419,6 +428,7 @@ func (p *portableDesktop) StartRecording(ctx context.Context, recordingID string
|
||||
"--idle-speedup", "20",
|
||||
"--idle-min-duration", "0.35",
|
||||
"--idle-noise-tolerance", "-38dB",
|
||||
"--thumbnail", thumbPath,
|
||||
filePath)
|
||||
|
||||
if err := cmd.Start(); err != nil {
|
||||
@@ -427,9 +437,10 @@ func (p *portableDesktop) StartRecording(ctx context.Context, recordingID string
|
||||
}
|
||||
|
||||
rec := &recordingProcess{
|
||||
cmd: cmd,
|
||||
filePath: filePath,
|
||||
done: make(chan struct{}),
|
||||
cmd: cmd,
|
||||
filePath: filePath,
|
||||
thumbPath: thumbPath,
|
||||
done: make(chan struct{}),
|
||||
}
|
||||
go func() {
|
||||
rec.waitErr = cmd.Wait()
|
||||
@@ -499,10 +510,35 @@ func (p *portableDesktop) StopRecording(ctx context.Context, recordingID string)
|
||||
_ = f.Close()
|
||||
return nil, xerrors.Errorf("stat recording artifact: %w", err)
|
||||
}
|
||||
return &RecordingArtifact{
|
||||
artifact := &RecordingArtifact{
|
||||
Reader: f,
|
||||
Size: info.Size(),
|
||||
}, nil
|
||||
}
|
||||
// Attach thumbnail if the subprocess wrote one.
|
||||
thumbFile, err := os.Open(rec.thumbPath)
|
||||
if err != nil {
|
||||
p.logger.Warn(ctx, "thumbnail not available",
|
||||
slog.F("thumbnail_path", rec.thumbPath),
|
||||
slog.Error(err))
|
||||
return artifact, nil
|
||||
}
|
||||
thumbInfo, err := thumbFile.Stat()
|
||||
if err != nil {
|
||||
_ = thumbFile.Close()
|
||||
p.logger.Warn(ctx, "thumbnail stat failed",
|
||||
slog.F("thumbnail_path", rec.thumbPath),
|
||||
slog.Error(err))
|
||||
return artifact, nil
|
||||
}
|
||||
if thumbInfo.Size() == 0 {
|
||||
_ = thumbFile.Close()
|
||||
p.logger.Warn(ctx, "thumbnail file is empty",
|
||||
slog.F("thumbnail_path", rec.thumbPath))
|
||||
return artifact, nil
|
||||
}
|
||||
artifact.ThumbnailReader = thumbFile
|
||||
artifact.ThumbnailSize = thumbInfo.Size()
|
||||
return artifact, nil
|
||||
}
|
||||
|
||||
// lockedStopRecordingProcess stops a single recording via stopOnce.
|
||||
@@ -571,18 +607,33 @@ func (p *portableDesktop) lockedCleanStaleRecordings(ctx context.Context) {
|
||||
}
|
||||
info, err := os.Stat(rec.filePath)
|
||||
if err != nil {
|
||||
// File already removed or inaccessible; drop entry.
|
||||
// File already removed or inaccessible; clean up
|
||||
// any leftover thumbnail and drop the entry.
|
||||
if err := os.Remove(rec.thumbPath); err != nil && !errors.Is(err, os.ErrNotExist) {
|
||||
p.logger.Warn(ctx, "failed to remove stale thumbnail file",
|
||||
slog.F("recording_id", id),
|
||||
slog.F("thumbnail_path", rec.thumbPath),
|
||||
slog.Error(err),
|
||||
)
|
||||
}
|
||||
delete(p.recordings, id)
|
||||
continue
|
||||
}
|
||||
if p.clock.Since(info.ModTime()) > time.Hour {
|
||||
if err := os.Remove(rec.filePath); err != nil && !os.IsNotExist(err) {
|
||||
if err := os.Remove(rec.filePath); err != nil && !errors.Is(err, os.ErrNotExist) {
|
||||
p.logger.Warn(ctx, "failed to remove stale recording file",
|
||||
slog.F("recording_id", id),
|
||||
slog.F("file_path", rec.filePath),
|
||||
slog.Error(err),
|
||||
)
|
||||
}
|
||||
if err := os.Remove(rec.thumbPath); err != nil && !errors.Is(err, os.ErrNotExist) {
|
||||
p.logger.Warn(ctx, "failed to remove stale thumbnail file",
|
||||
slog.F("recording_id", id),
|
||||
slog.F("thumbnail_path", rec.thumbPath),
|
||||
slog.Error(err),
|
||||
)
|
||||
}
|
||||
delete(p.recordings, id)
|
||||
}
|
||||
}
|
||||
@@ -603,13 +654,14 @@ func (p *portableDesktop) Close() error {
|
||||
// Snapshot recording file paths and idle goroutine channels
|
||||
// for cleanup, then clear the map.
|
||||
type recEntry struct {
|
||||
id string
|
||||
filePath string
|
||||
idleDone chan struct{}
|
||||
id string
|
||||
filePath string
|
||||
thumbPath string
|
||||
idleDone chan struct{}
|
||||
}
|
||||
var allRecs []recEntry
|
||||
for id, rec := range p.recordings {
|
||||
allRecs = append(allRecs, recEntry{id: id, filePath: rec.filePath, idleDone: rec.idleDone})
|
||||
allRecs = append(allRecs, recEntry{id: id, filePath: rec.filePath, thumbPath: rec.thumbPath, idleDone: rec.idleDone})
|
||||
delete(p.recordings, id)
|
||||
}
|
||||
session := p.session
|
||||
@@ -630,13 +682,20 @@ func (p *portableDesktop) Close() error {
|
||||
go func() {
|
||||
defer close(cleanupDone)
|
||||
for _, entry := range allRecs {
|
||||
if err := os.Remove(entry.filePath); err != nil && !os.IsNotExist(err) {
|
||||
if err := os.Remove(entry.filePath); err != nil && !errors.Is(err, os.ErrNotExist) {
|
||||
p.logger.Warn(context.Background(), "failed to remove recording file on close",
|
||||
slog.F("recording_id", entry.id),
|
||||
slog.F("file_path", entry.filePath),
|
||||
slog.Error(err),
|
||||
)
|
||||
}
|
||||
if err := os.Remove(entry.thumbPath); err != nil && !errors.Is(err, os.ErrNotExist) {
|
||||
p.logger.Warn(context.Background(), "failed to remove thumbnail file on close",
|
||||
slog.F("recording_id", entry.id),
|
||||
slog.F("thumbnail_path", entry.thumbPath),
|
||||
slog.Error(err),
|
||||
)
|
||||
}
|
||||
}
|
||||
if session != nil {
|
||||
session.cancel()
|
||||
|
||||
@@ -2,6 +2,7 @@ package agentdesktop
|
||||
|
||||
import (
|
||||
"context"
|
||||
"io"
|
||||
"os"
|
||||
"os/exec"
|
||||
"path/filepath"
|
||||
@@ -584,6 +585,7 @@ func TestPortableDesktop_StartRecording(t *testing.T) {
|
||||
joined := strings.Join(cmd, " ")
|
||||
if strings.Contains(joined, "record") && strings.Contains(joined, "coder-recording-"+recID) {
|
||||
found = true
|
||||
assert.Contains(t, joined, "--thumbnail", "record command should include --thumbnail flag")
|
||||
break
|
||||
}
|
||||
}
|
||||
@@ -666,6 +668,66 @@ func TestPortableDesktop_StopRecording_ReturnsArtifact(t *testing.T) {
|
||||
defer artifact.Reader.Close()
|
||||
assert.Equal(t, int64(len("fake-mp4-data")), artifact.Size)
|
||||
|
||||
// No thumbnail file exists, so ThumbnailReader should be nil.
|
||||
assert.Nil(t, artifact.ThumbnailReader, "ThumbnailReader should be nil when no thumbnail file exists")
|
||||
|
||||
require.NoError(t, pd.Close())
|
||||
}
|
||||
|
||||
// TestPortableDesktop_StopRecording_WithThumbnail verifies that
// StopRecording attaches ThumbnailReader and ThumbnailSize when the
// recorder subprocess wrote a thumbnail file at the expected path.
func TestPortableDesktop_StopRecording_WithThumbnail(t *testing.T) {
	t.Parallel()

	logger := slogtest.Make(t, nil)
	// The "record" script idles until interrupted; the "up" script
	// reports a running session then idles.
	rec := &recordedExecer{
		scripts: map[string]string{
			"record": `trap 'exit 0' INT; sleep 120 & wait`,
			"up":     `printf '{"vncPort":5901,"geometry":"1920x1080"}\n' && sleep 120`,
		},
	}

	clk := quartz.NewReal()
	pd := &portableDesktop{
		logger:       logger,
		execer:       rec,
		scriptBinDir: t.TempDir(),
		clock:        clk,
		binPath:      "portabledesktop",
		recordings:   make(map[string]*recordingProcess),
	}
	// Mark recent desktop activity so the idle monitor does not stop
	// the recording during the test.
	pd.lastDesktopActionAt.Store(clk.Now().UnixNano())

	ctx := t.Context()
	recID := uuid.New().String()
	err := pd.StartRecording(ctx, recID)
	require.NoError(t, err)

	// Write a dummy MP4 file at the expected path.
	filePath := filepath.Join(os.TempDir(), "coder-recording-"+recID+".mp4")
	require.NoError(t, os.WriteFile(filePath, []byte("fake-mp4-data"), 0o600))
	t.Cleanup(func() { _ = os.Remove(filePath) })

	// Write a thumbnail file at the expected path.
	thumbPath := filepath.Join(os.TempDir(), "coder-recording-"+recID+".thumb.jpg")
	thumbContent := []byte("fake-jpeg-thumbnail")
	require.NoError(t, os.WriteFile(thumbPath, thumbContent, 0o600))
	t.Cleanup(func() { _ = os.Remove(thumbPath) })

	artifact, err := pd.StopRecording(ctx, recID)
	require.NoError(t, err)
	defer artifact.Reader.Close()

	assert.Equal(t, int64(len("fake-mp4-data")), artifact.Size)

	// Thumbnail should be attached.
	require.NotNil(t, artifact.ThumbnailReader, "ThumbnailReader should be non-nil when thumbnail file exists")
	defer artifact.ThumbnailReader.Close()
	assert.Equal(t, int64(len(thumbContent)), artifact.ThumbnailSize)

	// Read and verify thumbnail content.
	thumbData, err := io.ReadAll(artifact.ThumbnailReader)
	require.NoError(t, err)
	assert.Equal(t, thumbContent, thumbData)

	require.NoError(t, pd.Close())
}
|
||||
|
||||
@@ -750,12 +812,18 @@ func TestPortableDesktop_IdleTimeout_StopsRecordings(t *testing.T) {
|
||||
stopTrap := clk.Trap().NewTimer("agentdesktop", "stop_timeout")
|
||||
|
||||
// Advance past idle timeout to trigger the stop-all.
|
||||
clk.Advance(idleTimeout)
|
||||
clk.Advance(idleTimeout).MustWait(ctx)
|
||||
|
||||
// Wait for the stop timer to be created, then release it.
|
||||
stopTrap.MustWait(ctx).MustRelease(ctx)
|
||||
stopTrap.Close()
|
||||
|
||||
// Advance past the 15s stop timeout so the process is
|
||||
// forcibly killed. Without this the test depends on the real
|
||||
// shell handling SIGINT promptly, which is unreliable on
|
||||
// macOS CI runners (the flake in #1461).
|
||||
clk.Advance(15 * time.Second).MustWait(ctx)
|
||||
|
||||
// The recording process should now be stopped.
|
||||
require.Eventually(t, func() bool {
|
||||
pd.mu.Lock()
|
||||
@@ -877,11 +945,17 @@ func TestPortableDesktop_IdleTimeout_MultipleRecordings(t *testing.T) {
|
||||
stopTrap := clk.Trap().NewTimer("agentdesktop", "stop_timeout")
|
||||
|
||||
// Advance past idle timeout.
|
||||
clk.Advance(idleTimeout)
|
||||
clk.Advance(idleTimeout).MustWait(ctx)
|
||||
|
||||
// Wait for both stop timers.
|
||||
// Each idle monitor goroutine serializes on p.mu, so the
|
||||
// second stop timer is only created after the first stop
|
||||
// completes. Advance past the 15s stop timeout after each
|
||||
// release so the process is forcibly killed instead of
|
||||
// depending on SIGINT (unreliable on macOS — see #1461).
|
||||
stopTrap.MustWait(ctx).MustRelease(ctx)
|
||||
clk.Advance(15 * time.Second).MustWait(ctx)
|
||||
stopTrap.MustWait(ctx).MustRelease(ctx)
|
||||
clk.Advance(15 * time.Second).MustWait(ctx)
|
||||
stopTrap.Close()
|
||||
|
||||
// Both recordings should be stopped.
|
||||
|
||||
@@ -87,6 +87,12 @@ func IsDevVersion(v string) bool {
|
||||
return strings.Contains(v, "-"+develPreRelease)
|
||||
}
|
||||
|
||||
// IsRCVersion reports whether v carries a release candidate
// pre-release tag, e.g. "v2.31.0-rc.0".
func IsRCVersion(v string) bool {
	// The tag always appears as "-rc." (hyphen-separated pre-release
	// followed by a dotted ordinal).
	const rcTag = "-rc."
	return strings.Contains(v, rcTag)
}
|
||||
|
||||
// IsDev returns true if this is a development build.
|
||||
// CI builds are also considered development builds.
|
||||
func IsDev() bool {
|
||||
|
||||
@@ -102,3 +102,29 @@ func TestBuildInfo(t *testing.T) {
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestIsRCVersion(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
cases := []struct {
|
||||
name string
|
||||
version string
|
||||
expected bool
|
||||
}{
|
||||
{"RC0", "v2.31.0-rc.0", true},
|
||||
{"RC1WithBuild", "v2.31.0-rc.1+abc123", true},
|
||||
{"RC10", "v2.31.0-rc.10", true},
|
||||
{"RCDevel", "v2.33.0-rc.1-devel+727ec00f7", true},
|
||||
{"DevelVersion", "v2.31.0-devel+abc123", false},
|
||||
{"StableVersion", "v2.31.0", false},
|
||||
{"DevNoVersion", "v0.0.0-devel+abc123", false},
|
||||
{"BetaVersion", "v2.31.0-beta.1", false},
|
||||
}
|
||||
|
||||
for _, c := range cases {
|
||||
t.Run(c.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
require.Equal(t, c.expected, buildinfo.IsRCVersion(c.version))
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
+22
-4
@@ -53,6 +53,8 @@ func workspaceAgent() *serpent.Command {
|
||||
slogJSONPath string
|
||||
slogStackdriverPath string
|
||||
blockFileTransfer bool
|
||||
blockReversePortForwarding bool
|
||||
blockLocalPortForwarding bool
|
||||
agentHeaderCommand string
|
||||
agentHeader []string
|
||||
devcontainers bool
|
||||
@@ -319,10 +321,12 @@ func workspaceAgent() *serpent.Command {
|
||||
SSHMaxTimeout: sshMaxTimeout,
|
||||
Subsystems: subsystems,
|
||||
|
||||
PrometheusRegistry: prometheusRegistry,
|
||||
BlockFileTransfer: blockFileTransfer,
|
||||
Execer: execer,
|
||||
Devcontainers: devcontainers,
|
||||
PrometheusRegistry: prometheusRegistry,
|
||||
BlockFileTransfer: blockFileTransfer,
|
||||
BlockReversePortForwarding: blockReversePortForwarding,
|
||||
BlockLocalPortForwarding: blockLocalPortForwarding,
|
||||
Execer: execer,
|
||||
Devcontainers: devcontainers,
|
||||
DevcontainerAPIOptions: []agentcontainers.Option{
|
||||
agentcontainers.WithSubAgentURL(agentAuth.agentURL.String()),
|
||||
agentcontainers.WithProjectDiscovery(devcontainerProjectDiscovery),
|
||||
@@ -493,6 +497,20 @@ func workspaceAgent() *serpent.Command {
|
||||
Description: fmt.Sprintf("Block file transfer using known applications: %s.", strings.Join(agentssh.BlockedFileTransferCommands, ",")),
|
||||
Value: serpent.BoolOf(&blockFileTransfer),
|
||||
},
|
||||
{
|
||||
Flag: "block-reverse-port-forwarding",
|
||||
Default: "false",
|
||||
Env: "CODER_AGENT_BLOCK_REVERSE_PORT_FORWARDING",
|
||||
Description: "Block reverse port forwarding through the SSH server (ssh -R).",
|
||||
Value: serpent.BoolOf(&blockReversePortForwarding),
|
||||
},
|
||||
{
|
||||
Flag: "block-local-port-forwarding",
|
||||
Default: "false",
|
||||
Env: "CODER_AGENT_BLOCK_LOCAL_PORT_FORWARDING",
|
||||
Description: "Block local port forwarding through the SSH server (ssh -L).",
|
||||
Value: serpent.BoolOf(&blockLocalPortForwarding),
|
||||
},
|
||||
{
|
||||
Flag: "devcontainers-enable",
|
||||
Default: "true",
|
||||
|
||||
+194
@@ -0,0 +1,194 @@
|
||||
package cli
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"golang.org/x/xerrors"
|
||||
|
||||
"github.com/coder/coder/v2/agent/agentcontextconfig"
|
||||
"github.com/coder/coder/v2/codersdk/agentsdk"
|
||||
"github.com/coder/serpent"
|
||||
)
|
||||
|
||||
func (r *RootCmd) chatCommand() *serpent.Command {
|
||||
return &serpent.Command{
|
||||
Use: "chat",
|
||||
Short: "Manage agent chats",
|
||||
Long: "Commands for interacting with chats from within a workspace.",
|
||||
Handler: func(i *serpent.Invocation) error {
|
||||
return i.Command.HelpHandler(i)
|
||||
},
|
||||
Children: []*serpent.Command{
|
||||
r.chatContextCommand(),
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func (r *RootCmd) chatContextCommand() *serpent.Command {
|
||||
return &serpent.Command{
|
||||
Use: "context",
|
||||
Short: "Manage chat context",
|
||||
Long: "Add or clear context files and skills for an active chat session.",
|
||||
Handler: func(i *serpent.Invocation) error {
|
||||
return i.Command.HelpHandler(i)
|
||||
},
|
||||
Children: []*serpent.Command{
|
||||
r.chatContextAddCommand(),
|
||||
r.chatContextClearCommand(),
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// chatContextAddCommand implements "chat context add": it reads
// instruction files and discovers skills from a directory, then posts
// them as context parts to an active chat via the agent API.
func (*RootCmd) chatContextAddCommand() *serpent.Command {
	var (
		dir    string
		chatID string
	)
	agentAuth := &AgentAuth{}
	cmd := &serpent.Command{
		Use:   "add",
		Short: "Add context to an active chat",
		Long: "Read instruction files and discover skills from a directory, then add " +
			"them as context to an active chat session. Multiple calls " +
			"are additive.",
		Handler: func(inv *serpent.Invocation) error {
			ctx := inv.Context()
			ctx, stop := inv.SignalNotifyContext(ctx, StopSignals...)
			defer stop()

			// Guard against accidental use outside a workspace:
			// require either an explicit --dir or CODER=true in the
			// environment.
			if dir == "" && inv.Environ.Get("CODER") != "true" {
				return xerrors.New("this command must be run inside a Coder workspace (set --dir to override)")
			}

			client, err := agentAuth.CreateClient()
			if err != nil {
				return xerrors.Errorf("create agent client: %w", err)
			}

			// Default to the working directory, then normalize to an
			// absolute path and verify it is a readable directory.
			resolvedDir := dir
			if resolvedDir == "" {
				resolvedDir, err = os.Getwd()
				if err != nil {
					return xerrors.Errorf("get working directory: %w", err)
				}
			}
			resolvedDir, err = filepath.Abs(resolvedDir)
			if err != nil {
				return xerrors.Errorf("resolve directory: %w", err)
			}
			info, err := os.Stat(resolvedDir)
			if err != nil {
				return xerrors.Errorf("cannot read directory %q: %w", resolvedDir, err)
			}
			if !info.IsDir() {
				return xerrors.Errorf("%q is not a directory", resolvedDir)
			}

			// An empty result is not an error: report and exit 0.
			parts := agentcontextconfig.ContextPartsFromDir(resolvedDir)
			if len(parts) == 0 {
				_, _ = fmt.Fprintln(inv.Stderr, "No context files or skills found in "+resolvedDir)
				return nil
			}

			// Resolve chat ID from flag or auto-detect.
			resolvedChatID, err := parseChatID(chatID)
			if err != nil {
				return err
			}

			resp, err := client.AddChatContext(ctx, agentsdk.AddChatContextRequest{
				ChatID: resolvedChatID,
				Parts:  parts,
			})
			if err != nil {
				return xerrors.Errorf("add chat context: %w", err)
			}

			_, _ = fmt.Fprintf(inv.Stdout, "Added %d context part(s) to chat %s\n", resp.Count, resp.ChatID)
			return nil
		},
		Options: serpent.OptionSet{
			{
				Name:        "Directory",
				Flag:        "dir",
				Description: "Directory to read context files and skills from. Defaults to the current working directory.",
				Value:       serpent.StringOf(&dir),
			},
			{
				Name:        "Chat ID",
				Flag:        "chat",
				Env:         "CODER_CHAT_ID",
				Description: "Chat ID to add context to. Auto-detected from CODER_CHAT_ID, the only active chat, or the only top-level active chat.",
				Value:       serpent.StringOf(&chatID),
			},
		},
	}
	agentAuth.AttachOptions(cmd, false)
	return cmd
}
|
||||
|
||||
// chatContextClearCommand implements "chat context clear": it asks
// the agent to soft-delete all context-file and skill messages from
// an active chat session.
func (*RootCmd) chatContextClearCommand() *serpent.Command {
	var chatID string
	agentAuth := &AgentAuth{}
	cmd := &serpent.Command{
		Use:   "clear",
		Short: "Clear context from an active chat",
		Long: "Soft-delete all context-file and skill messages from an active chat. " +
			"The next turn will re-fetch default context from the agent.",
		Handler: func(inv *serpent.Invocation) error {
			ctx := inv.Context()
			ctx, stop := inv.SignalNotifyContext(ctx, StopSignals...)
			defer stop()

			client, err := agentAuth.CreateClient()
			if err != nil {
				return xerrors.Errorf("create agent client: %w", err)
			}

			// Resolve the chat ID from the flag; uuid.Nil lets the
			// server auto-detect (see parseChatID).
			resolvedChatID, err := parseChatID(chatID)
			if err != nil {
				return err
			}

			resp, err := client.ClearChatContext(ctx, agentsdk.ClearChatContextRequest{
				ChatID: resolvedChatID,
			})
			if err != nil {
				return xerrors.Errorf("clear chat context: %w", err)
			}

			// A uuid.Nil ChatID in the response indicates no active
			// chat was found to clear.
			if resp.ChatID == uuid.Nil {
				_, _ = fmt.Fprintln(inv.Stdout, "No active chats to clear.")
			} else {
				_, _ = fmt.Fprintf(inv.Stdout, "Cleared context from chat %s\n", resp.ChatID)
			}
			return nil
		},
		Options: serpent.OptionSet{{
			Name:        "Chat ID",
			Flag:        "chat",
			Env:         "CODER_CHAT_ID",
			Description: "Chat ID to clear context from. Auto-detected from CODER_CHAT_ID, the only active chat, or the only top-level active chat.",
			Value:       serpent.StringOf(&chatID),
		}},
	}
	agentAuth.AttachOptions(cmd, false)
	return cmd
}
|
||||
|
||||
// parseChatID returns the chat UUID from the flag value (which
|
||||
// serpent already populates from --chat or CODER_CHAT_ID). Returns
|
||||
// uuid.Nil if empty (the server will auto-detect).
|
||||
func parseChatID(flagValue string) (uuid.UUID, error) {
|
||||
if flagValue == "" {
|
||||
return uuid.Nil, nil
|
||||
}
|
||||
parsed, err := uuid.Parse(flagValue)
|
||||
if err != nil {
|
||||
return uuid.Nil, xerrors.Errorf("invalid chat ID %q: %w", flagValue, err)
|
||||
}
|
||||
return parsed, nil
|
||||
}
|
||||
@@ -0,0 +1,46 @@
|
||||
package cli_test
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/coder/coder/v2/cli/clitest"
|
||||
)
|
||||
|
||||
// TestExpChatContextAdd exercises the workspace guard of
// "exp chat context add": the command errors outside a Coder
// workspace unless --dir is given or CODER=true is set.
func TestExpChatContextAdd(t *testing.T) {
	t.Parallel()

	t.Run("RequiresWorkspaceOrDir", func(t *testing.T) {
		t.Parallel()

		inv, _ := clitest.New(t, "exp", "chat", "context", "add")

		err := inv.Run()
		require.Error(t, err)
		require.Contains(t, err.Error(), "this command must be run inside a Coder workspace")
	})

	t.Run("AllowsExplicitDir", func(t *testing.T) {
		t.Parallel()

		inv, _ := clitest.New(t, "exp", "chat", "context", "add", "--dir", t.TempDir())

		// The command may still fail for other reasons (e.g. no
		// agent to talk to), but never on the workspace guard.
		err := inv.Run()
		if err != nil {
			require.NotContains(t, err.Error(), "this command must be run inside a Coder workspace")
		}
	})

	t.Run("AllowsWorkspaceEnv", func(t *testing.T) {
		t.Parallel()

		inv, _ := clitest.New(t, "exp", "chat", "context", "add")
		inv.Environ.Set("CODER", "true")

		// Same as above: any failure must not be the workspace guard.
		err := inv.Run()
		if err != nil {
			require.NotContains(t, err.Error(), "this command must be run inside a Coder workspace")
		}
	})
}
|
||||
+29
-5
@@ -7,6 +7,7 @@ import (
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"flag"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
@@ -148,6 +149,7 @@ func (r *RootCmd) AGPLExperimental() []*serpent.Command {
|
||||
return []*serpent.Command{
|
||||
r.scaletestCmd(),
|
||||
r.errorExample(),
|
||||
r.chatCommand(),
|
||||
r.mcpCommand(),
|
||||
r.promptExample(),
|
||||
r.rptyCommand(),
|
||||
@@ -710,7 +712,7 @@ func (r *RootCmd) createHTTPClient(ctx context.Context, serverURL *url.URL, inv
|
||||
transport = wrapTransportWithTelemetryHeader(transport, inv)
|
||||
transport = wrapTransportWithUserAgentHeader(transport, inv)
|
||||
if !r.noVersionCheck {
|
||||
transport = wrapTransportWithVersionMismatchCheck(transport, inv, buildinfo.Version(), func(ctx context.Context) (codersdk.BuildInfoResponse, error) {
|
||||
transport = wrapTransportWithVersionCheck(transport, inv, buildinfo.Version(), func(ctx context.Context) (codersdk.BuildInfoResponse, error) {
|
||||
// Create a new client without any wrapped transport
|
||||
// otherwise it creates an infinite loop!
|
||||
basicClient := codersdk.New(serverURL)
|
||||
@@ -1434,6 +1436,21 @@ func defaultUpgradeMessage(version string) string {
|
||||
return fmt.Sprintf("download the server version with: 'curl -L https://coder.com/install.sh | sh -s -- --version %s'", version)
|
||||
}
|
||||
|
||||
// serverVersionMessage returns a warning message if the server version
|
||||
// is a release candidate or development build. Returns empty string
|
||||
// for stable versions. RC is checked before devel because RC dev
|
||||
// builds (e.g. v2.33.0-rc.1-devel+hash) contain both tags.
|
||||
func serverVersionMessage(serverVersion string) string {
|
||||
switch {
|
||||
case buildinfo.IsRCVersion(serverVersion):
|
||||
return fmt.Sprintf("the server is running a release candidate of Coder (%s)", serverVersion)
|
||||
case buildinfo.IsDevVersion(serverVersion):
|
||||
return fmt.Sprintf("the server is running a development version of Coder (%s)", serverVersion)
|
||||
default:
|
||||
return ""
|
||||
}
|
||||
}
|
||||
|
||||
// wrapTransportWithEntitlementsCheck adds a middleware to the HTTP transport
|
||||
// that checks for entitlement warnings and prints them to the user.
|
||||
func wrapTransportWithEntitlementsCheck(rt http.RoundTripper, w io.Writer) http.RoundTripper {
|
||||
@@ -1452,10 +1469,10 @@ func wrapTransportWithEntitlementsCheck(rt http.RoundTripper, w io.Writer) http.
|
||||
})
|
||||
}
|
||||
|
||||
// wrapTransportWithVersionMismatchCheck adds a middleware to the HTTP transport
|
||||
// that checks for version mismatches between the client and server. If a mismatch
|
||||
// is detected, a warning is printed to the user.
|
||||
func wrapTransportWithVersionMismatchCheck(rt http.RoundTripper, inv *serpent.Invocation, clientVersion string, getBuildInfo func(ctx context.Context) (codersdk.BuildInfoResponse, error)) http.RoundTripper {
|
||||
// wrapTransportWithVersionCheck adds a middleware to the HTTP transport
|
||||
// that checks the server version and warns about development builds,
|
||||
// release candidates, and client/server version mismatches.
|
||||
func wrapTransportWithVersionCheck(rt http.RoundTripper, inv *serpent.Invocation, clientVersion string, getBuildInfo func(ctx context.Context) (codersdk.BuildInfoResponse, error)) http.RoundTripper {
|
||||
var once sync.Once
|
||||
return roundTripper(func(req *http.Request) (*http.Response, error) {
|
||||
res, err := rt.RoundTrip(req)
|
||||
@@ -1467,9 +1484,16 @@ func wrapTransportWithVersionMismatchCheck(rt http.RoundTripper, inv *serpent.In
|
||||
if serverVersion == "" {
|
||||
return
|
||||
}
|
||||
// Warn about non-stable server versions. Skip
|
||||
// during tests to avoid polluting golden files.
|
||||
if msg := serverVersionMessage(serverVersion); msg != "" && flag.Lookup("test.v") == nil {
|
||||
warning := pretty.Sprint(cliui.DefaultStyles.Warn, msg)
|
||||
_, _ = fmt.Fprintln(inv.Stderr, warning)
|
||||
}
|
||||
if buildinfo.VersionsMatch(clientVersion, serverVersion) {
|
||||
return
|
||||
}
|
||||
|
||||
upgradeMessage := defaultUpgradeMessage(semver.Canonical(serverVersion))
|
||||
if serverInfo, err := getBuildInfo(inv.Context()); err == nil {
|
||||
switch {
|
||||
|
||||
@@ -91,7 +91,7 @@ func Test_formatExamples(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func Test_wrapTransportWithVersionMismatchCheck(t *testing.T) {
|
||||
func Test_wrapTransportWithVersionCheck(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
t.Run("NoOutput", func(t *testing.T) {
|
||||
@@ -102,7 +102,7 @@ func Test_wrapTransportWithVersionMismatchCheck(t *testing.T) {
|
||||
var buf bytes.Buffer
|
||||
inv := cmd.Invoke()
|
||||
inv.Stderr = &buf
|
||||
rt := wrapTransportWithVersionMismatchCheck(roundTripper(func(req *http.Request) (*http.Response, error) {
|
||||
rt := wrapTransportWithVersionCheck(roundTripper(func(req *http.Request) (*http.Response, error) {
|
||||
return &http.Response{
|
||||
StatusCode: http.StatusOK,
|
||||
Header: http.Header{
|
||||
@@ -131,7 +131,7 @@ func Test_wrapTransportWithVersionMismatchCheck(t *testing.T) {
|
||||
inv := cmd.Invoke()
|
||||
inv.Stderr = &buf
|
||||
expectedUpgradeMessage := "My custom upgrade message"
|
||||
rt := wrapTransportWithVersionMismatchCheck(roundTripper(func(req *http.Request) (*http.Response, error) {
|
||||
rt := wrapTransportWithVersionCheck(roundTripper(func(req *http.Request) (*http.Response, error) {
|
||||
return &http.Response{
|
||||
StatusCode: http.StatusOK,
|
||||
Header: http.Header{
|
||||
@@ -159,6 +159,53 @@ func Test_wrapTransportWithVersionMismatchCheck(t *testing.T) {
|
||||
expectedOutput := fmt.Sprintln(pretty.Sprint(cliui.DefaultStyles.Warn, fmtOutput))
|
||||
require.Equal(t, expectedOutput, buf.String())
|
||||
})
|
||||
|
||||
t.Run("ServerStableVersion", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
r := &RootCmd{}
|
||||
cmd, err := r.Command(nil)
|
||||
require.NoError(t, err)
|
||||
var buf bytes.Buffer
|
||||
inv := cmd.Invoke()
|
||||
inv.Stderr = &buf
|
||||
rt := wrapTransportWithVersionCheck(roundTripper(func(req *http.Request) (*http.Response, error) {
|
||||
return &http.Response{
|
||||
StatusCode: http.StatusOK,
|
||||
Header: http.Header{
|
||||
codersdk.BuildVersionHeader: []string{"v2.31.0"},
|
||||
},
|
||||
Body: io.NopCloser(nil),
|
||||
}, nil
|
||||
}), inv, "v2.31.0", nil)
|
||||
req := httptest.NewRequest(http.MethodGet, "http://example.com", nil)
|
||||
res, err := rt.RoundTrip(req)
|
||||
require.NoError(t, err)
|
||||
defer res.Body.Close()
|
||||
require.Empty(t, buf.String())
|
||||
})
|
||||
}
|
||||
|
||||
func Test_serverVersionMessage(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
cases := []struct {
|
||||
name string
|
||||
version string
|
||||
expected string
|
||||
}{
|
||||
{"Stable", "v2.31.0", ""},
|
||||
{"Dev", "v0.0.0-devel+abc123", "the server is running a development version of Coder (v0.0.0-devel+abc123)"},
|
||||
{"RC", "v2.31.0-rc.1", "the server is running a release candidate of Coder (v2.31.0-rc.1)"},
|
||||
{"RCDevel", "v2.33.0-rc.1-devel+727ec00f7", "the server is running a release candidate of Coder (v2.33.0-rc.1-devel+727ec00f7)"},
|
||||
{"Empty", "", ""},
|
||||
}
|
||||
|
||||
for _, c := range cases {
|
||||
t.Run(c.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
require.Equal(t, c.expected, serverVersionMessage(c.version))
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func Test_wrapTransportWithTelemetryHeader(t *testing.T) {
|
||||
|
||||
+6
-4
@@ -69,15 +69,17 @@ var (
|
||||
// isRetryableError checks for transient connection errors worth
|
||||
// retrying: DNS failures, connection refused, and server 5xx.
|
||||
func isRetryableError(err error) bool {
|
||||
if err == nil {
|
||||
return false
|
||||
}
|
||||
if xerrors.Is(err, context.Canceled) || xerrors.Is(err, context.DeadlineExceeded) {
|
||||
if err == nil || xerrors.Is(err, context.Canceled) {
|
||||
return false
|
||||
}
|
||||
// Check connection errors before context.DeadlineExceeded because
|
||||
// net.Dialer.Timeout produces *net.OpError that matches both.
|
||||
if codersdk.IsConnectionError(err) {
|
||||
return true
|
||||
}
|
||||
if xerrors.Is(err, context.DeadlineExceeded) {
|
||||
return false
|
||||
}
|
||||
var sdkErr *codersdk.Error
|
||||
if xerrors.As(err, &sdkErr) {
|
||||
return sdkErr.StatusCode() >= 500
|
||||
|
||||
@@ -516,6 +516,23 @@ func TestIsRetryableError(t *testing.T) {
|
||||
assert.Equal(t, tt.retryable, isRetryableError(tt.err))
|
||||
})
|
||||
}
|
||||
|
||||
// net.Dialer.Timeout produces *net.OpError that matches both
|
||||
// IsConnectionError and context.DeadlineExceeded. Verify it is retryable.
|
||||
t.Run("DialTimeout", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx, cancel := context.WithDeadline(context.Background(), time.Now())
|
||||
defer cancel()
|
||||
<-ctx.Done() // ensure deadline has fired
|
||||
_, err := (&net.Dialer{}).DialContext(ctx, "tcp", "127.0.0.1:1")
|
||||
require.Error(t, err)
|
||||
// Proves the ambiguity: this error matches BOTH checks.
|
||||
require.ErrorIs(t, err, context.DeadlineExceeded)
|
||||
require.ErrorAs(t, err, new(*net.OpError))
|
||||
assert.True(t, isRetryableError(err))
|
||||
// Also when wrapped, as runCoderConnectStdio does.
|
||||
assert.True(t, isRetryableError(xerrors.Errorf("dial coder connect: %w", err)))
|
||||
})
|
||||
}
|
||||
|
||||
func TestRetryWithInterval(t *testing.T) {
|
||||
|
||||
+6
@@ -39,6 +39,12 @@ OPTIONS:
|
||||
--block-file-transfer bool, $CODER_AGENT_BLOCK_FILE_TRANSFER (default: false)
|
||||
Block file transfer using known applications: nc,rsync,scp,sftp.
|
||||
|
||||
--block-local-port-forwarding bool, $CODER_AGENT_BLOCK_LOCAL_PORT_FORWARDING (default: false)
|
||||
Block local port forwarding through the SSH server (ssh -L).
|
||||
|
||||
--block-reverse-port-forwarding bool, $CODER_AGENT_BLOCK_REVERSE_PORT_FORWARDING (default: false)
|
||||
Block reverse port forwarding through the SSH server (ssh -R).
|
||||
|
||||
--boundary-log-proxy-socket-path string, $CODER_AGENT_BOUNDARY_LOG_PROXY_SOCKET_PATH (default: /tmp/boundary-audit.sock)
|
||||
The path for the boundary log proxy server Unix socket. Boundary
|
||||
should write audit logs to this socket.
|
||||
|
||||
+1
-1
@@ -11,7 +11,7 @@ OPTIONS:
|
||||
-O, --org string, $CODER_ORGANIZATION
|
||||
Select which organization (uuid or name) to use.
|
||||
|
||||
-c, --column [id|created at|started at|completed at|canceled at|error|error code|status|worker id|worker name|file id|tags|queue position|queue size|organization id|initiator id|template version id|workspace build id|type|available workers|template version name|template id|template name|template display name|template icon|workspace id|workspace name|logs overflowed|organization|queue] (default: created at,id,type,template display name,status,queue,tags)
|
||||
-c, --column [id|created at|started at|completed at|canceled at|error|error code|status|worker id|worker name|file id|tags|queue position|queue size|organization id|initiator id|template version id|workspace build id|type|available workers|template version name|template id|template name|template display name|template icon|workspace id|workspace name|workspace build transition|logs overflowed|organization|queue] (default: created at,id,type,template display name,status,queue,tags)
|
||||
Columns to display in table output.
|
||||
|
||||
-i, --initiator string, $CODER_PROVISIONER_JOB_LIST_INITIATOR
|
||||
|
||||
@@ -58,7 +58,8 @@
|
||||
"template_display_name": "",
|
||||
"template_icon": "",
|
||||
"workspace_id": "===========[workspace ID]===========",
|
||||
"workspace_name": "test-workspace"
|
||||
"workspace_name": "test-workspace",
|
||||
"workspace_build_transition": "start"
|
||||
},
|
||||
"logs_overflowed": false,
|
||||
"organization_name": "Coder"
|
||||
|
||||
@@ -77,8 +77,9 @@ func (a *LogsAPI) BatchCreateLogs(ctx context.Context, req *agentproto.BatchCrea
|
||||
level := make([]database.LogLevel, 0)
|
||||
outputLength := 0
|
||||
for _, logEntry := range req.Logs {
|
||||
output = append(output, logEntry.Output)
|
||||
outputLength += len(logEntry.Output)
|
||||
sanitizedOutput := agentsdk.SanitizeLogOutput(logEntry.Output)
|
||||
output = append(output, sanitizedOutput)
|
||||
outputLength += len(sanitizedOutput)
|
||||
|
||||
var dbLevel database.LogLevel
|
||||
switch logEntry.Level {
|
||||
|
||||
@@ -139,6 +139,59 @@ func TestBatchCreateLogs(t *testing.T) {
|
||||
require.True(t, publishWorkspaceAgentLogsUpdateCalled)
|
||||
})
|
||||
|
||||
t.Run("SanitizesOutput", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
dbM := dbmock.NewMockStore(gomock.NewController(t))
|
||||
now := dbtime.Now()
|
||||
api := &agentapi.LogsAPI{
|
||||
AgentFn: func(context.Context) (database.WorkspaceAgent, error) {
|
||||
return agent, nil
|
||||
},
|
||||
Database: dbM,
|
||||
Log: testutil.Logger(t),
|
||||
TimeNowFn: func() time.Time {
|
||||
return now
|
||||
},
|
||||
}
|
||||
|
||||
rawOutput := "before\x00middle\xc3\x28after"
|
||||
sanitizedOutput := agentsdk.SanitizeLogOutput(rawOutput)
|
||||
expectedOutputLength := int32(len(sanitizedOutput)) //nolint:gosec // Test-controlled string length is small.
|
||||
req := &agentproto.BatchCreateLogsRequest{
|
||||
LogSourceId: logSource.ID[:],
|
||||
Logs: []*agentproto.Log{
|
||||
{
|
||||
CreatedAt: timestamppb.New(now),
|
||||
Level: agentproto.Log_WARN,
|
||||
Output: rawOutput,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
dbM.EXPECT().InsertWorkspaceAgentLogs(gomock.Any(), database.InsertWorkspaceAgentLogsParams{
|
||||
AgentID: agent.ID,
|
||||
LogSourceID: logSource.ID,
|
||||
CreatedAt: now,
|
||||
Output: []string{sanitizedOutput},
|
||||
Level: []database.LogLevel{database.LogLevelWarn},
|
||||
OutputLength: expectedOutputLength,
|
||||
}).Return([]database.WorkspaceAgentLog{
|
||||
{
|
||||
AgentID: agent.ID,
|
||||
CreatedAt: now,
|
||||
ID: 1,
|
||||
Output: sanitizedOutput,
|
||||
Level: database.LogLevelWarn,
|
||||
LogSourceID: logSource.ID,
|
||||
},
|
||||
}, nil)
|
||||
|
||||
resp, err := api.BatchCreateLogs(context.Background(), req)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, &agentproto.BatchCreateLogsResponse{}, resp)
|
||||
})
|
||||
|
||||
t.Run("NoWorkspacePublishIfNotFirstLogs", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
|
||||
@@ -71,7 +71,7 @@ func (a *SubAgentAPI) CreateSubAgent(ctx context.Context, req *agentproto.Create
|
||||
// An ID is only given in the request when it is a terraform-defined devcontainer
|
||||
// that has attached resources. These subagents are pre-provisioned by terraform
|
||||
// (the agent record already exists), so we update configurable fields like
|
||||
// display_apps rather than creating a new agent.
|
||||
// display_apps and directory rather than creating a new agent.
|
||||
if req.Id != nil {
|
||||
id, err := uuid.FromBytes(req.Id)
|
||||
if err != nil {
|
||||
@@ -97,6 +97,16 @@ func (a *SubAgentAPI) CreateSubAgent(ctx context.Context, req *agentproto.Create
|
||||
return nil, xerrors.Errorf("update workspace agent display apps: %w", err)
|
||||
}
|
||||
|
||||
if req.Directory != "" {
|
||||
if err := a.Database.UpdateWorkspaceAgentDirectoryByID(ctx, database.UpdateWorkspaceAgentDirectoryByIDParams{
|
||||
ID: id,
|
||||
Directory: req.Directory,
|
||||
UpdatedAt: createdAt,
|
||||
}); err != nil {
|
||||
return nil, xerrors.Errorf("update workspace agent directory: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
return &agentproto.CreateSubAgentResponse{
|
||||
Agent: &agentproto.SubAgent{
|
||||
Name: subAgent.Name,
|
||||
|
||||
@@ -1267,11 +1267,11 @@ func TestSubAgentAPI(t *testing.T) {
|
||||
agentID, err := uuid.FromBytes(resp.Agent.Id)
|
||||
require.NoError(t, err)
|
||||
|
||||
// And: The database agent's other fields are unchanged.
|
||||
// And: The database agent's name, architecture, and OS are unchanged.
|
||||
updatedAgent, err := db.GetWorkspaceAgentByID(dbauthz.AsSystemRestricted(ctx), agentID)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, baseChildAgent.Name, updatedAgent.Name)
|
||||
require.Equal(t, baseChildAgent.Directory, updatedAgent.Directory)
|
||||
require.Equal(t, "/different/path", updatedAgent.Directory)
|
||||
require.Equal(t, baseChildAgent.Architecture, updatedAgent.Architecture)
|
||||
require.Equal(t, baseChildAgent.OperatingSystem, updatedAgent.OperatingSystem)
|
||||
|
||||
@@ -1280,6 +1280,42 @@ func TestSubAgentAPI(t *testing.T) {
|
||||
require.Equal(t, database.DisplayAppWebTerminal, updatedAgent.DisplayApps[0])
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "OK_DirectoryUpdated",
|
||||
setup: func(t *testing.T, db database.Store, agent database.WorkspaceAgent) *proto.CreateSubAgentRequest {
|
||||
// Given: An existing child agent with a stale host-side
|
||||
// directory (as set by the provisioner at build time).
|
||||
childAgent := dbgen.WorkspaceAgent(t, db, database.WorkspaceAgent{
|
||||
ParentID: uuid.NullUUID{Valid: true, UUID: agent.ID},
|
||||
ResourceID: agent.ResourceID,
|
||||
Name: baseChildAgent.Name,
|
||||
Directory: "/home/coder/project",
|
||||
Architecture: baseChildAgent.Architecture,
|
||||
OperatingSystem: baseChildAgent.OperatingSystem,
|
||||
DisplayApps: baseChildAgent.DisplayApps,
|
||||
})
|
||||
|
||||
// When: Agent injection sends the correct
|
||||
// container-internal path.
|
||||
return &proto.CreateSubAgentRequest{
|
||||
Id: childAgent.ID[:],
|
||||
Directory: "/workspaces/project",
|
||||
DisplayApps: []proto.CreateSubAgentRequest_DisplayApp{
|
||||
proto.CreateSubAgentRequest_WEB_TERMINAL,
|
||||
},
|
||||
}
|
||||
},
|
||||
check: func(t *testing.T, ctx context.Context, db database.Store, resp *proto.CreateSubAgentResponse, agent database.WorkspaceAgent) {
|
||||
agentID, err := uuid.FromBytes(resp.Agent.Id)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Then: Directory is updated to the container-internal
|
||||
// path.
|
||||
updatedAgent, err := db.GetWorkspaceAgentByID(dbauthz.AsSystemRestricted(ctx), agentID)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "/workspaces/project", updatedAgent.Directory)
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "Error/MalformedID",
|
||||
setup: func(t *testing.T, db database.Store, agent database.WorkspaceAgent) *proto.CreateSubAgentRequest {
|
||||
|
||||
Generated
+359
@@ -1266,6 +1266,68 @@ const docTemplate = `{
|
||||
]
|
||||
}
|
||||
},
|
||||
"/experimental/chats/config/retention-days": {
|
||||
"get": {
|
||||
"produces": [
|
||||
"application/json"
|
||||
],
|
||||
"tags": [
|
||||
"Chats"
|
||||
],
|
||||
"summary": "Get chat retention days",
|
||||
"operationId": "get-chat-retention-days",
|
||||
"responses": {
|
||||
"200": {
|
||||
"description": "OK",
|
||||
"schema": {
|
||||
"$ref": "#/definitions/codersdk.ChatRetentionDaysResponse"
|
||||
}
|
||||
}
|
||||
},
|
||||
"security": [
|
||||
{
|
||||
"CoderSessionToken": []
|
||||
}
|
||||
],
|
||||
"x-apidocgen": {
|
||||
"skip": true
|
||||
}
|
||||
},
|
||||
"put": {
|
||||
"consumes": [
|
||||
"application/json"
|
||||
],
|
||||
"tags": [
|
||||
"Chats"
|
||||
],
|
||||
"summary": "Update chat retention days",
|
||||
"operationId": "update-chat-retention-days",
|
||||
"parameters": [
|
||||
{
|
||||
"description": "Request body",
|
||||
"name": "request",
|
||||
"in": "body",
|
||||
"required": true,
|
||||
"schema": {
|
||||
"$ref": "#/definitions/codersdk.UpdateChatRetentionDaysRequest"
|
||||
}
|
||||
}
|
||||
],
|
||||
"responses": {
|
||||
"204": {
|
||||
"description": "No Content"
|
||||
}
|
||||
},
|
||||
"security": [
|
||||
{
|
||||
"CoderSessionToken": []
|
||||
}
|
||||
],
|
||||
"x-apidocgen": {
|
||||
"skip": true
|
||||
}
|
||||
}
|
||||
},
|
||||
"/experimental/watch-all-workspacebuilds": {
|
||||
"get": {
|
||||
"produces": [
|
||||
@@ -9452,6 +9514,212 @@ const docTemplate = `{
|
||||
]
|
||||
}
|
||||
},
|
||||
"/users/{user}/secrets": {
|
||||
"get": {
|
||||
"produces": [
|
||||
"application/json"
|
||||
],
|
||||
"tags": [
|
||||
"Secrets"
|
||||
],
|
||||
"summary": "List user secrets",
|
||||
"operationId": "list-user-secrets",
|
||||
"parameters": [
|
||||
{
|
||||
"type": "string",
|
||||
"description": "User ID, username, or me",
|
||||
"name": "user",
|
||||
"in": "path",
|
||||
"required": true
|
||||
}
|
||||
],
|
||||
"responses": {
|
||||
"200": {
|
||||
"description": "OK",
|
||||
"schema": {
|
||||
"type": "array",
|
||||
"items": {
|
||||
"$ref": "#/definitions/codersdk.UserSecret"
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
"security": [
|
||||
{
|
||||
"CoderSessionToken": []
|
||||
}
|
||||
]
|
||||
},
|
||||
"post": {
|
||||
"consumes": [
|
||||
"application/json"
|
||||
],
|
||||
"produces": [
|
||||
"application/json"
|
||||
],
|
||||
"tags": [
|
||||
"Secrets"
|
||||
],
|
||||
"summary": "Create a new user secret",
|
||||
"operationId": "create-a-new-user-secret",
|
||||
"parameters": [
|
||||
{
|
||||
"type": "string",
|
||||
"description": "User ID, username, or me",
|
||||
"name": "user",
|
||||
"in": "path",
|
||||
"required": true
|
||||
},
|
||||
{
|
||||
"description": "Create secret request",
|
||||
"name": "request",
|
||||
"in": "body",
|
||||
"required": true,
|
||||
"schema": {
|
||||
"$ref": "#/definitions/codersdk.CreateUserSecretRequest"
|
||||
}
|
||||
}
|
||||
],
|
||||
"responses": {
|
||||
"201": {
|
||||
"description": "Created",
|
||||
"schema": {
|
||||
"$ref": "#/definitions/codersdk.UserSecret"
|
||||
}
|
||||
}
|
||||
},
|
||||
"security": [
|
||||
{
|
||||
"CoderSessionToken": []
|
||||
}
|
||||
]
|
||||
}
|
||||
},
|
||||
"/users/{user}/secrets/{name}": {
|
||||
"get": {
|
||||
"produces": [
|
||||
"application/json"
|
||||
],
|
||||
"tags": [
|
||||
"Secrets"
|
||||
],
|
||||
"summary": "Get a user secret by name",
|
||||
"operationId": "get-a-user-secret-by-name",
|
||||
"parameters": [
|
||||
{
|
||||
"type": "string",
|
||||
"description": "User ID, username, or me",
|
||||
"name": "user",
|
||||
"in": "path",
|
||||
"required": true
|
||||
},
|
||||
{
|
||||
"type": "string",
|
||||
"description": "Secret name",
|
||||
"name": "name",
|
||||
"in": "path",
|
||||
"required": true
|
||||
}
|
||||
],
|
||||
"responses": {
|
||||
"200": {
|
||||
"description": "OK",
|
||||
"schema": {
|
||||
"$ref": "#/definitions/codersdk.UserSecret"
|
||||
}
|
||||
}
|
||||
},
|
||||
"security": [
|
||||
{
|
||||
"CoderSessionToken": []
|
||||
}
|
||||
]
|
||||
},
|
||||
"delete": {
|
||||
"tags": [
|
||||
"Secrets"
|
||||
],
|
||||
"summary": "Delete a user secret",
|
||||
"operationId": "delete-a-user-secret",
|
||||
"parameters": [
|
||||
{
|
||||
"type": "string",
|
||||
"description": "User ID, username, or me",
|
||||
"name": "user",
|
||||
"in": "path",
|
||||
"required": true
|
||||
},
|
||||
{
|
||||
"type": "string",
|
||||
"description": "Secret name",
|
||||
"name": "name",
|
||||
"in": "path",
|
||||
"required": true
|
||||
}
|
||||
],
|
||||
"responses": {
|
||||
"204": {
|
||||
"description": "No Content"
|
||||
}
|
||||
},
|
||||
"security": [
|
||||
{
|
||||
"CoderSessionToken": []
|
||||
}
|
||||
]
|
||||
},
|
||||
"patch": {
|
||||
"consumes": [
|
||||
"application/json"
|
||||
],
|
||||
"produces": [
|
||||
"application/json"
|
||||
],
|
||||
"tags": [
|
||||
"Secrets"
|
||||
],
|
||||
"summary": "Update a user secret",
|
||||
"operationId": "update-a-user-secret",
|
||||
"parameters": [
|
||||
{
|
||||
"type": "string",
|
||||
"description": "User ID, username, or me",
|
||||
"name": "user",
|
||||
"in": "path",
|
||||
"required": true
|
||||
},
|
||||
{
|
||||
"type": "string",
|
||||
"description": "Secret name",
|
||||
"name": "name",
|
||||
"in": "path",
|
||||
"required": true
|
||||
},
|
||||
{
|
||||
"description": "Update secret request",
|
||||
"name": "request",
|
||||
"in": "body",
|
||||
"required": true,
|
||||
"schema": {
|
||||
"$ref": "#/definitions/codersdk.UpdateUserSecretRequest"
|
||||
}
|
||||
}
|
||||
],
|
||||
"responses": {
|
||||
"200": {
|
||||
"description": "OK",
|
||||
"schema": {
|
||||
"$ref": "#/definitions/codersdk.UserSecret"
|
||||
}
|
||||
}
|
||||
},
|
||||
"security": [
|
||||
{
|
||||
"CoderSessionToken": []
|
||||
}
|
||||
]
|
||||
}
|
||||
},
|
||||
"/users/{user}/status/activate": {
|
||||
"put": {
|
||||
"produces": [
|
||||
@@ -13177,6 +13445,12 @@ const docTemplate = `{
|
||||
"$ref": "#/definitions/codersdk.AIBridgeAgenticAction"
|
||||
}
|
||||
},
|
||||
"credential_hint": {
|
||||
"type": "string"
|
||||
},
|
||||
"credential_kind": {
|
||||
"type": "string"
|
||||
},
|
||||
"ended_at": {
|
||||
"type": "string",
|
||||
"format": "date-time"
|
||||
@@ -14420,6 +14694,14 @@ const docTemplate = `{
|
||||
}
|
||||
}
|
||||
},
|
||||
"codersdk.ChatRetentionDaysResponse": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"retention_days": {
|
||||
"type": "integer"
|
||||
}
|
||||
}
|
||||
},
|
||||
"codersdk.ConnectionLatency": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
@@ -15072,6 +15354,26 @@ const docTemplate = `{
|
||||
}
|
||||
}
|
||||
},
|
||||
"codersdk.CreateUserSecretRequest": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"description": {
|
||||
"type": "string"
|
||||
},
|
||||
"env_name": {
|
||||
"type": "string"
|
||||
},
|
||||
"file_path": {
|
||||
"type": "string"
|
||||
},
|
||||
"name": {
|
||||
"type": "string"
|
||||
},
|
||||
"value": {
|
||||
"type": "string"
|
||||
}
|
||||
}
|
||||
},
|
||||
"codersdk.CreateWorkspaceBuildReason": {
|
||||
"type": "string",
|
||||
"enum": [
|
||||
@@ -18847,6 +19149,9 @@ const docTemplate = `{
|
||||
"template_version_name": {
|
||||
"type": "string"
|
||||
},
|
||||
"workspace_build_transition": {
|
||||
"$ref": "#/definitions/codersdk.WorkspaceTransition"
|
||||
},
|
||||
"workspace_id": {
|
||||
"type": "string",
|
||||
"format": "uuid"
|
||||
@@ -20952,6 +21257,14 @@ const docTemplate = `{
|
||||
}
|
||||
}
|
||||
},
|
||||
"codersdk.UpdateChatRetentionDaysRequest": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"retention_days": {
|
||||
"type": "integer"
|
||||
}
|
||||
}
|
||||
},
|
||||
"codersdk.UpdateCheckResponse": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
@@ -21193,6 +21506,23 @@ const docTemplate = `{
|
||||
}
|
||||
}
|
||||
},
|
||||
"codersdk.UpdateUserSecretRequest": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"description": {
|
||||
"type": "string"
|
||||
},
|
||||
"env_name": {
|
||||
"type": "string"
|
||||
},
|
||||
"file_path": {
|
||||
"type": "string"
|
||||
},
|
||||
"value": {
|
||||
"type": "string"
|
||||
}
|
||||
}
|
||||
},
|
||||
"codersdk.UpdateWorkspaceACL": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
@@ -21648,6 +21978,35 @@ const docTemplate = `{
|
||||
}
|
||||
}
|
||||
},
|
||||
"codersdk.UserSecret": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"created_at": {
|
||||
"type": "string",
|
||||
"format": "date-time"
|
||||
},
|
||||
"description": {
|
||||
"type": "string"
|
||||
},
|
||||
"env_name": {
|
||||
"type": "string"
|
||||
},
|
||||
"file_path": {
|
||||
"type": "string"
|
||||
},
|
||||
"id": {
|
||||
"type": "string",
|
||||
"format": "uuid"
|
||||
},
|
||||
"name": {
|
||||
"type": "string"
|
||||
},
|
||||
"updated_at": {
|
||||
"type": "string",
|
||||
"format": "date-time"
|
||||
}
|
||||
}
|
||||
},
|
||||
"codersdk.UserStatus": {
|
||||
"type": "string",
|
||||
"enum": [
|
||||
|
||||
Generated
+329
@@ -1103,6 +1103,60 @@
|
||||
]
|
||||
}
|
||||
},
|
||||
"/experimental/chats/config/retention-days": {
|
||||
"get": {
|
||||
"produces": ["application/json"],
|
||||
"tags": ["Chats"],
|
||||
"summary": "Get chat retention days",
|
||||
"operationId": "get-chat-retention-days",
|
||||
"responses": {
|
||||
"200": {
|
||||
"description": "OK",
|
||||
"schema": {
|
||||
"$ref": "#/definitions/codersdk.ChatRetentionDaysResponse"
|
||||
}
|
||||
}
|
||||
},
|
||||
"security": [
|
||||
{
|
||||
"CoderSessionToken": []
|
||||
}
|
||||
],
|
||||
"x-apidocgen": {
|
||||
"skip": true
|
||||
}
|
||||
},
|
||||
"put": {
|
||||
"consumes": ["application/json"],
|
||||
"tags": ["Chats"],
|
||||
"summary": "Update chat retention days",
|
||||
"operationId": "update-chat-retention-days",
|
||||
"parameters": [
|
||||
{
|
||||
"description": "Request body",
|
||||
"name": "request",
|
||||
"in": "body",
|
||||
"required": true,
|
||||
"schema": {
|
||||
"$ref": "#/definitions/codersdk.UpdateChatRetentionDaysRequest"
|
||||
}
|
||||
}
|
||||
],
|
||||
"responses": {
|
||||
"204": {
|
||||
"description": "No Content"
|
||||
}
|
||||
},
|
||||
"security": [
|
||||
{
|
||||
"CoderSessionToken": []
|
||||
}
|
||||
],
|
||||
"x-apidocgen": {
|
||||
"skip": true
|
||||
}
|
||||
}
|
||||
},
|
||||
"/experimental/watch-all-workspacebuilds": {
|
||||
"get": {
|
||||
"produces": ["application/json"],
|
||||
@@ -8377,6 +8431,190 @@
|
||||
]
|
||||
}
|
||||
},
|
||||
"/users/{user}/secrets": {
|
||||
"get": {
|
||||
"produces": ["application/json"],
|
||||
"tags": ["Secrets"],
|
||||
"summary": "List user secrets",
|
||||
"operationId": "list-user-secrets",
|
||||
"parameters": [
|
||||
{
|
||||
"type": "string",
|
||||
"description": "User ID, username, or me",
|
||||
"name": "user",
|
||||
"in": "path",
|
||||
"required": true
|
||||
}
|
||||
],
|
||||
"responses": {
|
||||
"200": {
|
||||
"description": "OK",
|
||||
"schema": {
|
||||
"type": "array",
|
||||
"items": {
|
||||
"$ref": "#/definitions/codersdk.UserSecret"
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
"security": [
|
||||
{
|
||||
"CoderSessionToken": []
|
||||
}
|
||||
]
|
||||
},
|
||||
"post": {
|
||||
"consumes": ["application/json"],
|
||||
"produces": ["application/json"],
|
||||
"tags": ["Secrets"],
|
||||
"summary": "Create a new user secret",
|
||||
"operationId": "create-a-new-user-secret",
|
||||
"parameters": [
|
||||
{
|
||||
"type": "string",
|
||||
"description": "User ID, username, or me",
|
||||
"name": "user",
|
||||
"in": "path",
|
||||
"required": true
|
||||
},
|
||||
{
|
||||
"description": "Create secret request",
|
||||
"name": "request",
|
||||
"in": "body",
|
||||
"required": true,
|
||||
"schema": {
|
||||
"$ref": "#/definitions/codersdk.CreateUserSecretRequest"
|
||||
}
|
||||
}
|
||||
],
|
||||
"responses": {
|
||||
"201": {
|
||||
"description": "Created",
|
||||
"schema": {
|
||||
"$ref": "#/definitions/codersdk.UserSecret"
|
||||
}
|
||||
}
|
||||
},
|
||||
"security": [
|
||||
{
|
||||
"CoderSessionToken": []
|
||||
}
|
||||
]
|
||||
}
|
||||
},
|
||||
"/users/{user}/secrets/{name}": {
|
||||
"get": {
|
||||
"produces": ["application/json"],
|
||||
"tags": ["Secrets"],
|
||||
"summary": "Get a user secret by name",
|
||||
"operationId": "get-a-user-secret-by-name",
|
||||
"parameters": [
|
||||
{
|
||||
"type": "string",
|
||||
"description": "User ID, username, or me",
|
||||
"name": "user",
|
||||
"in": "path",
|
||||
"required": true
|
||||
},
|
||||
{
|
||||
"type": "string",
|
||||
"description": "Secret name",
|
||||
"name": "name",
|
||||
"in": "path",
|
||||
"required": true
|
||||
}
|
||||
],
|
||||
"responses": {
|
||||
"200": {
|
||||
"description": "OK",
|
||||
"schema": {
|
||||
"$ref": "#/definitions/codersdk.UserSecret"
|
||||
}
|
||||
}
|
||||
},
|
||||
"security": [
|
||||
{
|
||||
"CoderSessionToken": []
|
||||
}
|
||||
]
|
||||
},
|
||||
"delete": {
|
||||
"tags": ["Secrets"],
|
||||
"summary": "Delete a user secret",
|
||||
"operationId": "delete-a-user-secret",
|
||||
"parameters": [
|
||||
{
|
||||
"type": "string",
|
||||
"description": "User ID, username, or me",
|
||||
"name": "user",
|
||||
"in": "path",
|
||||
"required": true
|
||||
},
|
||||
{
|
||||
"type": "string",
|
||||
"description": "Secret name",
|
||||
"name": "name",
|
||||
"in": "path",
|
||||
"required": true
|
||||
}
|
||||
],
|
||||
"responses": {
|
||||
"204": {
|
||||
"description": "No Content"
|
||||
}
|
||||
},
|
||||
"security": [
|
||||
{
|
||||
"CoderSessionToken": []
|
||||
}
|
||||
]
|
||||
},
|
||||
"patch": {
|
||||
"consumes": ["application/json"],
|
||||
"produces": ["application/json"],
|
||||
"tags": ["Secrets"],
|
||||
"summary": "Update a user secret",
|
||||
"operationId": "update-a-user-secret",
|
||||
"parameters": [
|
||||
{
|
||||
"type": "string",
|
||||
"description": "User ID, username, or me",
|
||||
"name": "user",
|
||||
"in": "path",
|
||||
"required": true
|
||||
},
|
||||
{
|
||||
"type": "string",
|
||||
"description": "Secret name",
|
||||
"name": "name",
|
||||
"in": "path",
|
||||
"required": true
|
||||
},
|
||||
{
|
||||
"description": "Update secret request",
|
||||
"name": "request",
|
||||
"in": "body",
|
||||
"required": true,
|
||||
"schema": {
|
||||
"$ref": "#/definitions/codersdk.UpdateUserSecretRequest"
|
||||
}
|
||||
}
|
||||
],
|
||||
"responses": {
|
||||
"200": {
|
||||
"description": "OK",
|
||||
"schema": {
|
||||
"$ref": "#/definitions/codersdk.UserSecret"
|
||||
}
|
||||
}
|
||||
},
|
||||
"security": [
|
||||
{
|
||||
"CoderSessionToken": []
|
||||
}
|
||||
]
|
||||
}
|
||||
},
|
||||
"/users/{user}/status/activate": {
|
||||
"put": {
|
||||
"produces": ["application/json"],
|
||||
@@ -11755,6 +11993,12 @@
|
||||
"$ref": "#/definitions/codersdk.AIBridgeAgenticAction"
|
||||
}
|
||||
},
|
||||
"credential_hint": {
|
||||
"type": "string"
|
||||
},
|
||||
"credential_kind": {
|
||||
"type": "string"
|
||||
},
|
||||
"ended_at": {
|
||||
"type": "string",
|
||||
"format": "date-time"
|
||||
@@ -12963,6 +13207,14 @@
|
||||
}
|
||||
}
|
||||
},
|
||||
"codersdk.ChatRetentionDaysResponse": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"retention_days": {
|
||||
"type": "integer"
|
||||
}
|
||||
}
|
||||
},
|
||||
"codersdk.ConnectionLatency": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
@@ -13581,6 +13833,26 @@
|
||||
}
|
||||
}
|
||||
},
|
||||
"codersdk.CreateUserSecretRequest": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"description": {
|
||||
"type": "string"
|
||||
},
|
||||
"env_name": {
|
||||
"type": "string"
|
||||
},
|
||||
"file_path": {
|
||||
"type": "string"
|
||||
},
|
||||
"name": {
|
||||
"type": "string"
|
||||
},
|
||||
"value": {
|
||||
"type": "string"
|
||||
}
|
||||
}
|
||||
},
|
||||
"codersdk.CreateWorkspaceBuildReason": {
|
||||
"type": "string",
|
||||
"enum": [
|
||||
@@ -17237,6 +17509,9 @@
|
||||
"template_version_name": {
|
||||
"type": "string"
|
||||
},
|
||||
"workspace_build_transition": {
|
||||
"$ref": "#/definitions/codersdk.WorkspaceTransition"
|
||||
},
|
||||
"workspace_id": {
|
||||
"type": "string",
|
||||
"format": "uuid"
|
||||
@@ -19243,6 +19518,14 @@
|
||||
}
|
||||
}
|
||||
},
|
||||
"codersdk.UpdateChatRetentionDaysRequest": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"retention_days": {
|
||||
"type": "integer"
|
||||
}
|
||||
}
|
||||
},
|
||||
"codersdk.UpdateCheckResponse": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
@@ -19475,6 +19758,23 @@
|
||||
}
|
||||
}
|
||||
},
|
||||
"codersdk.UpdateUserSecretRequest": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"description": {
|
||||
"type": "string"
|
||||
},
|
||||
"env_name": {
|
||||
"type": "string"
|
||||
},
|
||||
"file_path": {
|
||||
"type": "string"
|
||||
},
|
||||
"value": {
|
||||
"type": "string"
|
||||
}
|
||||
}
|
||||
},
|
||||
"codersdk.UpdateWorkspaceACL": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
@@ -19905,6 +20205,35 @@
|
||||
}
|
||||
}
|
||||
},
|
||||
"codersdk.UserSecret": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"created_at": {
|
||||
"type": "string",
|
||||
"format": "date-time"
|
||||
},
|
||||
"description": {
|
||||
"type": "string"
|
||||
},
|
||||
"env_name": {
|
||||
"type": "string"
|
||||
},
|
||||
"file_path": {
|
||||
"type": "string"
|
||||
},
|
||||
"id": {
|
||||
"type": "string",
|
||||
"format": "uuid"
|
||||
},
|
||||
"name": {
|
||||
"type": "string"
|
||||
},
|
||||
"updated_at": {
|
||||
"type": "string",
|
||||
"format": "date-time"
|
||||
}
|
||||
}
|
||||
},
|
||||
"codersdk.UserStatus": {
|
||||
"type": "string",
|
||||
"enum": ["active", "dormant", "suspended"],
|
||||
|
||||
@@ -1189,6 +1189,8 @@ func New(options *Options) *API {
|
||||
r.Delete("/user-compaction-thresholds/{modelConfig}", api.deleteUserChatCompactionThreshold)
|
||||
r.Get("/workspace-ttl", api.getChatWorkspaceTTL)
|
||||
r.Put("/workspace-ttl", api.putChatWorkspaceTTL)
|
||||
r.Get("/retention-days", api.getChatRetentionDays)
|
||||
r.Put("/retention-days", api.putChatRetentionDays)
|
||||
r.Get("/template-allowlist", api.getChatTemplateAllowlist)
|
||||
r.Put("/template-allowlist", api.putChatTemplateAllowlist)
|
||||
})
|
||||
@@ -1243,6 +1245,7 @@ func New(options *Options) *API {
|
||||
r.Get("/git", api.watchChatGit)
|
||||
})
|
||||
r.Post("/interrupt", api.interruptChat)
|
||||
r.Post("/tool-results", api.postChatToolResults)
|
||||
r.Post("/title/regenerate", api.regenerateChatTitle)
|
||||
r.Get("/diff", api.getChatDiffContents)
|
||||
r.Route("/queue/{queuedMessage}", func(r chi.Router) {
|
||||
@@ -1605,6 +1608,15 @@ func New(options *Options) *API {
|
||||
|
||||
r.Get("/gitsshkey", api.gitSSHKey)
|
||||
r.Put("/gitsshkey", api.regenerateGitSSHKey)
|
||||
r.Route("/secrets", func(r chi.Router) {
|
||||
r.Post("/", api.postUserSecret)
|
||||
r.Get("/", api.getUserSecrets)
|
||||
r.Route("/{name}", func(r chi.Router) {
|
||||
r.Get("/", api.getUserSecret)
|
||||
r.Patch("/", api.patchUserSecret)
|
||||
r.Delete("/", api.deleteUserSecret)
|
||||
})
|
||||
})
|
||||
r.Route("/notifications", func(r chi.Router) {
|
||||
r.Route("/preferences", func(r chi.Router) {
|
||||
r.Get("/", api.userNotificationPreferences)
|
||||
@@ -1650,6 +1662,10 @@ func New(options *Options) *API {
|
||||
r.Get("/gitsshkey", api.agentGitSSHKey)
|
||||
r.Post("/log-source", api.workspaceAgentPostLogSource)
|
||||
r.Get("/reinit", api.workspaceAgentReinit)
|
||||
r.Route("/experimental", func(r chi.Router) {
|
||||
r.Post("/chat-context", api.workspaceAgentAddChatContext)
|
||||
r.Delete("/chat-context", api.workspaceAgentClearChatContext)
|
||||
})
|
||||
r.Route("/tasks/{task}", func(r chi.Router) {
|
||||
r.Post("/log-snapshot", api.postWorkspaceAgentTaskLogSnapshot)
|
||||
})
|
||||
|
||||
@@ -147,6 +147,10 @@ func parseSwaggerComment(commentGroup *ast.CommentGroup) SwaggerComment {
|
||||
return c
|
||||
}
|
||||
|
||||
func isExperimentalEndpoint(route string) bool {
|
||||
return strings.HasPrefix(route, "/workspaceagents/me/experimental/")
|
||||
}
|
||||
|
||||
func VerifySwaggerDefinitions(t *testing.T, router chi.Router, swaggerComments []SwaggerComment) {
|
||||
assertUniqueRoutes(t, swaggerComments)
|
||||
assertSingleAnnotations(t, swaggerComments)
|
||||
@@ -165,6 +169,9 @@ func VerifySwaggerDefinitions(t *testing.T, router chi.Router, swaggerComments [
|
||||
if strings.HasSuffix(route, "/*") {
|
||||
return
|
||||
}
|
||||
if isExperimentalEndpoint(route) {
|
||||
return
|
||||
}
|
||||
|
||||
c := findSwaggerCommentByMethodAndRoute(swaggerComments, method, route)
|
||||
assert.NotNil(t, c, "Missing @Router annotation")
|
||||
|
||||
@@ -538,6 +538,12 @@ func WorkspaceAgent(derpMap *tailcfg.DERPMap, coordinator tailnet.Coordinator,
|
||||
switch {
|
||||
case workspaceAgent.Status != codersdk.WorkspaceAgentConnected && workspaceAgent.LifecycleState == codersdk.WorkspaceAgentLifecycleOff:
|
||||
workspaceAgent.Health.Reason = "agent is not running"
|
||||
case workspaceAgent.Status == codersdk.WorkspaceAgentConnecting:
|
||||
// Note: the case above catches connecting+off as "not running".
|
||||
// This case handles connecting agents with a non-off lifecycle
|
||||
// (e.g. "created" or "starting"), where the agent binary has
|
||||
// not yet established a connection to coderd.
|
||||
workspaceAgent.Health.Reason = "agent has not yet connected"
|
||||
case workspaceAgent.Status == codersdk.WorkspaceAgentTimeout:
|
||||
workspaceAgent.Health.Reason = "agent is taking too long to connect"
|
||||
case workspaceAgent.Status == codersdk.WorkspaceAgentDisconnected:
|
||||
@@ -1234,6 +1240,8 @@ func buildAIBridgeThread(
|
||||
if rootIntc != nil {
|
||||
thread.Model = rootIntc.Model
|
||||
thread.Provider = rootIntc.Provider
|
||||
thread.CredentialKind = string(rootIntc.CredentialKind)
|
||||
thread.CredentialHint = rootIntc.CredentialHint
|
||||
// Get first user prompt from root interception.
|
||||
// A thread can only have one prompt, by definition, since we currently
|
||||
// only store the last prompt observed in an interception.
|
||||
@@ -1715,3 +1723,41 @@ func ChatDiffStatus(chatID uuid.UUID, status *database.ChatDiffStatus) codersdk.
|
||||
|
||||
return result
|
||||
}
|
||||
|
||||
// UserSecret converts a database ListUserSecretsRow (metadata only,
|
||||
// no value) to an SDK UserSecret.
|
||||
func UserSecret(secret database.ListUserSecretsRow) codersdk.UserSecret {
|
||||
return codersdk.UserSecret{
|
||||
ID: secret.ID,
|
||||
Name: secret.Name,
|
||||
Description: secret.Description,
|
||||
EnvName: secret.EnvName,
|
||||
FilePath: secret.FilePath,
|
||||
CreatedAt: secret.CreatedAt,
|
||||
UpdatedAt: secret.UpdatedAt,
|
||||
}
|
||||
}
|
||||
|
||||
// UserSecretFromFull converts a full database UserSecret row to an
|
||||
// SDK UserSecret, omitting the value and encryption key ID.
|
||||
func UserSecretFromFull(secret database.UserSecret) codersdk.UserSecret {
|
||||
return codersdk.UserSecret{
|
||||
ID: secret.ID,
|
||||
Name: secret.Name,
|
||||
Description: secret.Description,
|
||||
EnvName: secret.EnvName,
|
||||
FilePath: secret.FilePath,
|
||||
CreatedAt: secret.CreatedAt,
|
||||
UpdatedAt: secret.UpdatedAt,
|
||||
}
|
||||
}
|
||||
|
||||
// UserSecrets converts a slice of database ListUserSecretsRow to
|
||||
// SDK UserSecret values.
|
||||
func UserSecrets(secrets []database.ListUserSecretsRow) []codersdk.UserSecret {
|
||||
result := make([]codersdk.UserSecret, 0, len(secrets))
|
||||
for _, s := range secrets {
|
||||
result = append(result, UserSecret(s))
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
@@ -552,6 +552,10 @@ func TestChat_AllFieldsPopulated(t *testing.T) {
|
||||
RawMessage: json.RawMessage(`[{"type":"context-file","context_file_path":"/AGENTS.md"}]`),
|
||||
Valid: true,
|
||||
},
|
||||
DynamicTools: pqtype.NullRawMessage{
|
||||
RawMessage: json.RawMessage(`[{"name":"tool1","description":"test tool","inputSchema":{"type":"object"}}]`),
|
||||
Valid: true,
|
||||
},
|
||||
}
|
||||
// Only ChatID is needed here. This test checks that
|
||||
// Chat.DiffStatus is non-nil, not that every DiffStatus
|
||||
|
||||
@@ -1708,6 +1708,17 @@ func (q *querier) CleanupDeletedMCPServerIDsFromChats(ctx context.Context) error
|
||||
return q.db.CleanupDeletedMCPServerIDsFromChats(ctx)
|
||||
}
|
||||
|
||||
func (q *querier) ClearChatMessageProviderResponseIDsByChatID(ctx context.Context, chatID uuid.UUID) error {
|
||||
chat, err := q.db.GetChatByID(ctx, chatID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if err := q.authorizeContext(ctx, policy.ActionUpdate, chat); err != nil {
|
||||
return err
|
||||
}
|
||||
return q.db.ClearChatMessageProviderResponseIDsByChatID(ctx, chatID)
|
||||
}
|
||||
|
||||
func (q *querier) CountAIBridgeInterceptions(ctx context.Context, arg database.CountAIBridgeInterceptionsParams) (int64, error) {
|
||||
prep, err := prepareSQLFilter(ctx, q.auth, policy.ActionRead, rbac.ResourceAibridgeInterception.Type)
|
||||
if err != nil {
|
||||
@@ -2031,6 +2042,20 @@ func (q *querier) DeleteOldAuditLogs(ctx context.Context, arg database.DeleteOld
|
||||
return q.db.DeleteOldAuditLogs(ctx, arg)
|
||||
}
|
||||
|
||||
func (q *querier) DeleteOldChatFiles(ctx context.Context, arg database.DeleteOldChatFilesParams) (int64, error) {
|
||||
if err := q.authorizeContext(ctx, policy.ActionDelete, rbac.ResourceSystem); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return q.db.DeleteOldChatFiles(ctx, arg)
|
||||
}
|
||||
|
||||
func (q *querier) DeleteOldChats(ctx context.Context, arg database.DeleteOldChatsParams) (int64, error) {
|
||||
if err := q.authorizeContext(ctx, policy.ActionDelete, rbac.ResourceSystem); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return q.db.DeleteOldChats(ctx, arg)
|
||||
}
|
||||
|
||||
func (q *querier) DeleteOldConnectionLogs(ctx context.Context, arg database.DeleteOldConnectionLogsParams) (int64, error) {
|
||||
if err := q.authorizeContext(ctx, policy.ActionDelete, rbac.ResourceSystem); err != nil {
|
||||
return 0, err
|
||||
@@ -2155,10 +2180,10 @@ func (q *querier) DeleteUserChatProviderKey(ctx context.Context, arg database.De
|
||||
return q.db.DeleteUserChatProviderKey(ctx, arg)
|
||||
}
|
||||
|
||||
func (q *querier) DeleteUserSecretByUserIDAndName(ctx context.Context, arg database.DeleteUserSecretByUserIDAndNameParams) error {
|
||||
func (q *querier) DeleteUserSecretByUserIDAndName(ctx context.Context, arg database.DeleteUserSecretByUserIDAndNameParams) (int64, error) {
|
||||
obj := rbac.ResourceUserSecret.WithOwner(arg.UserID.String())
|
||||
if err := q.authorizeContext(ctx, policy.ActionDelete, obj); err != nil {
|
||||
return err
|
||||
return 0, err
|
||||
}
|
||||
return q.db.DeleteUserSecretByUserIDAndName(ctx, arg)
|
||||
}
|
||||
@@ -2399,6 +2424,10 @@ func (q *querier) GetActiveAISeatCount(ctx context.Context) (int64, error) {
|
||||
return q.db.GetActiveAISeatCount(ctx)
|
||||
}
|
||||
|
||||
func (q *querier) GetActiveChatsByAgentID(ctx context.Context, agentID uuid.UUID) ([]database.Chat, error) {
|
||||
return fetchWithPostFilter(q.auth, policy.ActionRead, q.db.GetActiveChatsByAgentID)(ctx, agentID)
|
||||
}
|
||||
|
||||
func (q *querier) GetActivePresetPrebuildSchedules(ctx context.Context) ([]database.TemplateVersionPresetPrebuildSchedule, error) {
|
||||
if err := q.authorizeContext(ctx, policy.ActionRead, rbac.ResourceTemplate.All()); err != nil {
|
||||
return nil, err
|
||||
@@ -2622,6 +2651,14 @@ func (q *querier) GetChatMessageByID(ctx context.Context, id int64) (database.Ch
|
||||
return msg, nil
|
||||
}
|
||||
|
||||
func (q *querier) GetChatMessageSummariesPerChat(ctx context.Context, createdAfter time.Time) ([]database.GetChatMessageSummariesPerChatRow, error) {
|
||||
// Telemetry queries are called from system contexts only.
|
||||
if err := q.authorizeContext(ctx, policy.ActionRead, rbac.ResourceSystem); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return q.db.GetChatMessageSummariesPerChat(ctx, createdAfter)
|
||||
}
|
||||
|
||||
func (q *querier) GetChatMessagesByChatID(ctx context.Context, arg database.GetChatMessagesByChatIDParams) ([]database.ChatMessage, error) {
|
||||
// Authorize read on the parent chat.
|
||||
_, err := q.GetChatByID(ctx, arg.ChatID)
|
||||
@@ -2670,6 +2707,14 @@ func (q *querier) GetChatModelConfigs(ctx context.Context) ([]database.ChatModel
|
||||
return q.db.GetChatModelConfigs(ctx)
|
||||
}
|
||||
|
||||
func (q *querier) GetChatModelConfigsForTelemetry(ctx context.Context) ([]database.GetChatModelConfigsForTelemetryRow, error) {
|
||||
// Telemetry queries are called from system contexts only.
|
||||
if err := q.authorizeContext(ctx, policy.ActionRead, rbac.ResourceSystem); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return q.db.GetChatModelConfigsForTelemetry(ctx)
|
||||
}
|
||||
|
||||
func (q *querier) GetChatProviderByID(ctx context.Context, id uuid.UUID) (database.ChatProvider, error) {
|
||||
if err := q.authorizeContext(ctx, policy.ActionRead, rbac.ResourceDeploymentConfig); err != nil {
|
||||
return database.ChatProvider{}, err
|
||||
@@ -2699,6 +2744,15 @@ func (q *querier) GetChatQueuedMessages(ctx context.Context, chatID uuid.UUID) (
|
||||
return q.db.GetChatQueuedMessages(ctx, chatID)
|
||||
}
|
||||
|
||||
func (q *querier) GetChatRetentionDays(ctx context.Context) (int32, error) {
|
||||
// Chat retention is a deployment-wide config read by dbpurge.
|
||||
// Only requires a valid actor in context.
|
||||
if _, ok := ActorFromContext(ctx); !ok {
|
||||
return 0, ErrNoActor
|
||||
}
|
||||
return q.db.GetChatRetentionDays(ctx)
|
||||
}
|
||||
|
||||
func (q *querier) GetChatSystemPrompt(ctx context.Context) (string, error) {
|
||||
// The system prompt is a deployment-wide setting read during chat
|
||||
// creation by every authenticated user, so no RBAC policy check
|
||||
@@ -2777,6 +2831,14 @@ func (q *querier) GetChatsByWorkspaceIDs(ctx context.Context, ids []uuid.UUID) (
|
||||
return fetchWithPostFilter(q.auth, policy.ActionRead, q.db.GetChatsByWorkspaceIDs)(ctx, ids)
|
||||
}
|
||||
|
||||
func (q *querier) GetChatsUpdatedAfter(ctx context.Context, updatedAfter time.Time) ([]database.GetChatsUpdatedAfterRow, error) {
|
||||
// Telemetry queries are called from system contexts only.
|
||||
if err := q.authorizeContext(ctx, policy.ActionRead, rbac.ResourceSystem); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return q.db.GetChatsUpdatedAfter(ctx, updatedAfter)
|
||||
}
|
||||
|
||||
func (q *querier) GetConnectionLogsOffset(ctx context.Context, arg database.GetConnectionLogsOffsetParams) ([]database.GetConnectionLogsOffsetRow, error) {
|
||||
// Just like with the audit logs query, shortcut if the user is an owner.
|
||||
err := q.authorizeContext(ctx, policy.ActionRead, rbac.ResourceConnectionLog)
|
||||
@@ -3339,11 +3401,11 @@ func (q *querier) GetPRInsightsPerModel(ctx context.Context, arg database.GetPRI
|
||||
return q.db.GetPRInsightsPerModel(ctx, arg)
|
||||
}
|
||||
|
||||
func (q *querier) GetPRInsightsRecentPRs(ctx context.Context, arg database.GetPRInsightsRecentPRsParams) ([]database.GetPRInsightsRecentPRsRow, error) {
|
||||
func (q *querier) GetPRInsightsPullRequests(ctx context.Context, arg database.GetPRInsightsPullRequestsParams) ([]database.GetPRInsightsPullRequestsRow, error) {
|
||||
if err := q.authorizeContext(ctx, policy.ActionRead, rbac.ResourceDeploymentConfig); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return q.db.GetPRInsightsRecentPRs(ctx, arg)
|
||||
return q.db.GetPRInsightsPullRequests(ctx, arg)
|
||||
}
|
||||
|
||||
func (q *querier) GetPRInsightsSummary(ctx context.Context, arg database.GetPRInsightsSummaryParams) (database.GetPRInsightsSummaryRow, error) {
|
||||
@@ -5681,6 +5743,17 @@ func (q *querier) SoftDeleteChatMessagesAfterID(ctx context.Context, arg databas
|
||||
return q.db.SoftDeleteChatMessagesAfterID(ctx, arg)
|
||||
}
|
||||
|
||||
func (q *querier) SoftDeleteContextFileMessages(ctx context.Context, chatID uuid.UUID) error {
|
||||
chat, err := q.db.GetChatByID(ctx, chatID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if err := q.authorizeContext(ctx, policy.ActionUpdate, chat); err != nil {
|
||||
return err
|
||||
}
|
||||
return q.db.SoftDeleteContextFileMessages(ctx, chatID)
|
||||
}
|
||||
|
||||
func (q *querier) TryAcquireLock(ctx context.Context, id int64) (bool, error) {
|
||||
return q.db.TryAcquireLock(ctx, id)
|
||||
}
|
||||
@@ -6710,6 +6783,19 @@ func (q *querier) UpdateWorkspaceAgentConnectionByID(ctx context.Context, arg da
|
||||
return q.db.UpdateWorkspaceAgentConnectionByID(ctx, arg)
|
||||
}
|
||||
|
||||
func (q *querier) UpdateWorkspaceAgentDirectoryByID(ctx context.Context, arg database.UpdateWorkspaceAgentDirectoryByIDParams) error {
|
||||
workspace, err := q.db.GetWorkspaceByAgentID(ctx, arg.ID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := q.authorizeContext(ctx, policy.ActionUpdateAgent, workspace); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return q.db.UpdateWorkspaceAgentDirectoryByID(ctx, arg)
|
||||
}
|
||||
|
||||
func (q *querier) UpdateWorkspaceAgentDisplayAppsByID(ctx context.Context, arg database.UpdateWorkspaceAgentDisplayAppsByIDParams) error {
|
||||
workspace, err := q.db.GetWorkspaceByAgentID(ctx, arg.ID)
|
||||
if err != nil {
|
||||
@@ -7031,6 +7117,13 @@ func (q *querier) UpsertChatIncludeDefaultSystemPrompt(ctx context.Context, incl
|
||||
return q.db.UpsertChatIncludeDefaultSystemPrompt(ctx, includeDefaultSystemPrompt)
|
||||
}
|
||||
|
||||
func (q *querier) UpsertChatRetentionDays(ctx context.Context, retentionDays int32) error {
|
||||
if err := q.authorizeContext(ctx, policy.ActionUpdate, rbac.ResourceDeploymentConfig); err != nil {
|
||||
return err
|
||||
}
|
||||
return q.db.UpsertChatRetentionDays(ctx, retentionDays)
|
||||
}
|
||||
|
||||
func (q *querier) UpsertChatSystemPrompt(ctx context.Context, value string) error {
|
||||
if err := q.authorizeContext(ctx, policy.ActionUpdate, rbac.ResourceDeploymentConfig); err != nil {
|
||||
return err
|
||||
|
||||
@@ -478,6 +478,24 @@ func (s *MethodTestSuite) TestChats() {
|
||||
dbm.EXPECT().GetChatsByWorkspaceIDs(gomock.Any(), arg).Return([]database.Chat{chatA, chatB}, nil).AnyTimes()
|
||||
check.Args(arg).Asserts(chatA, policy.ActionRead, chatB, policy.ActionRead).Returns([]database.Chat{chatA, chatB})
|
||||
}))
|
||||
s.Run("GetActiveChatsByAgentID", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
|
||||
chat := testutil.Fake(s.T(), faker, database.Chat{})
|
||||
agentID := uuid.New()
|
||||
dbm.EXPECT().GetActiveChatsByAgentID(gomock.Any(), agentID).Return([]database.Chat{chat}, nil).AnyTimes()
|
||||
check.Args(agentID).Asserts(chat, policy.ActionRead).Returns([]database.Chat{chat})
|
||||
}))
|
||||
s.Run("SoftDeleteContextFileMessages", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
|
||||
chat := testutil.Fake(s.T(), faker, database.Chat{})
|
||||
dbm.EXPECT().GetChatByID(gomock.Any(), chat.ID).Return(chat, nil).AnyTimes()
|
||||
dbm.EXPECT().SoftDeleteContextFileMessages(gomock.Any(), chat.ID).Return(nil).AnyTimes()
|
||||
check.Args(chat.ID).Asserts(chat, policy.ActionUpdate).Returns()
|
||||
}))
|
||||
s.Run("ClearChatMessageProviderResponseIDsByChatID", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
|
||||
chat := testutil.Fake(s.T(), faker, database.Chat{})
|
||||
dbm.EXPECT().GetChatByID(gomock.Any(), chat.ID).Return(chat, nil).AnyTimes()
|
||||
dbm.EXPECT().ClearChatMessageProviderResponseIDsByChatID(gomock.Any(), chat.ID).Return(nil).AnyTimes()
|
||||
check.Args(chat.ID).Asserts(chat, policy.ActionUpdate).Returns()
|
||||
}))
|
||||
s.Run("GetChatCostPerChat", s.Mocked(func(dbm *dbmock.MockStore, _ *gofakeit.Faker, check *expects) {
|
||||
arg := database.GetChatCostPerChatParams{
|
||||
OwnerID: uuid.New(),
|
||||
@@ -600,6 +618,22 @@ func (s *MethodTestSuite) TestChats() {
|
||||
dbm.EXPECT().GetChatFileMetadataByChatID(gomock.Any(), file.ID).Return(rows, nil).AnyTimes()
|
||||
check.Args(file.ID).Asserts(rbac.ResourceChat.WithOwner(file.OwnerID.String()).InOrg(file.OrganizationID).WithID(file.ID), policy.ActionRead).Returns(rows)
|
||||
}))
|
||||
s.Run("DeleteOldChatFiles", s.Mocked(func(dbm *dbmock.MockStore, _ *gofakeit.Faker, check *expects) {
|
||||
dbm.EXPECT().DeleteOldChatFiles(gomock.Any(), database.DeleteOldChatFilesParams{}).Return(int64(0), nil).AnyTimes()
|
||||
check.Args(database.DeleteOldChatFilesParams{}).Asserts(rbac.ResourceSystem, policy.ActionDelete)
|
||||
}))
|
||||
s.Run("DeleteOldChats", s.Mocked(func(dbm *dbmock.MockStore, _ *gofakeit.Faker, check *expects) {
|
||||
dbm.EXPECT().DeleteOldChats(gomock.Any(), database.DeleteOldChatsParams{}).Return(int64(0), nil).AnyTimes()
|
||||
check.Args(database.DeleteOldChatsParams{}).Asserts(rbac.ResourceSystem, policy.ActionDelete)
|
||||
}))
|
||||
s.Run("GetChatRetentionDays", s.Mocked(func(dbm *dbmock.MockStore, _ *gofakeit.Faker, check *expects) {
|
||||
dbm.EXPECT().GetChatRetentionDays(gomock.Any()).Return(int32(30), nil).AnyTimes()
|
||||
check.Args().Asserts()
|
||||
}))
|
||||
s.Run("UpsertChatRetentionDays", s.Mocked(func(dbm *dbmock.MockStore, _ *gofakeit.Faker, check *expects) {
|
||||
dbm.EXPECT().UpsertChatRetentionDays(gomock.Any(), int32(30)).Return(nil).AnyTimes()
|
||||
check.Args(int32(30)).Asserts(rbac.ResourceDeploymentConfig, policy.ActionUpdate)
|
||||
}))
|
||||
s.Run("GetChatMessageByID", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
|
||||
chat := testutil.Fake(s.T(), faker, database.Chat{})
|
||||
msg := testutil.Fake(s.T(), faker, database.ChatMessage{ChatID: chat.ID})
|
||||
@@ -2227,9 +2261,9 @@ func (s *MethodTestSuite) TestTemplate() {
|
||||
dbm.EXPECT().GetPRInsightsPerModel(gomock.Any(), arg).Return([]database.GetPRInsightsPerModelRow{}, nil).AnyTimes()
|
||||
check.Args(arg).Asserts(rbac.ResourceDeploymentConfig, policy.ActionRead)
|
||||
}))
|
||||
s.Run("GetPRInsightsRecentPRs", s.Mocked(func(dbm *dbmock.MockStore, _ *gofakeit.Faker, check *expects) {
|
||||
arg := database.GetPRInsightsRecentPRsParams{}
|
||||
dbm.EXPECT().GetPRInsightsRecentPRs(gomock.Any(), arg).Return([]database.GetPRInsightsRecentPRsRow{}, nil).AnyTimes()
|
||||
s.Run("GetPRInsightsPullRequests", s.Mocked(func(dbm *dbmock.MockStore, _ *gofakeit.Faker, check *expects) {
|
||||
arg := database.GetPRInsightsPullRequestsParams{}
|
||||
dbm.EXPECT().GetPRInsightsPullRequests(gomock.Any(), arg).Return([]database.GetPRInsightsPullRequestsRow{}, nil).AnyTimes()
|
||||
check.Args(arg).Asserts(rbac.ResourceDeploymentConfig, policy.ActionRead)
|
||||
}))
|
||||
s.Run("GetTelemetryTaskEvents", s.Mocked(func(dbm *dbmock.MockStore, _ *gofakeit.Faker, check *expects) {
|
||||
@@ -2901,6 +2935,17 @@ func (s *MethodTestSuite) TestWorkspace() {
|
||||
dbm.EXPECT().UpdateWorkspaceAgentStartupByID(gomock.Any(), arg).Return(nil).AnyTimes()
|
||||
check.Args(arg).Asserts(w, policy.ActionUpdate).Returns()
|
||||
}))
|
||||
s.Run("UpdateWorkspaceAgentDirectoryByID", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
|
||||
w := testutil.Fake(s.T(), faker, database.Workspace{})
|
||||
agt := testutil.Fake(s.T(), faker, database.WorkspaceAgent{})
|
||||
arg := database.UpdateWorkspaceAgentDirectoryByIDParams{
|
||||
ID: agt.ID,
|
||||
Directory: "/workspaces/project",
|
||||
}
|
||||
dbm.EXPECT().GetWorkspaceByAgentID(gomock.Any(), agt.ID).Return(w, nil).AnyTimes()
|
||||
dbm.EXPECT().UpdateWorkspaceAgentDirectoryByID(gomock.Any(), arg).Return(nil).AnyTimes()
|
||||
check.Args(arg).Asserts(w, policy.ActionUpdateAgent).Returns()
|
||||
}))
|
||||
s.Run("UpdateWorkspaceAgentDisplayAppsByID", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
|
||||
w := testutil.Fake(s.T(), faker, database.Workspace{})
|
||||
agt := testutil.Fake(s.T(), faker, database.WorkspaceAgent{})
|
||||
@@ -3996,6 +4041,20 @@ func (s *MethodTestSuite) TestSystemFunctions() {
|
||||
dbm.EXPECT().GetWorkspaceAgentsCreatedAfter(gomock.Any(), ts).Return([]database.WorkspaceAgent{}, nil).AnyTimes()
|
||||
check.Args(ts).Asserts(rbac.ResourceSystem, policy.ActionRead)
|
||||
}))
|
||||
s.Run("GetChatsUpdatedAfter", s.Mocked(func(dbm *dbmock.MockStore, _ *gofakeit.Faker, check *expects) {
|
||||
ts := dbtime.Now()
|
||||
dbm.EXPECT().GetChatsUpdatedAfter(gomock.Any(), ts).Return([]database.GetChatsUpdatedAfterRow{}, nil).AnyTimes()
|
||||
check.Args(ts).Asserts(rbac.ResourceSystem, policy.ActionRead)
|
||||
}))
|
||||
s.Run("GetChatMessageSummariesPerChat", s.Mocked(func(dbm *dbmock.MockStore, _ *gofakeit.Faker, check *expects) {
|
||||
ts := dbtime.Now()
|
||||
dbm.EXPECT().GetChatMessageSummariesPerChat(gomock.Any(), ts).Return([]database.GetChatMessageSummariesPerChatRow{}, nil).AnyTimes()
|
||||
check.Args(ts).Asserts(rbac.ResourceSystem, policy.ActionRead)
|
||||
}))
|
||||
s.Run("GetChatModelConfigsForTelemetry", s.Mocked(func(dbm *dbmock.MockStore, _ *gofakeit.Faker, check *expects) {
|
||||
dbm.EXPECT().GetChatModelConfigsForTelemetry(gomock.Any()).Return([]database.GetChatModelConfigsForTelemetryRow{}, nil).AnyTimes()
|
||||
check.Args().Asserts(rbac.ResourceSystem, policy.ActionRead)
|
||||
}))
|
||||
s.Run("GetWorkspaceAppsCreatedAfter", s.Mocked(func(dbm *dbmock.MockStore, _ *gofakeit.Faker, check *expects) {
|
||||
ts := dbtime.Now()
|
||||
dbm.EXPECT().GetWorkspaceAppsCreatedAfter(gomock.Any(), ts).Return([]database.WorkspaceApp{}, nil).AnyTimes()
|
||||
@@ -5383,10 +5442,10 @@ func (s *MethodTestSuite) TestUserSecrets() {
|
||||
s.Run("DeleteUserSecretByUserIDAndName", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
|
||||
user := testutil.Fake(s.T(), faker, database.User{})
|
||||
arg := database.DeleteUserSecretByUserIDAndNameParams{UserID: user.ID, Name: "test"}
|
||||
dbm.EXPECT().DeleteUserSecretByUserIDAndName(gomock.Any(), arg).Return(nil).AnyTimes()
|
||||
dbm.EXPECT().DeleteUserSecretByUserIDAndName(gomock.Any(), arg).Return(int64(1), nil).AnyTimes()
|
||||
check.Args(arg).
|
||||
Asserts(rbac.ResourceUserSecret.WithOwner(user.ID.String()), policy.ActionDelete).
|
||||
Returns()
|
||||
Returns(int64(1))
|
||||
}))
|
||||
}
|
||||
|
||||
|
||||
@@ -1644,6 +1644,8 @@ func AIBridgeInterception(t testing.TB, db database.Store, seed database.InsertA
|
||||
ThreadParentInterceptionID: seed.ThreadParentInterceptionID,
|
||||
ThreadRootInterceptionID: seed.ThreadRootInterceptionID,
|
||||
ClientSessionID: seed.ClientSessionID,
|
||||
CredentialKind: takeFirst(seed.CredentialKind, database.CredentialKindCentralized),
|
||||
CredentialHint: takeFirst(seed.CredentialHint, ""),
|
||||
})
|
||||
if endedAt != nil {
|
||||
interception, err = db.UpdateAIBridgeInterceptionEnded(genCtx, database.UpdateAIBridgeInterceptionEndedParams{
|
||||
|
||||
@@ -280,6 +280,14 @@ func (m queryMetricsStore) CleanupDeletedMCPServerIDsFromChats(ctx context.Conte
|
||||
return r0
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) ClearChatMessageProviderResponseIDsByChatID(ctx context.Context, chatID uuid.UUID) error {
|
||||
start := time.Now()
|
||||
r0 := m.s.ClearChatMessageProviderResponseIDsByChatID(ctx, chatID)
|
||||
m.queryLatencies.WithLabelValues("ClearChatMessageProviderResponseIDsByChatID").Observe(time.Since(start).Seconds())
|
||||
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "ClearChatMessageProviderResponseIDsByChatID").Inc()
|
||||
return r0
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) CountAIBridgeInterceptions(ctx context.Context, arg database.CountAIBridgeInterceptionsParams) (int64, error) {
|
||||
start := time.Now()
|
||||
r0, r1 := m.s.CountAIBridgeInterceptions(ctx, arg)
|
||||
@@ -592,6 +600,22 @@ func (m queryMetricsStore) DeleteOldAuditLogs(ctx context.Context, arg database.
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) DeleteOldChatFiles(ctx context.Context, arg database.DeleteOldChatFilesParams) (int64, error) {
|
||||
start := time.Now()
|
||||
r0, r1 := m.s.DeleteOldChatFiles(ctx, arg)
|
||||
m.queryLatencies.WithLabelValues("DeleteOldChatFiles").Observe(time.Since(start).Seconds())
|
||||
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "DeleteOldChatFiles").Inc()
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) DeleteOldChats(ctx context.Context, arg database.DeleteOldChatsParams) (int64, error) {
|
||||
start := time.Now()
|
||||
r0, r1 := m.s.DeleteOldChats(ctx, arg)
|
||||
m.queryLatencies.WithLabelValues("DeleteOldChats").Observe(time.Since(start).Seconds())
|
||||
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "DeleteOldChats").Inc()
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) DeleteOldConnectionLogs(ctx context.Context, arg database.DeleteOldConnectionLogsParams) (int64, error) {
|
||||
start := time.Now()
|
||||
r0, r1 := m.s.DeleteOldConnectionLogs(ctx, arg)
|
||||
@@ -712,12 +736,12 @@ func (m queryMetricsStore) DeleteUserChatProviderKey(ctx context.Context, arg da
|
||||
return r0
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) DeleteUserSecretByUserIDAndName(ctx context.Context, arg database.DeleteUserSecretByUserIDAndNameParams) error {
|
||||
func (m queryMetricsStore) DeleteUserSecretByUserIDAndName(ctx context.Context, arg database.DeleteUserSecretByUserIDAndNameParams) (int64, error) {
|
||||
start := time.Now()
|
||||
r0 := m.s.DeleteUserSecretByUserIDAndName(ctx, arg)
|
||||
r0, r1 := m.s.DeleteUserSecretByUserIDAndName(ctx, arg)
|
||||
m.queryLatencies.WithLabelValues("DeleteUserSecretByUserIDAndName").Observe(time.Since(start).Seconds())
|
||||
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "DeleteUserSecretByUserIDAndName").Inc()
|
||||
return r0
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) DeleteWebpushSubscriptionByUserIDAndEndpoint(ctx context.Context, arg database.DeleteWebpushSubscriptionByUserIDAndEndpointParams) error {
|
||||
@@ -952,6 +976,14 @@ func (m queryMetricsStore) GetActiveAISeatCount(ctx context.Context) (int64, err
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) GetActiveChatsByAgentID(ctx context.Context, agentID uuid.UUID) ([]database.Chat, error) {
|
||||
start := time.Now()
|
||||
r0, r1 := m.s.GetActiveChatsByAgentID(ctx, agentID)
|
||||
m.queryLatencies.WithLabelValues("GetActiveChatsByAgentID").Observe(time.Since(start).Seconds())
|
||||
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "GetActiveChatsByAgentID").Inc()
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) GetActivePresetPrebuildSchedules(ctx context.Context) ([]database.TemplateVersionPresetPrebuildSchedule, error) {
|
||||
start := time.Now()
|
||||
r0, r1 := m.s.GetActivePresetPrebuildSchedules(ctx)
|
||||
@@ -1160,6 +1192,14 @@ func (m queryMetricsStore) GetChatMessageByID(ctx context.Context, id int64) (da
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) GetChatMessageSummariesPerChat(ctx context.Context, createdAfter time.Time) ([]database.GetChatMessageSummariesPerChatRow, error) {
|
||||
start := time.Now()
|
||||
r0, r1 := m.s.GetChatMessageSummariesPerChat(ctx, createdAfter)
|
||||
m.queryLatencies.WithLabelValues("GetChatMessageSummariesPerChat").Observe(time.Since(start).Seconds())
|
||||
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "GetChatMessageSummariesPerChat").Inc()
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) GetChatMessagesByChatID(ctx context.Context, chatID database.GetChatMessagesByChatIDParams) ([]database.ChatMessage, error) {
|
||||
start := time.Now()
|
||||
r0, r1 := m.s.GetChatMessagesByChatID(ctx, chatID)
|
||||
@@ -1208,6 +1248,14 @@ func (m queryMetricsStore) GetChatModelConfigs(ctx context.Context) ([]database.
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) GetChatModelConfigsForTelemetry(ctx context.Context) ([]database.GetChatModelConfigsForTelemetryRow, error) {
|
||||
start := time.Now()
|
||||
r0, r1 := m.s.GetChatModelConfigsForTelemetry(ctx)
|
||||
m.queryLatencies.WithLabelValues("GetChatModelConfigsForTelemetry").Observe(time.Since(start).Seconds())
|
||||
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "GetChatModelConfigsForTelemetry").Inc()
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) GetChatProviderByID(ctx context.Context, id uuid.UUID) (database.ChatProvider, error) {
|
||||
start := time.Now()
|
||||
r0, r1 := m.s.GetChatProviderByID(ctx, id)
|
||||
@@ -1240,6 +1288,14 @@ func (m queryMetricsStore) GetChatQueuedMessages(ctx context.Context, chatID uui
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) GetChatRetentionDays(ctx context.Context) (int32, error) {
|
||||
start := time.Now()
|
||||
r0, r1 := m.s.GetChatRetentionDays(ctx)
|
||||
m.queryLatencies.WithLabelValues("GetChatRetentionDays").Observe(time.Since(start).Seconds())
|
||||
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "GetChatRetentionDays").Inc()
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) GetChatSystemPrompt(ctx context.Context) (string, error) {
|
||||
start := time.Now()
|
||||
r0, r1 := m.s.GetChatSystemPrompt(ctx)
|
||||
@@ -1312,6 +1368,14 @@ func (m queryMetricsStore) GetChatsByWorkspaceIDs(ctx context.Context, ids []uui
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) GetChatsUpdatedAfter(ctx context.Context, updatedAfter time.Time) ([]database.GetChatsUpdatedAfterRow, error) {
|
||||
start := time.Now()
|
||||
r0, r1 := m.s.GetChatsUpdatedAfter(ctx, updatedAfter)
|
||||
m.queryLatencies.WithLabelValues("GetChatsUpdatedAfter").Observe(time.Since(start).Seconds())
|
||||
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "GetChatsUpdatedAfter").Inc()
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) GetConnectionLogsOffset(ctx context.Context, arg database.GetConnectionLogsOffsetParams) ([]database.GetConnectionLogsOffsetRow, error) {
|
||||
start := time.Now()
|
||||
r0, r1 := m.s.GetConnectionLogsOffset(ctx, arg)
|
||||
@@ -1928,11 +1992,11 @@ func (m queryMetricsStore) GetPRInsightsPerModel(ctx context.Context, arg databa
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) GetPRInsightsRecentPRs(ctx context.Context, arg database.GetPRInsightsRecentPRsParams) ([]database.GetPRInsightsRecentPRsRow, error) {
|
||||
func (m queryMetricsStore) GetPRInsightsPullRequests(ctx context.Context, arg database.GetPRInsightsPullRequestsParams) ([]database.GetPRInsightsPullRequestsRow, error) {
|
||||
start := time.Now()
|
||||
r0, r1 := m.s.GetPRInsightsRecentPRs(ctx, arg)
|
||||
m.queryLatencies.WithLabelValues("GetPRInsightsRecentPRs").Observe(time.Since(start).Seconds())
|
||||
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "GetPRInsightsRecentPRs").Inc()
|
||||
r0, r1 := m.s.GetPRInsightsPullRequests(ctx, arg)
|
||||
m.queryLatencies.WithLabelValues("GetPRInsightsPullRequests").Observe(time.Since(start).Seconds())
|
||||
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "GetPRInsightsPullRequests").Inc()
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
@@ -4056,6 +4120,14 @@ func (m queryMetricsStore) SoftDeleteChatMessagesAfterID(ctx context.Context, ar
|
||||
return r0
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) SoftDeleteContextFileMessages(ctx context.Context, chatID uuid.UUID) error {
|
||||
start := time.Now()
|
||||
r0 := m.s.SoftDeleteContextFileMessages(ctx, chatID)
|
||||
m.queryLatencies.WithLabelValues("SoftDeleteContextFileMessages").Observe(time.Since(start).Seconds())
|
||||
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "SoftDeleteContextFileMessages").Inc()
|
||||
return r0
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) TryAcquireLock(ctx context.Context, pgTryAdvisoryXactLock int64) (bool, error) {
|
||||
start := time.Now()
|
||||
r0, r1 := m.s.TryAcquireLock(ctx, pgTryAdvisoryXactLock)
|
||||
@@ -4768,6 +4840,14 @@ func (m queryMetricsStore) UpdateWorkspaceAgentConnectionByID(ctx context.Contex
|
||||
return r0
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) UpdateWorkspaceAgentDirectoryByID(ctx context.Context, arg database.UpdateWorkspaceAgentDirectoryByIDParams) error {
|
||||
start := time.Now()
|
||||
r0 := m.s.UpdateWorkspaceAgentDirectoryByID(ctx, arg)
|
||||
m.queryLatencies.WithLabelValues("UpdateWorkspaceAgentDirectoryByID").Observe(time.Since(start).Seconds())
|
||||
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "UpdateWorkspaceAgentDirectoryByID").Inc()
|
||||
return r0
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) UpdateWorkspaceAgentDisplayAppsByID(ctx context.Context, arg database.UpdateWorkspaceAgentDisplayAppsByIDParams) error {
|
||||
start := time.Now()
|
||||
r0 := m.s.UpdateWorkspaceAgentDisplayAppsByID(ctx, arg)
|
||||
@@ -5000,6 +5080,14 @@ func (m queryMetricsStore) UpsertChatIncludeDefaultSystemPrompt(ctx context.Cont
|
||||
return r0
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) UpsertChatRetentionDays(ctx context.Context, retentionDays int32) error {
|
||||
start := time.Now()
|
||||
r0 := m.s.UpsertChatRetentionDays(ctx, retentionDays)
|
||||
m.queryLatencies.WithLabelValues("UpsertChatRetentionDays").Observe(time.Since(start).Seconds())
|
||||
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "UpsertChatRetentionDays").Inc()
|
||||
return r0
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) UpsertChatSystemPrompt(ctx context.Context, value string) error {
|
||||
start := time.Now()
|
||||
r0 := m.s.UpsertChatSystemPrompt(ctx, value)
|
||||
|
||||
@@ -363,6 +363,20 @@ func (mr *MockStoreMockRecorder) CleanupDeletedMCPServerIDsFromChats(ctx any) *g
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CleanupDeletedMCPServerIDsFromChats", reflect.TypeOf((*MockStore)(nil).CleanupDeletedMCPServerIDsFromChats), ctx)
|
||||
}
|
||||
|
||||
// ClearChatMessageProviderResponseIDsByChatID mocks base method.
|
||||
func (m *MockStore) ClearChatMessageProviderResponseIDsByChatID(ctx context.Context, chatID uuid.UUID) error {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "ClearChatMessageProviderResponseIDsByChatID", ctx, chatID)
|
||||
ret0, _ := ret[0].(error)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// ClearChatMessageProviderResponseIDsByChatID indicates an expected call of ClearChatMessageProviderResponseIDsByChatID.
|
||||
func (mr *MockStoreMockRecorder) ClearChatMessageProviderResponseIDsByChatID(ctx, chatID any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ClearChatMessageProviderResponseIDsByChatID", reflect.TypeOf((*MockStore)(nil).ClearChatMessageProviderResponseIDsByChatID), ctx, chatID)
|
||||
}
|
||||
|
||||
// CountAIBridgeInterceptions mocks base method.
|
||||
func (m *MockStore) CountAIBridgeInterceptions(ctx context.Context, arg database.CountAIBridgeInterceptionsParams) (int64, error) {
|
||||
m.ctrl.T.Helper()
|
||||
@@ -984,6 +998,36 @@ func (mr *MockStoreMockRecorder) DeleteOldAuditLogs(ctx, arg any) *gomock.Call {
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteOldAuditLogs", reflect.TypeOf((*MockStore)(nil).DeleteOldAuditLogs), ctx, arg)
|
||||
}
|
||||
|
||||
// DeleteOldChatFiles mocks base method.
|
||||
func (m *MockStore) DeleteOldChatFiles(ctx context.Context, arg database.DeleteOldChatFilesParams) (int64, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "DeleteOldChatFiles", ctx, arg)
|
||||
ret0, _ := ret[0].(int64)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// DeleteOldChatFiles indicates an expected call of DeleteOldChatFiles.
|
||||
func (mr *MockStoreMockRecorder) DeleteOldChatFiles(ctx, arg any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteOldChatFiles", reflect.TypeOf((*MockStore)(nil).DeleteOldChatFiles), ctx, arg)
|
||||
}
|
||||
|
||||
// DeleteOldChats mocks base method.
|
||||
func (m *MockStore) DeleteOldChats(ctx context.Context, arg database.DeleteOldChatsParams) (int64, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "DeleteOldChats", ctx, arg)
|
||||
ret0, _ := ret[0].(int64)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// DeleteOldChats indicates an expected call of DeleteOldChats.
|
||||
func (mr *MockStoreMockRecorder) DeleteOldChats(ctx, arg any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteOldChats", reflect.TypeOf((*MockStore)(nil).DeleteOldChats), ctx, arg)
|
||||
}
|
||||
|
||||
// DeleteOldConnectionLogs mocks base method.
|
||||
func (m *MockStore) DeleteOldConnectionLogs(ctx context.Context, arg database.DeleteOldConnectionLogsParams) (int64, error) {
|
||||
m.ctrl.T.Helper()
|
||||
@@ -1200,11 +1244,12 @@ func (mr *MockStoreMockRecorder) DeleteUserChatProviderKey(ctx, arg any) *gomock
|
||||
}
|
||||
|
||||
// DeleteUserSecretByUserIDAndName mocks base method.
|
||||
func (m *MockStore) DeleteUserSecretByUserIDAndName(ctx context.Context, arg database.DeleteUserSecretByUserIDAndNameParams) error {
|
||||
func (m *MockStore) DeleteUserSecretByUserIDAndName(ctx context.Context, arg database.DeleteUserSecretByUserIDAndNameParams) (int64, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "DeleteUserSecretByUserIDAndName", ctx, arg)
|
||||
ret0, _ := ret[0].(error)
|
||||
return ret0
|
||||
ret0, _ := ret[0].(int64)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// DeleteUserSecretByUserIDAndName indicates an expected call of DeleteUserSecretByUserIDAndName.
|
||||
@@ -1637,6 +1682,21 @@ func (mr *MockStoreMockRecorder) GetActiveAISeatCount(ctx any) *gomock.Call {
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetActiveAISeatCount", reflect.TypeOf((*MockStore)(nil).GetActiveAISeatCount), ctx)
|
||||
}
|
||||
|
||||
// GetActiveChatsByAgentID mocks base method.
|
||||
func (m *MockStore) GetActiveChatsByAgentID(ctx context.Context, agentID uuid.UUID) ([]database.Chat, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "GetActiveChatsByAgentID", ctx, agentID)
|
||||
ret0, _ := ret[0].([]database.Chat)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// GetActiveChatsByAgentID indicates an expected call of GetActiveChatsByAgentID.
|
||||
func (mr *MockStoreMockRecorder) GetActiveChatsByAgentID(ctx, agentID any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetActiveChatsByAgentID", reflect.TypeOf((*MockStore)(nil).GetActiveChatsByAgentID), ctx, agentID)
|
||||
}
|
||||
|
||||
// GetActivePresetPrebuildSchedules mocks base method.
|
||||
func (m *MockStore) GetActivePresetPrebuildSchedules(ctx context.Context) ([]database.TemplateVersionPresetPrebuildSchedule, error) {
|
||||
m.ctrl.T.Helper()
|
||||
@@ -2132,6 +2192,21 @@ func (mr *MockStoreMockRecorder) GetChatMessageByID(ctx, id any) *gomock.Call {
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetChatMessageByID", reflect.TypeOf((*MockStore)(nil).GetChatMessageByID), ctx, id)
|
||||
}
|
||||
|
||||
// GetChatMessageSummariesPerChat mocks base method.
|
||||
func (m *MockStore) GetChatMessageSummariesPerChat(ctx context.Context, createdAfter time.Time) ([]database.GetChatMessageSummariesPerChatRow, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "GetChatMessageSummariesPerChat", ctx, createdAfter)
|
||||
ret0, _ := ret[0].([]database.GetChatMessageSummariesPerChatRow)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// GetChatMessageSummariesPerChat indicates an expected call of GetChatMessageSummariesPerChat.
|
||||
func (mr *MockStoreMockRecorder) GetChatMessageSummariesPerChat(ctx, createdAfter any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetChatMessageSummariesPerChat", reflect.TypeOf((*MockStore)(nil).GetChatMessageSummariesPerChat), ctx, createdAfter)
|
||||
}
|
||||
|
||||
// GetChatMessagesByChatID mocks base method.
|
||||
func (m *MockStore) GetChatMessagesByChatID(ctx context.Context, arg database.GetChatMessagesByChatIDParams) ([]database.ChatMessage, error) {
|
||||
m.ctrl.T.Helper()
|
||||
@@ -2222,6 +2297,21 @@ func (mr *MockStoreMockRecorder) GetChatModelConfigs(ctx any) *gomock.Call {
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetChatModelConfigs", reflect.TypeOf((*MockStore)(nil).GetChatModelConfigs), ctx)
|
||||
}
|
||||
|
||||
// GetChatModelConfigsForTelemetry mocks base method.
|
||||
func (m *MockStore) GetChatModelConfigsForTelemetry(ctx context.Context) ([]database.GetChatModelConfigsForTelemetryRow, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "GetChatModelConfigsForTelemetry", ctx)
|
||||
ret0, _ := ret[0].([]database.GetChatModelConfigsForTelemetryRow)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// GetChatModelConfigsForTelemetry indicates an expected call of GetChatModelConfigsForTelemetry.
|
||||
func (mr *MockStoreMockRecorder) GetChatModelConfigsForTelemetry(ctx any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetChatModelConfigsForTelemetry", reflect.TypeOf((*MockStore)(nil).GetChatModelConfigsForTelemetry), ctx)
|
||||
}
|
||||
|
||||
// GetChatProviderByID mocks base method.
|
||||
func (m *MockStore) GetChatProviderByID(ctx context.Context, id uuid.UUID) (database.ChatProvider, error) {
|
||||
m.ctrl.T.Helper()
|
||||
@@ -2282,6 +2372,21 @@ func (mr *MockStoreMockRecorder) GetChatQueuedMessages(ctx, chatID any) *gomock.
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetChatQueuedMessages", reflect.TypeOf((*MockStore)(nil).GetChatQueuedMessages), ctx, chatID)
|
||||
}
|
||||
|
||||
// GetChatRetentionDays mocks base method.
|
||||
func (m *MockStore) GetChatRetentionDays(ctx context.Context) (int32, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "GetChatRetentionDays", ctx)
|
||||
ret0, _ := ret[0].(int32)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// GetChatRetentionDays indicates an expected call of GetChatRetentionDays.
|
||||
func (mr *MockStoreMockRecorder) GetChatRetentionDays(ctx any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetChatRetentionDays", reflect.TypeOf((*MockStore)(nil).GetChatRetentionDays), ctx)
|
||||
}
|
||||
|
||||
// GetChatSystemPrompt mocks base method.
|
||||
func (m *MockStore) GetChatSystemPrompt(ctx context.Context) (string, error) {
|
||||
m.ctrl.T.Helper()
|
||||
@@ -2417,6 +2522,21 @@ func (mr *MockStoreMockRecorder) GetChatsByWorkspaceIDs(ctx, ids any) *gomock.Ca
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetChatsByWorkspaceIDs", reflect.TypeOf((*MockStore)(nil).GetChatsByWorkspaceIDs), ctx, ids)
|
||||
}
|
||||
|
||||
// GetChatsUpdatedAfter mocks base method.
|
||||
func (m *MockStore) GetChatsUpdatedAfter(ctx context.Context, updatedAfter time.Time) ([]database.GetChatsUpdatedAfterRow, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "GetChatsUpdatedAfter", ctx, updatedAfter)
|
||||
ret0, _ := ret[0].([]database.GetChatsUpdatedAfterRow)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// GetChatsUpdatedAfter indicates an expected call of GetChatsUpdatedAfter.
|
||||
func (mr *MockStoreMockRecorder) GetChatsUpdatedAfter(ctx, updatedAfter any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetChatsUpdatedAfter", reflect.TypeOf((*MockStore)(nil).GetChatsUpdatedAfter), ctx, updatedAfter)
|
||||
}
|
||||
|
||||
// GetConnectionLogsOffset mocks base method.
|
||||
func (m *MockStore) GetConnectionLogsOffset(ctx context.Context, arg database.GetConnectionLogsOffsetParams) ([]database.GetConnectionLogsOffsetRow, error) {
|
||||
m.ctrl.T.Helper()
|
||||
@@ -3572,19 +3692,19 @@ func (mr *MockStoreMockRecorder) GetPRInsightsPerModel(ctx, arg any) *gomock.Cal
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetPRInsightsPerModel", reflect.TypeOf((*MockStore)(nil).GetPRInsightsPerModel), ctx, arg)
|
||||
}
|
||||
|
||||
// GetPRInsightsRecentPRs mocks base method.
|
||||
func (m *MockStore) GetPRInsightsRecentPRs(ctx context.Context, arg database.GetPRInsightsRecentPRsParams) ([]database.GetPRInsightsRecentPRsRow, error) {
|
||||
// GetPRInsightsPullRequests mocks base method.
|
||||
func (m *MockStore) GetPRInsightsPullRequests(ctx context.Context, arg database.GetPRInsightsPullRequestsParams) ([]database.GetPRInsightsPullRequestsRow, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "GetPRInsightsRecentPRs", ctx, arg)
|
||||
ret0, _ := ret[0].([]database.GetPRInsightsRecentPRsRow)
|
||||
ret := m.ctrl.Call(m, "GetPRInsightsPullRequests", ctx, arg)
|
||||
ret0, _ := ret[0].([]database.GetPRInsightsPullRequestsRow)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// GetPRInsightsRecentPRs indicates an expected call of GetPRInsightsRecentPRs.
|
||||
func (mr *MockStoreMockRecorder) GetPRInsightsRecentPRs(ctx, arg any) *gomock.Call {
|
||||
// GetPRInsightsPullRequests indicates an expected call of GetPRInsightsPullRequests.
|
||||
func (mr *MockStoreMockRecorder) GetPRInsightsPullRequests(ctx, arg any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetPRInsightsRecentPRs", reflect.TypeOf((*MockStore)(nil).GetPRInsightsRecentPRs), ctx, arg)
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetPRInsightsPullRequests", reflect.TypeOf((*MockStore)(nil).GetPRInsightsPullRequests), ctx, arg)
|
||||
}
|
||||
|
||||
// GetPRInsightsSummary mocks base method.
|
||||
@@ -7690,6 +7810,20 @@ func (mr *MockStoreMockRecorder) SoftDeleteChatMessagesAfterID(ctx, arg any) *go
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SoftDeleteChatMessagesAfterID", reflect.TypeOf((*MockStore)(nil).SoftDeleteChatMessagesAfterID), ctx, arg)
|
||||
}
|
||||
|
||||
// SoftDeleteContextFileMessages mocks base method.
|
||||
func (m *MockStore) SoftDeleteContextFileMessages(ctx context.Context, chatID uuid.UUID) error {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "SoftDeleteContextFileMessages", ctx, chatID)
|
||||
ret0, _ := ret[0].(error)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// SoftDeleteContextFileMessages indicates an expected call of SoftDeleteContextFileMessages.
|
||||
func (mr *MockStoreMockRecorder) SoftDeleteContextFileMessages(ctx, chatID any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SoftDeleteContextFileMessages", reflect.TypeOf((*MockStore)(nil).SoftDeleteContextFileMessages), ctx, chatID)
|
||||
}
|
||||
|
||||
// TryAcquireLock mocks base method.
|
||||
func (m *MockStore) TryAcquireLock(ctx context.Context, pgTryAdvisoryXactLock int64) (bool, error) {
|
||||
m.ctrl.T.Helper()
|
||||
@@ -8986,6 +9120,20 @@ func (mr *MockStoreMockRecorder) UpdateWorkspaceAgentConnectionByID(ctx, arg any
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateWorkspaceAgentConnectionByID", reflect.TypeOf((*MockStore)(nil).UpdateWorkspaceAgentConnectionByID), ctx, arg)
|
||||
}
|
||||
|
||||
// UpdateWorkspaceAgentDirectoryByID mocks base method.
|
||||
func (m *MockStore) UpdateWorkspaceAgentDirectoryByID(ctx context.Context, arg database.UpdateWorkspaceAgentDirectoryByIDParams) error {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "UpdateWorkspaceAgentDirectoryByID", ctx, arg)
|
||||
ret0, _ := ret[0].(error)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// UpdateWorkspaceAgentDirectoryByID indicates an expected call of UpdateWorkspaceAgentDirectoryByID.
|
||||
func (mr *MockStoreMockRecorder) UpdateWorkspaceAgentDirectoryByID(ctx, arg any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateWorkspaceAgentDirectoryByID", reflect.TypeOf((*MockStore)(nil).UpdateWorkspaceAgentDirectoryByID), ctx, arg)
|
||||
}
|
||||
|
||||
// UpdateWorkspaceAgentDisplayAppsByID mocks base method.
|
||||
func (m *MockStore) UpdateWorkspaceAgentDisplayAppsByID(ctx context.Context, arg database.UpdateWorkspaceAgentDisplayAppsByIDParams) error {
|
||||
m.ctrl.T.Helper()
|
||||
@@ -9399,6 +9547,20 @@ func (mr *MockStoreMockRecorder) UpsertChatIncludeDefaultSystemPrompt(ctx, inclu
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpsertChatIncludeDefaultSystemPrompt", reflect.TypeOf((*MockStore)(nil).UpsertChatIncludeDefaultSystemPrompt), ctx, includeDefaultSystemPrompt)
|
||||
}
|
||||
|
||||
// UpsertChatRetentionDays mocks base method.
|
||||
func (m *MockStore) UpsertChatRetentionDays(ctx context.Context, retentionDays int32) error {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "UpsertChatRetentionDays", ctx, retentionDays)
|
||||
ret0, _ := ret[0].(error)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// UpsertChatRetentionDays indicates an expected call of UpsertChatRetentionDays.
|
||||
func (mr *MockStoreMockRecorder) UpsertChatRetentionDays(ctx, retentionDays any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpsertChatRetentionDays", reflect.TypeOf((*MockStore)(nil).UpsertChatRetentionDays), ctx, retentionDays)
|
||||
}
|
||||
|
||||
// UpsertChatSystemPrompt mocks base method.
|
||||
func (m *MockStore) UpsertChatSystemPrompt(ctx context.Context, value string) error {
|
||||
m.ctrl.T.Helper()
|
||||
|
||||
@@ -34,6 +34,11 @@ const (
|
||||
// long enough to cover the maximum interval of a heartbeat event (currently
|
||||
// 1 hour) plus some buffer.
|
||||
maxTelemetryHeartbeatAge = 24 * time.Hour
|
||||
// Batch sizes for chat purging. Both use 1000, which is smaller
|
||||
// than audit/connection log batches (10000), because chat_files
|
||||
// rows contain bytea blob data that make large batches heavier.
|
||||
chatsBatchSize = 1000
|
||||
chatFilesBatchSize = 1000
|
||||
)
|
||||
|
||||
// New creates a new periodically purging database instance.
|
||||
@@ -109,6 +114,17 @@ func New(ctx context.Context, logger slog.Logger, db database.Store, vals *coder
|
||||
// purgeTick performs a single purge iteration. It returns an error if the
|
||||
// purge fails.
|
||||
func (i *instance) purgeTick(ctx context.Context, db database.Store, start time.Time) error {
|
||||
// Read chat retention config outside the transaction to
|
||||
// avoid poisoning the tx if the stored value is corrupt.
|
||||
// A SQL-level cast error (e.g. non-numeric text) puts PG
|
||||
// into error state, failing all subsequent queries in the
|
||||
// same transaction.
|
||||
chatRetentionDays, err := db.GetChatRetentionDays(ctx)
|
||||
if err != nil {
|
||||
i.logger.Warn(ctx, "failed to read chat retention config, skipping chat purge", slog.Error(err))
|
||||
chatRetentionDays = 0
|
||||
}
|
||||
|
||||
// Start a transaction to grab advisory lock, we don't want to run
|
||||
// multiple purges at the same time (multiple replicas).
|
||||
return db.InTx(func(tx database.Store) error {
|
||||
@@ -213,12 +229,43 @@ func (i *instance) purgeTick(ctx context.Context, db database.Store, start time.
|
||||
}
|
||||
}
|
||||
|
||||
// Chat retention is configured via site_configs. When
|
||||
// enabled, old archived chats are deleted first, then
|
||||
// orphaned chat files. Deleting a chat cascades to
|
||||
// chat_file_links (removing references) but not to
|
||||
// chat_files directly, so files from deleted chats
|
||||
// become orphaned and are caught by DeleteOldChatFiles
|
||||
// in the same tick.
|
||||
var purgedChats int64
|
||||
var purgedChatFiles int64
|
||||
if chatRetentionDays > 0 {
|
||||
chatRetention := time.Duration(chatRetentionDays) * 24 * time.Hour
|
||||
deleteChatsBefore := start.Add(-chatRetention)
|
||||
|
||||
purgedChats, err = tx.DeleteOldChats(ctx, database.DeleteOldChatsParams{
|
||||
BeforeTime: deleteChatsBefore,
|
||||
LimitCount: chatsBatchSize,
|
||||
})
|
||||
if err != nil {
|
||||
return xerrors.Errorf("failed to delete old chats: %w", err)
|
||||
}
|
||||
|
||||
purgedChatFiles, err = tx.DeleteOldChatFiles(ctx, database.DeleteOldChatFilesParams{
|
||||
BeforeTime: deleteChatsBefore,
|
||||
LimitCount: chatFilesBatchSize,
|
||||
})
|
||||
if err != nil {
|
||||
return xerrors.Errorf("failed to delete old chat files: %w", err)
|
||||
}
|
||||
}
|
||||
i.logger.Debug(ctx, "purged old database entries",
|
||||
slog.F("workspace_agent_logs", purgedWorkspaceAgentLogs),
|
||||
slog.F("expired_api_keys", expiredAPIKeys),
|
||||
slog.F("aibridge_records", purgedAIBridgeRecords),
|
||||
slog.F("connection_logs", purgedConnectionLogs),
|
||||
slog.F("audit_logs", purgedAuditLogs),
|
||||
slog.F("chats", purgedChats),
|
||||
slog.F("chat_files", purgedChatFiles),
|
||||
slog.F("duration", i.clk.Since(start)),
|
||||
)
|
||||
|
||||
@@ -232,6 +279,8 @@ func (i *instance) purgeTick(ctx context.Context, db database.Store, start time.
|
||||
i.recordsPurged.WithLabelValues("aibridge_records").Add(float64(purgedAIBridgeRecords))
|
||||
i.recordsPurged.WithLabelValues("connection_logs").Add(float64(purgedConnectionLogs))
|
||||
i.recordsPurged.WithLabelValues("audit_logs").Add(float64(purgedAuditLogs))
|
||||
i.recordsPurged.WithLabelValues("chats").Add(float64(purgedChats))
|
||||
i.recordsPurged.WithLabelValues("chat_files").Add(float64(purgedChatFiles))
|
||||
}
|
||||
|
||||
return nil
|
||||
|
||||
@@ -12,6 +12,7 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/lib/pq"
|
||||
"github.com/prometheus/client_golang/prometheus"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
@@ -53,6 +54,7 @@ func TestPurge(t *testing.T) {
|
||||
clk := quartz.NewMock(t)
|
||||
done := awaitDoTick(ctx, t, clk)
|
||||
mDB := dbmock.NewMockStore(gomock.NewController(t))
|
||||
mDB.EXPECT().GetChatRetentionDays(gomock.Any()).Return(int32(0), nil).AnyTimes()
|
||||
mDB.EXPECT().InTx(gomock.Any(), database.DefaultTXOptions().WithID("db_purge")).Return(nil).Times(2)
|
||||
purger := dbpurge.New(context.Background(), testutil.Logger(t), mDB, &codersdk.DeploymentValues{}, clk, prometheus.NewRegistry())
|
||||
<-done // wait for doTick() to run.
|
||||
@@ -125,6 +127,16 @@ func TestMetrics(t *testing.T) {
|
||||
"record_type": "audit_logs",
|
||||
})
|
||||
require.GreaterOrEqual(t, auditLogs, 0)
|
||||
|
||||
chats := promhelp.CounterValue(t, reg, "coderd_dbpurge_records_purged_total", prometheus.Labels{
|
||||
"record_type": "chats",
|
||||
})
|
||||
require.GreaterOrEqual(t, chats, 0)
|
||||
|
||||
chatFiles := promhelp.CounterValue(t, reg, "coderd_dbpurge_records_purged_total", prometheus.Labels{
|
||||
"record_type": "chat_files",
|
||||
})
|
||||
require.GreaterOrEqual(t, chatFiles, 0)
|
||||
})
|
||||
|
||||
t.Run("FailedIteration", func(t *testing.T) {
|
||||
@@ -138,6 +150,7 @@ func TestMetrics(t *testing.T) {
|
||||
|
||||
ctrl := gomock.NewController(t)
|
||||
mDB := dbmock.NewMockStore(ctrl)
|
||||
mDB.EXPECT().GetChatRetentionDays(gomock.Any()).Return(int32(0), nil).AnyTimes()
|
||||
mDB.EXPECT().InTx(gomock.Any(), database.DefaultTXOptions().WithID("db_purge")).
|
||||
Return(xerrors.New("simulated database error")).
|
||||
MinTimes(1)
|
||||
@@ -1634,3 +1647,488 @@ func TestDeleteExpiredAPIKeys(t *testing.T) {
|
||||
func ptr[T any](v T) *T {
|
||||
return &v
|
||||
}
|
||||
|
||||
//nolint:paralleltest // It uses LockIDDBPurge.
|
||||
func TestDeleteOldChatFiles(t *testing.T) {
|
||||
now := time.Date(2025, 6, 15, 12, 0, 0, 0, time.UTC)
|
||||
|
||||
// createChatFile inserts a chat file and backdates created_at.
|
||||
createChatFile := func(ctx context.Context, t *testing.T, db database.Store, rawDB *sql.DB, ownerID, orgID uuid.UUID, createdAt time.Time) uuid.UUID {
|
||||
t.Helper()
|
||||
row, err := db.InsertChatFile(ctx, database.InsertChatFileParams{
|
||||
OwnerID: ownerID,
|
||||
OrganizationID: orgID,
|
||||
Name: "test.png",
|
||||
Mimetype: "image/png",
|
||||
Data: []byte("fake-image-data"),
|
||||
})
|
||||
require.NoError(t, err)
|
||||
_, err = rawDB.ExecContext(ctx, "UPDATE chat_files SET created_at = $1 WHERE id = $2", createdAt, row.ID)
|
||||
require.NoError(t, err)
|
||||
return row.ID
|
||||
}
|
||||
|
||||
// createChat inserts a chat and optionally archives it, then
|
||||
// backdates updated_at to control the "archived since" window.
|
||||
createChat := func(ctx context.Context, t *testing.T, db database.Store, rawDB *sql.DB, ownerID, modelConfigID uuid.UUID, archived bool, updatedAt time.Time) database.Chat {
|
||||
t.Helper()
|
||||
chat, err := db.InsertChat(ctx, database.InsertChatParams{
|
||||
OwnerID: ownerID,
|
||||
LastModelConfigID: modelConfigID,
|
||||
Title: "test-chat",
|
||||
Status: database.ChatStatusWaiting,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
if archived {
|
||||
_, err = db.ArchiveChatByID(ctx, chat.ID)
|
||||
require.NoError(t, err)
|
||||
}
|
||||
_, err = rawDB.ExecContext(ctx, "UPDATE chats SET updated_at = $1 WHERE id = $2", updatedAt, chat.ID)
|
||||
require.NoError(t, err)
|
||||
return chat
|
||||
}
|
||||
// setupChatDeps creates the common dependencies needed for
|
||||
// chat-related tests: user, org, org member, provider, model config.
|
||||
type chatDeps struct {
|
||||
user database.User
|
||||
org database.Organization
|
||||
modelConfig database.ChatModelConfig
|
||||
}
|
||||
setupChatDeps := func(ctx context.Context, t *testing.T, db database.Store) chatDeps {
|
||||
t.Helper()
|
||||
user := dbgen.User(t, db, database.User{})
|
||||
org := dbgen.Organization(t, db, database.Organization{})
|
||||
_ = dbgen.OrganizationMember(t, db, database.OrganizationMember{UserID: user.ID, OrganizationID: org.ID})
|
||||
_, err := db.InsertChatProvider(ctx, database.InsertChatProviderParams{
|
||||
Provider: "openai",
|
||||
DisplayName: "OpenAI",
|
||||
Enabled: true,
|
||||
CentralApiKeyEnabled: true,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
mc, err := db.InsertChatModelConfig(ctx, database.InsertChatModelConfigParams{
|
||||
Provider: "openai",
|
||||
Model: "test-model",
|
||||
ContextLimit: 8192,
|
||||
Options: json.RawMessage("{}"),
|
||||
})
|
||||
require.NoError(t, err)
|
||||
return chatDeps{user: user, org: org, modelConfig: mc}
|
||||
}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
run func(t *testing.T)
|
||||
}{
|
||||
{
|
||||
name: "ChatRetentionDisabled",
|
||||
run: func(t *testing.T) {
|
||||
ctx := testutil.Context(t, testutil.WaitLong)
|
||||
clk := quartz.NewMock(t)
|
||||
clk.Set(now).MustWait(ctx)
|
||||
|
||||
db, _, rawDB := dbtestutil.NewDBWithSQLDB(t, dbtestutil.WithDumpOnFailure())
|
||||
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true})
|
||||
deps := setupChatDeps(ctx, t, db)
|
||||
|
||||
// Disable retention.
|
||||
err := db.UpsertChatRetentionDays(ctx, int32(0))
|
||||
require.NoError(t, err)
|
||||
|
||||
// Create an old archived chat and an orphaned old file.
|
||||
oldChat := createChat(ctx, t, db, rawDB, deps.user.ID, deps.modelConfig.ID, true, now.Add(-31*24*time.Hour))
|
||||
oldFileID := createChatFile(ctx, t, db, rawDB, deps.user.ID, deps.org.ID, now.Add(-31*24*time.Hour))
|
||||
|
||||
done := awaitDoTick(ctx, t, clk)
|
||||
closer := dbpurge.New(ctx, logger, db, &codersdk.DeploymentValues{}, clk, prometheus.NewRegistry())
|
||||
defer closer.Close()
|
||||
testutil.TryReceive(ctx, t, done)
|
||||
|
||||
// Both should still exist.
|
||||
_, err = db.GetChatByID(ctx, oldChat.ID)
|
||||
require.NoError(t, err, "chat should not be deleted when retention is disabled")
|
||||
_, err = db.GetChatFileByID(ctx, oldFileID)
|
||||
require.NoError(t, err, "chat file should not be deleted when retention is disabled")
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "OldArchivedChatsDeleted",
|
||||
run: func(t *testing.T) {
|
||||
ctx := testutil.Context(t, testutil.WaitLong)
|
||||
clk := quartz.NewMock(t)
|
||||
clk.Set(now).MustWait(ctx)
|
||||
|
||||
db, _, rawDB := dbtestutil.NewDBWithSQLDB(t, dbtestutil.WithDumpOnFailure())
|
||||
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true})
|
||||
deps := setupChatDeps(ctx, t, db)
|
||||
|
||||
err := db.UpsertChatRetentionDays(ctx, int32(30))
|
||||
require.NoError(t, err)
|
||||
|
||||
// Old archived chat (31 days) — should be deleted.
|
||||
oldChat := createChat(ctx, t, db, rawDB, deps.user.ID, deps.modelConfig.ID, true, now.Add(-31*24*time.Hour))
|
||||
// Insert a message so we can verify CASCADE.
|
||||
_, err = db.InsertChatMessages(ctx, database.InsertChatMessagesParams{
|
||||
ChatID: oldChat.ID,
|
||||
CreatedBy: []uuid.UUID{deps.user.ID},
|
||||
ModelConfigID: []uuid.UUID{deps.modelConfig.ID},
|
||||
Role: []database.ChatMessageRole{database.ChatMessageRoleUser},
|
||||
Content: []string{`[{"type":"text","text":"hello"}]`},
|
||||
ContentVersion: []int16{0},
|
||||
Visibility: []database.ChatMessageVisibility{database.ChatMessageVisibilityBoth},
|
||||
InputTokens: []int64{0},
|
||||
OutputTokens: []int64{0},
|
||||
TotalTokens: []int64{0},
|
||||
ReasoningTokens: []int64{0},
|
||||
CacheCreationTokens: []int64{0},
|
||||
CacheReadTokens: []int64{0},
|
||||
ContextLimit: []int64{0},
|
||||
Compressed: []bool{false},
|
||||
TotalCostMicros: []int64{0},
|
||||
RuntimeMs: []int64{0},
|
||||
ProviderResponseID: []string{""},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
// Recently archived chat (10 days) — should be retained.
|
||||
recentChat := createChat(ctx, t, db, rawDB, deps.user.ID, deps.modelConfig.ID, true, now.Add(-10*24*time.Hour))
|
||||
|
||||
// Active chat — should be retained.
|
||||
activeChat := createChat(ctx, t, db, rawDB, deps.user.ID, deps.modelConfig.ID, false, now)
|
||||
|
||||
done := awaitDoTick(ctx, t, clk)
|
||||
closer := dbpurge.New(ctx, logger, db, &codersdk.DeploymentValues{}, clk, prometheus.NewRegistry())
|
||||
defer closer.Close()
|
||||
testutil.TryReceive(ctx, t, done)
|
||||
|
||||
// Old archived chat should be gone.
|
||||
_, err = db.GetChatByID(ctx, oldChat.ID)
|
||||
require.Error(t, err, "old archived chat should be deleted")
|
||||
|
||||
// Its messages should be gone too (CASCADE).
|
||||
msgs, err := db.GetChatMessagesByChatID(ctx, database.GetChatMessagesByChatIDParams{
|
||||
ChatID: oldChat.ID,
|
||||
AfterID: 0,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.Empty(t, msgs, "messages should be cascade-deleted")
|
||||
|
||||
// Recent archived and active chats should remain.
|
||||
_, err = db.GetChatByID(ctx, recentChat.ID)
|
||||
require.NoError(t, err, "recently archived chat should be retained")
|
||||
_, err = db.GetChatByID(ctx, activeChat.ID)
|
||||
require.NoError(t, err, "active chat should be retained")
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "OrphanedOldFilesDeleted",
|
||||
run: func(t *testing.T) {
|
||||
ctx := testutil.Context(t, testutil.WaitLong)
|
||||
clk := quartz.NewMock(t)
|
||||
clk.Set(now).MustWait(ctx)
|
||||
|
||||
db, _, rawDB := dbtestutil.NewDBWithSQLDB(t, dbtestutil.WithDumpOnFailure())
|
||||
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true})
|
||||
deps := setupChatDeps(ctx, t, db)
|
||||
|
||||
err := db.UpsertChatRetentionDays(ctx, int32(30))
|
||||
require.NoError(t, err)
|
||||
|
||||
// File A: 31 days old, NOT in any chat -> should be deleted.
|
||||
fileA := createChatFile(ctx, t, db, rawDB, deps.user.ID, deps.org.ID, now.Add(-31*24*time.Hour))
|
||||
|
||||
// File B: 31 days old, in an active chat -> should be retained.
|
||||
fileB := createChatFile(ctx, t, db, rawDB, deps.user.ID, deps.org.ID, now.Add(-31*24*time.Hour))
|
||||
activeChat := createChat(ctx, t, db, rawDB, deps.user.ID, deps.modelConfig.ID, false, now)
|
||||
_, err = db.LinkChatFiles(ctx, database.LinkChatFilesParams{
|
||||
ChatID: activeChat.ID,
|
||||
MaxFileLinks: 100,
|
||||
FileIds: []uuid.UUID{fileB},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
// File C: 10 days old, NOT in any chat -> should be retained (too young).
|
||||
fileC := createChatFile(ctx, t, db, rawDB, deps.user.ID, deps.org.ID, now.Add(-10*24*time.Hour))
|
||||
|
||||
// File near boundary: 29d23h old — close to threshold.
|
||||
fileBoundary := createChatFile(ctx, t, db, rawDB, deps.user.ID, deps.org.ID, now.Add(-30*24*time.Hour).Add(time.Hour))
|
||||
|
||||
done := awaitDoTick(ctx, t, clk)
|
||||
closer := dbpurge.New(ctx, logger, db, &codersdk.DeploymentValues{}, clk, prometheus.NewRegistry())
|
||||
defer closer.Close()
|
||||
testutil.TryReceive(ctx, t, done)
|
||||
|
||||
_, err = db.GetChatFileByID(ctx, fileA)
|
||||
require.Error(t, err, "orphaned old file A should be deleted")
|
||||
|
||||
_, err = db.GetChatFileByID(ctx, fileB)
|
||||
require.NoError(t, err, "file B in active chat should be retained")
|
||||
|
||||
_, err = db.GetChatFileByID(ctx, fileC)
|
||||
require.NoError(t, err, "young file C should be retained")
|
||||
|
||||
_, err = db.GetChatFileByID(ctx, fileBoundary)
|
||||
require.NoError(t, err, "file near 30d boundary should be retained")
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "ArchivedChatFilesDeleted",
|
||||
run: func(t *testing.T) {
|
||||
ctx := testutil.Context(t, testutil.WaitLong)
|
||||
clk := quartz.NewMock(t)
|
||||
clk.Set(now).MustWait(ctx)
|
||||
|
||||
db, _, rawDB := dbtestutil.NewDBWithSQLDB(t, dbtestutil.WithDumpOnFailure())
|
||||
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true})
|
||||
deps := setupChatDeps(ctx, t, db)
|
||||
|
||||
err := db.UpsertChatRetentionDays(ctx, int32(30))
|
||||
require.NoError(t, err)
|
||||
|
||||
// File D: 31 days old, in a chat archived 31 days ago -> should be deleted.
|
||||
fileD := createChatFile(ctx, t, db, rawDB, deps.user.ID, deps.org.ID, now.Add(-31*24*time.Hour))
|
||||
oldArchivedChat := createChat(ctx, t, db, rawDB, deps.user.ID, deps.modelConfig.ID, true, now.Add(-31*24*time.Hour))
|
||||
_, err = db.LinkChatFiles(ctx, database.LinkChatFilesParams{
|
||||
ChatID: oldArchivedChat.ID,
|
||||
MaxFileLinks: 100,
|
||||
FileIds: []uuid.UUID{fileD},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
// LinkChatFiles does not update chats.updated_at, so backdate.
|
||||
_, err = rawDB.ExecContext(ctx, "UPDATE chats SET updated_at = $1 WHERE id = $2",
|
||||
now.Add(-31*24*time.Hour), oldArchivedChat.ID)
|
||||
require.NoError(t, err)
|
||||
|
||||
// File E: 31 days old, in a chat archived 10 days ago -> should be retained.
|
||||
fileE := createChatFile(ctx, t, db, rawDB, deps.user.ID, deps.org.ID, now.Add(-31*24*time.Hour))
|
||||
recentArchivedChat := createChat(ctx, t, db, rawDB, deps.user.ID, deps.modelConfig.ID, true, now.Add(-10*24*time.Hour))
|
||||
_, err = db.LinkChatFiles(ctx, database.LinkChatFilesParams{
|
||||
ChatID: recentArchivedChat.ID,
|
||||
MaxFileLinks: 100,
|
||||
FileIds: []uuid.UUID{fileE},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
_, err = rawDB.ExecContext(ctx, "UPDATE chats SET updated_at = $1 WHERE id = $2",
|
||||
now.Add(-10*24*time.Hour), recentArchivedChat.ID)
|
||||
require.NoError(t, err)
|
||||
|
||||
// File F: 31 days old, in BOTH an active chat AND an old archived chat -> should be retained.
|
||||
fileF := createChatFile(ctx, t, db, rawDB, deps.user.ID, deps.org.ID, now.Add(-31*24*time.Hour))
|
||||
anotherOldArchivedChat := createChat(ctx, t, db, rawDB, deps.user.ID, deps.modelConfig.ID, true, now.Add(-31*24*time.Hour))
|
||||
_, err = db.LinkChatFiles(ctx, database.LinkChatFilesParams{
|
||||
ChatID: anotherOldArchivedChat.ID,
|
||||
MaxFileLinks: 100,
|
||||
FileIds: []uuid.UUID{fileF},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
_, err = rawDB.ExecContext(ctx, "UPDATE chats SET updated_at = $1 WHERE id = $2",
|
||||
now.Add(-31*24*time.Hour), anotherOldArchivedChat.ID)
|
||||
require.NoError(t, err)
|
||||
|
||||
activeChatForF := createChat(ctx, t, db, rawDB, deps.user.ID, deps.modelConfig.ID, false, now)
|
||||
_, err = db.LinkChatFiles(ctx, database.LinkChatFilesParams{
|
||||
ChatID: activeChatForF.ID,
|
||||
MaxFileLinks: 100,
|
||||
FileIds: []uuid.UUID{fileF},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
done := awaitDoTick(ctx, t, clk)
|
||||
closer := dbpurge.New(ctx, logger, db, &codersdk.DeploymentValues{}, clk, prometheus.NewRegistry())
|
||||
defer closer.Close()
|
||||
testutil.TryReceive(ctx, t, done)
|
||||
|
||||
_, err = db.GetChatFileByID(ctx, fileD)
|
||||
require.Error(t, err, "file D in old archived chat should be deleted")
|
||||
|
||||
_, err = db.GetChatFileByID(ctx, fileE)
|
||||
require.NoError(t, err, "file E in recently archived chat should be retained")
|
||||
|
||||
_, err = db.GetChatFileByID(ctx, fileF)
|
||||
require.NoError(t, err, "file F in active + old archived chat should be retained")
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "UnarchiveAfterFilePurge",
|
||||
run: func(t *testing.T) {
|
||||
// Validates that when dbpurge deletes chat_files rows,
|
||||
// the FK cascade on chat_file_links automatically
|
||||
// removes the stale links. Unarchiving a chat after
|
||||
// file purge should show only surviving files.
|
||||
ctx := testutil.Context(t, testutil.WaitLong)
|
||||
db, _, rawDB := dbtestutil.NewDBWithSQLDB(t, dbtestutil.WithDumpOnFailure())
|
||||
deps := setupChatDeps(ctx, t, db)
|
||||
|
||||
// Create a chat with three attached files.
|
||||
fileA := createChatFile(ctx, t, db, rawDB, deps.user.ID, deps.org.ID, now)
|
||||
fileB := createChatFile(ctx, t, db, rawDB, deps.user.ID, deps.org.ID, now)
|
||||
fileC := createChatFile(ctx, t, db, rawDB, deps.user.ID, deps.org.ID, now)
|
||||
|
||||
chat := createChat(ctx, t, db, rawDB, deps.user.ID, deps.modelConfig.ID, false, now)
|
||||
_, err := db.LinkChatFiles(ctx, database.LinkChatFilesParams{
|
||||
ChatID: chat.ID,
|
||||
MaxFileLinks: 100,
|
||||
FileIds: []uuid.UUID{fileA, fileB, fileC},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
// Archive the chat.
|
||||
_, err = db.ArchiveChatByID(ctx, chat.ID)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Simulate dbpurge deleting files A and B. The FK
|
||||
// cascade on chat_file_links_file_id_fkey should
|
||||
// automatically remove the corresponding link rows.
|
||||
_, err = rawDB.ExecContext(ctx, "DELETE FROM chat_files WHERE id = ANY($1)", pq.Array([]uuid.UUID{fileA, fileB}))
|
||||
require.NoError(t, err)
|
||||
|
||||
// Unarchive the chat.
|
||||
_, err = db.UnarchiveChatByID(ctx, chat.ID)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Only file C should remain linked (FK cascade
|
||||
// removed the links for deleted files A and B).
|
||||
files, err := db.GetChatFileMetadataByChatID(ctx, chat.ID)
|
||||
require.NoError(t, err)
|
||||
require.Len(t, files, 1, "only surviving file should be linked")
|
||||
require.Equal(t, fileC, files[0].ID)
|
||||
|
||||
// Edge case: delete the last file too. The chat
|
||||
// should have zero linked files, not an error.
|
||||
_, err = db.ArchiveChatByID(ctx, chat.ID)
|
||||
require.NoError(t, err)
|
||||
_, err = rawDB.ExecContext(ctx, "DELETE FROM chat_files WHERE id = $1", fileC)
|
||||
require.NoError(t, err)
|
||||
_, err = db.UnarchiveChatByID(ctx, chat.ID)
|
||||
require.NoError(t, err)
|
||||
|
||||
files, err = db.GetChatFileMetadataByChatID(ctx, chat.ID)
|
||||
require.NoError(t, err)
|
||||
require.Empty(t, files, "all-files-deleted should yield empty result")
|
||||
|
||||
// Test parent+child cascade: deleting files should
|
||||
// clean up links for both parent and child chats
|
||||
// independently via FK cascade.
|
||||
parentChat := createChat(ctx, t, db, rawDB, deps.user.ID, deps.modelConfig.ID, false, now)
|
||||
childChat, err := db.InsertChat(ctx, database.InsertChatParams{
|
||||
OwnerID: deps.user.ID,
|
||||
LastModelConfigID: deps.modelConfig.ID,
|
||||
Title: "child-chat",
|
||||
Status: database.ChatStatusWaiting,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
// Set root_chat_id to link child to parent.
|
||||
_, err = rawDB.ExecContext(ctx, "UPDATE chats SET root_chat_id = $1 WHERE id = $2", parentChat.ID, childChat.ID)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Attach different files to parent and child.
|
||||
parentFileKeep := createChatFile(ctx, t, db, rawDB, deps.user.ID, deps.org.ID, now)
|
||||
parentFileStale := createChatFile(ctx, t, db, rawDB, deps.user.ID, deps.org.ID, now)
|
||||
childFileKeep := createChatFile(ctx, t, db, rawDB, deps.user.ID, deps.org.ID, now)
|
||||
childFileStale := createChatFile(ctx, t, db, rawDB, deps.user.ID, deps.org.ID, now)
|
||||
|
||||
_, err = db.LinkChatFiles(ctx, database.LinkChatFilesParams{
|
||||
ChatID: parentChat.ID,
|
||||
MaxFileLinks: 100,
|
||||
FileIds: []uuid.UUID{parentFileKeep, parentFileStale},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
_, err = db.LinkChatFiles(ctx, database.LinkChatFilesParams{
|
||||
ChatID: childChat.ID,
|
||||
MaxFileLinks: 100,
|
||||
FileIds: []uuid.UUID{childFileKeep, childFileStale},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
// Archive via parent (cascades to child).
|
||||
_, err = db.ArchiveChatByID(ctx, parentChat.ID)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Delete one file from each chat.
|
||||
_, err = rawDB.ExecContext(ctx, "DELETE FROM chat_files WHERE id = ANY($1)",
|
||||
pq.Array([]uuid.UUID{parentFileStale, childFileStale}))
|
||||
require.NoError(t, err)
|
||||
|
||||
// Unarchive via parent.
|
||||
_, err = db.UnarchiveChatByID(ctx, parentChat.ID)
|
||||
require.NoError(t, err)
|
||||
|
||||
parentFiles, err := db.GetChatFileMetadataByChatID(ctx, parentChat.ID)
|
||||
require.NoError(t, err)
|
||||
require.Len(t, parentFiles, 1)
|
||||
require.Equal(t, parentFileKeep, parentFiles[0].ID,
|
||||
"parent should retain only non-stale file")
|
||||
|
||||
childFiles, err := db.GetChatFileMetadataByChatID(ctx, childChat.ID)
|
||||
require.NoError(t, err)
|
||||
require.Len(t, childFiles, 1)
|
||||
require.Equal(t, childFileKeep, childFiles[0].ID,
|
||||
"child should retain only non-stale file")
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "BatchLimitFiles",
|
||||
run: func(t *testing.T) {
|
||||
ctx := testutil.Context(t, testutil.WaitLong)
|
||||
db, _, rawDB := dbtestutil.NewDBWithSQLDB(t, dbtestutil.WithDumpOnFailure())
|
||||
deps := setupChatDeps(ctx, t, db)
|
||||
|
||||
// Create 3 deletable orphaned files (all 31 days old).
|
||||
for range 3 {
|
||||
createChatFile(ctx, t, db, rawDB, deps.user.ID, deps.org.ID, now.Add(-31*24*time.Hour))
|
||||
}
|
||||
|
||||
// Delete with limit 2 — should delete 2, leave 1.
|
||||
deleted, err := db.DeleteOldChatFiles(ctx, database.DeleteOldChatFilesParams{
|
||||
BeforeTime: now.Add(-30 * 24 * time.Hour),
|
||||
LimitCount: 2,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, int64(2), deleted, "should delete exactly 2 files")
|
||||
|
||||
// Delete again — should delete the remaining 1.
|
||||
deleted, err = db.DeleteOldChatFiles(ctx, database.DeleteOldChatFilesParams{
|
||||
BeforeTime: now.Add(-30 * 24 * time.Hour),
|
||||
LimitCount: 2,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, int64(1), deleted, "should delete remaining 1 file")
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "BatchLimitChats",
|
||||
run: func(t *testing.T) {
|
||||
ctx := testutil.Context(t, testutil.WaitLong)
|
||||
db, _, rawDB := dbtestutil.NewDBWithSQLDB(t, dbtestutil.WithDumpOnFailure())
|
||||
deps := setupChatDeps(ctx, t, db)
|
||||
|
||||
// Create 3 deletable old archived chats.
|
||||
for range 3 {
|
||||
createChat(ctx, t, db, rawDB, deps.user.ID, deps.modelConfig.ID, true, now.Add(-31*24*time.Hour))
|
||||
}
|
||||
|
||||
// Delete with limit 2 — should delete 2, leave 1.
|
||||
deleted, err := db.DeleteOldChats(ctx, database.DeleteOldChatsParams{
|
||||
BeforeTime: now.Add(-30 * 24 * time.Hour),
|
||||
LimitCount: 2,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, int64(2), deleted, "should delete exactly 2 chats")
|
||||
|
||||
// Delete again — should delete the remaining 1.
|
||||
deleted, err = db.DeleteOldChats(ctx, database.DeleteOldChatsParams{
|
||||
BeforeTime: now.Add(-30 * 24 * time.Hour),
|
||||
LimitCount: 2,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, int64(1), deleted, "should delete remaining 1 chat")
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
tc.run(t)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
Generated
+18
-5
@@ -293,7 +293,8 @@ CREATE TYPE chat_status AS ENUM (
|
||||
'running',
|
||||
'paused',
|
||||
'completed',
|
||||
'error'
|
||||
'error',
|
||||
'requires_action'
|
||||
);
|
||||
|
||||
CREATE TYPE connection_status AS ENUM (
|
||||
@@ -315,6 +316,11 @@ CREATE TYPE cors_behavior AS ENUM (
|
||||
'passthru'
|
||||
);
|
||||
|
||||
CREATE TYPE credential_kind AS ENUM (
|
||||
'centralized',
|
||||
'byok'
|
||||
);
|
||||
|
||||
CREATE TYPE crypto_key_feature AS ENUM (
|
||||
'workspace_apps_token',
|
||||
'workspace_apps_api_key',
|
||||
@@ -1101,7 +1107,9 @@ CREATE TABLE aibridge_interceptions (
|
||||
thread_root_id uuid,
|
||||
client_session_id character varying(256),
|
||||
session_id text GENERATED ALWAYS AS (COALESCE(client_session_id, ((thread_root_id)::text)::character varying, ((id)::text)::character varying)) STORED NOT NULL,
|
||||
provider_name text DEFAULT ''::text NOT NULL
|
||||
provider_name text DEFAULT ''::text NOT NULL,
|
||||
credential_kind credential_kind DEFAULT 'centralized'::credential_kind NOT NULL,
|
||||
credential_hint character varying(15) DEFAULT ''::character varying NOT NULL
|
||||
);
|
||||
|
||||
COMMENT ON TABLE aibridge_interceptions IS 'Audit log of requests intercepted by AI Bridge';
|
||||
@@ -1118,6 +1126,10 @@ COMMENT ON COLUMN aibridge_interceptions.session_id IS 'Groups related intercept
|
||||
|
||||
COMMENT ON COLUMN aibridge_interceptions.provider_name IS 'The provider instance name which may differ from provider when multiple instances of the same provider type exist.';
|
||||
|
||||
COMMENT ON COLUMN aibridge_interceptions.credential_kind IS 'How the request was authenticated: centralized or byok.';
|
||||
|
||||
COMMENT ON COLUMN aibridge_interceptions.credential_hint IS 'Masked credential identifier for audit (e.g. sk-a***efgh).';
|
||||
|
||||
CREATE TABLE aibridge_model_thoughts (
|
||||
interception_id uuid NOT NULL,
|
||||
content text NOT NULL,
|
||||
@@ -1418,7 +1430,8 @@ CREATE TABLE chats (
|
||||
agent_id uuid,
|
||||
pin_order integer DEFAULT 0 NOT NULL,
|
||||
last_read_message_id bigint,
|
||||
last_injected_context jsonb
|
||||
last_injected_context jsonb,
|
||||
dynamic_tools jsonb
|
||||
);
|
||||
|
||||
CREATE TABLE connection_logs (
|
||||
@@ -3770,14 +3783,14 @@ CREATE INDEX idx_chat_providers_enabled ON chat_providers USING btree (enabled);
|
||||
|
||||
CREATE INDEX idx_chat_queued_messages_chat_id ON chat_queued_messages USING btree (chat_id);
|
||||
|
||||
CREATE INDEX idx_chats_agent_id ON chats USING btree (agent_id) WHERE (agent_id IS NOT NULL);
|
||||
|
||||
CREATE INDEX idx_chats_labels ON chats USING gin (labels);
|
||||
|
||||
CREATE INDEX idx_chats_last_model_config_id ON chats USING btree (last_model_config_id);
|
||||
|
||||
CREATE INDEX idx_chats_owner ON chats USING btree (owner_id);
|
||||
|
||||
CREATE INDEX idx_chats_owner_updated_id ON chats USING btree (owner_id, updated_at DESC, id DESC);
|
||||
|
||||
CREATE INDEX idx_chats_parent_chat_id ON chats USING btree (parent_chat_id);
|
||||
|
||||
CREATE INDEX idx_chats_pending ON chats USING btree (status) WHERE (status = 'pending'::chat_status);
|
||||
|
||||
@@ -0,0 +1,31 @@
|
||||
-- First update any rows using the value we're about to remove.
|
||||
-- The column type is still the original chat_status at this point.
|
||||
UPDATE chats SET status = 'error' WHERE status = 'requires_action';
|
||||
|
||||
-- Drop the column (this is independent of the enum).
|
||||
ALTER TABLE chats DROP COLUMN IF EXISTS dynamic_tools;
|
||||
|
||||
-- Drop the partial index that references the chat_status enum type.
|
||||
-- It must be removed before the rename-create-cast-drop cycle
|
||||
-- because the index's WHERE clause (status = 'pending'::chat_status)
|
||||
-- would otherwise cause a cross-type comparison failure.
|
||||
DROP INDEX IF EXISTS idx_chats_pending;
|
||||
|
||||
-- Now recreate the enum without requires_action.
|
||||
-- We must use the rename-create-cast-drop pattern.
|
||||
ALTER TYPE chat_status RENAME TO chat_status_old;
|
||||
CREATE TYPE chat_status AS ENUM (
|
||||
'waiting',
|
||||
'pending',
|
||||
'running',
|
||||
'paused',
|
||||
'completed',
|
||||
'error'
|
||||
);
|
||||
ALTER TABLE chats ALTER COLUMN status DROP DEFAULT;
|
||||
ALTER TABLE chats ALTER COLUMN status TYPE chat_status USING status::text::chat_status;
|
||||
ALTER TABLE chats ALTER COLUMN status SET DEFAULT 'waiting';
|
||||
DROP TYPE chat_status_old;
|
||||
|
||||
-- Recreate the partial index.
|
||||
CREATE INDEX idx_chats_pending ON chats USING btree (status) WHERE (status = 'pending'::chat_status);
|
||||
@@ -0,0 +1,3 @@
|
||||
ALTER TYPE chat_status ADD VALUE IF NOT EXISTS 'requires_action';
|
||||
|
||||
ALTER TABLE chats ADD COLUMN dynamic_tools JSONB DEFAULT NULL;
|
||||
@@ -0,0 +1,5 @@
|
||||
ALTER TABLE aibridge_interceptions
|
||||
DROP COLUMN IF EXISTS credential_kind,
|
||||
DROP COLUMN IF EXISTS credential_hint;
|
||||
|
||||
DROP TYPE IF EXISTS credential_kind;
|
||||
@@ -0,0 +1,12 @@
|
||||
CREATE TYPE credential_kind AS ENUM ('centralized', 'byok');
|
||||
|
||||
-- Records how each LLM request was authenticated and a masked credential
|
||||
-- identifier for audit purposes. Existing rows default to 'centralized'
|
||||
-- with an empty hint since we cannot retroactively determine their values.
|
||||
ALTER TABLE aibridge_interceptions
|
||||
ADD COLUMN credential_kind credential_kind NOT NULL DEFAULT 'centralized',
|
||||
-- Length capped as a safety measure to ensure only masked values are stored.
|
||||
ADD COLUMN credential_hint CHARACTER VARYING(15) NOT NULL DEFAULT '';
|
||||
|
||||
COMMENT ON COLUMN aibridge_interceptions.credential_kind IS 'How the request was authenticated: centralized or byok.';
|
||||
COMMENT ON COLUMN aibridge_interceptions.credential_hint IS 'Masked credential identifier for audit (e.g. sk-a***efgh).';
|
||||
@@ -0,0 +1 @@
|
||||
DROP INDEX IF EXISTS idx_chats_agent_id;
|
||||
@@ -0,0 +1 @@
|
||||
CREATE INDEX idx_chats_agent_id ON chats(agent_id) WHERE agent_id IS NOT NULL;
|
||||
@@ -0,0 +1 @@
|
||||
CREATE INDEX idx_chats_owner_updated_id ON chats (owner_id, updated_at DESC, id DESC);
|
||||
@@ -0,0 +1,5 @@
|
||||
-- The GetChats ORDER BY changed from (updated_at, id) DESC to a 4-column
|
||||
-- expression sort (pinned-first flag, negated pin_order, updated_at, id).
|
||||
-- This index was purpose-built for the old sort and no longer provides
|
||||
-- read benefit. The simpler idx_chats_owner covers the owner_id filter.
|
||||
DROP INDEX IF EXISTS idx_chats_owner_updated_id;
|
||||
@@ -798,6 +798,7 @@ func (q *sqlQuerier) GetAuthorizedChats(ctx context.Context, arg GetChatsParams,
|
||||
&i.Chat.PinOrder,
|
||||
&i.Chat.LastReadMessageID,
|
||||
&i.Chat.LastInjectedContext,
|
||||
&i.Chat.DynamicTools,
|
||||
&i.HasUnread); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -868,6 +869,8 @@ func (q *sqlQuerier) ListAuthorizedAIBridgeInterceptions(ctx context.Context, ar
|
||||
&i.AIBridgeInterception.ClientSessionID,
|
||||
&i.AIBridgeInterception.SessionID,
|
||||
&i.AIBridgeInterception.ProviderName,
|
||||
&i.AIBridgeInterception.CredentialKind,
|
||||
&i.AIBridgeInterception.CredentialHint,
|
||||
&i.VisibleUser.ID,
|
||||
&i.VisibleUser.Username,
|
||||
&i.VisibleUser.Name,
|
||||
@@ -1131,6 +1134,8 @@ func (q *sqlQuerier) ListAuthorizedAIBridgeSessionThreads(ctx context.Context, a
|
||||
&i.AIBridgeInterception.ClientSessionID,
|
||||
&i.AIBridgeInterception.SessionID,
|
||||
&i.AIBridgeInterception.ProviderName,
|
||||
&i.AIBridgeInterception.CredentialKind,
|
||||
&i.AIBridgeInterception.CredentialHint,
|
||||
); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
@@ -1290,12 +1290,13 @@ func AllChatModeValues() []ChatMode {
|
||||
type ChatStatus string
|
||||
|
||||
const (
|
||||
ChatStatusWaiting ChatStatus = "waiting"
|
||||
ChatStatusPending ChatStatus = "pending"
|
||||
ChatStatusRunning ChatStatus = "running"
|
||||
ChatStatusPaused ChatStatus = "paused"
|
||||
ChatStatusCompleted ChatStatus = "completed"
|
||||
ChatStatusError ChatStatus = "error"
|
||||
ChatStatusWaiting ChatStatus = "waiting"
|
||||
ChatStatusPending ChatStatus = "pending"
|
||||
ChatStatusRunning ChatStatus = "running"
|
||||
ChatStatusPaused ChatStatus = "paused"
|
||||
ChatStatusCompleted ChatStatus = "completed"
|
||||
ChatStatusError ChatStatus = "error"
|
||||
ChatStatusRequiresAction ChatStatus = "requires_action"
|
||||
)
|
||||
|
||||
func (e *ChatStatus) Scan(src interface{}) error {
|
||||
@@ -1340,7 +1341,8 @@ func (e ChatStatus) Valid() bool {
|
||||
ChatStatusRunning,
|
||||
ChatStatusPaused,
|
||||
ChatStatusCompleted,
|
||||
ChatStatusError:
|
||||
ChatStatusError,
|
||||
ChatStatusRequiresAction:
|
||||
return true
|
||||
}
|
||||
return false
|
||||
@@ -1354,6 +1356,7 @@ func AllChatStatusValues() []ChatStatus {
|
||||
ChatStatusPaused,
|
||||
ChatStatusCompleted,
|
||||
ChatStatusError,
|
||||
ChatStatusRequiresAction,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1543,6 +1546,64 @@ func AllCorsBehaviorValues() []CorsBehavior {
|
||||
}
|
||||
}
|
||||
|
||||
type CredentialKind string
|
||||
|
||||
const (
|
||||
CredentialKindCentralized CredentialKind = "centralized"
|
||||
CredentialKindByok CredentialKind = "byok"
|
||||
)
|
||||
|
||||
func (e *CredentialKind) Scan(src interface{}) error {
|
||||
switch s := src.(type) {
|
||||
case []byte:
|
||||
*e = CredentialKind(s)
|
||||
case string:
|
||||
*e = CredentialKind(s)
|
||||
default:
|
||||
return fmt.Errorf("unsupported scan type for CredentialKind: %T", src)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
type NullCredentialKind struct {
|
||||
CredentialKind CredentialKind `json:"credential_kind"`
|
||||
Valid bool `json:"valid"` // Valid is true if CredentialKind is not NULL
|
||||
}
|
||||
|
||||
// Scan implements the Scanner interface.
|
||||
func (ns *NullCredentialKind) Scan(value interface{}) error {
|
||||
if value == nil {
|
||||
ns.CredentialKind, ns.Valid = "", false
|
||||
return nil
|
||||
}
|
||||
ns.Valid = true
|
||||
return ns.CredentialKind.Scan(value)
|
||||
}
|
||||
|
||||
// Value implements the driver Valuer interface.
|
||||
func (ns NullCredentialKind) Value() (driver.Value, error) {
|
||||
if !ns.Valid {
|
||||
return nil, nil
|
||||
}
|
||||
return string(ns.CredentialKind), nil
|
||||
}
|
||||
|
||||
func (e CredentialKind) Valid() bool {
|
||||
switch e {
|
||||
case CredentialKindCentralized,
|
||||
CredentialKindByok:
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func AllCredentialKindValues() []CredentialKind {
|
||||
return []CredentialKind{
|
||||
CredentialKindCentralized,
|
||||
CredentialKindByok,
|
||||
}
|
||||
}
|
||||
|
||||
type CryptoKeyFeature string
|
||||
|
||||
const (
|
||||
@@ -4040,6 +4101,10 @@ type AIBridgeInterception struct {
|
||||
SessionID string `db:"session_id" json:"session_id"`
|
||||
// The provider instance name which may differ from provider when multiple instances of the same provider type exist.
|
||||
ProviderName string `db:"provider_name" json:"provider_name"`
|
||||
// How the request was authenticated: centralized or byok.
|
||||
CredentialKind CredentialKind `db:"credential_kind" json:"credential_kind"`
|
||||
// Masked credential identifier for audit (e.g. sk-a***efgh).
|
||||
CredentialHint string `db:"credential_hint" json:"credential_hint"`
|
||||
}
|
||||
|
||||
// Audit log of model thinking in intercepted requests in AI Bridge
|
||||
@@ -4180,6 +4245,7 @@ type Chat struct {
|
||||
PinOrder int32 `db:"pin_order" json:"pin_order"`
|
||||
LastReadMessageID sql.NullInt64 `db:"last_read_message_id" json:"last_read_message_id"`
|
||||
LastInjectedContext pqtype.NullRawMessage `db:"last_injected_context" json:"last_injected_context"`
|
||||
DynamicTools pqtype.NullRawMessage `db:"dynamic_tools" json:"dynamic_tools"`
|
||||
}
|
||||
|
||||
type ChatDiffStatus struct {
|
||||
|
||||
@@ -76,6 +76,7 @@ type sqlcQuerier interface {
|
||||
CleanTailnetLostPeers(ctx context.Context) error
|
||||
CleanTailnetTunnels(ctx context.Context) error
|
||||
CleanupDeletedMCPServerIDsFromChats(ctx context.Context) error
|
||||
ClearChatMessageProviderResponseIDsByChatID(ctx context.Context, chatID uuid.UUID) error
|
||||
CountAIBridgeInterceptions(ctx context.Context, arg CountAIBridgeInterceptionsParams) (int64, error)
|
||||
CountAIBridgeSessions(ctx context.Context, arg CountAIBridgeSessionsParams) (int64, error)
|
||||
CountAuditLogs(ctx context.Context, arg CountAuditLogsParams) (int64, error)
|
||||
@@ -128,6 +129,22 @@ type sqlcQuerier interface {
|
||||
// connection events (connect, disconnect, open, close) which are handled
|
||||
// separately by DeleteOldAuditLogConnectionEvents.
|
||||
DeleteOldAuditLogs(ctx context.Context, arg DeleteOldAuditLogsParams) (int64, error)
|
||||
// TODO(cian): Add indexes on chats(archived, updated_at) and
|
||||
// chat_files(created_at) for purge query performance.
|
||||
// See: https://github.com/coder/internal/issues/1438
|
||||
// Deletes chat files that are older than the given threshold and are
|
||||
// not referenced by any chat that is still active or was archived
|
||||
// within the same threshold window. This covers two cases:
|
||||
// 1. Orphaned files not linked to any chat.
|
||||
// 2. Files whose every referencing chat has been archived for longer
|
||||
// than the retention period.
|
||||
DeleteOldChatFiles(ctx context.Context, arg DeleteOldChatFilesParams) (int64, error)
|
||||
// Deletes chats that have been archived for longer than the given
|
||||
// threshold. Active (non-archived) chats are never deleted.
|
||||
// Related chat_messages, chat_diff_statuses, and
|
||||
// chat_queued_messages are removed via ON DELETE CASCADE.
|
||||
// Parent/root references on child chats are SET NULL.
|
||||
DeleteOldChats(ctx context.Context, arg DeleteOldChatsParams) (int64, error)
|
||||
DeleteOldConnectionLogs(ctx context.Context, arg DeleteOldConnectionLogsParams) (int64, error)
|
||||
// Delete all notification messages which have not been updated for over a week.
|
||||
DeleteOldNotificationMessages(ctx context.Context) error
|
||||
@@ -152,7 +169,7 @@ type sqlcQuerier interface {
|
||||
DeleteTask(ctx context.Context, arg DeleteTaskParams) (uuid.UUID, error)
|
||||
DeleteUserChatCompactionThreshold(ctx context.Context, arg DeleteUserChatCompactionThresholdParams) error
|
||||
DeleteUserChatProviderKey(ctx context.Context, arg DeleteUserChatProviderKeyParams) error
|
||||
DeleteUserSecretByUserIDAndName(ctx context.Context, arg DeleteUserSecretByUserIDAndNameParams) error
|
||||
DeleteUserSecretByUserIDAndName(ctx context.Context, arg DeleteUserSecretByUserIDAndNameParams) (int64, error)
|
||||
DeleteWebpushSubscriptionByUserIDAndEndpoint(ctx context.Context, arg DeleteWebpushSubscriptionByUserIDAndEndpointParams) error
|
||||
DeleteWebpushSubscriptions(ctx context.Context, ids []uuid.UUID) error
|
||||
DeleteWorkspaceACLByID(ctx context.Context, id uuid.UUID) error
|
||||
@@ -199,6 +216,7 @@ type sqlcQuerier interface {
|
||||
GetAPIKeysByUserID(ctx context.Context, arg GetAPIKeysByUserIDParams) ([]APIKey, error)
|
||||
GetAPIKeysLastUsedAfter(ctx context.Context, lastUsed time.Time) ([]APIKey, error)
|
||||
GetActiveAISeatCount(ctx context.Context) (int64, error)
|
||||
GetActiveChatsByAgentID(ctx context.Context, agentID uuid.UUID) ([]Chat, error)
|
||||
GetActivePresetPrebuildSchedules(ctx context.Context) ([]TemplateVersionPresetPrebuildSchedule, error)
|
||||
GetActiveUserCount(ctx context.Context, includeSystem bool) (int64, error)
|
||||
GetActiveWorkspaceBuildsByTemplateID(ctx context.Context, templateID uuid.UUID) ([]WorkspaceBuild, error)
|
||||
@@ -255,16 +273,27 @@ type sqlcQuerier interface {
|
||||
// otherwise the setting defaults to true.
|
||||
GetChatIncludeDefaultSystemPrompt(ctx context.Context) (bool, error)
|
||||
GetChatMessageByID(ctx context.Context, id int64) (ChatMessage, error)
|
||||
// Aggregates message-level metrics per chat for messages created
|
||||
// after the given timestamp. Uses message created_at so that
|
||||
// ongoing activity in long-running chats is captured each window.
|
||||
GetChatMessageSummariesPerChat(ctx context.Context, createdAfter time.Time) ([]GetChatMessageSummariesPerChatRow, error)
|
||||
GetChatMessagesByChatID(ctx context.Context, arg GetChatMessagesByChatIDParams) ([]ChatMessage, error)
|
||||
GetChatMessagesByChatIDAscPaginated(ctx context.Context, arg GetChatMessagesByChatIDAscPaginatedParams) ([]ChatMessage, error)
|
||||
GetChatMessagesByChatIDDescPaginated(ctx context.Context, arg GetChatMessagesByChatIDDescPaginatedParams) ([]ChatMessage, error)
|
||||
GetChatMessagesForPromptByChatID(ctx context.Context, chatID uuid.UUID) ([]ChatMessage, error)
|
||||
GetChatModelConfigByID(ctx context.Context, id uuid.UUID) (ChatModelConfig, error)
|
||||
GetChatModelConfigs(ctx context.Context) ([]ChatModelConfig, error)
|
||||
// Returns all model configurations for telemetry snapshot collection.
|
||||
GetChatModelConfigsForTelemetry(ctx context.Context) ([]GetChatModelConfigsForTelemetryRow, error)
|
||||
GetChatProviderByID(ctx context.Context, id uuid.UUID) (ChatProvider, error)
|
||||
GetChatProviderByProvider(ctx context.Context, provider string) (ChatProvider, error)
|
||||
GetChatProviders(ctx context.Context) ([]ChatProvider, error)
|
||||
GetChatQueuedMessages(ctx context.Context, chatID uuid.UUID) ([]ChatQueuedMessage, error)
|
||||
// Returns the chat retention period in days. Chats archived longer
|
||||
// than this and orphaned chat files older than this are purged by
|
||||
// dbpurge. Returns 30 (days) when no value has been configured.
|
||||
// A value of 0 disables chat purging entirely.
|
||||
GetChatRetentionDays(ctx context.Context) (int32, error)
|
||||
GetChatSystemPrompt(ctx context.Context) (string, error)
|
||||
// GetChatSystemPromptConfig returns both chat system prompt settings in a
|
||||
// single read to avoid torn reads between separate site-config lookups.
|
||||
@@ -283,6 +312,10 @@ type sqlcQuerier interface {
|
||||
GetChatWorkspaceTTL(ctx context.Context) (string, error)
|
||||
GetChats(ctx context.Context, arg GetChatsParams) ([]GetChatsRow, error)
|
||||
GetChatsByWorkspaceIDs(ctx context.Context, ids []uuid.UUID) ([]Chat, error)
|
||||
// Retrieves chats updated after the given timestamp for telemetry
|
||||
// snapshot collection. Uses updated_at so that long-running chats
|
||||
// still appear in each snapshot window while they are active.
|
||||
GetChatsUpdatedAfter(ctx context.Context, updatedAfter time.Time) ([]GetChatsUpdatedAfterRow, error)
|
||||
GetConnectionLogsOffset(ctx context.Context, arg GetConnectionLogsOffsetParams) ([]GetConnectionLogsOffsetRow, error)
|
||||
GetCryptoKeyByFeatureAndSequence(ctx context.Context, arg GetCryptoKeyByFeatureAndSequenceParams) (CryptoKey, error)
|
||||
GetCryptoKeys(ctx context.Context) ([]CryptoKey, error)
|
||||
@@ -385,11 +418,12 @@ type sqlcQuerier interface {
|
||||
// per PR for state/additions/deletions/model (model comes from the
|
||||
// most recent chat).
|
||||
GetPRInsightsPerModel(ctx context.Context, arg GetPRInsightsPerModelParams) ([]GetPRInsightsPerModelRow, error)
|
||||
// Returns individual PR rows with cost for the recent PRs table.
|
||||
// Returns all individual PR rows with cost for the selected time range.
|
||||
// Uses two CTEs: pr_costs sums cost for the PR-linked chat and its
|
||||
// direct children (that lack their own PR), and deduped picks one row
|
||||
// per PR for metadata.
|
||||
GetPRInsightsRecentPRs(ctx context.Context, arg GetPRInsightsRecentPRsParams) ([]GetPRInsightsRecentPRsRow, error)
|
||||
// per PR for metadata. A safety-cap LIMIT guards against unexpectedly
|
||||
// large result sets from direct API callers.
|
||||
GetPRInsightsPullRequests(ctx context.Context, arg GetPRInsightsPullRequestsParams) ([]GetPRInsightsPullRequestsRow, error)
|
||||
// PR Insights queries for the /agents analytics dashboard.
|
||||
// These aggregate data from chat_diff_statuses (PR metadata) joined
|
||||
// with chats and chat_messages (cost) to power the PR Insights view.
|
||||
@@ -478,8 +512,10 @@ type sqlcQuerier interface {
|
||||
GetReplicasUpdatedAfter(ctx context.Context, updatedAt time.Time) ([]Replica, error)
|
||||
GetRunningPrebuiltWorkspaces(ctx context.Context) ([]GetRunningPrebuiltWorkspacesRow, error)
|
||||
GetRuntimeConfig(ctx context.Context, key string) (string, error)
|
||||
// Find chats that appear stuck (running but heartbeat has expired).
|
||||
// Used for recovery after coderd crashes or long hangs.
|
||||
// Find chats that appear stuck and need recovery. This covers:
|
||||
// 1. Running chats whose heartbeat has expired (worker crash).
|
||||
// 2. Chats awaiting client action (requires_action) past the
|
||||
// timeout threshold (client disappeared).
|
||||
GetStaleChats(ctx context.Context, staleThreshold time.Time) ([]Chat, error)
|
||||
GetTailnetPeers(ctx context.Context, id uuid.UUID) ([]TailnetPeer, error)
|
||||
GetTailnetTunnelPeerBindingsBatch(ctx context.Context, ids []uuid.UUID) ([]GetTailnetTunnelPeerBindingsBatchRow, error)
|
||||
@@ -860,11 +896,16 @@ type sqlcQuerier interface {
|
||||
SelectUsageEventsForPublishing(ctx context.Context, now time.Time) ([]UsageEvent, error)
|
||||
SoftDeleteChatMessageByID(ctx context.Context, id int64) error
|
||||
SoftDeleteChatMessagesAfterID(ctx context.Context, arg SoftDeleteChatMessagesAfterIDParams) error
|
||||
SoftDeleteContextFileMessages(ctx context.Context, chatID uuid.UUID) error
|
||||
// Non blocking lock. Returns true if the lock was acquired, false otherwise.
|
||||
//
|
||||
// This must be called from within a transaction. The lock will be automatically
|
||||
// released when the transaction ends.
|
||||
TryAcquireLock(ctx context.Context, pgTryAdvisoryXactLock int64) (bool, error)
|
||||
// Unarchives a chat (and its children). Stale file references are
|
||||
// handled automatically by FK cascades on chat_file_links: when
|
||||
// dbpurge deletes a chat_files row, the corresponding
|
||||
// chat_file_links rows are cascade-deleted by PostgreSQL.
|
||||
UnarchiveChatByID(ctx context.Context, id uuid.UUID) ([]Chat, error)
|
||||
// This will always work regardless of the current state of the template version.
|
||||
UnarchiveTemplateVersion(ctx context.Context, arg UnarchiveTemplateVersionParams) error
|
||||
@@ -971,6 +1012,7 @@ type sqlcQuerier interface {
|
||||
UpdateWorkspace(ctx context.Context, arg UpdateWorkspaceParams) (WorkspaceTable, error)
|
||||
UpdateWorkspaceACLByID(ctx context.Context, arg UpdateWorkspaceACLByIDParams) error
|
||||
UpdateWorkspaceAgentConnectionByID(ctx context.Context, arg UpdateWorkspaceAgentConnectionByIDParams) error
|
||||
UpdateWorkspaceAgentDirectoryByID(ctx context.Context, arg UpdateWorkspaceAgentDirectoryByIDParams) error
|
||||
UpdateWorkspaceAgentDisplayAppsByID(ctx context.Context, arg UpdateWorkspaceAgentDisplayAppsByIDParams) error
|
||||
UpdateWorkspaceAgentLifecycleStateByID(ctx context.Context, arg UpdateWorkspaceAgentLifecycleStateByIDParams) error
|
||||
UpdateWorkspaceAgentLogOverflowByID(ctx context.Context, arg UpdateWorkspaceAgentLogOverflowByIDParams) error
|
||||
@@ -1006,6 +1048,7 @@ type sqlcQuerier interface {
|
||||
UpsertChatDiffStatus(ctx context.Context, arg UpsertChatDiffStatusParams) (ChatDiffStatus, error)
|
||||
UpsertChatDiffStatusReference(ctx context.Context, arg UpsertChatDiffStatusReferenceParams) (ChatDiffStatus, error)
|
||||
UpsertChatIncludeDefaultSystemPrompt(ctx context.Context, includeDefaultSystemPrompt bool) error
|
||||
UpsertChatRetentionDays(ctx context.Context, retentionDays int32) error
|
||||
UpsertChatSystemPrompt(ctx context.Context, value string) error
|
||||
UpsertChatTemplateAllowlist(ctx context.Context, templateAllowlist string) error
|
||||
UpsertChatUsageLimitConfig(ctx context.Context, arg UpsertChatUsageLimitConfigParams) (ChatUsageLimitConfig, error)
|
||||
|
||||
@@ -7376,7 +7376,7 @@ func TestUserSecretsCRUDOperations(t *testing.T) {
|
||||
assert.Equal(t, "WORKFLOW_ENV", updatedSecret.EnvName) // EnvName unchanged
|
||||
|
||||
// 6. DELETE
|
||||
err = db.DeleteUserSecretByUserIDAndName(ctx, database.DeleteUserSecretByUserIDAndNameParams{
|
||||
_, err = db.DeleteUserSecretByUserIDAndName(ctx, database.DeleteUserSecretByUserIDAndNameParams{
|
||||
UserID: testUser.ID,
|
||||
Name: "workflow-secret",
|
||||
})
|
||||
@@ -9085,10 +9085,11 @@ func TestUpdateAIBridgeInterceptionEnded(t *testing.T) {
|
||||
|
||||
for _, uid := range []uuid.UUID{{1}, {2}, {3}} {
|
||||
insertParams := database.InsertAIBridgeInterceptionParams{
|
||||
ID: uid,
|
||||
InitiatorID: user.ID,
|
||||
Metadata: json.RawMessage("{}"),
|
||||
Client: sql.NullString{String: "client", Valid: true},
|
||||
ID: uid,
|
||||
InitiatorID: user.ID,
|
||||
Metadata: json.RawMessage("{}"),
|
||||
Client: sql.NullString{String: "client", Valid: true},
|
||||
CredentialKind: database.CredentialKindCentralized,
|
||||
}
|
||||
|
||||
intc, err := db.InsertAIBridgeInterception(ctx, insertParams)
|
||||
@@ -10407,11 +10408,10 @@ func TestGetPRInsights(t *testing.T) {
|
||||
assert.Equal(t, int64(1), summary.TotalPrsCreated)
|
||||
assert.Equal(t, int64(8_000_000), summary.TotalCostMicros)
|
||||
|
||||
recent, err := store.GetPRInsightsRecentPRs(context.Background(), database.GetPRInsightsRecentPRsParams{
|
||||
recent, err := store.GetPRInsightsPullRequests(context.Background(), database.GetPRInsightsPullRequestsParams{
|
||||
StartDate: startDate,
|
||||
EndDate: endDate,
|
||||
OwnerID: noOwner,
|
||||
LimitVal: 20,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.Len(t, recent, 1)
|
||||
@@ -10441,11 +10441,10 @@ func TestGetPRInsights(t *testing.T) {
|
||||
assert.Equal(t, int64(1), summary.TotalPrsMerged)
|
||||
|
||||
// RecentPRs ordered by created_at DESC: chatB is newer.
|
||||
recent, err := store.GetPRInsightsRecentPRs(context.Background(), database.GetPRInsightsRecentPRsParams{
|
||||
recent, err := store.GetPRInsightsPullRequests(context.Background(), database.GetPRInsightsPullRequestsParams{
|
||||
StartDate: startDate,
|
||||
EndDate: endDate,
|
||||
OwnerID: noOwner,
|
||||
LimitVal: 20,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.Len(t, recent, 2)
|
||||
@@ -10490,11 +10489,10 @@ func TestGetPRInsights(t *testing.T) {
|
||||
assert.Equal(t, int64(1), summary.TotalPrsCreated)
|
||||
assert.Equal(t, int64(1), summary.TotalPrsMerged)
|
||||
|
||||
recent, err := store.GetPRInsightsRecentPRs(context.Background(), database.GetPRInsightsRecentPRsParams{
|
||||
recent, err := store.GetPRInsightsPullRequests(context.Background(), database.GetPRInsightsPullRequestsParams{
|
||||
StartDate: startDate,
|
||||
EndDate: endDate,
|
||||
OwnerID: noOwner,
|
||||
LimitVal: 20,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.Len(t, recent, 1)
|
||||
@@ -10532,11 +10530,10 @@ func TestGetPRInsights(t *testing.T) {
|
||||
assert.Equal(t, int64(9_000_000), summary.TotalCostMicros)
|
||||
|
||||
// RecentPRs should return 1 row with the full tree cost.
|
||||
recent, err := store.GetPRInsightsRecentPRs(context.Background(), database.GetPRInsightsRecentPRsParams{
|
||||
recent, err := store.GetPRInsightsPullRequests(context.Background(), database.GetPRInsightsPullRequestsParams{
|
||||
StartDate: startDate,
|
||||
EndDate: endDate,
|
||||
OwnerID: noOwner,
|
||||
LimitVal: 20,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.Len(t, recent, 1)
|
||||
@@ -10574,11 +10571,10 @@ func TestGetPRInsights(t *testing.T) {
|
||||
assert.Equal(t, int64(2), summary.TotalPrsCreated)
|
||||
assert.Equal(t, int64(8_000_000), summary.TotalCostMicros)
|
||||
|
||||
recent, err := store.GetPRInsightsRecentPRs(context.Background(), database.GetPRInsightsRecentPRsParams{
|
||||
recent, err := store.GetPRInsightsPullRequests(context.Background(), database.GetPRInsightsPullRequestsParams{
|
||||
StartDate: startDate,
|
||||
EndDate: endDate,
|
||||
OwnerID: noOwner,
|
||||
LimitVal: 20,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.Len(t, recent, 2)
|
||||
@@ -10620,11 +10616,10 @@ func TestGetPRInsights(t *testing.T) {
|
||||
assert.Equal(t, int64(2), summary.TotalPrsCreated)
|
||||
assert.Equal(t, int64(17_000_000), summary.TotalCostMicros)
|
||||
|
||||
recent, err := store.GetPRInsightsRecentPRs(context.Background(), database.GetPRInsightsRecentPRsParams{
|
||||
recent, err := store.GetPRInsightsPullRequests(context.Background(), database.GetPRInsightsPullRequestsParams{
|
||||
StartDate: startDate,
|
||||
EndDate: endDate,
|
||||
OwnerID: noOwner,
|
||||
LimitVal: 20,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.Len(t, recent, 2)
|
||||
@@ -10657,11 +10652,10 @@ func TestGetPRInsights(t *testing.T) {
|
||||
assert.Equal(t, int64(2), summary.TotalPrsCreated)
|
||||
assert.Equal(t, int64(10_000_000), summary.TotalCostMicros)
|
||||
|
||||
recent, err := store.GetPRInsightsRecentPRs(context.Background(), database.GetPRInsightsRecentPRsParams{
|
||||
recent, err := store.GetPRInsightsPullRequests(context.Background(), database.GetPRInsightsPullRequestsParams{
|
||||
StartDate: startDate,
|
||||
EndDate: endDate,
|
||||
OwnerID: noOwner,
|
||||
LimitVal: 20,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.Len(t, recent, 2)
|
||||
@@ -10694,11 +10688,10 @@ func TestGetPRInsights(t *testing.T) {
|
||||
assert.Equal(t, int64(1), summary.TotalPrsCreated)
|
||||
assert.Equal(t, int64(15_000_000), summary.TotalCostMicros)
|
||||
|
||||
recent, err := store.GetPRInsightsRecentPRs(context.Background(), database.GetPRInsightsRecentPRsParams{
|
||||
recent, err := store.GetPRInsightsPullRequests(context.Background(), database.GetPRInsightsPullRequestsParams{
|
||||
StartDate: startDate,
|
||||
EndDate: endDate,
|
||||
OwnerID: noOwner,
|
||||
LimitVal: 20,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.Len(t, recent, 1)
|
||||
@@ -10723,11 +10716,10 @@ func TestGetPRInsights(t *testing.T) {
|
||||
assert.Equal(t, int64(1), summary.TotalPrsCreated)
|
||||
assert.Equal(t, int64(0), summary.TotalCostMicros)
|
||||
|
||||
recent, err := store.GetPRInsightsRecentPRs(context.Background(), database.GetPRInsightsRecentPRsParams{
|
||||
recent, err := store.GetPRInsightsPullRequests(context.Background(), database.GetPRInsightsPullRequestsParams{
|
||||
StartDate: startDate,
|
||||
EndDate: endDate,
|
||||
OwnerID: noOwner,
|
||||
LimitVal: 20,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.Len(t, recent, 1)
|
||||
@@ -10766,11 +10758,10 @@ func TestGetPRInsights(t *testing.T) {
|
||||
require.Len(t, byModel, 1)
|
||||
assert.Equal(t, modelName, byModel[0].DisplayName)
|
||||
|
||||
recent, err := store.GetPRInsightsRecentPRs(context.Background(), database.GetPRInsightsRecentPRsParams{
|
||||
recent, err := store.GetPRInsightsPullRequests(context.Background(), database.GetPRInsightsPullRequestsParams{
|
||||
StartDate: startDate,
|
||||
EndDate: endDate,
|
||||
OwnerID: noOwner,
|
||||
LimitVal: 20,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.Len(t, recent, 1)
|
||||
@@ -10802,6 +10793,30 @@ func TestGetPRInsights(t *testing.T) {
|
||||
assert.Equal(t, int64(8_000_000), summary.TotalCostMicros)
|
||||
assert.Equal(t, int64(5_000_000), summary.MergedCostMicros)
|
||||
})
|
||||
|
||||
t.Run("AllPRsReturnedWithSafetyCap", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
store, userID, mcID := setupChatInfra(t)
|
||||
|
||||
// Create 25 distinct PRs — more than the old LIMIT 20 — and
|
||||
// verify all are returned.
|
||||
const prCount = 25
|
||||
for i := range prCount {
|
||||
chat := createChat(t, store, userID, mcID, fmt.Sprintf("chat-%d", i))
|
||||
insertCostMessage(t, store, chat.ID, userID, mcID, 1_000_000)
|
||||
linkPR(t, store, chat.ID,
|
||||
fmt.Sprintf("https://github.com/org/repo/pull/%d", 100+i),
|
||||
"merged", fmt.Sprintf("fix: pr-%d", i), 10, 2, 1)
|
||||
}
|
||||
|
||||
recent, err := store.GetPRInsightsPullRequests(context.Background(), database.GetPRInsightsPullRequestsParams{
|
||||
StartDate: startDate,
|
||||
EndDate: endDate,
|
||||
OwnerID: noOwner,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
assert.Len(t, recent, prCount, "all PRs within the date range should be returned")
|
||||
})
|
||||
}
|
||||
|
||||
func TestChatPinOrderQueries(t *testing.T) {
|
||||
|
||||
+550
-89
File diff suppressed because it is too large
Load Diff
@@ -1,8 +1,8 @@
|
||||
-- name: InsertAIBridgeInterception :one
|
||||
INSERT INTO aibridge_interceptions (
|
||||
id, api_key_id, initiator_id, provider, provider_name, model, metadata, started_at, client, client_session_id, thread_parent_id, thread_root_id
|
||||
id, api_key_id, initiator_id, provider, provider_name, model, metadata, started_at, client, client_session_id, thread_parent_id, thread_root_id, credential_kind, credential_hint
|
||||
) VALUES (
|
||||
@id, @api_key_id, @initiator_id, @provider, @provider_name, @model, COALESCE(@metadata::jsonb, '{}'::jsonb), @started_at, @client, sqlc.narg('client_session_id'), sqlc.narg('thread_parent_interception_id')::uuid, sqlc.narg('thread_root_interception_id')::uuid
|
||||
@id, @api_key_id, @initiator_id, @provider, @provider_name, @model, COALESCE(@metadata::jsonb, '{}'::jsonb), @started_at, @client, sqlc.narg('client_session_id'), sqlc.narg('thread_parent_interception_id')::uuid, sqlc.narg('thread_root_interception_id')::uuid, @credential_kind, @credential_hint
|
||||
)
|
||||
RETURNING *;
|
||||
|
||||
|
||||
@@ -18,3 +18,37 @@ FROM chat_files cf
|
||||
JOIN chat_file_links cfl ON cfl.file_id = cf.id
|
||||
WHERE cfl.chat_id = @chat_id::uuid
|
||||
ORDER BY cf.created_at ASC;
|
||||
|
||||
-- TODO(cian): Add indexes on chats(archived, updated_at) and
|
||||
-- chat_files(created_at) for purge query performance.
|
||||
-- See: https://github.com/coder/internal/issues/1438
|
||||
-- name: DeleteOldChatFiles :execrows
|
||||
-- Deletes chat files that are older than the given threshold and are
|
||||
-- not referenced by any chat that is still active or was archived
|
||||
-- within the same threshold window. This covers two cases:
|
||||
-- 1. Orphaned files not linked to any chat.
|
||||
-- 2. Files whose every referencing chat has been archived for longer
|
||||
-- than the retention period.
|
||||
WITH kept_file_ids AS (
|
||||
-- NOTE: This uses updated_at as a proxy for archive time
|
||||
-- because there is no archived_at column. Correctness
|
||||
-- requires that updated_at is never backdated on archived
|
||||
-- chats. See ArchiveChatByID.
|
||||
SELECT DISTINCT cfl.file_id
|
||||
FROM chat_file_links cfl
|
||||
JOIN chats c ON c.id = cfl.chat_id
|
||||
WHERE c.archived = false
|
||||
OR c.updated_at >= @before_time::timestamptz
|
||||
),
|
||||
deletable AS (
|
||||
SELECT cf.id
|
||||
FROM chat_files cf
|
||||
LEFT JOIN kept_file_ids k ON cf.id = k.file_id
|
||||
WHERE cf.created_at < @before_time::timestamptz
|
||||
AND k.file_id IS NULL
|
||||
ORDER BY cf.created_at ASC
|
||||
LIMIT @limit_count
|
||||
)
|
||||
DELETE FROM chat_files
|
||||
USING deletable
|
||||
WHERE chat_files.id = deletable.id;
|
||||
|
||||
@@ -173,11 +173,12 @@ JOIN pr_costs pc ON pc.pr_key = d.pr_key
|
||||
GROUP BY d.model_config_id, d.display_name, d.model, d.provider
|
||||
ORDER BY total_prs DESC;
|
||||
|
||||
-- name: GetPRInsightsRecentPRs :many
|
||||
-- Returns individual PR rows with cost for the recent PRs table.
|
||||
-- name: GetPRInsightsPullRequests :many
|
||||
-- Returns all individual PR rows with cost for the selected time range.
|
||||
-- Uses two CTEs: pr_costs sums cost for the PR-linked chat and its
|
||||
-- direct children (that lack their own PR), and deduped picks one row
|
||||
-- per PR for metadata.
|
||||
-- per PR for metadata. A safety-cap LIMIT guards against unexpectedly
|
||||
-- large result sets from direct API callers.
|
||||
WITH pr_costs AS (
|
||||
SELECT
|
||||
prc.pr_key,
|
||||
@@ -264,4 +265,4 @@ SELECT * FROM (
|
||||
JOIN pr_costs pc ON pc.pr_key = d.pr_key
|
||||
) sub
|
||||
ORDER BY sub.created_at DESC
|
||||
LIMIT @limit_val::int;
|
||||
LIMIT 500;
|
||||
|
||||
@@ -10,9 +10,14 @@ FROM chats
|
||||
ORDER BY (id = @id::uuid) DESC, created_at ASC, id ASC;
|
||||
|
||||
-- name: UnarchiveChatByID :many
|
||||
-- Unarchives a chat (and its children). Stale file references are
|
||||
-- handled automatically by FK cascades on chat_file_links: when
|
||||
-- dbpurge deletes a chat_files row, the corresponding
|
||||
-- chat_file_links rows are cascade-deleted by PostgreSQL.
|
||||
WITH chats AS (
|
||||
UPDATE chats
|
||||
SET archived = false, updated_at = NOW()
|
||||
UPDATE chats SET
|
||||
archived = false,
|
||||
updated_at = NOW()
|
||||
WHERE id = @id::uuid OR root_chat_id = @id::uuid
|
||||
RETURNING *
|
||||
)
|
||||
@@ -348,20 +353,18 @@ WHERE
|
||||
ELSE chats.archived = sqlc.narg('archived') :: boolean
|
||||
END
|
||||
AND CASE
|
||||
-- This allows using the last element on a page as effectively a cursor.
|
||||
-- This is an important option for scripts that need to paginate without
|
||||
-- duplicating or missing data.
|
||||
-- Cursor pagination: the last element on a page acts as the cursor.
|
||||
-- The 4-tuple matches the ORDER BY below. All columns sort DESC
|
||||
-- (pin_order is negated so lower values sort first in DESC order),
|
||||
-- which lets us use a single tuple < comparison.
|
||||
WHEN @after_id :: uuid != '00000000-0000-0000-0000-000000000000'::uuid THEN (
|
||||
-- The pagination cursor is the last ID of the previous page.
|
||||
-- The query is ordered by the updated_at field, so select all
|
||||
-- rows before the cursor.
|
||||
(updated_at, id) < (
|
||||
(CASE WHEN pin_order > 0 THEN 1 ELSE 0 END, -pin_order, updated_at, id) < (
|
||||
SELECT
|
||||
updated_at, id
|
||||
CASE WHEN c2.pin_order > 0 THEN 1 ELSE 0 END, -c2.pin_order, c2.updated_at, c2.id
|
||||
FROM
|
||||
chats
|
||||
chats c2
|
||||
WHERE
|
||||
id = @after_id
|
||||
c2.id = @after_id
|
||||
)
|
||||
)
|
||||
ELSE true
|
||||
@@ -373,9 +376,15 @@ WHERE
|
||||
-- Authorize Filter clause will be injected below in GetAuthorizedChats
|
||||
-- @authorize_filter
|
||||
ORDER BY
|
||||
-- Deterministic and consistent ordering of all rows, even if they share
|
||||
-- a timestamp. This is to ensure consistent pagination.
|
||||
(updated_at, id) DESC OFFSET @offset_opt
|
||||
-- Pinned chats (pin_order > 0) sort before unpinned ones. Within
|
||||
-- pinned chats, lower pin_order values come first. The negation
|
||||
-- trick (-pin_order) keeps all sort columns DESC so the cursor
|
||||
-- tuple < comparison works with uniform direction.
|
||||
CASE WHEN pin_order > 0 THEN 1 ELSE 0 END DESC,
|
||||
-pin_order DESC,
|
||||
updated_at DESC,
|
||||
id DESC
|
||||
OFFSET @offset_opt
|
||||
LIMIT
|
||||
-- The chat list is unbounded and expected to grow large.
|
||||
-- Default to 50 to prevent accidental excessively large queries.
|
||||
@@ -394,7 +403,8 @@ INSERT INTO chats (
|
||||
mode,
|
||||
status,
|
||||
mcp_server_ids,
|
||||
labels
|
||||
labels,
|
||||
dynamic_tools
|
||||
) VALUES (
|
||||
@owner_id::uuid,
|
||||
sqlc.narg('workspace_id')::uuid,
|
||||
@@ -407,7 +417,8 @@ INSERT INTO chats (
|
||||
sqlc.narg('mode')::chat_mode,
|
||||
@status::chat_status,
|
||||
COALESCE(@mcp_server_ids::uuid[], '{}'::uuid[]),
|
||||
COALESCE(sqlc.narg('labels')::jsonb, '{}'::jsonb)
|
||||
COALESCE(sqlc.narg('labels')::jsonb, '{}'::jsonb),
|
||||
sqlc.narg('dynamic_tools')::jsonb
|
||||
)
|
||||
RETURNING
|
||||
*;
|
||||
@@ -664,15 +675,19 @@ RETURNING
|
||||
*;
|
||||
|
||||
-- name: GetStaleChats :many
|
||||
-- Find chats that appear stuck (running but heartbeat has expired).
|
||||
-- Used for recovery after coderd crashes or long hangs.
|
||||
-- Find chats that appear stuck and need recovery. This covers:
|
||||
-- 1. Running chats whose heartbeat has expired (worker crash).
|
||||
-- 2. Chats awaiting client action (requires_action) past the
|
||||
-- timeout threshold (client disappeared).
|
||||
SELECT
|
||||
*
|
||||
FROM
|
||||
chats
|
||||
WHERE
|
||||
status = 'running'::chat_status
|
||||
AND heartbeat_at < @stale_threshold::timestamptz;
|
||||
(status = 'running'::chat_status
|
||||
AND heartbeat_at < @stale_threshold::timestamptz)
|
||||
OR (status = 'requires_action'::chat_status
|
||||
AND updated_at < @stale_threshold::timestamptz);
|
||||
|
||||
-- name: UpdateChatHeartbeats :many
|
||||
-- Bumps the heartbeat timestamp for the given set of chat IDs,
|
||||
@@ -1220,3 +1235,88 @@ LIMIT 1;
|
||||
UPDATE chats
|
||||
SET last_read_message_id = @last_read_message_id::bigint
|
||||
WHERE id = @id::uuid;
|
||||
|
||||
-- name: DeleteOldChats :execrows
|
||||
-- Deletes chats that have been archived for longer than the given
|
||||
-- threshold. Active (non-archived) chats are never deleted.
|
||||
-- Related chat_messages, chat_diff_statuses, and
|
||||
-- chat_queued_messages are removed via ON DELETE CASCADE.
|
||||
-- Parent/root references on child chats are SET NULL.
|
||||
WITH deletable AS (
|
||||
SELECT id
|
||||
FROM chats
|
||||
WHERE archived = true
|
||||
AND updated_at < @before_time::timestamptz
|
||||
ORDER BY updated_at ASC
|
||||
LIMIT @limit_count
|
||||
)
|
||||
DELETE FROM chats
|
||||
USING deletable
|
||||
WHERE chats.id = deletable.id
|
||||
AND chats.archived = true;
|
||||
|
||||
-- name: GetChatsUpdatedAfter :many
|
||||
-- Retrieves chats updated after the given timestamp for telemetry
|
||||
-- snapshot collection. Uses updated_at so that long-running chats
|
||||
-- still appear in each snapshot window while they are active.
|
||||
SELECT
|
||||
id, owner_id, created_at, updated_at, status,
|
||||
(parent_chat_id IS NOT NULL)::bool AS has_parent,
|
||||
root_chat_id, workspace_id,
|
||||
mode, archived, last_model_config_id
|
||||
FROM chats
|
||||
WHERE updated_at > @updated_after;
|
||||
|
||||
-- name: GetChatMessageSummariesPerChat :many
|
||||
-- Aggregates message-level metrics per chat for messages created
|
||||
-- after the given timestamp. Uses message created_at so that
|
||||
-- ongoing activity in long-running chats is captured each window.
|
||||
SELECT
|
||||
cm.chat_id,
|
||||
COUNT(*)::bigint AS message_count,
|
||||
COUNT(*) FILTER (WHERE cm.role = 'user')::bigint AS user_message_count,
|
||||
COUNT(*) FILTER (WHERE cm.role = 'assistant')::bigint AS assistant_message_count,
|
||||
COUNT(*) FILTER (WHERE cm.role = 'tool')::bigint AS tool_message_count,
|
||||
COUNT(*) FILTER (WHERE cm.role = 'system')::bigint AS system_message_count,
|
||||
COALESCE(SUM(cm.input_tokens), 0)::bigint AS total_input_tokens,
|
||||
COALESCE(SUM(cm.output_tokens), 0)::bigint AS total_output_tokens,
|
||||
COALESCE(SUM(cm.reasoning_tokens), 0)::bigint AS total_reasoning_tokens,
|
||||
COALESCE(SUM(cm.cache_creation_tokens), 0)::bigint AS total_cache_creation_tokens,
|
||||
COALESCE(SUM(cm.cache_read_tokens), 0)::bigint AS total_cache_read_tokens,
|
||||
COALESCE(SUM(cm.total_cost_micros), 0)::bigint AS total_cost_micros,
|
||||
COALESCE(SUM(cm.runtime_ms), 0)::bigint AS total_runtime_ms,
|
||||
COUNT(DISTINCT cm.model_config_id)::bigint AS distinct_model_count,
|
||||
COUNT(*) FILTER (WHERE cm.compressed)::bigint AS compressed_message_count
|
||||
FROM chat_messages cm
|
||||
WHERE cm.created_at > @created_after
|
||||
AND cm.deleted = false
|
||||
GROUP BY cm.chat_id;
|
||||
|
||||
-- name: GetChatModelConfigsForTelemetry :many
|
||||
-- Returns all model configurations for telemetry snapshot collection.
|
||||
SELECT id, provider, model, context_limit, enabled, is_default
|
||||
FROM chat_model_configs
|
||||
WHERE deleted = false;
|
||||
-- name: GetActiveChatsByAgentID :many
|
||||
SELECT *
|
||||
FROM chats
|
||||
WHERE agent_id = @agent_id::uuid
|
||||
AND archived = false
|
||||
-- Active statuses only: waiting, pending, running, paused,
|
||||
-- requires_action.
|
||||
-- Excludes completed and error (terminal states).
|
||||
AND status IN ('waiting', 'running', 'paused', 'pending', 'requires_action')
|
||||
ORDER BY updated_at DESC;
|
||||
|
||||
-- name: ClearChatMessageProviderResponseIDsByChatID :exec
|
||||
UPDATE chat_messages
|
||||
SET provider_response_id = NULL
|
||||
WHERE chat_id = @chat_id::uuid
|
||||
AND deleted = false
|
||||
AND provider_response_id IS NOT NULL;
|
||||
|
||||
-- name: SoftDeleteContextFileMessages :exec
|
||||
UPDATE chat_messages SET deleted = true
|
||||
WHERE chat_id = @chat_id::uuid
|
||||
AND deleted = false
|
||||
AND content::jsonb @> '[{"type": "context-file"}]';
|
||||
|
||||
@@ -195,7 +195,8 @@ SELECT
|
||||
w.id AS workspace_id,
|
||||
COALESCE(w.name, '') AS workspace_name,
|
||||
-- Include the name of the provisioner_daemon associated to the job
|
||||
COALESCE(pd.name, '') AS worker_name
|
||||
COALESCE(pd.name, '') AS worker_name,
|
||||
wb.transition as workspace_build_transition
|
||||
FROM
|
||||
provisioner_jobs pj
|
||||
LEFT JOIN
|
||||
@@ -240,7 +241,8 @@ GROUP BY
|
||||
t.icon,
|
||||
w.id,
|
||||
w.name,
|
||||
pd.name
|
||||
pd.name,
|
||||
wb.transition
|
||||
ORDER BY
|
||||
pj.created_at DESC
|
||||
LIMIT
|
||||
|
||||
@@ -236,3 +236,20 @@ VALUES ('agents_workspace_ttl', @workspace_ttl::text)
|
||||
ON CONFLICT (key) DO UPDATE
|
||||
SET value = @workspace_ttl::text
|
||||
WHERE site_configs.key = 'agents_workspace_ttl';
|
||||
|
||||
-- name: GetChatRetentionDays :one
|
||||
-- Returns the chat retention period in days. Chats archived longer
|
||||
-- than this and orphaned chat files older than this are purged by
|
||||
-- dbpurge. Returns 30 (days) when no value has been configured.
|
||||
-- A value of 0 disables chat purging entirely.
|
||||
SELECT COALESCE(
|
||||
(SELECT value::integer FROM site_configs
|
||||
WHERE key = 'agents_chat_retention_days'),
|
||||
30
|
||||
) :: integer AS retention_days;
|
||||
|
||||
-- name: UpsertChatRetentionDays :exec
|
||||
INSERT INTO site_configs (key, value)
|
||||
VALUES ('agents_chat_retention_days', CAST(@retention_days AS integer)::text)
|
||||
ON CONFLICT (key) DO UPDATE SET value = CAST(@retention_days AS integer)::text
|
||||
WHERE site_configs.key = 'agents_chat_retention_days';
|
||||
|
||||
@@ -56,6 +56,6 @@ SET
|
||||
WHERE user_id = @user_id AND name = @name
|
||||
RETURNING *;
|
||||
|
||||
-- name: DeleteUserSecretByUserIDAndName :exec
|
||||
-- name: DeleteUserSecretByUserIDAndName :execrows
|
||||
DELETE FROM user_secrets
|
||||
WHERE user_id = @user_id AND name = @name;
|
||||
|
||||
@@ -190,6 +190,14 @@ SET
|
||||
WHERE
|
||||
id = $1;
|
||||
|
||||
-- name: UpdateWorkspaceAgentDirectoryByID :exec
|
||||
UPDATE
|
||||
workspace_agents
|
||||
SET
|
||||
directory = $2, updated_at = $3
|
||||
WHERE
|
||||
id = $1;
|
||||
|
||||
-- name: GetWorkspaceAgentLogsAfter :many
|
||||
SELECT
|
||||
*
|
||||
|
||||
+257
-77
@@ -137,8 +137,9 @@ func publishChatConfigEvent(logger slog.Logger, ps dbpubsub.Pubsub, kind pubsub.
|
||||
func (api *API) watchChats(rw http.ResponseWriter, r *http.Request) {
|
||||
ctx := r.Context()
|
||||
apiKey := httpmw.APIKey(r)
|
||||
logger := api.Logger.Named("chat_watcher")
|
||||
|
||||
sendEvent, senderClosed, err := httpapi.OneWayWebSocketEventSender(api.Logger)(rw, r)
|
||||
conn, err := websocket.Accept(rw, r, nil)
|
||||
if err != nil {
|
||||
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
|
||||
Message: "Failed to open chat watch stream.",
|
||||
@@ -146,54 +147,44 @@ func (api *API) watchChats(rw http.ResponseWriter, r *http.Request) {
|
||||
})
|
||||
return
|
||||
}
|
||||
defer func() {
|
||||
<-senderClosed
|
||||
}()
|
||||
|
||||
cancelSubscribe, err := api.Pubsub.SubscribeWithErr(pubsub.ChatEventChannel(apiKey.UserID),
|
||||
pubsub.HandleChatEvent(
|
||||
func(ctx context.Context, payload pubsub.ChatEvent, err error) {
|
||||
ctx, cancel := context.WithCancel(ctx)
|
||||
defer cancel()
|
||||
|
||||
_ = conn.CloseRead(context.Background())
|
||||
|
||||
ctx, wsNetConn := codersdk.WebsocketNetConn(ctx, conn, websocket.MessageText)
|
||||
defer wsNetConn.Close()
|
||||
|
||||
go httpapi.HeartbeatClose(ctx, logger, cancel, conn)
|
||||
|
||||
// The encoder is only written from the SubscribeWithErr callback,
|
||||
// which delivers serially per subscription. Do not add a second
|
||||
// write path without introducing synchronization.
|
||||
encoder := json.NewEncoder(wsNetConn)
|
||||
|
||||
cancelSubscribe, err := api.Pubsub.SubscribeWithErr(pubsub.ChatWatchEventChannel(apiKey.UserID),
|
||||
pubsub.HandleChatWatchEvent(
|
||||
func(ctx context.Context, payload codersdk.ChatWatchEvent, err error) {
|
||||
if err != nil {
|
||||
api.Logger.Error(ctx, "chat event subscription error", slog.Error(err))
|
||||
logger.Error(ctx, "chat watch event subscription error", slog.Error(err))
|
||||
return
|
||||
}
|
||||
if err := sendEvent(codersdk.ServerSentEvent{
|
||||
Type: codersdk.ServerSentEventTypeData,
|
||||
Data: payload,
|
||||
}); err != nil {
|
||||
api.Logger.Debug(ctx, "failed to send chat event", slog.Error(err))
|
||||
if err := encoder.Encode(payload); err != nil {
|
||||
logger.Debug(ctx, "failed to send chat watch event", slog.Error(err))
|
||||
cancel()
|
||||
return
|
||||
}
|
||||
},
|
||||
))
|
||||
if err != nil {
|
||||
if err := sendEvent(codersdk.ServerSentEvent{
|
||||
Type: codersdk.ServerSentEventTypeError,
|
||||
Data: codersdk.Response{
|
||||
Message: "Internal error subscribing to chat events.",
|
||||
Detail: err.Error(),
|
||||
},
|
||||
}); err != nil {
|
||||
api.Logger.Debug(ctx, "failed to send chat subscribe error event", slog.Error(err))
|
||||
}
|
||||
logger.Error(ctx, "failed to subscribe to chat watch events", slog.Error(err))
|
||||
_ = conn.Close(websocket.StatusInternalError, "Failed to subscribe to chat events.")
|
||||
return
|
||||
}
|
||||
defer cancelSubscribe()
|
||||
|
||||
// Send initial ping to signal the connection is ready.
|
||||
if err := sendEvent(codersdk.ServerSentEvent{
|
||||
Type: codersdk.ServerSentEventTypePing,
|
||||
}); err != nil {
|
||||
api.Logger.Debug(ctx, "failed to send chat ping event", slog.Error(err))
|
||||
}
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case <-senderClosed:
|
||||
return
|
||||
}
|
||||
}
|
||||
<-ctx.Done()
|
||||
}
|
||||
|
||||
// EXPERIMENTAL: chatsByWorkspace returns a mapping of workspace ID to
|
||||
@@ -398,6 +389,10 @@ func (api *API) postChats(rw http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
// Cap the raw request body to prevent excessive memory use
|
||||
// from large dynamic tool schemas.
|
||||
r.Body = http.MaxBytesReader(rw, r.Body, int64(2*maxSystemPromptLenBytes))
|
||||
|
||||
var req codersdk.CreateChatRequest
|
||||
if !httpapi.Read(ctx, rw, r, &req) {
|
||||
return
|
||||
@@ -488,6 +483,50 @@ func (api *API) postChats(rw http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
if len(req.UnsafeDynamicTools) > 250 {
|
||||
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
|
||||
Message: "Too many dynamic tools.",
|
||||
Detail: "Maximum 250 dynamic tools per chat.",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// Validate that dynamic tool names are non-empty and unique
|
||||
// within the list. Name collision with built-in tools is
|
||||
// checked at chatloop time when the full tool set is known.
|
||||
if len(req.UnsafeDynamicTools) > 0 {
|
||||
seenNames := make(map[string]struct{}, len(req.UnsafeDynamicTools))
|
||||
for _, dt := range req.UnsafeDynamicTools {
|
||||
if dt.Name == "" {
|
||||
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
|
||||
Message: "Dynamic tool name must not be empty.",
|
||||
})
|
||||
return
|
||||
}
|
||||
if _, exists := seenNames[dt.Name]; exists {
|
||||
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
|
||||
Message: "Duplicate dynamic tool name.",
|
||||
Detail: fmt.Sprintf("Tool %q appears more than once.", dt.Name),
|
||||
})
|
||||
return
|
||||
}
|
||||
seenNames[dt.Name] = struct{}{}
|
||||
}
|
||||
}
|
||||
|
||||
var dynamicToolsJSON json.RawMessage
|
||||
if len(req.UnsafeDynamicTools) > 0 {
|
||||
var err error
|
||||
dynamicToolsJSON, err = json.Marshal(req.UnsafeDynamicTools)
|
||||
if err != nil {
|
||||
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
|
||||
Message: "Failed to marshal dynamic tools.",
|
||||
Detail: err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
chat, err := api.chatDaemon.CreateChat(ctx, chatd.CreateOptions{
|
||||
OwnerID: apiKey.UserID,
|
||||
WorkspaceID: workspaceSelection.WorkspaceID,
|
||||
@@ -497,6 +536,7 @@ func (api *API) postChats(rw http.ResponseWriter, r *http.Request) {
|
||||
InitialUserContent: contentBlocks,
|
||||
MCPServerIDs: mcpServerIDs,
|
||||
Labels: labels,
|
||||
DynamicTools: dynamicToolsJSON,
|
||||
})
|
||||
if err != nil {
|
||||
if maybeWriteLimitErr(ctx, rw, err) {
|
||||
@@ -1770,9 +1810,9 @@ func (api *API) patchChat(rw http.ResponseWriter, r *http.Request) {
|
||||
// - pinOrder > 0 && already pinned: reorder (shift
|
||||
// neighbors, clamp to [1, count]).
|
||||
// - pinOrder > 0 && not pinned: append to end. The
|
||||
// requested value is intentionally ignored because
|
||||
// PinChatByID also bumps updated_at to keep the
|
||||
// chat visible in the paginated sidebar.
|
||||
// requested value is intentionally ignored; the
|
||||
// SQL ORDER BY sorts pinned chats first so they
|
||||
// appear on page 1 of the paginated sidebar.
|
||||
var err error
|
||||
errMsg := "Failed to pin chat."
|
||||
switch {
|
||||
@@ -2127,6 +2167,7 @@ func (api *API) streamChat(rw http.ResponseWriter, r *http.Request) {
|
||||
ctx := r.Context()
|
||||
chat := httpmw.ChatParam(r)
|
||||
chatID := chat.ID
|
||||
logger := api.Logger.Named("chat_streamer").With(slog.F("chat_id", chatID))
|
||||
|
||||
if api.chatDaemon == nil {
|
||||
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
|
||||
@@ -2149,7 +2190,22 @@ func (api *API) streamChat(rw http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
}
|
||||
|
||||
sendEvent, senderClosed, err := httpapi.OneWayWebSocketEventSender(api.Logger)(rw, r)
|
||||
// Subscribe before accepting the WebSocket so that failures
|
||||
// can still be reported as normal HTTP errors.
|
||||
snapshot, events, cancelSub, ok := api.chatDaemon.Subscribe(ctx, chatID, r.Header, afterMessageID)
|
||||
// Subscribe only fails today when the receiver is nil, which
|
||||
// the chatDaemon == nil guard above already catches. This is
|
||||
// defensive against future Subscribe failure modes.
|
||||
if !ok {
|
||||
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
|
||||
Message: "Chat streaming is not available.",
|
||||
Detail: "Chat stream state is not configured.",
|
||||
})
|
||||
return
|
||||
}
|
||||
defer cancelSub()
|
||||
|
||||
conn, err := websocket.Accept(rw, r, nil)
|
||||
if err != nil {
|
||||
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
|
||||
Message: "Failed to open chat stream.",
|
||||
@@ -2157,41 +2213,30 @@ func (api *API) streamChat(rw http.ResponseWriter, r *http.Request) {
|
||||
})
|
||||
return
|
||||
}
|
||||
snapshot, events, cancel, ok := api.chatDaemon.Subscribe(ctx, chatID, r.Header, afterMessageID)
|
||||
if !ok {
|
||||
if err := sendEvent(codersdk.ServerSentEvent{
|
||||
Type: codersdk.ServerSentEventTypeError,
|
||||
Data: codersdk.Response{
|
||||
Message: "Chat streaming is not available.",
|
||||
Detail: "Chat stream state is not configured.",
|
||||
},
|
||||
}); err != nil {
|
||||
api.Logger.Debug(ctx, "failed to send chat stream unavailable event", slog.Error(err))
|
||||
}
|
||||
// Ensure the WebSocket is closed so senderClosed
|
||||
// completes and the handler can return.
|
||||
<-senderClosed
|
||||
return
|
||||
}
|
||||
defer func() {
|
||||
<-senderClosed
|
||||
}()
|
||||
|
||||
ctx, cancel := context.WithCancel(ctx)
|
||||
defer cancel()
|
||||
|
||||
_ = conn.CloseRead(context.Background())
|
||||
|
||||
ctx, wsNetConn := codersdk.WebsocketNetConn(ctx, conn, websocket.MessageText)
|
||||
defer wsNetConn.Close()
|
||||
|
||||
go httpapi.HeartbeatClose(ctx, logger, cancel, conn)
|
||||
|
||||
// Mark the chat as read when the stream connects and again
|
||||
// when it disconnects so we avoid per-message API calls while
|
||||
// messages are actively streaming.
|
||||
api.markChatAsRead(ctx, chatID)
|
||||
defer api.markChatAsRead(context.WithoutCancel(ctx), chatID)
|
||||
|
||||
encoder := json.NewEncoder(wsNetConn)
|
||||
|
||||
sendChatStreamBatch := func(batch []codersdk.ChatStreamEvent) error {
|
||||
if len(batch) == 0 {
|
||||
return nil
|
||||
}
|
||||
return sendEvent(codersdk.ServerSentEvent{
|
||||
Type: codersdk.ServerSentEventTypeData,
|
||||
Data: batch,
|
||||
})
|
||||
return encoder.Encode(batch)
|
||||
}
|
||||
|
||||
drainChatStreamBatch := func(
|
||||
@@ -2224,7 +2269,7 @@ func (api *API) streamChat(rw http.ResponseWriter, r *http.Request) {
|
||||
end = len(snapshot)
|
||||
}
|
||||
if err := sendChatStreamBatch(snapshot[start:end]); err != nil {
|
||||
api.Logger.Debug(ctx, "failed to send chat stream snapshot", slog.Error(err))
|
||||
logger.Debug(ctx, "failed to send chat stream snapshot", slog.Error(err))
|
||||
return
|
||||
}
|
||||
}
|
||||
@@ -2233,8 +2278,6 @@ func (api *API) streamChat(rw http.ResponseWriter, r *http.Request) {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case <-senderClosed:
|
||||
return
|
||||
case firstEvent, ok := <-events:
|
||||
if !ok {
|
||||
return
|
||||
@@ -2244,7 +2287,7 @@ func (api *API) streamChat(rw http.ResponseWriter, r *http.Request) {
|
||||
chatStreamBatchSize,
|
||||
)
|
||||
if err := sendChatStreamBatch(batch); err != nil {
|
||||
api.Logger.Debug(ctx, "failed to send chat stream event", slog.Error(err))
|
||||
logger.Debug(ctx, "failed to send chat stream event", slog.Error(err))
|
||||
return
|
||||
}
|
||||
if streamClosed {
|
||||
@@ -2259,6 +2302,7 @@ func (api *API) interruptChat(rw http.ResponseWriter, r *http.Request) {
|
||||
ctx := r.Context()
|
||||
chat := httpmw.ChatParam(r)
|
||||
chatID := chat.ID
|
||||
logger := api.Logger.Named("chat_interrupt").With(slog.F("chat_id", chatID))
|
||||
|
||||
if api.chatDaemon != nil {
|
||||
chat = api.chatDaemon.InterruptChat(ctx, chat)
|
||||
@@ -2272,8 +2316,7 @@ func (api *API) interruptChat(rw http.ResponseWriter, r *http.Request) {
|
||||
LastError: sql.NullString{},
|
||||
})
|
||||
if updateErr != nil {
|
||||
api.Logger.Error(ctx, "failed to mark chat as waiting",
|
||||
slog.F("chat_id", chatID), slog.Error(updateErr))
|
||||
logger.Error(ctx, "failed to mark chat as waiting", slog.Error(updateErr))
|
||||
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
|
||||
Message: "Failed to interrupt chat.",
|
||||
Detail: updateErr.Error(),
|
||||
@@ -3183,6 +3226,70 @@ func (api *API) putChatWorkspaceTTL(rw http.ResponseWriter, r *http.Request) {
|
||||
rw.WriteHeader(http.StatusNoContent)
|
||||
}
|
||||
|
||||
// @Summary Get chat retention days
|
||||
// @ID get-chat-retention-days
|
||||
// @Security CoderSessionToken
|
||||
// @Tags Chats
|
||||
// @Produce json
|
||||
// @Success 200 {object} codersdk.ChatRetentionDaysResponse
|
||||
// @Router /experimental/chats/config/retention-days [get]
|
||||
// @x-apidocgen {"skip": true}
|
||||
//
|
||||
//nolint:revive // get-return: revive assumes get* must be a getter, but this is an HTTP handler.
|
||||
func (api *API) getChatRetentionDays(rw http.ResponseWriter, r *http.Request) {
|
||||
ctx := r.Context()
|
||||
retentionDays, err := api.Database.GetChatRetentionDays(ctx)
|
||||
if err != nil {
|
||||
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
|
||||
Message: "Failed to get chat retention days.",
|
||||
Detail: err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
httpapi.Write(ctx, rw, http.StatusOK, codersdk.ChatRetentionDaysResponse{
|
||||
RetentionDays: retentionDays,
|
||||
})
|
||||
}
|
||||
|
||||
// Keep in sync with retentionDaysMaximum in
|
||||
// site/src/pages/AgentsPage/AgentSettingsBehaviorPageView.tsx.
|
||||
const retentionDaysMaximum = 3650 // ~10 years
|
||||
|
||||
// @Summary Update chat retention days
|
||||
// @ID update-chat-retention-days
|
||||
// @Security CoderSessionToken
|
||||
// @Tags Chats
|
||||
// @Accept json
|
||||
// @Param request body codersdk.UpdateChatRetentionDaysRequest true "Request body"
|
||||
// @Success 204
|
||||
// @Router /experimental/chats/config/retention-days [put]
|
||||
// @x-apidocgen {"skip": true}
|
||||
func (api *API) putChatRetentionDays(rw http.ResponseWriter, r *http.Request) {
|
||||
ctx := r.Context()
|
||||
if !api.Authorize(r, policy.ActionUpdate, rbac.ResourceDeploymentConfig) {
|
||||
httpapi.Forbidden(rw)
|
||||
return
|
||||
}
|
||||
var req codersdk.UpdateChatRetentionDaysRequest
|
||||
if !httpapi.Read(ctx, rw, r, &req) {
|
||||
return
|
||||
}
|
||||
if req.RetentionDays < 0 || req.RetentionDays > retentionDaysMaximum {
|
||||
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
|
||||
Message: fmt.Sprintf("Retention days must be between 0 and %d.", retentionDaysMaximum),
|
||||
})
|
||||
return
|
||||
}
|
||||
if err := api.Database.UpsertChatRetentionDays(ctx, req.RetentionDays); err != nil {
|
||||
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
|
||||
Message: "Failed to update chat retention days.",
|
||||
Detail: err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
rw.WriteHeader(http.StatusNoContent)
|
||||
}
|
||||
|
||||
// EXPERIMENTAL: this endpoint is experimental and is subject to change.
|
||||
//
|
||||
//nolint:revive // get-return: revive assumes get* must be a getter, but this is an HTTP handler.
|
||||
@@ -5519,7 +5626,7 @@ func (api *API) prInsights(rw http.ResponseWriter, r *http.Request) {
|
||||
previousSummary database.GetPRInsightsSummaryRow
|
||||
timeSeries []database.GetPRInsightsTimeSeriesRow
|
||||
byModel []database.GetPRInsightsPerModelRow
|
||||
recentPRs []database.GetPRInsightsRecentPRsRow
|
||||
recentPRs []database.GetPRInsightsPullRequestsRow
|
||||
)
|
||||
|
||||
eg, egCtx := errgroup.WithContext(ctx)
|
||||
@@ -5567,11 +5674,10 @@ func (api *API) prInsights(rw http.ResponseWriter, r *http.Request) {
|
||||
|
||||
eg.Go(func() error {
|
||||
var err error
|
||||
recentPRs, err = api.Database.GetPRInsightsRecentPRs(egCtx, database.GetPRInsightsRecentPRsParams{
|
||||
recentPRs, err = api.Database.GetPRInsightsPullRequests(egCtx, database.GetPRInsightsPullRequestsParams{
|
||||
StartDate: startDate,
|
||||
EndDate: endDate,
|
||||
OwnerID: ownerID,
|
||||
LimitVal: 20,
|
||||
})
|
||||
return err
|
||||
})
|
||||
@@ -5681,9 +5787,83 @@ func (api *API) prInsights(rw http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
|
||||
httpapi.Write(ctx, rw, http.StatusOK, codersdk.PRInsightsResponse{
|
||||
Summary: summary,
|
||||
TimeSeries: tsEntries,
|
||||
ByModel: modelEntries,
|
||||
RecentPRs: prEntries,
|
||||
Summary: summary,
|
||||
TimeSeries: tsEntries,
|
||||
ByModel: modelEntries,
|
||||
PullRequests: prEntries,
|
||||
})
|
||||
}
|
||||
|
||||
// EXPERIMENTAL: this endpoint is experimental and is subject to change.
|
||||
//
|
||||
//nolint:revive // HTTP handler writes to ResponseWriter.
|
||||
func (api *API) postChatToolResults(rw http.ResponseWriter, r *http.Request) {
|
||||
ctx := r.Context()
|
||||
chat := httpmw.ChatParam(r)
|
||||
apiKey := httpmw.APIKey(r)
|
||||
|
||||
// Cap the raw request body to prevent excessive memory use.
|
||||
r.Body = http.MaxBytesReader(rw, r.Body, int64(2*maxSystemPromptLenBytes))
|
||||
var req codersdk.SubmitToolResultsRequest
|
||||
|
||||
if !httpapi.Read(ctx, rw, r, &req) {
|
||||
return
|
||||
}
|
||||
|
||||
if len(req.Results) == 0 {
|
||||
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
|
||||
Message: "At least one tool result is required.",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// Fast-path check outside the transaction. The authoritative
|
||||
// check happens inside SubmitToolResults under a row lock.
|
||||
if chat.Status != database.ChatStatusRequiresAction {
|
||||
httpapi.Write(ctx, rw, http.StatusConflict, codersdk.Response{
|
||||
Message: "Chat is not waiting for tool results.",
|
||||
Detail: fmt.Sprintf("Chat status is %q, expected %q.", chat.Status, database.ChatStatusRequiresAction),
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
var dynamicTools json.RawMessage
|
||||
if chat.DynamicTools.Valid {
|
||||
dynamicTools = chat.DynamicTools.RawMessage
|
||||
}
|
||||
|
||||
err := api.chatDaemon.SubmitToolResults(ctx, chatd.SubmitToolResultsOptions{
|
||||
ChatID: chat.ID,
|
||||
UserID: apiKey.UserID,
|
||||
ModelConfigID: chat.LastModelConfigID,
|
||||
Results: req.Results,
|
||||
DynamicTools: dynamicTools,
|
||||
})
|
||||
if err != nil {
|
||||
var validationErr *chatd.ToolResultValidationError
|
||||
var conflictErr *chatd.ToolResultStatusConflictError
|
||||
switch {
|
||||
case errors.As(err, &conflictErr):
|
||||
httpapi.Write(ctx, rw, http.StatusConflict, codersdk.Response{
|
||||
Message: "Chat is not waiting for tool results.",
|
||||
Detail: err.Error(),
|
||||
})
|
||||
case errors.As(err, &validationErr):
|
||||
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
|
||||
Message: validationErr.Message,
|
||||
Detail: validationErr.Detail,
|
||||
})
|
||||
default:
|
||||
api.Logger.Error(ctx, "tool results submission failed",
|
||||
slog.F("chat_id", chat.ID),
|
||||
slog.Error(err),
|
||||
)
|
||||
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
|
||||
Message: "Internal error submitting tool results.",
|
||||
})
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
rw.WriteHeader(http.StatusNoContent)
|
||||
}
|
||||
|
||||
+630
-116
@@ -16,7 +16,9 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/mark3labs/mcp-go/mcp"
|
||||
"github.com/shopspring/decimal"
|
||||
"github.com/sqlc-dev/pqtype"
|
||||
"github.com/stretchr/testify/require"
|
||||
"golang.org/x/xerrors"
|
||||
|
||||
@@ -268,17 +270,10 @@ func TestPostChats(t *testing.T) {
|
||||
_ = createChatModelConfig(t, client)
|
||||
|
||||
// Member without agents-access should be denied.
|
||||
memberClientRaw, member := coderdtest.CreateAnotherUser(t, client.Client, firstUser.OrganizationID)
|
||||
memberClientRaw, _ := coderdtest.CreateAnotherUser(t, client.Client, firstUser.OrganizationID)
|
||||
memberClient := codersdk.NewExperimentalClient(memberClientRaw)
|
||||
|
||||
// Strip the auto-assigned agents-access role to test
|
||||
// the denied case.
|
||||
_, err := client.Client.UpdateUserRoles(ctx, member.Username, codersdk.UpdateRoles{
|
||||
Roles: []string{},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
_, err = memberClient.CreateChat(ctx, codersdk.CreateChatRequest{
|
||||
_, err := memberClient.CreateChat(ctx, codersdk.CreateChatRequest{
|
||||
Content: []codersdk.ChatInputPart{
|
||||
{
|
||||
Type: codersdk.ChatInputPartTypeText,
|
||||
@@ -288,6 +283,7 @@ func TestPostChats(t *testing.T) {
|
||||
})
|
||||
requireSDKError(t, err, http.StatusForbidden)
|
||||
})
|
||||
|
||||
t.Run("HidesSystemPromptMessages", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
@@ -756,15 +752,7 @@ func TestListChats(t *testing.T) {
|
||||
// returning empty because no chats exist.
|
||||
memberClientRaw, member := coderdtest.CreateAnotherUser(t, client.Client, firstUser.OrganizationID)
|
||||
memberClient := codersdk.NewExperimentalClient(memberClientRaw)
|
||||
|
||||
// Strip the auto-assigned agents-access role to test
|
||||
// the denied case.
|
||||
_, err := client.Client.UpdateUserRoles(ctx, member.Username, codersdk.UpdateRoles{
|
||||
Roles: []string{},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
_, err = db.InsertChat(dbauthz.AsSystemRestricted(ctx), database.InsertChatParams{
|
||||
_, err := db.InsertChat(dbauthz.AsSystemRestricted(ctx), database.InsertChatParams{
|
||||
Status: database.ChatStatusWaiting,
|
||||
OwnerID: member.ID,
|
||||
LastModelConfigID: modelConfig.ID,
|
||||
@@ -888,6 +876,186 @@ func TestListChats(t *testing.T) {
|
||||
require.NoError(t, err)
|
||||
require.Len(t, allChats, totalChats)
|
||||
})
|
||||
|
||||
// Test that a pinned chat with an old updated_at appears on page 1.
|
||||
t.Run("PinnedOnFirstPage", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctx := testutil.Context(t, testutil.WaitLong)
|
||||
client, _ := newChatClientWithDatabase(t)
|
||||
_ = coderdtest.CreateFirstUser(t, client.Client)
|
||||
_ = createChatModelConfig(t, client)
|
||||
|
||||
// Create the chat that will later be pinned. It gets the
|
||||
// earliest updated_at because it is inserted first.
|
||||
pinnedChat, err := client.CreateChat(ctx, codersdk.CreateChatRequest{
|
||||
Content: []codersdk.ChatInputPart{{
|
||||
Type: codersdk.ChatInputPartTypeText,
|
||||
Text: "pinned-chat",
|
||||
}},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
// Fill page 1 with newer chats so the pinned chat would
|
||||
// normally be pushed off the first page (default limit 50).
|
||||
const fillerCount = 51
|
||||
fillerChats := make([]codersdk.Chat, 0, fillerCount)
|
||||
for i := range fillerCount {
|
||||
c, createErr := client.CreateChat(ctx, codersdk.CreateChatRequest{
|
||||
Content: []codersdk.ChatInputPart{{
|
||||
Type: codersdk.ChatInputPartTypeText,
|
||||
Text: fmt.Sprintf("filler-%d", i),
|
||||
}},
|
||||
})
|
||||
require.NoError(t, createErr)
|
||||
fillerChats = append(fillerChats, c)
|
||||
}
|
||||
|
||||
// Wait for all chats to reach a terminal status so
|
||||
// updated_at is stable before paginating. A single
|
||||
// polling loop checks every chat per tick to avoid
|
||||
// O(N) separate Eventually loops.
|
||||
allCreated := append([]codersdk.Chat{pinnedChat}, fillerChats...)
|
||||
pending := make(map[uuid.UUID]struct{}, len(allCreated))
|
||||
for _, c := range allCreated {
|
||||
pending[c.ID] = struct{}{}
|
||||
}
|
||||
testutil.Eventually(ctx, t, func(_ context.Context) bool {
|
||||
all, listErr := client.ListChats(ctx, &codersdk.ListChatsOptions{
|
||||
Pagination: codersdk.Pagination{Limit: fillerCount + 10},
|
||||
})
|
||||
if listErr != nil {
|
||||
return false
|
||||
}
|
||||
for _, ch := range all {
|
||||
if _, ok := pending[ch.ID]; ok && ch.Status != codersdk.ChatStatusPending && ch.Status != codersdk.ChatStatusRunning {
|
||||
delete(pending, ch.ID)
|
||||
}
|
||||
}
|
||||
return len(pending) == 0
|
||||
}, testutil.IntervalFast)
|
||||
|
||||
// Pin the earliest chat.
|
||||
err = client.UpdateChat(ctx, pinnedChat.ID, codersdk.UpdateChatRequest{
|
||||
PinOrder: ptr.Ref(int32(1)),
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
// Fetch page 1 with default limit (50).
|
||||
page1, err := client.ListChats(ctx, &codersdk.ListChatsOptions{
|
||||
Pagination: codersdk.Pagination{Limit: 50},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
// The pinned chat must appear on page 1.
|
||||
page1IDs := make(map[uuid.UUID]struct{}, len(page1))
|
||||
for _, c := range page1 {
|
||||
page1IDs[c.ID] = struct{}{}
|
||||
}
|
||||
_, found := page1IDs[pinnedChat.ID]
|
||||
require.True(t, found, "pinned chat should appear on page 1")
|
||||
|
||||
// The pinned chat should be the first item in the list.
|
||||
require.Equal(t, pinnedChat.ID, page1[0].ID, "pinned chat should be first")
|
||||
})
|
||||
|
||||
// Test cursor pagination with a mix of pinned and unpinned chats.
|
||||
t.Run("CursorWithPins", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctx := testutil.Context(t, testutil.WaitLong)
|
||||
client, _ := newChatClientWithDatabase(t)
|
||||
_ = coderdtest.CreateFirstUser(t, client.Client)
|
||||
_ = createChatModelConfig(t, client)
|
||||
|
||||
// Create 5 chats: 2 will be pinned, 3 unpinned.
|
||||
const totalChats = 5
|
||||
createdChats := make([]codersdk.Chat, 0, totalChats)
|
||||
for i := range totalChats {
|
||||
c, createErr := client.CreateChat(ctx, codersdk.CreateChatRequest{
|
||||
Content: []codersdk.ChatInputPart{{
|
||||
Type: codersdk.ChatInputPartTypeText,
|
||||
Text: fmt.Sprintf("cursor-pin-chat-%d", i),
|
||||
}},
|
||||
})
|
||||
require.NoError(t, createErr)
|
||||
createdChats = append(createdChats, c)
|
||||
}
|
||||
|
||||
// Wait for all chats to reach terminal status.
|
||||
// Check each chat by ID rather than fetching the full list.
|
||||
testutil.Eventually(ctx, t, func(_ context.Context) bool {
|
||||
for _, c := range createdChats {
|
||||
ch, err := client.GetChat(ctx, c.ID)
|
||||
require.NoError(t, err, "GetChat should succeed for just-created chat %s", c.ID)
|
||||
if ch.Status == codersdk.ChatStatusPending || ch.Status == codersdk.ChatStatusRunning {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
}, testutil.IntervalFast)
|
||||
|
||||
// Pin the first two chats (oldest updated_at).
|
||||
err := client.UpdateChat(ctx, createdChats[0].ID, codersdk.UpdateChatRequest{
|
||||
PinOrder: ptr.Ref(int32(1)),
|
||||
})
|
||||
require.NoError(t, err)
|
||||
err = client.UpdateChat(ctx, createdChats[1].ID, codersdk.UpdateChatRequest{
|
||||
PinOrder: ptr.Ref(int32(1)),
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
// Paginate with limit=2 using cursor (after_id).
|
||||
const pageSize = 2
|
||||
maxPages := totalChats/pageSize + 2
|
||||
var allPaginated []codersdk.Chat
|
||||
var afterID uuid.UUID
|
||||
for range maxPages {
|
||||
opts := &codersdk.ListChatsOptions{
|
||||
Pagination: codersdk.Pagination{Limit: pageSize},
|
||||
}
|
||||
if afterID != uuid.Nil {
|
||||
opts.Pagination.AfterID = afterID
|
||||
}
|
||||
page, listErr := client.ListChats(ctx, opts)
|
||||
require.NoError(t, listErr)
|
||||
if len(page) == 0 {
|
||||
break
|
||||
}
|
||||
allPaginated = append(allPaginated, page...)
|
||||
afterID = page[len(page)-1].ID
|
||||
}
|
||||
|
||||
// All chats should appear exactly once.
|
||||
seenIDs := make(map[uuid.UUID]struct{}, len(allPaginated))
|
||||
for _, c := range allPaginated {
|
||||
_, dup := seenIDs[c.ID]
|
||||
require.False(t, dup, "chat %s appeared more than once", c.ID)
|
||||
seenIDs[c.ID] = struct{}{}
|
||||
}
|
||||
require.Len(t, seenIDs, totalChats, "all chats should appear in paginated results")
|
||||
|
||||
// Pinned chats should come before unpinned ones, and
|
||||
// within the pinned group, lower pin_order sorts first.
|
||||
pinnedSeen := false
|
||||
unpinnedSeen := false
|
||||
for _, c := range allPaginated {
|
||||
if c.PinOrder > 0 {
|
||||
require.False(t, unpinnedSeen, "pinned chat %s appeared after unpinned chat", c.ID)
|
||||
pinnedSeen = true
|
||||
} else {
|
||||
unpinnedSeen = true
|
||||
}
|
||||
}
|
||||
require.True(t, pinnedSeen, "at least one pinned chat should exist")
|
||||
|
||||
// Verify within-pinned ordering: pin_order=1 before
|
||||
// pin_order=2 (the -pin_order DESC column).
|
||||
require.Equal(t, createdChats[0].ID, allPaginated[0].ID,
|
||||
"pin_order=1 chat should be first")
|
||||
require.Equal(t, createdChats[1].ID, allPaginated[1].ID,
|
||||
"pin_order=2 chat should be second")
|
||||
})
|
||||
}
|
||||
|
||||
func TestListChatModels(t *testing.T) {
|
||||
@@ -1126,17 +1294,6 @@ func TestWatchChats(t *testing.T) {
|
||||
require.NoError(t, err)
|
||||
defer conn.Close(websocket.StatusNormalClosure, "done")
|
||||
|
||||
type watchEvent struct {
|
||||
Type codersdk.ServerSentEventType `json:"type"`
|
||||
Data json.RawMessage `json:"data,omitempty"`
|
||||
}
|
||||
|
||||
var event watchEvent
|
||||
err = wsjson.Read(ctx, conn, &event)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, codersdk.ServerSentEventTypePing, event.Type)
|
||||
require.True(t, len(event.Data) == 0 || string(event.Data) == "null")
|
||||
|
||||
createdChat, err := client.CreateChat(ctx, codersdk.CreateChatRequest{
|
||||
Content: []codersdk.ChatInputPart{
|
||||
{
|
||||
@@ -1148,25 +1305,16 @@ func TestWatchChats(t *testing.T) {
|
||||
require.NoError(t, err)
|
||||
|
||||
for {
|
||||
var update watchEvent
|
||||
err = wsjson.Read(ctx, conn, &update)
|
||||
var payload codersdk.ChatWatchEvent
|
||||
err = wsjson.Read(ctx, conn, &payload)
|
||||
require.NoError(t, err)
|
||||
|
||||
if update.Type == codersdk.ServerSentEventTypePing {
|
||||
continue
|
||||
}
|
||||
require.Equal(t, codersdk.ServerSentEventTypeData, update.Type)
|
||||
|
||||
var payload coderdpubsub.ChatEvent
|
||||
err = json.Unmarshal(update.Data, &payload)
|
||||
require.NoError(t, err)
|
||||
if payload.Kind == coderdpubsub.ChatEventKindCreated &&
|
||||
if payload.Kind == codersdk.ChatWatchEventKindCreated &&
|
||||
payload.Chat.ID == createdChat.ID {
|
||||
break
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("CreatedEventIncludesAllChatFields", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
@@ -1186,18 +1334,6 @@ func TestWatchChats(t *testing.T) {
|
||||
require.NoError(t, err)
|
||||
defer conn.Close(websocket.StatusNormalClosure, "done")
|
||||
|
||||
type watchEvent struct {
|
||||
Type codersdk.ServerSentEventType `json:"type"`
|
||||
Data json.RawMessage `json:"data,omitempty"`
|
||||
}
|
||||
|
||||
// Skip the initial ping.
|
||||
var event watchEvent
|
||||
err = wsjson.Read(ctx, conn, &event)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, codersdk.ServerSentEventTypePing, event.Type)
|
||||
require.True(t, len(event.Data) == 0 || string(event.Data) == "null")
|
||||
|
||||
createdChat, err := client.CreateChat(ctx, codersdk.CreateChatRequest{
|
||||
Content: []codersdk.ChatInputPart{
|
||||
{
|
||||
@@ -1210,18 +1346,11 @@ func TestWatchChats(t *testing.T) {
|
||||
|
||||
var got codersdk.Chat
|
||||
testutil.Eventually(ctx, t, func(_ context.Context) bool {
|
||||
var update watchEvent
|
||||
if readErr := wsjson.Read(ctx, conn, &update); readErr != nil {
|
||||
var payload codersdk.ChatWatchEvent
|
||||
if readErr := wsjson.Read(ctx, conn, &payload); readErr != nil {
|
||||
return false
|
||||
}
|
||||
if update.Type != codersdk.ServerSentEventTypeData {
|
||||
return false
|
||||
}
|
||||
var payload coderdpubsub.ChatEvent
|
||||
if unmarshalErr := json.Unmarshal(update.Data, &payload); unmarshalErr != nil {
|
||||
return false
|
||||
}
|
||||
if payload.Kind == coderdpubsub.ChatEventKindCreated &&
|
||||
if payload.Kind == codersdk.ChatWatchEventKindCreated &&
|
||||
payload.Chat.ID == createdChat.ID {
|
||||
got = payload.Chat
|
||||
return true
|
||||
@@ -1294,25 +1423,14 @@ func TestWatchChats(t *testing.T) {
|
||||
require.NoError(t, err)
|
||||
defer conn.Close(websocket.StatusNormalClosure, "done")
|
||||
|
||||
type watchEvent struct {
|
||||
Type codersdk.ServerSentEventType `json:"type"`
|
||||
Data json.RawMessage `json:"data,omitempty"`
|
||||
}
|
||||
|
||||
// Read the initial ping.
|
||||
var ping watchEvent
|
||||
err = wsjson.Read(ctx, conn, &ping)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, codersdk.ServerSentEventTypePing, ping.Type)
|
||||
|
||||
// Publish a diff_status_change event via pubsub,
|
||||
// mimicking what PublishDiffStatusChange does after
|
||||
// it reads the diff status from the DB.
|
||||
dbStatus, err := db.GetChatDiffStatusByChatID(dbauthz.AsSystemRestricted(ctx), chat.ID)
|
||||
require.NoError(t, err)
|
||||
sdkDiffStatus := db2sdk.ChatDiffStatus(chat.ID, &dbStatus)
|
||||
event := coderdpubsub.ChatEvent{
|
||||
Kind: coderdpubsub.ChatEventKindDiffStatusChange,
|
||||
event := codersdk.ChatWatchEvent{
|
||||
Kind: codersdk.ChatWatchEventKindDiffStatusChange,
|
||||
Chat: codersdk.Chat{
|
||||
ID: chat.ID,
|
||||
OwnerID: chat.OwnerID,
|
||||
@@ -1325,25 +1443,15 @@ func TestWatchChats(t *testing.T) {
|
||||
}
|
||||
payload, err := json.Marshal(event)
|
||||
require.NoError(t, err)
|
||||
err = api.Pubsub.Publish(coderdpubsub.ChatEventChannel(user.UserID), payload)
|
||||
err = api.Pubsub.Publish(coderdpubsub.ChatWatchEventChannel(user.UserID), payload)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Read events until we find the diff_status_change.
|
||||
for {
|
||||
var update watchEvent
|
||||
err = wsjson.Read(ctx, conn, &update)
|
||||
var received codersdk.ChatWatchEvent
|
||||
err = wsjson.Read(ctx, conn, &received)
|
||||
require.NoError(t, err)
|
||||
|
||||
if update.Type == codersdk.ServerSentEventTypePing {
|
||||
continue
|
||||
}
|
||||
require.Equal(t, codersdk.ServerSentEventTypeData, update.Type)
|
||||
|
||||
var received coderdpubsub.ChatEvent
|
||||
err = json.Unmarshal(update.Data, &received)
|
||||
require.NoError(t, err)
|
||||
|
||||
if received.Kind != coderdpubsub.ChatEventKindDiffStatusChange ||
|
||||
if received.Kind != codersdk.ChatWatchEventKindDiffStatusChange ||
|
||||
received.Chat.ID != chat.ID {
|
||||
continue
|
||||
}
|
||||
@@ -1362,7 +1470,6 @@ func TestWatchChats(t *testing.T) {
|
||||
break
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("ArchiveAndUnarchiveEmitEventsForDescendants", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
@@ -1405,31 +1512,13 @@ func TestWatchChats(t *testing.T) {
|
||||
require.NoError(t, err)
|
||||
defer conn.Close(websocket.StatusNormalClosure, "done")
|
||||
|
||||
type watchEvent struct {
|
||||
Type codersdk.ServerSentEventType `json:"type"`
|
||||
Data json.RawMessage `json:"data,omitempty"`
|
||||
}
|
||||
|
||||
var ping watchEvent
|
||||
err = wsjson.Read(ctx, conn, &ping)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, codersdk.ServerSentEventTypePing, ping.Type)
|
||||
|
||||
collectLifecycleEvents := func(expectedKind coderdpubsub.ChatEventKind) map[uuid.UUID]coderdpubsub.ChatEvent {
|
||||
collectLifecycleEvents := func(expectedKind codersdk.ChatWatchEventKind) map[uuid.UUID]codersdk.ChatWatchEvent {
|
||||
t.Helper()
|
||||
|
||||
events := make(map[uuid.UUID]coderdpubsub.ChatEvent, 3)
|
||||
events := make(map[uuid.UUID]codersdk.ChatWatchEvent, 3)
|
||||
for len(events) < 3 {
|
||||
var update watchEvent
|
||||
err = wsjson.Read(ctx, conn, &update)
|
||||
require.NoError(t, err)
|
||||
if update.Type == codersdk.ServerSentEventTypePing {
|
||||
continue
|
||||
}
|
||||
require.Equal(t, codersdk.ServerSentEventTypeData, update.Type)
|
||||
|
||||
var payload coderdpubsub.ChatEvent
|
||||
err = json.Unmarshal(update.Data, &payload)
|
||||
var payload codersdk.ChatWatchEvent
|
||||
err = wsjson.Read(ctx, conn, &payload)
|
||||
require.NoError(t, err)
|
||||
if payload.Kind != expectedKind {
|
||||
continue
|
||||
@@ -1439,7 +1528,7 @@ func TestWatchChats(t *testing.T) {
|
||||
return events
|
||||
}
|
||||
|
||||
assertLifecycleEvents := func(events map[uuid.UUID]coderdpubsub.ChatEvent, archived bool) {
|
||||
assertLifecycleEvents := func(events map[uuid.UUID]codersdk.ChatWatchEvent, archived bool) {
|
||||
t.Helper()
|
||||
|
||||
require.Len(t, events, 3)
|
||||
@@ -1452,12 +1541,12 @@ func TestWatchChats(t *testing.T) {
|
||||
|
||||
err = client.UpdateChat(ctx, parentChat.ID, codersdk.UpdateChatRequest{Archived: ptr.Ref(true)})
|
||||
require.NoError(t, err)
|
||||
deletedEvents := collectLifecycleEvents(coderdpubsub.ChatEventKindDeleted)
|
||||
deletedEvents := collectLifecycleEvents(codersdk.ChatWatchEventKindDeleted)
|
||||
assertLifecycleEvents(deletedEvents, true)
|
||||
|
||||
err = client.UpdateChat(ctx, parentChat.ID, codersdk.UpdateChatRequest{Archived: ptr.Ref(false)})
|
||||
require.NoError(t, err)
|
||||
createdEvents := collectLifecycleEvents(coderdpubsub.ChatEventKindCreated)
|
||||
createdEvents := collectLifecycleEvents(codersdk.ChatWatchEventKindCreated)
|
||||
assertLifecycleEvents(createdEvents, false)
|
||||
})
|
||||
|
||||
@@ -7747,6 +7836,62 @@ func TestChatWorkspaceTTL(t *testing.T) {
|
||||
requireSDKError(t, err, http.StatusBadRequest)
|
||||
}
|
||||
|
||||
func TestChatRetentionDays(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := testutil.Context(t, testutil.WaitLong)
|
||||
|
||||
adminClient := newChatClient(t)
|
||||
firstUser := coderdtest.CreateFirstUser(t, adminClient.Client)
|
||||
memberClientRaw, _ := coderdtest.CreateAnotherUser(t, adminClient.Client, firstUser.OrganizationID)
|
||||
memberClient := codersdk.NewExperimentalClient(memberClientRaw)
|
||||
|
||||
// Default value is 30 (days) when nothing has been configured.
|
||||
resp, err := adminClient.GetChatRetentionDays(ctx)
|
||||
require.NoError(t, err, "get default")
|
||||
require.Equal(t, int32(30), resp.RetentionDays, "default should be 30")
|
||||
|
||||
// Admin can set retention days to 90.
|
||||
err = adminClient.UpdateChatRetentionDays(ctx, codersdk.UpdateChatRetentionDaysRequest{
|
||||
RetentionDays: 90,
|
||||
})
|
||||
require.NoError(t, err, "admin set 90")
|
||||
|
||||
resp, err = adminClient.GetChatRetentionDays(ctx)
|
||||
require.NoError(t, err, "get after set")
|
||||
require.Equal(t, int32(90), resp.RetentionDays, "should return 90")
|
||||
|
||||
// Non-admin member can read the value.
|
||||
resp, err = memberClient.GetChatRetentionDays(ctx)
|
||||
require.NoError(t, err, "member get")
|
||||
require.Equal(t, int32(90), resp.RetentionDays, "member should see same value")
|
||||
|
||||
// Non-admin member cannot write.
|
||||
err = memberClient.UpdateChatRetentionDays(ctx, codersdk.UpdateChatRetentionDaysRequest{RetentionDays: 7})
|
||||
requireSDKError(t, err, http.StatusForbidden)
|
||||
|
||||
// Admin can disable purge by setting 0.
|
||||
err = adminClient.UpdateChatRetentionDays(ctx, codersdk.UpdateChatRetentionDaysRequest{
|
||||
RetentionDays: 0,
|
||||
})
|
||||
require.NoError(t, err, "admin set 0")
|
||||
|
||||
resp, err = adminClient.GetChatRetentionDays(ctx)
|
||||
require.NoError(t, err, "get after zero")
|
||||
require.Equal(t, int32(0), resp.RetentionDays, "should be 0 after disable")
|
||||
|
||||
// Validation: negative value is rejected.
|
||||
err = adminClient.UpdateChatRetentionDays(ctx, codersdk.UpdateChatRetentionDaysRequest{
|
||||
RetentionDays: -1,
|
||||
})
|
||||
requireSDKError(t, err, http.StatusBadRequest)
|
||||
|
||||
// Validation: exceeding the 3650-day maximum is rejected.
|
||||
err = adminClient.UpdateChatRetentionDays(ctx, codersdk.UpdateChatRetentionDaysRequest{
|
||||
RetentionDays: 3651, // retentionDaysMaximum + 1; keep in sync with coderd/exp_chats.go.
|
||||
})
|
||||
requireSDKError(t, err, http.StatusBadRequest)
|
||||
}
|
||||
|
||||
//nolint:tparallel,paralleltest // Subtests share a single coderdtest instance.
|
||||
func TestUserChatCompactionThresholds(t *testing.T) {
|
||||
t.Parallel()
|
||||
@@ -8153,6 +8298,375 @@ func TestGetChatsByWorkspace(t *testing.T) {
|
||||
})
|
||||
}
|
||||
|
||||
func TestSubmitToolResults(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
// setupRequiresAction creates a chat via the DB with dynamic tools,
|
||||
// inserts an assistant message containing tool-call parts for each
|
||||
// given toolCallID, and sets the chat status to requires_action.
|
||||
// It returns the chat row so callers can exercise the endpoint.
|
||||
setupRequiresAction := func(
|
||||
ctx context.Context,
|
||||
t *testing.T,
|
||||
db database.Store,
|
||||
ownerID uuid.UUID,
|
||||
modelConfigID uuid.UUID,
|
||||
dynamicToolName string,
|
||||
toolCallIDs []string,
|
||||
) database.Chat {
|
||||
t.Helper()
|
||||
|
||||
// Marshal dynamic tools into the chat row.
|
||||
dynamicTools := []mcp.Tool{{
|
||||
Name: dynamicToolName,
|
||||
Description: "a test dynamic tool",
|
||||
InputSchema: mcp.ToolInputSchema{Type: "object"},
|
||||
}}
|
||||
dtJSON, err := json.Marshal(dynamicTools)
|
||||
require.NoError(t, err)
|
||||
|
||||
chat, err := db.InsertChat(dbauthz.AsSystemRestricted(ctx), database.InsertChatParams{
|
||||
Status: database.ChatStatusWaiting,
|
||||
OwnerID: ownerID,
|
||||
LastModelConfigID: modelConfigID,
|
||||
Title: "tool-results-test",
|
||||
DynamicTools: pqtype.NullRawMessage{RawMessage: dtJSON, Valid: true},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
// Build assistant message with tool-call parts.
|
||||
parts := make([]codersdk.ChatMessagePart, 0, len(toolCallIDs))
|
||||
for _, id := range toolCallIDs {
|
||||
parts = append(parts, codersdk.ChatMessagePart{
|
||||
Type: codersdk.ChatMessagePartTypeToolCall,
|
||||
ToolCallID: id,
|
||||
ToolName: dynamicToolName,
|
||||
Args: json.RawMessage(`{"key":"value"}`),
|
||||
})
|
||||
}
|
||||
content, err := chatprompt.MarshalParts(parts)
|
||||
require.NoError(t, err)
|
||||
|
||||
_, err = db.InsertChatMessages(dbauthz.AsSystemRestricted(ctx), database.InsertChatMessagesParams{
|
||||
ChatID: chat.ID,
|
||||
CreatedBy: []uuid.UUID{uuid.Nil},
|
||||
ModelConfigID: []uuid.UUID{modelConfigID},
|
||||
Role: []database.ChatMessageRole{database.ChatMessageRoleAssistant},
|
||||
ContentVersion: []int16{chatprompt.CurrentContentVersion},
|
||||
Content: []string{string(content.RawMessage)},
|
||||
Visibility: []database.ChatMessageVisibility{database.ChatMessageVisibilityBoth},
|
||||
InputTokens: []int64{0},
|
||||
OutputTokens: []int64{0},
|
||||
TotalTokens: []int64{0},
|
||||
ReasoningTokens: []int64{0},
|
||||
CacheCreationTokens: []int64{0},
|
||||
CacheReadTokens: []int64{0},
|
||||
ContextLimit: []int64{0},
|
||||
Compressed: []bool{false},
|
||||
TotalCostMicros: []int64{0},
|
||||
RuntimeMs: []int64{0},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
// Transition to requires_action.
|
||||
chat, err = db.UpdateChatStatus(dbauthz.AsSystemRestricted(ctx), database.UpdateChatStatusParams{
|
||||
ID: chat.ID,
|
||||
Status: database.ChatStatusRequiresAction,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, database.ChatStatusRequiresAction, chat.Status)
|
||||
|
||||
return chat
|
||||
}
|
||||
|
||||
t.Run("Success", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctx := testutil.Context(t, testutil.WaitLong)
|
||||
client, db := newChatClientWithDatabase(t)
|
||||
user := coderdtest.CreateFirstUser(t, client.Client)
|
||||
modelConfig := createChatModelConfig(t, client)
|
||||
|
||||
const toolName = "my_dynamic_tool"
|
||||
toolCallIDs := []string{"call_abc", "call_def"}
|
||||
|
||||
chat := setupRequiresAction(ctx, t, db, user.UserID, modelConfig.ID, toolName, toolCallIDs)
|
||||
|
||||
err := client.SubmitToolResults(ctx, chat.ID, codersdk.SubmitToolResultsRequest{
|
||||
Results: []codersdk.ToolResult{
|
||||
{ToolCallID: "call_abc", Output: json.RawMessage(`"result_a"`)},
|
||||
{ToolCallID: "call_def", Output: json.RawMessage(`"result_b"`)},
|
||||
},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
// Verify status is no longer requires_action. The chatd
|
||||
// loop may have already picked the chat up and
|
||||
// transitioned it further (pending → running → …), so we
|
||||
// accept any non-requires_action status.
|
||||
gotChat, err := client.GetChat(ctx, chat.ID)
|
||||
require.NoError(t, err)
|
||||
require.NotEqual(t, codersdk.ChatStatusRequiresAction, gotChat.Status,
|
||||
"chat should no longer be in requires_action after submitting tool results")
|
||||
|
||||
// Verify tool-result messages were persisted.
|
||||
msgsResp, err := client.GetChatMessages(ctx, chat.ID, nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
var toolResultCount int
|
||||
for _, msg := range msgsResp.Messages {
|
||||
if msg.Role == codersdk.ChatMessageRoleTool {
|
||||
toolResultCount++
|
||||
}
|
||||
}
|
||||
require.Equal(t, len(toolCallIDs), toolResultCount,
|
||||
"expected one tool-result message per submitted result")
|
||||
})
|
||||
|
||||
t.Run("WrongStatus", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctx := testutil.Context(t, testutil.WaitLong)
|
||||
client, db := newChatClientWithDatabase(t)
|
||||
user := coderdtest.CreateFirstUser(t, client.Client)
|
||||
modelConfig := createChatModelConfig(t, client)
|
||||
|
||||
// Create a chat that is NOT in requires_action status.
|
||||
chat, err := db.InsertChat(dbauthz.AsSystemRestricted(ctx), database.InsertChatParams{
|
||||
Status: database.ChatStatusWaiting,
|
||||
OwnerID: user.UserID,
|
||||
LastModelConfigID: modelConfig.ID,
|
||||
Title: "wrong-status-test",
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
err = client.SubmitToolResults(ctx, chat.ID, codersdk.SubmitToolResultsRequest{
|
||||
Results: []codersdk.ToolResult{
|
||||
{ToolCallID: "call_xyz", Output: json.RawMessage(`"nope"`)},
|
||||
},
|
||||
})
|
||||
requireSDKError(t, err, http.StatusConflict)
|
||||
})
|
||||
|
||||
t.Run("MissingResult", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctx := testutil.Context(t, testutil.WaitLong)
|
||||
client, db := newChatClientWithDatabase(t)
|
||||
user := coderdtest.CreateFirstUser(t, client.Client)
|
||||
modelConfig := createChatModelConfig(t, client)
|
||||
|
||||
const toolName = "my_dynamic_tool"
|
||||
toolCallIDs := []string{"call_one", "call_two"}
|
||||
|
||||
chat := setupRequiresAction(ctx, t, db, user.UserID, modelConfig.ID, toolName, toolCallIDs)
|
||||
|
||||
// Submit only one of the two required results.
|
||||
err := client.SubmitToolResults(ctx, chat.ID, codersdk.SubmitToolResultsRequest{
|
||||
Results: []codersdk.ToolResult{
|
||||
{ToolCallID: "call_one", Output: json.RawMessage(`"partial"`)},
|
||||
},
|
||||
})
|
||||
requireSDKError(t, err, http.StatusBadRequest)
|
||||
})
|
||||
|
||||
t.Run("UnexpectedResult", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctx := testutil.Context(t, testutil.WaitLong)
|
||||
client, db := newChatClientWithDatabase(t)
|
||||
user := coderdtest.CreateFirstUser(t, client.Client)
|
||||
modelConfig := createChatModelConfig(t, client)
|
||||
|
||||
const toolName = "my_dynamic_tool"
|
||||
toolCallIDs := []string{"call_real"}
|
||||
|
||||
chat := setupRequiresAction(ctx, t, db, user.UserID, modelConfig.ID, toolName, toolCallIDs)
|
||||
|
||||
// Submit a result with a wrong tool_call_id.
|
||||
err := client.SubmitToolResults(ctx, chat.ID, codersdk.SubmitToolResultsRequest{
|
||||
Results: []codersdk.ToolResult{
|
||||
{ToolCallID: "call_bogus", Output: json.RawMessage(`"wrong"`)},
|
||||
},
|
||||
})
|
||||
requireSDKError(t, err, http.StatusBadRequest)
|
||||
})
|
||||
|
||||
t.Run("InvalidJSONOutput", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctx := testutil.Context(t, testutil.WaitLong)
|
||||
client, db := newChatClientWithDatabase(t)
|
||||
user := coderdtest.CreateFirstUser(t, client.Client)
|
||||
modelConfig := createChatModelConfig(t, client)
|
||||
|
||||
const toolName = "my_dynamic_tool"
|
||||
toolCallIDs := []string{"call_json"}
|
||||
|
||||
chat := setupRequiresAction(ctx, t, db, user.UserID, modelConfig.ID, toolName, toolCallIDs)
|
||||
|
||||
// We must bypass the SDK client because json.RawMessage
|
||||
// rejects invalid JSON during json.Marshal. A raw HTTP
|
||||
// request lets the invalid payload reach the server so we
|
||||
// can verify server-side validation.
|
||||
rawBody := `{"results":[{"tool_call_id":"call_json","output":not-json,"is_error":false}]}`
|
||||
url := client.URL.JoinPath(fmt.Sprintf("/api/experimental/chats/%s/tool-results", chat.ID)).String()
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewBufferString(rawBody))
|
||||
require.NoError(t, err)
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set(codersdk.SessionTokenHeader, client.SessionToken())
|
||||
|
||||
resp, err := http.DefaultClient.Do(req)
|
||||
require.NoError(t, err)
|
||||
defer resp.Body.Close()
|
||||
|
||||
require.Equal(t, http.StatusBadRequest, resp.StatusCode)
|
||||
})
|
||||
|
||||
t.Run("DuplicateToolCallID", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := testutil.Context(t, testutil.WaitLong)
|
||||
client, db := newChatClientWithDatabase(t)
|
||||
user := coderdtest.CreateFirstUser(t, client.Client)
|
||||
modelConfig := createChatModelConfig(t, client)
|
||||
|
||||
const toolName = "my_dynamic_tool"
|
||||
toolCallIDs := []string{"call_dup1", "call_dup2"}
|
||||
|
||||
chat := setupRequiresAction(ctx, t, db, user.UserID, modelConfig.ID, toolName, toolCallIDs)
|
||||
|
||||
err := client.SubmitToolResults(ctx, chat.ID, codersdk.SubmitToolResultsRequest{
|
||||
Results: []codersdk.ToolResult{
|
||||
{ToolCallID: "call_dup1", Output: json.RawMessage(`"result_a"`)},
|
||||
{ToolCallID: "call_dup1", Output: json.RawMessage(`"result_b"`)},
|
||||
},
|
||||
})
|
||||
sdkErr := requireSDKError(t, err, http.StatusBadRequest)
|
||||
require.Contains(t, sdkErr.Message, "Duplicate tool_call_id")
|
||||
})
|
||||
|
||||
t.Run("EmptyResults", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctx := testutil.Context(t, testutil.WaitLong)
|
||||
client, db := newChatClientWithDatabase(t)
|
||||
user := coderdtest.CreateFirstUser(t, client.Client)
|
||||
modelConfig := createChatModelConfig(t, client)
|
||||
|
||||
const toolName = "my_dynamic_tool"
|
||||
toolCallIDs := []string{"call_empty"}
|
||||
|
||||
chat := setupRequiresAction(ctx, t, db, user.UserID, modelConfig.ID, toolName, toolCallIDs)
|
||||
|
||||
err := client.SubmitToolResults(ctx, chat.ID, codersdk.SubmitToolResultsRequest{
|
||||
Results: []codersdk.ToolResult{},
|
||||
})
|
||||
requireSDKError(t, err, http.StatusBadRequest)
|
||||
})
|
||||
|
||||
t.Run("NotFoundForDifferentUser", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctx := testutil.Context(t, testutil.WaitLong)
|
||||
client, db := newChatClientWithDatabase(t)
|
||||
user := coderdtest.CreateFirstUser(t, client.Client)
|
||||
modelConfig := createChatModelConfig(t, client)
|
||||
|
||||
const toolName = "my_dynamic_tool"
|
||||
toolCallIDs := []string{"call_other"}
|
||||
|
||||
chat := setupRequiresAction(ctx, t, db, user.UserID, modelConfig.ID, toolName, toolCallIDs)
|
||||
|
||||
// Create a second user and try to submit tool results
|
||||
// to user A's chat.
|
||||
otherClientRaw, _ := coderdtest.CreateAnotherUser(
|
||||
t, client.Client, user.OrganizationID,
|
||||
rbac.RoleAgentsAccess(),
|
||||
)
|
||||
otherClient := codersdk.NewExperimentalClient(otherClientRaw)
|
||||
|
||||
err := otherClient.SubmitToolResults(ctx, chat.ID, codersdk.SubmitToolResultsRequest{
|
||||
Results: []codersdk.ToolResult{
|
||||
{ToolCallID: "call_other", Output: json.RawMessage(`"nope"`)},
|
||||
},
|
||||
})
|
||||
requireSDKError(t, err, http.StatusNotFound)
|
||||
})
|
||||
}
|
||||
|
||||
func TestPostChats_DynamicToolValidation(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
t.Run("TooManyTools", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctx := testutil.Context(t, testutil.WaitLong)
|
||||
client := newChatClient(t)
|
||||
_ = coderdtest.CreateFirstUser(t, client.Client)
|
||||
_ = createChatModelConfig(t, client)
|
||||
|
||||
tools := make([]codersdk.DynamicTool, 251)
|
||||
for i := range tools {
|
||||
tools[i] = codersdk.DynamicTool{
|
||||
Name: fmt.Sprintf("tool-%d", i),
|
||||
}
|
||||
}
|
||||
|
||||
_, err := client.CreateChat(ctx, codersdk.CreateChatRequest{
|
||||
Content: []codersdk.ChatInputPart{{
|
||||
Type: codersdk.ChatInputPartTypeText,
|
||||
Text: "hello",
|
||||
}},
|
||||
UnsafeDynamicTools: tools,
|
||||
})
|
||||
sdkErr := requireSDKError(t, err, http.StatusBadRequest)
|
||||
require.Equal(t, "Too many dynamic tools.", sdkErr.Message)
|
||||
})
|
||||
|
||||
t.Run("EmptyToolName", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctx := testutil.Context(t, testutil.WaitLong)
|
||||
client := newChatClient(t)
|
||||
_ = coderdtest.CreateFirstUser(t, client.Client)
|
||||
_ = createChatModelConfig(t, client)
|
||||
|
||||
_, err := client.CreateChat(ctx, codersdk.CreateChatRequest{
|
||||
Content: []codersdk.ChatInputPart{{
|
||||
Type: codersdk.ChatInputPartTypeText,
|
||||
Text: "hello",
|
||||
}},
|
||||
UnsafeDynamicTools: []codersdk.DynamicTool{
|
||||
{Name: ""},
|
||||
},
|
||||
})
|
||||
sdkErr := requireSDKError(t, err, http.StatusBadRequest)
|
||||
require.Equal(t, "Dynamic tool name must not be empty.", sdkErr.Message)
|
||||
})
|
||||
|
||||
t.Run("DuplicateToolName", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctx := testutil.Context(t, testutil.WaitLong)
|
||||
client := newChatClient(t)
|
||||
_ = coderdtest.CreateFirstUser(t, client.Client)
|
||||
_ = createChatModelConfig(t, client)
|
||||
|
||||
_, err := client.CreateChat(ctx, codersdk.CreateChatRequest{
|
||||
Content: []codersdk.ChatInputPart{{
|
||||
Type: codersdk.ChatInputPartTypeText,
|
||||
Text: "hello",
|
||||
}},
|
||||
UnsafeDynamicTools: []codersdk.DynamicTool{
|
||||
{Name: "dup-tool"},
|
||||
{Name: "dup-tool"},
|
||||
},
|
||||
})
|
||||
sdkErr := requireSDKError(t, err, http.StatusBadRequest)
|
||||
require.Equal(t, "Duplicate dynamic tool name.", sdkErr.Message)
|
||||
})
|
||||
}
|
||||
|
||||
func requireSDKError(t *testing.T, err error, expectedStatus int) *codersdk.Error {
|
||||
t.Helper()
|
||||
|
||||
|
||||
@@ -0,0 +1,4 @@
|
||||
package coderd
|
||||
|
||||
// InsertAgentChatTestModelConfig exposes insertAgentChatTestModelConfig for external tests.
|
||||
var InsertAgentChatTestModelConfig = insertAgentChatTestModelConfig
|
||||
@@ -5,6 +5,7 @@ import (
|
||||
"database/sql"
|
||||
"encoding/hex"
|
||||
"errors"
|
||||
htmltemplate "html/template"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strings"
|
||||
@@ -146,12 +147,35 @@ func ShowAuthorizePage(accessURL *url.URL) http.HandlerFunc {
|
||||
cancel := params.redirectURL
|
||||
cancelQuery := params.redirectURL.Query()
|
||||
cancelQuery.Add("error", "access_denied")
|
||||
cancelQuery.Add("error_description", "The resource owner or authorization server denied the request")
|
||||
if params.state != "" {
|
||||
cancelQuery.Add("state", params.state)
|
||||
}
|
||||
cancel.RawQuery = cancelQuery.Encode()
|
||||
|
||||
cancelURI := cancel.String()
|
||||
if err := codersdk.ValidateRedirectURIScheme(cancel); err != nil {
|
||||
site.RenderStaticErrorPage(rw, r, site.ErrorPageData{
|
||||
Status: http.StatusBadRequest,
|
||||
HideStatus: false,
|
||||
Title: "Invalid Callback URL",
|
||||
Description: "The application's registered callback URL has an invalid scheme.",
|
||||
Actions: []site.Action{
|
||||
{
|
||||
URL: accessURL.String(),
|
||||
Text: "Back to site",
|
||||
},
|
||||
},
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
site.RenderOAuthAllowPage(rw, r, site.RenderOAuthAllowData{
|
||||
AppIcon: app.Icon,
|
||||
AppName: app.Name,
|
||||
CancelURI: cancel.String(),
|
||||
AppIcon: app.Icon,
|
||||
AppName: app.Name,
|
||||
// #nosec G203 -- The scheme is validated by
|
||||
// codersdk.ValidateRedirectURIScheme above.
|
||||
CancelURI: htmltemplate.URL(cancelURI),
|
||||
RedirectURI: r.URL.String(),
|
||||
CSRFToken: nosurf.Token(r),
|
||||
Username: ua.FriendlyName,
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package oauth2provider_test
|
||||
|
||||
import (
|
||||
htmltemplate "html/template"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
@@ -20,7 +21,7 @@ func TestOAuthConsentFormIncludesCSRFToken(t *testing.T) {
|
||||
|
||||
site.RenderOAuthAllowPage(rec, req, site.RenderOAuthAllowData{
|
||||
AppName: "Test OAuth App",
|
||||
CancelURI: "https://coder.com/cancel",
|
||||
CancelURI: htmltemplate.URL("https://coder.com/cancel"),
|
||||
RedirectURI: "https://coder.com/oauth2/authorize?client_id=test",
|
||||
CSRFToken: csrfFieldValue,
|
||||
Username: "test-user",
|
||||
|
||||
@@ -435,6 +435,9 @@ func convertProvisionerJobWithQueuePosition(pj database.GetProvisionerJobsByOrga
|
||||
if pj.WorkspaceID.Valid {
|
||||
job.Metadata.WorkspaceID = &pj.WorkspaceID.UUID
|
||||
}
|
||||
if pj.WorkspaceBuildTransition.Valid {
|
||||
job.Metadata.WorkspaceBuildTransition = codersdk.WorkspaceTransition(pj.WorkspaceBuildTransition.WorkspaceTransition)
|
||||
}
|
||||
return job
|
||||
}
|
||||
|
||||
|
||||
@@ -97,13 +97,14 @@ func TestProvisionerJobs(t *testing.T) {
|
||||
|
||||
// Verify that job metadata is correct.
|
||||
assert.Equal(t, job2.Metadata, codersdk.ProvisionerJobMetadata{
|
||||
TemplateVersionName: version.Name,
|
||||
TemplateID: template.ID,
|
||||
TemplateName: template.Name,
|
||||
TemplateDisplayName: template.DisplayName,
|
||||
TemplateIcon: template.Icon,
|
||||
WorkspaceID: &w.ID,
|
||||
WorkspaceName: w.Name,
|
||||
TemplateVersionName: version.Name,
|
||||
TemplateID: template.ID,
|
||||
TemplateName: template.Name,
|
||||
TemplateDisplayName: template.DisplayName,
|
||||
TemplateIcon: template.Icon,
|
||||
WorkspaceID: &w.ID,
|
||||
WorkspaceName: w.Name,
|
||||
WorkspaceBuildTransition: codersdk.WorkspaceTransitionStart,
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
@@ -14,7 +14,7 @@ import (
|
||||
const ChatConfigEventChannel = "chat:config_change"
|
||||
|
||||
// HandleChatConfigEvent wraps a typed callback for ChatConfigEvent
|
||||
// messages, following the same pattern as HandleChatEvent.
|
||||
// messages, following the same pattern as HandleChatWatchEvent.
|
||||
func HandleChatConfigEvent(cb func(ctx context.Context, payload ChatConfigEvent, err error)) func(ctx context.Context, message []byte, err error) {
|
||||
return func(ctx context.Context, message []byte, err error) {
|
||||
if err != nil {
|
||||
|
||||
@@ -1,47 +0,0 @@
|
||||
package pubsub
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"golang.org/x/xerrors"
|
||||
|
||||
"github.com/coder/coder/v2/codersdk"
|
||||
)
|
||||
|
||||
func ChatEventChannel(ownerID uuid.UUID) string {
|
||||
return fmt.Sprintf("chat:owner:%s", ownerID)
|
||||
}
|
||||
|
||||
func HandleChatEvent(cb func(ctx context.Context, payload ChatEvent, err error)) func(ctx context.Context, message []byte, err error) {
|
||||
return func(ctx context.Context, message []byte, err error) {
|
||||
if err != nil {
|
||||
cb(ctx, ChatEvent{}, xerrors.Errorf("chat event pubsub: %w", err))
|
||||
return
|
||||
}
|
||||
var payload ChatEvent
|
||||
if err := json.Unmarshal(message, &payload); err != nil {
|
||||
cb(ctx, ChatEvent{}, xerrors.Errorf("unmarshal chat event: %w", err))
|
||||
return
|
||||
}
|
||||
|
||||
cb(ctx, payload, err)
|
||||
}
|
||||
}
|
||||
|
||||
type ChatEvent struct {
|
||||
Kind ChatEventKind `json:"kind"`
|
||||
Chat codersdk.Chat `json:"chat"`
|
||||
}
|
||||
|
||||
type ChatEventKind string
|
||||
|
||||
const (
|
||||
ChatEventKindStatusChange ChatEventKind = "status_change"
|
||||
ChatEventKindTitleChange ChatEventKind = "title_change"
|
||||
ChatEventKindCreated ChatEventKind = "created"
|
||||
ChatEventKindDeleted ChatEventKind = "deleted"
|
||||
ChatEventKindDiffStatusChange ChatEventKind = "diff_status_change"
|
||||
)
|
||||
@@ -0,0 +1,36 @@
|
||||
package pubsub
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"golang.org/x/xerrors"
|
||||
|
||||
"github.com/coder/coder/v2/codersdk"
|
||||
)
|
||||
|
||||
// ChatWatchEventChannel returns the pubsub channel for chat
|
||||
// lifecycle events scoped to a single user.
|
||||
func ChatWatchEventChannel(ownerID uuid.UUID) string {
|
||||
return fmt.Sprintf("chat:owner:%s", ownerID)
|
||||
}
|
||||
|
||||
// HandleChatWatchEvent wraps a typed callback for
|
||||
// ChatWatchEvent messages delivered via pubsub.
|
||||
func HandleChatWatchEvent(cb func(ctx context.Context, payload codersdk.ChatWatchEvent, err error)) func(ctx context.Context, message []byte, err error) {
|
||||
return func(ctx context.Context, message []byte, err error) {
|
||||
if err != nil {
|
||||
cb(ctx, codersdk.ChatWatchEvent{}, xerrors.Errorf("chat watch event pubsub: %w", err))
|
||||
return
|
||||
}
|
||||
var payload codersdk.ChatWatchEvent
|
||||
if err := json.Unmarshal(message, &payload); err != nil {
|
||||
cb(ctx, codersdk.ChatWatchEvent{}, xerrors.Errorf("unmarshal chat watch event: %w", err))
|
||||
return
|
||||
}
|
||||
|
||||
cb(ctx, payload, err)
|
||||
}
|
||||
}
|
||||
@@ -776,6 +776,40 @@ func (r *remoteReporter) createSnapshot() (*Snapshot, error) {
|
||||
return nil
|
||||
})
|
||||
|
||||
eg.Go(func() error {
|
||||
chats, err := r.options.Database.GetChatsUpdatedAfter(ctx, createdAfter)
|
||||
if err != nil {
|
||||
return xerrors.Errorf("get chats updated after: %w", err)
|
||||
}
|
||||
snapshot.Chats = make([]Chat, 0, len(chats))
|
||||
for _, chat := range chats {
|
||||
snapshot.Chats = append(snapshot.Chats, ConvertChat(chat))
|
||||
}
|
||||
return nil
|
||||
})
|
||||
eg.Go(func() error {
|
||||
summaries, err := r.options.Database.GetChatMessageSummariesPerChat(ctx, createdAfter)
|
||||
if err != nil {
|
||||
return xerrors.Errorf("get chat message summaries: %w", err)
|
||||
}
|
||||
snapshot.ChatMessageSummaries = make([]ChatMessageSummary, 0, len(summaries))
|
||||
for _, s := range summaries {
|
||||
snapshot.ChatMessageSummaries = append(snapshot.ChatMessageSummaries, ConvertChatMessageSummary(s))
|
||||
}
|
||||
return nil
|
||||
})
|
||||
eg.Go(func() error {
|
||||
configs, err := r.options.Database.GetChatModelConfigsForTelemetry(ctx)
|
||||
if err != nil {
|
||||
return xerrors.Errorf("get chat model configs: %w", err)
|
||||
}
|
||||
snapshot.ChatModelConfigs = make([]ChatModelConfig, 0, len(configs))
|
||||
for _, c := range configs {
|
||||
snapshot.ChatModelConfigs = append(snapshot.ChatModelConfigs, ConvertChatModelConfig(c))
|
||||
}
|
||||
return nil
|
||||
})
|
||||
|
||||
err := eg.Wait()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@@ -1503,6 +1537,9 @@ type Snapshot struct {
|
||||
AIBridgeInterceptionsSummaries []AIBridgeInterceptionsSummary `json:"aibridge_interceptions_summaries"`
|
||||
BoundaryUsageSummary *BoundaryUsageSummary `json:"boundary_usage_summary"`
|
||||
FirstUserOnboarding *FirstUserOnboarding `json:"first_user_onboarding"`
|
||||
Chats []Chat `json:"chats"`
|
||||
ChatMessageSummaries []ChatMessageSummary `json:"chat_message_summaries"`
|
||||
ChatModelConfigs []ChatModelConfig `json:"chat_model_configs"`
|
||||
}
|
||||
|
||||
// Deployment contains information about the host running Coder.
|
||||
@@ -2113,6 +2150,66 @@ func ConvertTask(task database.Task) Task {
|
||||
return t
|
||||
}
|
||||
|
||||
// ConvertChat converts a database chat row to a telemetry Chat.
|
||||
func ConvertChat(dbChat database.GetChatsUpdatedAfterRow) Chat {
|
||||
c := Chat{
|
||||
ID: dbChat.ID,
|
||||
OwnerID: dbChat.OwnerID,
|
||||
CreatedAt: dbChat.CreatedAt,
|
||||
UpdatedAt: dbChat.UpdatedAt,
|
||||
Status: string(dbChat.Status),
|
||||
HasParent: dbChat.HasParent,
|
||||
Archived: dbChat.Archived,
|
||||
LastModelConfigID: dbChat.LastModelConfigID,
|
||||
}
|
||||
if dbChat.RootChatID.Valid {
|
||||
c.RootChatID = &dbChat.RootChatID.UUID
|
||||
}
|
||||
if dbChat.WorkspaceID.Valid {
|
||||
c.WorkspaceID = &dbChat.WorkspaceID.UUID
|
||||
}
|
||||
if dbChat.Mode.Valid {
|
||||
mode := string(dbChat.Mode.ChatMode)
|
||||
c.Mode = &mode
|
||||
}
|
||||
return c
|
||||
}
|
||||
|
||||
// ConvertChatMessageSummary converts a database chat message
|
||||
// summary row to a telemetry ChatMessageSummary.
|
||||
func ConvertChatMessageSummary(dbRow database.GetChatMessageSummariesPerChatRow) ChatMessageSummary {
|
||||
return ChatMessageSummary{
|
||||
ChatID: dbRow.ChatID,
|
||||
MessageCount: dbRow.MessageCount,
|
||||
UserMessageCount: dbRow.UserMessageCount,
|
||||
AssistantMessageCount: dbRow.AssistantMessageCount,
|
||||
ToolMessageCount: dbRow.ToolMessageCount,
|
||||
SystemMessageCount: dbRow.SystemMessageCount,
|
||||
TotalInputTokens: dbRow.TotalInputTokens,
|
||||
TotalOutputTokens: dbRow.TotalOutputTokens,
|
||||
TotalReasoningTokens: dbRow.TotalReasoningTokens,
|
||||
TotalCacheCreationTokens: dbRow.TotalCacheCreationTokens,
|
||||
TotalCacheReadTokens: dbRow.TotalCacheReadTokens,
|
||||
TotalCostMicros: dbRow.TotalCostMicros,
|
||||
TotalRuntimeMs: dbRow.TotalRuntimeMs,
|
||||
DistinctModelCount: dbRow.DistinctModelCount,
|
||||
CompressedMessageCount: dbRow.CompressedMessageCount,
|
||||
}
|
||||
}
|
||||
|
||||
// ConvertChatModelConfig converts a database model config row to a
|
||||
// telemetry ChatModelConfig.
|
||||
func ConvertChatModelConfig(dbRow database.GetChatModelConfigsForTelemetryRow) ChatModelConfig {
|
||||
return ChatModelConfig{
|
||||
ID: dbRow.ID,
|
||||
Provider: dbRow.Provider,
|
||||
Model: dbRow.Model,
|
||||
ContextLimit: dbRow.ContextLimit,
|
||||
Enabled: dbRow.Enabled,
|
||||
IsDefault: dbRow.IsDefault,
|
||||
}
|
||||
}
|
||||
|
||||
type telemetryItemKey string
|
||||
|
||||
// The comment below gets rid of the warning that the name "TelemetryItemKey" has
|
||||
@@ -2234,6 +2331,53 @@ type BoundaryUsageSummary struct {
|
||||
PeriodDurationMilliseconds int64 `json:"period_duration_ms"`
|
||||
}
|
||||
|
||||
// Chat contains anonymized metadata about a chat for telemetry.
|
||||
// Titles and message content are excluded to avoid PII leakage.
|
||||
type Chat struct {
|
||||
ID uuid.UUID `json:"id"`
|
||||
OwnerID uuid.UUID `json:"owner_id"`
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
UpdatedAt time.Time `json:"updated_at"`
|
||||
Status string `json:"status"`
|
||||
HasParent bool `json:"has_parent"`
|
||||
RootChatID *uuid.UUID `json:"root_chat_id"`
|
||||
WorkspaceID *uuid.UUID `json:"workspace_id"`
|
||||
Mode *string `json:"mode"`
|
||||
Archived bool `json:"archived"`
|
||||
LastModelConfigID uuid.UUID `json:"last_model_config_id"`
|
||||
}
|
||||
|
||||
// ChatMessageSummary contains per-chat aggregated message metrics
// for telemetry. Individual message content is never included.
type ChatMessageSummary struct {
	// ChatID identifies the chat these aggregates belong to.
	ChatID uuid.UUID `json:"chat_id"`
	// MessageCount is the total number of messages across all roles.
	MessageCount          int64 `json:"message_count"`
	UserMessageCount      int64 `json:"user_message_count"`
	AssistantMessageCount int64 `json:"assistant_message_count"`
	ToolMessageCount      int64 `json:"tool_message_count"`
	SystemMessageCount    int64 `json:"system_message_count"`
	// Token totals are summed over all messages in the chat.
	TotalInputTokens         int64 `json:"total_input_tokens"`
	TotalOutputTokens        int64 `json:"total_output_tokens"`
	TotalReasoningTokens     int64 `json:"total_reasoning_tokens"`
	TotalCacheCreationTokens int64 `json:"total_cache_creation_tokens"`
	TotalCacheReadTokens     int64 `json:"total_cache_read_tokens"`
	// TotalCostMicros is the summed cost in micro-units of currency.
	TotalCostMicros int64 `json:"total_cost_micros"`
	TotalRuntimeMs  int64 `json:"total_runtime_ms"`
	// DistinctModelCount is the number of distinct model configs used.
	DistinctModelCount int64 `json:"distinct_model_count"`
	// CompressedMessageCount counts messages flagged as compressed.
	CompressedMessageCount int64 `json:"compressed_message_count"`
}
|
||||
|
||||
// ChatModelConfig contains model configuration metadata for
// telemetry. Sensitive fields like API keys are excluded.
type ChatModelConfig struct {
	ID uuid.UUID `json:"id"`
	// Provider is the upstream AI provider name (e.g. "anthropic").
	Provider string `json:"provider"`
	// Model is the provider-specific model identifier.
	Model string `json:"model"`
	// ContextLimit is the configured context window size in tokens.
	ContextLimit int64 `json:"context_limit"`
	Enabled      bool  `json:"enabled"`
	// IsDefault marks the deployment's default model config.
	IsDefault bool `json:"is_default"`
}
|
||||
|
||||
func ConvertAIBridgeInterceptionsSummary(endTime time.Time, provider, model, client string, summary database.CalculateAIBridgeInterceptionsTelemetrySummaryRow) AIBridgeInterceptionsSummary {
|
||||
return AIBridgeInterceptionsSummary{
|
||||
ID: uuid.New(),
|
||||
|
||||
@@ -1549,3 +1549,303 @@ func TestTelemetry_BoundaryUsageSummary(t *testing.T) {
|
||||
require.Nil(t, snapshot2.BoundaryUsageSummary)
|
||||
})
|
||||
}
|
||||
|
||||
// TestChatsTelemetry seeds chats, messages, and model configs (including
// soft-deleted "poison" rows) and verifies the telemetry snapshot reports
// the expected chats, per-chat message summaries, and model configs while
// excluding all soft-deleted data.
func TestChatsTelemetry(t *testing.T) {
	t.Parallel()

	ctx := testutil.Context(t, testutil.WaitMedium)
	db, _ := dbtestutil.NewDB(t)

	user := dbgen.User(t, db, database.User{})

	// Create chat providers (required FK for model configs).
	_, err := db.InsertChatProvider(ctx, database.InsertChatProviderParams{
		Provider:             "anthropic",
		DisplayName:          "Anthropic",
		Enabled:              true,
		CentralApiKeyEnabled: true,
	})
	require.NoError(t, err)
	_, err = db.InsertChatProvider(ctx, database.InsertChatProviderParams{
		Provider:             "openai",
		DisplayName:          "OpenAI",
		Enabled:              true,
		CentralApiKeyEnabled: true,
	})
	require.NoError(t, err)

	// Create a model config.
	modelCfg, err := db.InsertChatModelConfig(ctx, database.InsertChatModelConfigParams{
		Provider:             "anthropic",
		Model:                "claude-sonnet-4-20250514",
		DisplayName:          "Claude Sonnet",
		Enabled:              true,
		IsDefault:            true,
		ContextLimit:         200000,
		CompressionThreshold: 70,
		Options:              json.RawMessage("{}"),
	})
	require.NoError(t, err)

	// Create a second model config to test full dump.
	modelCfg2, err := db.InsertChatModelConfig(ctx, database.InsertChatModelConfigParams{
		Provider:             "openai",
		Model:                "gpt-4o",
		DisplayName:          "GPT-4o",
		Enabled:              true,
		IsDefault:            false,
		ContextLimit:         128000,
		CompressionThreshold: 70,
		Options:              json.RawMessage("{}"),
	})
	require.NoError(t, err)

	// Create a soft-deleted model config — should NOT appear in telemetry.
	deletedCfg, err := db.InsertChatModelConfig(ctx, database.InsertChatModelConfigParams{
		Provider:             "anthropic",
		Model:                "claude-deleted",
		DisplayName:          "Deleted Model",
		Enabled:              true,
		IsDefault:            false,
		ContextLimit:         100000,
		CompressionThreshold: 70,
		Options:              json.RawMessage("{}"),
	})
	require.NoError(t, err)
	err = db.DeleteChatModelConfigByID(ctx, deletedCfg.ID)
	require.NoError(t, err)

	// Create a root chat with a workspace. The provisioner job, template,
	// template version, and workspace build satisfy the FK chain required
	// by the workspace row.
	org, err := db.GetDefaultOrganization(ctx)
	require.NoError(t, err)
	job := dbgen.ProvisionerJob(t, db, nil, database.ProvisionerJob{
		OrganizationID: org.ID,
		Type:           database.ProvisionerJobTypeTemplateVersionDryRun,
	})
	tpl := dbgen.Template(t, db, database.Template{
		OrganizationID: org.ID,
		CreatedBy:      user.ID,
	})
	tv := dbgen.TemplateVersion(t, db, database.TemplateVersion{
		OrganizationID: org.ID,
		TemplateID:     uuid.NullUUID{UUID: tpl.ID, Valid: true},
		CreatedBy:      user.ID,
		JobID:          job.ID,
	})
	ws := dbgen.Workspace(t, db, database.WorkspaceTable{
		OwnerID:        user.ID,
		OrganizationID: org.ID,
		TemplateID:     tpl.ID,
	})
	_ = dbgen.WorkspaceBuild(t, db, database.WorkspaceBuild{
		Transition:        database.WorkspaceTransitionStart,
		Reason:            database.BuildReasonInitiator,
		WorkspaceID:       ws.ID,
		TemplateVersionID: tv.ID,
		JobID:             job.ID,
	})

	rootChat, err := db.InsertChat(ctx, database.InsertChatParams{
		OwnerID:           user.ID,
		LastModelConfigID: modelCfg.ID,
		Title:             "Root Chat",
		Status:            database.ChatStatusRunning,
		WorkspaceID:       uuid.NullUUID{UUID: ws.ID, Valid: true},
		Mode:              database.NullChatMode{ChatMode: database.ChatModeComputerUse, Valid: true},
	})
	require.NoError(t, err)

	// Create a child chat (has parent + root).
	childChat, err := db.InsertChat(ctx, database.InsertChatParams{
		OwnerID:           user.ID,
		LastModelConfigID: modelCfg2.ID,
		Title:             "Child Chat",
		Status:            database.ChatStatusCompleted,
		ParentChatID:      uuid.NullUUID{UUID: rootChat.ID, Valid: true},
		RootChatID:        uuid.NullUUID{UUID: rootChat.ID, Valid: true},
	})
	require.NoError(t, err)

	// Insert messages for root chat: 2 user, 2 assistant, 1 tool.
	// The parallel slices are columnar insert params: index i across all
	// slices describes message i.
	_, err = db.InsertChatMessages(ctx, database.InsertChatMessagesParams{
		ChatID:              rootChat.ID,
		CreatedBy:           []uuid.UUID{user.ID, uuid.Nil, user.ID, uuid.Nil, uuid.Nil},
		ModelConfigID:       []uuid.UUID{modelCfg.ID, modelCfg.ID, modelCfg.ID, modelCfg.ID, modelCfg.ID},
		Role:                []database.ChatMessageRole{database.ChatMessageRoleUser, database.ChatMessageRoleAssistant, database.ChatMessageRoleUser, database.ChatMessageRoleAssistant, database.ChatMessageRoleTool},
		Content:             []string{`[{"type":"text","text":"hello"}]`, `[{"type":"text","text":"hi"}]`, `[{"type":"text","text":"help"}]`, `[{"type":"text","text":"sure"}]`, `[{"type":"text","text":"result"}]`},
		ContentVersion:      []int16{1, 1, 1, 1, 1},
		Visibility:          []database.ChatMessageVisibility{database.ChatMessageVisibilityBoth, database.ChatMessageVisibilityBoth, database.ChatMessageVisibilityBoth, database.ChatMessageVisibilityBoth, database.ChatMessageVisibilityBoth},
		InputTokens:         []int64{100, 200, 150, 300, 0},
		OutputTokens:        []int64{0, 50, 0, 100, 0},
		TotalTokens:         []int64{100, 250, 150, 400, 0},
		ReasoningTokens:     []int64{0, 10, 0, 20, 0},
		CacheCreationTokens: []int64{50, 0, 30, 0, 0},
		CacheReadTokens:     []int64{0, 25, 0, 40, 0},
		ContextLimit:        []int64{200000, 200000, 200000, 200000, 200000},
		Compressed:          []bool{false, false, false, false, false},
		TotalCostMicros:     []int64{1000, 2000, 1500, 3000, 0},
		RuntimeMs:           []int64{0, 500, 0, 800, 100},
		ProviderResponseID:  []string{"", "resp-1", "", "resp-2", ""},
	})
	require.NoError(t, err)

	// Insert messages for child chat: 1 user, 1 assistant (compressed).
	_, err = db.InsertChatMessages(ctx, database.InsertChatMessagesParams{
		ChatID:              childChat.ID,
		CreatedBy:           []uuid.UUID{user.ID, uuid.Nil},
		ModelConfigID:       []uuid.UUID{modelCfg2.ID, modelCfg2.ID},
		Role:                []database.ChatMessageRole{database.ChatMessageRoleUser, database.ChatMessageRoleAssistant},
		Content:             []string{`[{"type":"text","text":"q"}]`, `[{"type":"text","text":"a"}]`},
		ContentVersion:      []int16{1, 1},
		Visibility:          []database.ChatMessageVisibility{database.ChatMessageVisibilityBoth, database.ChatMessageVisibilityBoth},
		InputTokens:         []int64{500, 600},
		OutputTokens:        []int64{0, 200},
		TotalTokens:         []int64{500, 800},
		ReasoningTokens:     []int64{0, 50},
		CacheCreationTokens: []int64{100, 0},
		CacheReadTokens:     []int64{0, 75},
		ContextLimit:        []int64{128000, 128000},
		Compressed:          []bool{false, true},
		TotalCostMicros:     []int64{5000, 8000},
		RuntimeMs:           []int64{0, 1200},
		ProviderResponseID:  []string{"", "resp-3"},
	})
	require.NoError(t, err)

	// Insert a soft-deleted message on root chat with large token values.
	// This acts as "poison" — if the deleted filter is missing, totals
	// will be inflated and assertions below will fail.
	poisonMsgs, err := db.InsertChatMessages(ctx, database.InsertChatMessagesParams{
		ChatID:              rootChat.ID,
		CreatedBy:           []uuid.UUID{uuid.Nil},
		ModelConfigID:       []uuid.UUID{modelCfg.ID},
		Role:                []database.ChatMessageRole{database.ChatMessageRoleAssistant},
		Content:             []string{`[{"type":"text","text":"poison"}]`},
		ContentVersion:      []int16{1},
		Visibility:          []database.ChatMessageVisibility{database.ChatMessageVisibilityBoth},
		InputTokens:         []int64{999999},
		OutputTokens:        []int64{999999},
		TotalTokens:         []int64{999999},
		ReasoningTokens:     []int64{999999},
		CacheCreationTokens: []int64{999999},
		CacheReadTokens:     []int64{999999},
		ContextLimit:        []int64{200000},
		Compressed:          []bool{false},
		TotalCostMicros:     []int64{999999},
		RuntimeMs:           []int64{999999},
		ProviderResponseID:  []string{""},
	})
	require.NoError(t, err)
	err = db.SoftDeleteChatMessageByID(ctx, poisonMsgs[0].ID)
	require.NoError(t, err)

	_, snapshot := collectSnapshot(ctx, t, db, nil)

	// --- Assert Chats ---
	require.Len(t, snapshot.Chats, 2)

	// Find root and child by HasParent flag.
	var foundRoot, foundChild *telemetry.Chat
	for i := range snapshot.Chats {
		if !snapshot.Chats[i].HasParent {
			foundRoot = &snapshot.Chats[i]
		} else {
			foundChild = &snapshot.Chats[i]
		}
	}
	require.NotNil(t, foundRoot, "expected root chat")
	require.NotNil(t, foundChild, "expected child chat")

	// Root chat assertions.
	assert.Equal(t, rootChat.ID, foundRoot.ID)
	assert.Equal(t, user.ID, foundRoot.OwnerID)
	assert.Equal(t, "running", foundRoot.Status)
	assert.False(t, foundRoot.HasParent)
	assert.Nil(t, foundRoot.RootChatID)
	require.NotNil(t, foundRoot.WorkspaceID)
	assert.Equal(t, ws.ID, *foundRoot.WorkspaceID)
	assert.Equal(t, modelCfg.ID, foundRoot.LastModelConfigID)
	require.NotNil(t, foundRoot.Mode)
	assert.Equal(t, "computer_use", *foundRoot.Mode)
	assert.False(t, foundRoot.Archived)

	// Child chat assertions.
	assert.Equal(t, childChat.ID, foundChild.ID)
	assert.Equal(t, user.ID, foundChild.OwnerID)
	assert.True(t, foundChild.HasParent)
	require.NotNil(t, foundChild.RootChatID)
	assert.Equal(t, rootChat.ID, *foundChild.RootChatID)
	assert.Nil(t, foundChild.WorkspaceID)
	assert.Equal(t, "completed", foundChild.Status)
	assert.Equal(t, modelCfg2.ID, foundChild.LastModelConfigID)
	assert.Nil(t, foundChild.Mode)
	assert.False(t, foundChild.Archived)

	// --- Assert ChatMessageSummaries ---
	require.Len(t, snapshot.ChatMessageSummaries, 2)

	summaryMap := make(map[uuid.UUID]telemetry.ChatMessageSummary)
	for _, s := range snapshot.ChatMessageSummaries {
		summaryMap[s.ChatID] = s
	}

	// Root chat summary: 2 user + 2 assistant + 1 tool = 5 messages.
	// The soft-deleted poison message must not contribute to any total.
	rootSummary, ok := summaryMap[rootChat.ID]
	require.True(t, ok, "expected summary for root chat")
	assert.Equal(t, int64(5), rootSummary.MessageCount)
	assert.Equal(t, int64(2), rootSummary.UserMessageCount)
	assert.Equal(t, int64(2), rootSummary.AssistantMessageCount)
	assert.Equal(t, int64(1), rootSummary.ToolMessageCount)
	assert.Equal(t, int64(0), rootSummary.SystemMessageCount)
	assert.Equal(t, int64(750), rootSummary.TotalInputTokens)        // 100+200+150+300+0
	assert.Equal(t, int64(150), rootSummary.TotalOutputTokens)       // 0+50+0+100+0
	assert.Equal(t, int64(30), rootSummary.TotalReasoningTokens)     // 0+10+0+20+0
	assert.Equal(t, int64(80), rootSummary.TotalCacheCreationTokens) // 50+0+30+0+0
	assert.Equal(t, int64(65), rootSummary.TotalCacheReadTokens)     // 0+25+0+40+0
	assert.Equal(t, int64(7500), rootSummary.TotalCostMicros)        // 1000+2000+1500+3000+0
	assert.Equal(t, int64(1400), rootSummary.TotalRuntimeMs)         // 0+500+0+800+100
	assert.Equal(t, int64(1), rootSummary.DistinctModelCount)
	assert.Equal(t, int64(0), rootSummary.CompressedMessageCount)

	// Child chat summary: 1 user + 1 assistant = 2 messages, 1 compressed.
	childSummary, ok := summaryMap[childChat.ID]
	require.True(t, ok, "expected summary for child chat")
	assert.Equal(t, int64(2), childSummary.MessageCount)
	assert.Equal(t, int64(1), childSummary.UserMessageCount)
	assert.Equal(t, int64(1), childSummary.AssistantMessageCount)
	assert.Equal(t, int64(1100), childSummary.TotalInputTokens) // 500+600
	assert.Equal(t, int64(200), childSummary.TotalOutputTokens) // 0+200
	assert.Equal(t, int64(50), childSummary.TotalReasoningTokens) // 0+50
	assert.Equal(t, int64(0), childSummary.ToolMessageCount)
	assert.Equal(t, int64(0), childSummary.SystemMessageCount)
	assert.Equal(t, int64(100), childSummary.TotalCacheCreationTokens) // 100+0
	assert.Equal(t, int64(75), childSummary.TotalCacheReadTokens) // 0+75
	assert.Equal(t, int64(13000), childSummary.TotalCostMicros) // 5000+8000
	assert.Equal(t, int64(1200), childSummary.TotalRuntimeMs) // 0+1200
	assert.Equal(t, int64(1), childSummary.DistinctModelCount)
	assert.Equal(t, int64(1), childSummary.CompressedMessageCount)

	// --- Assert ChatModelConfigs ---
	// Only the two live configs appear; the soft-deleted one is excluded.
	require.Len(t, snapshot.ChatModelConfigs, 2)

	configMap := make(map[uuid.UUID]telemetry.ChatModelConfig)
	for _, c := range snapshot.ChatModelConfigs {
		configMap[c.ID] = c
	}

	cfg1, ok := configMap[modelCfg.ID]
	require.True(t, ok)
	assert.Equal(t, "anthropic", cfg1.Provider)
	assert.Equal(t, "claude-sonnet-4-20250514", cfg1.Model)
	assert.Equal(t, int64(200000), cfg1.ContextLimit)
	assert.True(t, cfg1.Enabled)
	assert.True(t, cfg1.IsDefault)

	cfg2, ok := configMap[modelCfg2.ID]
	require.True(t, ok)
	assert.Equal(t, "openai", cfg2.Provider)
	assert.Equal(t, "gpt-4o", cfg2.Model)
	assert.Equal(t, int64(128000), cfg2.ContextLimit)
	assert.True(t, cfg2.Enabled)
	assert.False(t, cfg2.IsDefault)
}
|
||||
|
||||
@@ -1638,18 +1638,6 @@ func (api *API) CreateUser(ctx context.Context, store database.Store, req Create
|
||||
rbacRoles = req.RBACRoles
|
||||
}
|
||||
|
||||
// When the agents experiment is enabled, auto-assign the
|
||||
// agents-access role so new users can use Coder Agents
|
||||
// without manual admin intervention. Skip this for OIDC
|
||||
// users when site role sync is enabled, because the sync
|
||||
// will overwrite roles on every login anyway — those
|
||||
// admins should use --oidc-user-role-default instead.
|
||||
if api.Experiments.Enabled(codersdk.ExperimentAgents) &&
|
||||
!(req.LoginType == database.LoginTypeOIDC && api.IDPSync.SiteRoleSyncEnabled()) &&
|
||||
!slices.Contains(rbacRoles, codersdk.RoleAgentsAccess) {
|
||||
rbacRoles = append(rbacRoles, codersdk.RoleAgentsAccess)
|
||||
}
|
||||
|
||||
var user database.User
|
||||
err := store.InTx(func(tx database.Store) error {
|
||||
orgRoles := make([]string, 0)
|
||||
|
||||
@@ -829,35 +829,6 @@ func TestPostUsers(t *testing.T) {
|
||||
assert.Equal(t, firstUser.OrganizationID, user.OrganizationIDs[0])
|
||||
})
|
||||
|
||||
// CreateWithAgentsExperiment verifies that new users
|
||||
// are auto-assigned the agents-access role when the
|
||||
// experiment is enabled. The experiment-disabled case
|
||||
// is implicitly covered by TestInitialRoles, which
|
||||
// asserts exactly [owner] with no experiment — it
|
||||
// would fail if agents-access leaked through.
|
||||
t.Run("CreateWithAgentsExperiment", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
dv := coderdtest.DeploymentValues(t)
|
||||
dv.Experiments = []string{string(codersdk.ExperimentAgents)}
|
||||
client := coderdtest.New(t, &coderdtest.Options{DeploymentValues: dv})
|
||||
firstUser := coderdtest.CreateFirstUser(t, client)
|
||||
|
||||
ctx := testutil.Context(t, testutil.WaitLong)
|
||||
|
||||
user, err := client.CreateUserWithOrgs(ctx, codersdk.CreateUserRequestWithOrgs{
|
||||
OrganizationIDs: []uuid.UUID{firstUser.OrganizationID},
|
||||
Email: "another@user.org",
|
||||
Username: "someone-else",
|
||||
Password: "SomeSecurePassword!",
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
roles, err := client.UserRoles(ctx, user.Username)
|
||||
require.NoError(t, err)
|
||||
require.Contains(t, roles.Roles, codersdk.RoleAgentsAccess,
|
||||
"new user should have agents-access role when agents experiment is enabled")
|
||||
})
|
||||
|
||||
t.Run("CreateWithStatus", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
auditor := audit.NewMock()
|
||||
|
||||
@@ -0,0 +1,280 @@
|
||||
package coderd
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"errors"
|
||||
"net/http"
|
||||
|
||||
"github.com/go-chi/chi/v5"
|
||||
"github.com/google/uuid"
|
||||
|
||||
"github.com/coder/coder/v2/coderd/database"
|
||||
"github.com/coder/coder/v2/coderd/database/db2sdk"
|
||||
"github.com/coder/coder/v2/coderd/httpapi"
|
||||
"github.com/coder/coder/v2/coderd/httpmw"
|
||||
"github.com/coder/coder/v2/codersdk"
|
||||
)
|
||||
|
||||
// @Summary Create a new user secret
// @ID create-a-new-user-secret
// @Security CoderSessionToken
// @Accept json
// @Produce json
// @Tags Secrets
// @Param user path string true "User ID, username, or me"
// @Param request body codersdk.CreateUserSecretRequest true "Create secret request"
// @Success 201 {object} codersdk.UserSecret
// @Router /users/{user}/secrets [post]
//
// postUserSecret validates the request (name/value required, env name and
// file path must pass codersdk validation) and inserts the secret for the
// user resolved from the URL path. Returns 409 when the name, env name, or
// file path collides with an existing secret.
func (api *API) postUserSecret(rw http.ResponseWriter, r *http.Request) {
	ctx := r.Context()
	user := httpmw.UserParam(r)

	var req codersdk.CreateUserSecretRequest
	if !httpapi.Read(ctx, rw, r, &req) {
		// httpapi.Read has already written the error response.
		return
	}

	// Name and Value are mandatory; the optional EnvName/FilePath fields
	// are validated only for format, not presence.
	if req.Name == "" {
		httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
			Message: "Name is required.",
		})
		return
	}
	if req.Value == "" {
		httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
			Message: "Value is required.",
		})
		return
	}
	// Env-name validation depends on whether the AI bridge is enabled
	// (presumably some env names are reserved for it — see codersdk).
	envOpts := codersdk.UserSecretEnvValidationOptions{
		AIGatewayEnabled: api.DeploymentValues.AI.BridgeConfig.Enabled.Value(),
	}
	if err := codersdk.UserSecretEnvNameValid(req.EnvName, envOpts); err != nil {
		httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
			Message: "Invalid environment variable name.",
			Detail:  err.Error(),
		})
		return
	}
	if err := codersdk.UserSecretFilePathValid(req.FilePath); err != nil {
		httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
			Message: "Invalid file path.",
			Detail:  err.Error(),
		})
		return
	}

	secret, err := api.Database.CreateUserSecret(ctx, database.CreateUserSecretParams{
		ID:          uuid.New(),
		UserID:      user.ID,
		Name:        req.Name,
		Description: req.Description,
		Value:       req.Value,
		// ValueKeyID is left empty here — NOTE(review): presumably set
		// when values are encrypted with a key; confirm against the
		// encryption path.
		ValueKeyID: sql.NullString{},
		EnvName:    req.EnvName,
		FilePath:   req.FilePath,
	})
	if err != nil {
		// Unique-constraint violations map to 409 so clients can
		// distinguish duplicates from server faults.
		if database.IsUniqueViolation(err) {
			httpapi.Write(ctx, rw, http.StatusConflict, codersdk.Response{
				Message: "A secret with that name, environment variable, or file path already exists.",
				Detail:  err.Error(),
			})
			return
		}
		httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
			Message: "Internal error creating secret.",
			Detail:  err.Error(),
		})
		return
	}

	httpapi.Write(ctx, rw, http.StatusCreated, db2sdk.UserSecretFromFull(secret))
}
|
||||
|
||||
// @Summary List user secrets
|
||||
// @ID list-user-secrets
|
||||
// @Security CoderSessionToken
|
||||
// @Produce json
|
||||
// @Tags Secrets
|
||||
// @Param user path string true "User ID, username, or me"
|
||||
// @Success 200 {array} codersdk.UserSecret
|
||||
// @Router /users/{user}/secrets [get]
|
||||
func (api *API) getUserSecrets(rw http.ResponseWriter, r *http.Request) { //nolint:revive // Method name matches route.
|
||||
ctx := r.Context()
|
||||
user := httpmw.UserParam(r)
|
||||
|
||||
secrets, err := api.Database.ListUserSecrets(ctx, user.ID)
|
||||
if err != nil {
|
||||
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
|
||||
Message: "Internal error listing secrets.",
|
||||
Detail: err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
httpapi.Write(ctx, rw, http.StatusOK, db2sdk.UserSecrets(secrets))
|
||||
}
|
||||
|
||||
// @Summary Get a user secret by name
|
||||
// @ID get-a-user-secret-by-name
|
||||
// @Security CoderSessionToken
|
||||
// @Produce json
|
||||
// @Tags Secrets
|
||||
// @Param user path string true "User ID, username, or me"
|
||||
// @Param name path string true "Secret name"
|
||||
// @Success 200 {object} codersdk.UserSecret
|
||||
// @Router /users/{user}/secrets/{name} [get]
|
||||
func (api *API) getUserSecret(rw http.ResponseWriter, r *http.Request) { //nolint:revive // Method name matches route.
|
||||
ctx := r.Context()
|
||||
user := httpmw.UserParam(r)
|
||||
name := chi.URLParam(r, "name")
|
||||
|
||||
secret, err := api.Database.GetUserSecretByUserIDAndName(ctx, database.GetUserSecretByUserIDAndNameParams{
|
||||
UserID: user.ID,
|
||||
Name: name,
|
||||
})
|
||||
if err != nil {
|
||||
if errors.Is(err, sql.ErrNoRows) {
|
||||
httpapi.ResourceNotFound(rw)
|
||||
return
|
||||
}
|
||||
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
|
||||
Message: "Internal error fetching secret.",
|
||||
Detail: err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
httpapi.Write(ctx, rw, http.StatusOK, db2sdk.UserSecretFromFull(secret))
|
||||
}
|
||||
|
||||
// @Summary Update a user secret
// @ID update-a-user-secret
// @Security CoderSessionToken
// @Accept json
// @Produce json
// @Tags Secrets
// @Param user path string true "User ID, username, or me"
// @Param name path string true "Secret name"
// @Param request body codersdk.UpdateUserSecretRequest true "Update secret request"
// @Success 200 {object} codersdk.UserSecret
// @Router /users/{user}/secrets/{name} [patch]
//
// patchUserSecret performs a partial update: every request field is a
// pointer, and only non-nil fields are applied. At least one field must
// be provided. Returns 404 when the secret does not exist and 409 when
// the update would collide with another secret's unique fields.
func (api *API) patchUserSecret(rw http.ResponseWriter, r *http.Request) {
	ctx := r.Context()
	user := httpmw.UserParam(r)
	name := chi.URLParam(r, "name")

	var req codersdk.UpdateUserSecretRequest
	if !httpapi.Read(ctx, rw, r, &req) {
		// httpapi.Read has already written the error response.
		return
	}

	// Reject empty patches outright rather than issuing a no-op update.
	if req.Value == nil && req.Description == nil && req.EnvName == nil && req.FilePath == nil {
		httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
			Message: "At least one field must be provided.",
		})
		return
	}
	// Validate only the fields actually being changed.
	if req.EnvName != nil {
		envOpts := codersdk.UserSecretEnvValidationOptions{
			AIGatewayEnabled: api.DeploymentValues.AI.BridgeConfig.Enabled.Value(),
		}
		if err := codersdk.UserSecretEnvNameValid(*req.EnvName, envOpts); err != nil {
			httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
				Message: "Invalid environment variable name.",
				Detail:  err.Error(),
			})
			return
		}
	}
	if req.FilePath != nil {
		if err := codersdk.UserSecretFilePathValid(*req.FilePath); err != nil {
			httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
				Message: "Invalid file path.",
				Detail:  err.Error(),
			})
			return
		}
	}

	// The query uses Update* booleans to distinguish "set to empty" from
	// "leave unchanged"; the value fields are only read when the matching
	// flag is true, so they default to "".
	params := database.UpdateUserSecretByUserIDAndNameParams{
		UserID:            user.ID,
		Name:              name,
		UpdateValue:       req.Value != nil,
		Value:             "",
		ValueKeyID:        sql.NullString{},
		UpdateDescription: req.Description != nil,
		Description:       "",
		UpdateEnvName:     req.EnvName != nil,
		EnvName:           "",
		UpdateFilePath:    req.FilePath != nil,
		FilePath:          "",
	}
	if req.Value != nil {
		params.Value = *req.Value
	}
	if req.Description != nil {
		params.Description = *req.Description
	}
	if req.EnvName != nil {
		params.EnvName = *req.EnvName
	}
	if req.FilePath != nil {
		params.FilePath = *req.FilePath
	}

	secret, err := api.Database.UpdateUserSecretByUserIDAndName(ctx, params)
	if err != nil {
		if errors.Is(err, sql.ErrNoRows) {
			// No secret with this name for this user.
			httpapi.ResourceNotFound(rw)
			return
		}
		// Unique-constraint violations (name/env/file path) map to 409.
		if database.IsUniqueViolation(err) {
			httpapi.Write(ctx, rw, http.StatusConflict, codersdk.Response{
				Message: "Update would conflict with an existing secret.",
				Detail:  err.Error(),
			})
			return
		}
		httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
			Message: "Internal error updating secret.",
			Detail:  err.Error(),
		})
		return
	}

	httpapi.Write(ctx, rw, http.StatusOK, db2sdk.UserSecretFromFull(secret))
}
|
||||
|
||||
// @Summary Delete a user secret
|
||||
// @ID delete-a-user-secret
|
||||
// @Security CoderSessionToken
|
||||
// @Tags Secrets
|
||||
// @Param user path string true "User ID, username, or me"
|
||||
// @Param name path string true "Secret name"
|
||||
// @Success 204
|
||||
// @Router /users/{user}/secrets/{name} [delete]
|
||||
func (api *API) deleteUserSecret(rw http.ResponseWriter, r *http.Request) {
|
||||
ctx := r.Context()
|
||||
user := httpmw.UserParam(r)
|
||||
name := chi.URLParam(r, "name")
|
||||
|
||||
rowsAffected, err := api.Database.DeleteUserSecretByUserIDAndName(ctx, database.DeleteUserSecretByUserIDAndNameParams{
|
||||
UserID: user.ID,
|
||||
Name: name,
|
||||
})
|
||||
if err != nil {
|
||||
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
|
||||
Message: "Internal error deleting secret.",
|
||||
Detail: err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
if rowsAffected == 0 {
|
||||
httpapi.ResourceNotFound(rw)
|
||||
return
|
||||
}
|
||||
|
||||
rw.WriteHeader(http.StatusNoContent)
|
||||
}
|
||||
@@ -0,0 +1,413 @@
|
||||
package coderd_test
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/coder/coder/v2/coderd/coderdtest"
|
||||
"github.com/coder/coder/v2/codersdk"
|
||||
"github.com/coder/coder/v2/testutil"
|
||||
)
|
||||
|
||||
func TestPostUserSecret(t *testing.T) {
|
||||
t.Parallel()
|
||||
client := coderdtest.New(t, nil)
|
||||
_ = coderdtest.CreateFirstUser(t, client)
|
||||
|
||||
t.Run("Success", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := testutil.Context(t, testutil.WaitMedium)
|
||||
|
||||
secret, err := client.CreateUserSecret(ctx, codersdk.Me, codersdk.CreateUserSecretRequest{
|
||||
Name: "github-token",
|
||||
Value: "ghp_xxxxxxxxxxxx",
|
||||
Description: "Personal GitHub PAT",
|
||||
EnvName: "GITHUB_TOKEN",
|
||||
FilePath: "~/.github-token",
|
||||
})
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "github-token", secret.Name)
|
||||
assert.Equal(t, "Personal GitHub PAT", secret.Description)
|
||||
assert.Equal(t, "GITHUB_TOKEN", secret.EnvName)
|
||||
assert.Equal(t, "~/.github-token", secret.FilePath)
|
||||
assert.NotZero(t, secret.ID)
|
||||
assert.NotZero(t, secret.CreatedAt)
|
||||
})
|
||||
|
||||
t.Run("MissingName", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := testutil.Context(t, testutil.WaitMedium)
|
||||
|
||||
_, err := client.CreateUserSecret(ctx, codersdk.Me, codersdk.CreateUserSecretRequest{
|
||||
Value: "some-value",
|
||||
})
|
||||
require.Error(t, err)
|
||||
var sdkErr *codersdk.Error
|
||||
require.ErrorAs(t, err, &sdkErr)
|
||||
assert.Equal(t, http.StatusBadRequest, sdkErr.StatusCode())
|
||||
assert.Contains(t, sdkErr.Message, "Name is required")
|
||||
})
|
||||
|
||||
t.Run("MissingValue", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := testutil.Context(t, testutil.WaitMedium)
|
||||
|
||||
_, err := client.CreateUserSecret(ctx, codersdk.Me, codersdk.CreateUserSecretRequest{
|
||||
Name: "missing-value-secret",
|
||||
})
|
||||
require.Error(t, err)
|
||||
var sdkErr *codersdk.Error
|
||||
require.ErrorAs(t, err, &sdkErr)
|
||||
assert.Equal(t, http.StatusBadRequest, sdkErr.StatusCode())
|
||||
assert.Contains(t, sdkErr.Message, "Value is required")
|
||||
})
|
||||
|
||||
t.Run("DuplicateName", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := testutil.Context(t, testutil.WaitMedium)
|
||||
|
||||
_, err := client.CreateUserSecret(ctx, codersdk.Me, codersdk.CreateUserSecretRequest{
|
||||
Name: "dup-secret",
|
||||
Value: "value1",
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
_, err = client.CreateUserSecret(ctx, codersdk.Me, codersdk.CreateUserSecretRequest{
|
||||
Name: "dup-secret",
|
||||
Value: "value2",
|
||||
})
|
||||
require.Error(t, err)
|
||||
var sdkErr *codersdk.Error
|
||||
require.ErrorAs(t, err, &sdkErr)
|
||||
assert.Equal(t, http.StatusConflict, sdkErr.StatusCode())
|
||||
})
|
||||
|
||||
t.Run("DuplicateEnvName", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := testutil.Context(t, testutil.WaitMedium)
|
||||
|
||||
_, err := client.CreateUserSecret(ctx, codersdk.Me, codersdk.CreateUserSecretRequest{
|
||||
Name: "env-dup-1",
|
||||
Value: "value1",
|
||||
EnvName: "DUPLICATE_ENV",
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
_, err = client.CreateUserSecret(ctx, codersdk.Me, codersdk.CreateUserSecretRequest{
|
||||
Name: "env-dup-2",
|
||||
Value: "value2",
|
||||
EnvName: "DUPLICATE_ENV",
|
||||
})
|
||||
require.Error(t, err)
|
||||
var sdkErr *codersdk.Error
|
||||
require.ErrorAs(t, err, &sdkErr)
|
||||
assert.Equal(t, http.StatusConflict, sdkErr.StatusCode())
|
||||
})
|
||||
|
||||
t.Run("DuplicateFilePath", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := testutil.Context(t, testutil.WaitMedium)
|
||||
|
||||
_, err := client.CreateUserSecret(ctx, codersdk.Me, codersdk.CreateUserSecretRequest{
|
||||
Name: "fp-dup-1",
|
||||
Value: "value1",
|
||||
FilePath: "/tmp/dup-file",
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
_, err = client.CreateUserSecret(ctx, codersdk.Me, codersdk.CreateUserSecretRequest{
|
||||
Name: "fp-dup-2",
|
||||
Value: "value2",
|
||||
FilePath: "/tmp/dup-file",
|
||||
})
|
||||
require.Error(t, err)
|
||||
var sdkErr *codersdk.Error
|
||||
require.ErrorAs(t, err, &sdkErr)
|
||||
assert.Equal(t, http.StatusConflict, sdkErr.StatusCode())
|
||||
})
|
||||
|
||||
t.Run("InvalidEnvName", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := testutil.Context(t, testutil.WaitMedium)
|
||||
|
||||
_, err := client.CreateUserSecret(ctx, codersdk.Me, codersdk.CreateUserSecretRequest{
|
||||
Name: "invalid-env-secret",
|
||||
Value: "value",
|
||||
EnvName: "1INVALID",
|
||||
})
|
||||
require.Error(t, err)
|
||||
var sdkErr *codersdk.Error
|
||||
require.ErrorAs(t, err, &sdkErr)
|
||||
assert.Equal(t, http.StatusBadRequest, sdkErr.StatusCode())
|
||||
})
|
||||
|
||||
t.Run("ReservedEnvName", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := testutil.Context(t, testutil.WaitMedium)
|
||||
|
||||
_, err := client.CreateUserSecret(ctx, codersdk.Me, codersdk.CreateUserSecretRequest{
|
||||
Name: "reserved-env-secret",
|
||||
Value: "value",
|
||||
EnvName: "PATH",
|
||||
})
|
||||
require.Error(t, err)
|
||||
var sdkErr *codersdk.Error
|
||||
require.ErrorAs(t, err, &sdkErr)
|
||||
assert.Equal(t, http.StatusBadRequest, sdkErr.StatusCode())
|
||||
})
|
||||
|
||||
t.Run("CoderPrefixEnvName", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := testutil.Context(t, testutil.WaitMedium)
|
||||
|
||||
_, err := client.CreateUserSecret(ctx, codersdk.Me, codersdk.CreateUserSecretRequest{
|
||||
Name: "coder-prefix-secret",
|
||||
Value: "value",
|
||||
EnvName: "CODER_AGENT_TOKEN",
|
||||
})
|
||||
require.Error(t, err)
|
||||
var sdkErr *codersdk.Error
|
||||
require.ErrorAs(t, err, &sdkErr)
|
||||
assert.Equal(t, http.StatusBadRequest, sdkErr.StatusCode())
|
||||
})
|
||||
|
||||
t.Run("InvalidFilePath", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := testutil.Context(t, testutil.WaitMedium)
|
||||
|
||||
_, err := client.CreateUserSecret(ctx, codersdk.Me, codersdk.CreateUserSecretRequest{
|
||||
Name: "bad-path-secret",
|
||||
Value: "value",
|
||||
FilePath: "relative/path",
|
||||
})
|
||||
require.Error(t, err)
|
||||
var sdkErr *codersdk.Error
|
||||
require.ErrorAs(t, err, &sdkErr)
|
||||
assert.Equal(t, http.StatusBadRequest, sdkErr.StatusCode())
|
||||
})
|
||||
}
|
||||
|
||||
func TestGetUserSecrets(t *testing.T) {
|
||||
t.Parallel()
|
||||
client := coderdtest.New(t, nil)
|
||||
_ = coderdtest.CreateFirstUser(t, client)
|
||||
|
||||
// Verify no secrets exist on a fresh user.
|
||||
ctx := testutil.Context(t, testutil.WaitMedium)
|
||||
secrets, err := client.UserSecrets(ctx, codersdk.Me)
|
||||
require.NoError(t, err)
|
||||
assert.Empty(t, secrets)
|
||||
|
||||
t.Run("WithSecrets", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := testutil.Context(t, testutil.WaitMedium)
|
||||
|
||||
_, err := client.CreateUserSecret(ctx, codersdk.Me, codersdk.CreateUserSecretRequest{
|
||||
Name: "list-secret-a",
|
||||
Value: "value-a",
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
_, err = client.CreateUserSecret(ctx, codersdk.Me, codersdk.CreateUserSecretRequest{
|
||||
Name: "list-secret-b",
|
||||
Value: "value-b",
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
secrets, err := client.UserSecrets(ctx, codersdk.Me)
|
||||
require.NoError(t, err)
|
||||
require.Len(t, secrets, 2)
|
||||
// Sorted by name.
|
||||
assert.Equal(t, "list-secret-a", secrets[0].Name)
|
||||
assert.Equal(t, "list-secret-b", secrets[1].Name)
|
||||
})
|
||||
}
|
||||
|
||||
func TestGetUserSecret(t *testing.T) {
|
||||
t.Parallel()
|
||||
client := coderdtest.New(t, nil)
|
||||
_ = coderdtest.CreateFirstUser(t, client)
|
||||
|
||||
t.Run("Found", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := testutil.Context(t, testutil.WaitMedium)
|
||||
|
||||
created, err := client.CreateUserSecret(ctx, codersdk.Me, codersdk.CreateUserSecretRequest{
|
||||
Name: "get-found-secret",
|
||||
Value: "my-value",
|
||||
EnvName: "GET_FOUND_SECRET",
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
got, err := client.UserSecretByName(ctx, codersdk.Me, "get-found-secret")
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, created.ID, got.ID)
|
||||
assert.Equal(t, "get-found-secret", got.Name)
|
||||
assert.Equal(t, "GET_FOUND_SECRET", got.EnvName)
|
||||
})
|
||||
|
||||
t.Run("NotFound", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := testutil.Context(t, testutil.WaitMedium)
|
||||
|
||||
_, err := client.UserSecretByName(ctx, codersdk.Me, "nonexistent")
|
||||
require.Error(t, err)
|
||||
var sdkErr *codersdk.Error
|
||||
require.ErrorAs(t, err, &sdkErr)
|
||||
assert.Equal(t, http.StatusNotFound, sdkErr.StatusCode())
|
||||
})
|
||||
}
|
||||
|
||||
func TestPatchUserSecret(t *testing.T) {
|
||||
t.Parallel()
|
||||
client := coderdtest.New(t, nil)
|
||||
_ = coderdtest.CreateFirstUser(t, client)
|
||||
|
||||
t.Run("UpdateDescription", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := testutil.Context(t, testutil.WaitMedium)
|
||||
|
||||
_, err := client.CreateUserSecret(ctx, codersdk.Me, codersdk.CreateUserSecretRequest{
|
||||
Name: "patch-desc-secret",
|
||||
Value: "my-value",
|
||||
Description: "original",
|
||||
EnvName: "PATCH_DESC_ENV",
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
newDesc := "updated"
|
||||
updated, err := client.UpdateUserSecret(ctx, codersdk.Me, "patch-desc-secret", codersdk.UpdateUserSecretRequest{
|
||||
Description: &newDesc,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "updated", updated.Description)
|
||||
// Other fields unchanged.
|
||||
assert.Equal(t, "PATCH_DESC_ENV", updated.EnvName)
|
||||
})
|
||||
|
||||
t.Run("NoFields", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := testutil.Context(t, testutil.WaitMedium)
|
||||
|
||||
_, err := client.CreateUserSecret(ctx, codersdk.Me, codersdk.CreateUserSecretRequest{
|
||||
Name: "patch-nofields-secret",
|
||||
Value: "my-value",
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
_, err = client.UpdateUserSecret(ctx, codersdk.Me, "patch-nofields-secret", codersdk.UpdateUserSecretRequest{})
|
||||
require.Error(t, err)
|
||||
var sdkErr *codersdk.Error
|
||||
require.ErrorAs(t, err, &sdkErr)
|
||||
assert.Equal(t, http.StatusBadRequest, sdkErr.StatusCode())
|
||||
})
|
||||
|
||||
t.Run("NotFound", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := testutil.Context(t, testutil.WaitMedium)
|
||||
|
||||
newVal := "new-value"
|
||||
_, err := client.UpdateUserSecret(ctx, codersdk.Me, "nonexistent", codersdk.UpdateUserSecretRequest{
|
||||
Value: &newVal,
|
||||
})
|
||||
require.Error(t, err)
|
||||
var sdkErr *codersdk.Error
|
||||
require.ErrorAs(t, err, &sdkErr)
|
||||
assert.Equal(t, http.StatusNotFound, sdkErr.StatusCode())
|
||||
})
|
||||
|
||||
t.Run("ConflictEnvName", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := testutil.Context(t, testutil.WaitMedium)
|
||||
|
||||
_, err := client.CreateUserSecret(ctx, codersdk.Me, codersdk.CreateUserSecretRequest{
|
||||
Name: "conflict-env-1",
|
||||
Value: "value1",
|
||||
EnvName: "CONFLICT_TAKEN_ENV",
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
_, err = client.CreateUserSecret(ctx, codersdk.Me, codersdk.CreateUserSecretRequest{
|
||||
Name: "conflict-env-2",
|
||||
Value: "value2",
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
taken := "CONFLICT_TAKEN_ENV"
|
||||
_, err = client.UpdateUserSecret(ctx, codersdk.Me, "conflict-env-2", codersdk.UpdateUserSecretRequest{
|
||||
EnvName: &taken,
|
||||
})
|
||||
require.Error(t, err)
|
||||
var sdkErr *codersdk.Error
|
||||
require.ErrorAs(t, err, &sdkErr)
|
||||
assert.Equal(t, http.StatusConflict, sdkErr.StatusCode())
|
||||
})
|
||||
|
||||
t.Run("ConflictFilePath", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := testutil.Context(t, testutil.WaitMedium)
|
||||
|
||||
_, err := client.CreateUserSecret(ctx, codersdk.Me, codersdk.CreateUserSecretRequest{
|
||||
Name: "conflict-fp-1",
|
||||
Value: "value1",
|
||||
FilePath: "/tmp/conflict-taken",
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
_, err = client.CreateUserSecret(ctx, codersdk.Me, codersdk.CreateUserSecretRequest{
|
||||
Name: "conflict-fp-2",
|
||||
Value: "value2",
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
taken := "/tmp/conflict-taken"
|
||||
_, err = client.UpdateUserSecret(ctx, codersdk.Me, "conflict-fp-2", codersdk.UpdateUserSecretRequest{
|
||||
FilePath: &taken,
|
||||
})
|
||||
require.Error(t, err)
|
||||
var sdkErr *codersdk.Error
|
||||
require.ErrorAs(t, err, &sdkErr)
|
||||
assert.Equal(t, http.StatusConflict, sdkErr.StatusCode())
|
||||
})
|
||||
}
|
||||
|
||||
func TestDeleteUserSecret(t *testing.T) {
|
||||
t.Parallel()
|
||||
client := coderdtest.New(t, nil)
|
||||
_ = coderdtest.CreateFirstUser(t, client)
|
||||
|
||||
t.Run("Success", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := testutil.Context(t, testutil.WaitMedium)
|
||||
|
||||
_, err := client.CreateUserSecret(ctx, codersdk.Me, codersdk.CreateUserSecretRequest{
|
||||
Name: "delete-me-secret",
|
||||
Value: "my-value",
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
err = client.DeleteUserSecret(ctx, codersdk.Me, "delete-me-secret")
|
||||
require.NoError(t, err)
|
||||
|
||||
// Verify it's gone.
|
||||
_, err = client.UserSecretByName(ctx, codersdk.Me, "delete-me-secret")
|
||||
require.Error(t, err)
|
||||
var sdkErr *codersdk.Error
|
||||
require.ErrorAs(t, err, &sdkErr)
|
||||
assert.Equal(t, http.StatusNotFound, sdkErr.StatusCode())
|
||||
})
|
||||
|
||||
t.Run("NotFound", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := testutil.Context(t, testutil.WaitMedium)
|
||||
|
||||
err := client.DeleteUserSecret(ctx, codersdk.Me, "nonexistent")
|
||||
require.Error(t, err)
|
||||
var sdkErr *codersdk.Error
|
||||
require.ErrorAs(t, err, &sdkErr)
|
||||
assert.Equal(t, http.StatusNotFound, sdkErr.StatusCode())
|
||||
})
|
||||
}
|
||||
+600
-2
@@ -42,6 +42,8 @@ import (
|
||||
"github.com/coder/coder/v2/coderd/telemetry"
|
||||
maputil "github.com/coder/coder/v2/coderd/util/maps"
|
||||
"github.com/coder/coder/v2/coderd/wspubsub"
|
||||
"github.com/coder/coder/v2/coderd/x/chatd"
|
||||
"github.com/coder/coder/v2/coderd/x/chatd/chatprompt"
|
||||
"github.com/coder/coder/v2/coderd/x/gitsync"
|
||||
"github.com/coder/coder/v2/codersdk"
|
||||
"github.com/coder/coder/v2/codersdk/agentsdk"
|
||||
@@ -181,8 +183,9 @@ func (api *API) patchWorkspaceAgentLogs(rw http.ResponseWriter, r *http.Request)
|
||||
level := make([]database.LogLevel, 0)
|
||||
outputLength := 0
|
||||
for _, logEntry := range req.Logs {
|
||||
output = append(output, logEntry.Output)
|
||||
outputLength += len(logEntry.Output)
|
||||
sanitizedOutput := agentsdk.SanitizeLogOutput(logEntry.Output)
|
||||
output = append(output, sanitizedOutput)
|
||||
outputLength += len(sanitizedOutput)
|
||||
if logEntry.Level == "" {
|
||||
// Default to "info" to support older agents that didn't have the level field.
|
||||
logEntry.Level = codersdk.LogLevelInfo
|
||||
@@ -2392,3 +2395,598 @@ func convertWorkspaceAgentLogs(logs []database.WorkspaceAgentLog) []codersdk.Wor
|
||||
}
|
||||
return sdk
|
||||
}
|
||||
|
||||
// maxChatContextParts caps the number of parts per request to
|
||||
// prevent unbounded message payloads.
|
||||
const maxChatContextParts = 100
|
||||
|
||||
// maxChatContextFileBytes caps each context-file part to the same
|
||||
// 64KiB budget used when the agent reads instruction files from disk.
|
||||
const maxChatContextFileBytes = 64 * 1024
|
||||
|
||||
// maxChatContextRequestBodyBytes caps the JSON request body size for
|
||||
// agent-added context to roughly the same per-part budget used when
|
||||
// reading instruction files from disk.
|
||||
const maxChatContextRequestBodyBytes int64 = maxChatContextParts * maxChatContextFileBytes
|
||||
|
||||
// sanitizeWorkspaceAgentContextFileContent applies prompt
|
||||
// sanitization, then enforces the 64KiB per-file budget. The
|
||||
// truncated flag is preserved when the caller already capped the
|
||||
// file before sending it.
|
||||
func sanitizeWorkspaceAgentContextFileContent(
|
||||
content string,
|
||||
truncated bool,
|
||||
) (string, bool) {
|
||||
content = chatd.SanitizePromptText(content)
|
||||
if len(content) > maxChatContextFileBytes {
|
||||
content = content[:maxChatContextFileBytes]
|
||||
truncated = true
|
||||
}
|
||||
return content, truncated
|
||||
}
|
||||
|
||||
// readChatContextBody reads and validates the request body for chat
|
||||
// context endpoints. It handles MaxBytesReader wrapping, error
|
||||
// responses, and body rewind. If the body is empty or whitespace-only
|
||||
// and allowEmpty is true, it returns false without writing an error.
|
||||
//
|
||||
//nolint:revive // Add and clear endpoints only differ by empty-body handling.
|
||||
func readChatContextBody(ctx context.Context, rw http.ResponseWriter, r *http.Request, dst any, allowEmpty bool) bool {
|
||||
r.Body = http.MaxBytesReader(rw, r.Body, maxChatContextRequestBodyBytes)
|
||||
body, err := io.ReadAll(r.Body)
|
||||
if err != nil {
|
||||
var maxBytesErr *http.MaxBytesError
|
||||
if errors.As(err, &maxBytesErr) {
|
||||
httpapi.Write(ctx, rw, http.StatusRequestEntityTooLarge, codersdk.Response{
|
||||
Message: "Request body too large.",
|
||||
Detail: fmt.Sprintf("Maximum request body size is %d bytes.", maxChatContextRequestBodyBytes),
|
||||
})
|
||||
return false
|
||||
}
|
||||
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
|
||||
Message: "Failed to read request body.",
|
||||
Detail: err.Error(),
|
||||
})
|
||||
return false
|
||||
}
|
||||
if allowEmpty && len(bytes.TrimSpace(body)) == 0 {
|
||||
r.Body = http.NoBody
|
||||
return false
|
||||
}
|
||||
|
||||
r.Body = io.NopCloser(bytes.NewReader(body))
|
||||
return httpapi.Read(ctx, rw, r, dst)
|
||||
}
|
||||
|
||||
// @x-apidocgen {"skip": true}
|
||||
func (api *API) workspaceAgentAddChatContext(rw http.ResponseWriter, r *http.Request) {
|
||||
ctx := r.Context()
|
||||
workspaceAgent := httpmw.WorkspaceAgent(r)
|
||||
|
||||
var req agentsdk.AddChatContextRequest
|
||||
if !readChatContextBody(ctx, rw, r, &req, false) {
|
||||
return
|
||||
}
|
||||
|
||||
if len(req.Parts) == 0 {
|
||||
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
|
||||
Message: "No context parts provided.",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
if len(req.Parts) > maxChatContextParts {
|
||||
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
|
||||
Message: fmt.Sprintf("Too many context parts (%d). Maximum is %d.", len(req.Parts), maxChatContextParts),
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// Filter to only non-empty context-file and skill parts.
|
||||
filtered := chatd.FilterContextParts(req.Parts, false)
|
||||
if len(filtered) == 0 {
|
||||
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
|
||||
Message: "No context-file or skill parts provided.",
|
||||
})
|
||||
return
|
||||
}
|
||||
req.Parts = filtered
|
||||
responsePartCount := 0
|
||||
|
||||
// Use system context for chat operations since the
|
||||
// workspace agent scope does not include chat resources.
|
||||
// We verify agent-to-chat ownership explicitly below.
|
||||
//nolint:gocritic // Agent needs system access to read/write chat resources.
|
||||
sysCtx := dbauthz.AsSystemRestricted(ctx)
|
||||
workspace, err := api.Database.GetWorkspaceByAgentID(sysCtx, workspaceAgent.ID)
|
||||
if err != nil {
|
||||
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
|
||||
Message: "Failed to determine workspace from agent token.",
|
||||
Detail: err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
chat, err := resolveAgentChat(sysCtx, api.Database, workspaceAgent.ID, workspace.OwnerID, req.ChatID)
|
||||
if err != nil {
|
||||
writeAgentChatError(ctx, rw, err)
|
||||
return
|
||||
}
|
||||
|
||||
// Stamp each persisted part with the agent identity. Context-file
|
||||
// parts also get server-authoritative workspace metadata.
|
||||
directory := workspaceAgent.ExpandedDirectory
|
||||
if directory == "" {
|
||||
directory = workspaceAgent.Directory
|
||||
}
|
||||
for i := range req.Parts {
|
||||
req.Parts[i].ContextFileAgentID = uuid.NullUUID{
|
||||
UUID: workspaceAgent.ID,
|
||||
Valid: true,
|
||||
}
|
||||
if req.Parts[i].Type != codersdk.ChatMessagePartTypeContextFile {
|
||||
continue
|
||||
}
|
||||
req.Parts[i].ContextFileContent, req.Parts[i].ContextFileTruncated = sanitizeWorkspaceAgentContextFileContent(
|
||||
req.Parts[i].ContextFileContent,
|
||||
req.Parts[i].ContextFileTruncated,
|
||||
)
|
||||
req.Parts[i].ContextFileOS = workspaceAgent.OperatingSystem
|
||||
req.Parts[i].ContextFileDirectory = directory
|
||||
}
|
||||
req.Parts = chatd.FilterContextParts(req.Parts, false)
|
||||
if len(req.Parts) == 0 {
|
||||
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
|
||||
Message: "No context-file or skill parts provided.",
|
||||
})
|
||||
return
|
||||
}
|
||||
responsePartCount = len(req.Parts)
|
||||
|
||||
// Skill-only messages need a sentinel context-file part so the turn
|
||||
// pipeline trusts the associated skill metadata.
|
||||
req.Parts = prependAgentChatContextSentinelIfNeeded(
|
||||
req.Parts,
|
||||
workspaceAgent.ID,
|
||||
workspaceAgent.OperatingSystem,
|
||||
directory,
|
||||
)
|
||||
|
||||
content, err := chatprompt.MarshalParts(req.Parts)
|
||||
if err != nil {
|
||||
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
|
||||
Message: "Failed to marshal context parts.",
|
||||
Detail: err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
err = api.Database.InTx(func(tx database.Store) error {
|
||||
locked, err := tx.GetChatByIDForUpdate(sysCtx, chat.ID)
|
||||
if err != nil {
|
||||
return xerrors.Errorf("lock chat: %w", err)
|
||||
}
|
||||
if !isActiveAgentChat(locked) {
|
||||
return errChatNotActive
|
||||
}
|
||||
if !locked.AgentID.Valid || locked.AgentID.UUID != workspaceAgent.ID {
|
||||
return errChatDoesNotBelongToAgent
|
||||
}
|
||||
if locked.OwnerID != workspace.OwnerID {
|
||||
return errChatDoesNotBelongToWorkspaceOwner
|
||||
}
|
||||
if _, err := tx.InsertChatMessages(sysCtx, chatd.BuildSingleChatMessageInsertParams(
|
||||
chat.ID,
|
||||
database.ChatMessageRoleUser,
|
||||
content,
|
||||
database.ChatMessageVisibilityBoth,
|
||||
locked.LastModelConfigID,
|
||||
chatprompt.CurrentContentVersion,
|
||||
uuid.Nil,
|
||||
)); err != nil {
|
||||
return xerrors.Errorf("insert context message: %w", err)
|
||||
}
|
||||
if err := updateAgentChatLastInjectedContextFromMessages(sysCtx, api.Logger, tx, chat.ID); err != nil {
|
||||
return xerrors.Errorf("rebuild injected context cache: %w", err)
|
||||
}
|
||||
return nil
|
||||
}, nil)
|
||||
if err != nil {
|
||||
if errors.Is(err, errChatNotActive) || errors.Is(err, errChatDoesNotBelongToAgent) || errors.Is(err, errChatDoesNotBelongToWorkspaceOwner) {
|
||||
writeAgentChatError(ctx, rw, err)
|
||||
return
|
||||
}
|
||||
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
|
||||
Message: "Failed to persist context message.",
|
||||
Detail: err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
httpapi.Write(ctx, rw, http.StatusOK, agentsdk.AddChatContextResponse{
|
||||
ChatID: chat.ID,
|
||||
Count: responsePartCount,
|
||||
})
|
||||
}
|
||||
|
||||
// @x-apidocgen {"skip": true}
|
||||
func (api *API) workspaceAgentClearChatContext(rw http.ResponseWriter, r *http.Request) {
|
||||
ctx := r.Context()
|
||||
workspaceAgent := httpmw.WorkspaceAgent(r)
|
||||
|
||||
var req agentsdk.ClearChatContextRequest
|
||||
populated := readChatContextBody(ctx, rw, r, &req, true)
|
||||
if !populated && r.Body != http.NoBody {
|
||||
return
|
||||
}
|
||||
|
||||
// Use system context for chat operations since the
|
||||
// workspace agent scope does not include chat resources.
|
||||
//nolint:gocritic // Agent needs system access to read/write chat resources.
|
||||
sysCtx := dbauthz.AsSystemRestricted(ctx)
|
||||
workspace, err := api.Database.GetWorkspaceByAgentID(sysCtx, workspaceAgent.ID)
|
||||
if err != nil {
|
||||
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
|
||||
Message: "Failed to determine workspace from agent token.",
|
||||
Detail: err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
chat, err := resolveAgentChat(sysCtx, api.Database, workspaceAgent.ID, workspace.OwnerID, req.ChatID)
|
||||
if err != nil {
|
||||
// Zero active chats is not an error for clear.
|
||||
if errors.Is(err, errNoActiveChats) {
|
||||
httpapi.Write(ctx, rw, http.StatusOK, agentsdk.ClearChatContextResponse{})
|
||||
return
|
||||
}
|
||||
writeAgentChatError(ctx, rw, err)
|
||||
return
|
||||
}
|
||||
|
||||
err = clearAgentChatContext(sysCtx, api.Database, chat.ID, workspaceAgent.ID, workspace.OwnerID)
|
||||
if err != nil {
|
||||
if errors.Is(err, errChatNotActive) || errors.Is(err, errChatDoesNotBelongToAgent) || errors.Is(err, errChatDoesNotBelongToWorkspaceOwner) {
|
||||
writeAgentChatError(ctx, rw, err)
|
||||
return
|
||||
}
|
||||
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
|
||||
Message: "Failed to clear context from chat.",
|
||||
Detail: err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
httpapi.Write(ctx, rw, http.StatusOK, agentsdk.ClearChatContextResponse{
|
||||
ChatID: chat.ID,
|
||||
})
|
||||
}
|
||||
|
||||
var (
|
||||
errNoActiveChats = xerrors.New("no active chats found")
|
||||
errChatNotFound = xerrors.New("chat not found")
|
||||
errChatNotActive = xerrors.New("chat is not active")
|
||||
errChatDoesNotBelongToAgent = xerrors.New("chat does not belong to this agent")
|
||||
errChatDoesNotBelongToWorkspaceOwner = xerrors.New("chat does not belong to this workspace owner")
|
||||
)
|
||||
|
||||
type multipleActiveChatsError struct {
|
||||
count int
|
||||
}
|
||||
|
||||
func (e *multipleActiveChatsError) Error() string {
|
||||
return fmt.Sprintf(
|
||||
"multiple active chats (%d) found for this agent, specify a chat ID",
|
||||
e.count,
|
||||
)
|
||||
}
|
||||
|
||||
func resolveDefaultAgentChat(chats []database.Chat) (database.Chat, error) {
|
||||
switch len(chats) {
|
||||
case 0:
|
||||
return database.Chat{}, errNoActiveChats
|
||||
case 1:
|
||||
return chats[0], nil
|
||||
}
|
||||
|
||||
var rootChat *database.Chat
|
||||
for i := range chats {
|
||||
chat := &chats[i]
|
||||
if chat.ParentChatID.Valid {
|
||||
continue
|
||||
}
|
||||
if rootChat != nil {
|
||||
return database.Chat{}, &multipleActiveChatsError{count: len(chats)}
|
||||
}
|
||||
rootChat = chat
|
||||
}
|
||||
if rootChat != nil {
|
||||
return *rootChat, nil
|
||||
}
|
||||
return database.Chat{}, &multipleActiveChatsError{count: len(chats)}
|
||||
}
|
||||
|
||||
// resolveAgentChat finds the target chat from either an explicit ID
|
||||
// or auto-detection via the agent's active chats.
|
||||
func resolveAgentChat(
|
||||
ctx context.Context,
|
||||
db database.Store,
|
||||
agentID uuid.UUID,
|
||||
workspaceOwnerID uuid.UUID,
|
||||
explicitChatID uuid.UUID,
|
||||
) (database.Chat, error) {
|
||||
if explicitChatID == uuid.Nil {
|
||||
chats, err := db.GetActiveChatsByAgentID(ctx, agentID)
|
||||
if err != nil {
|
||||
return database.Chat{}, xerrors.Errorf("list active chats: %w", err)
|
||||
}
|
||||
ownerChats := make([]database.Chat, 0, len(chats))
|
||||
for _, chat := range chats {
|
||||
if chat.OwnerID != workspaceOwnerID {
|
||||
continue
|
||||
}
|
||||
ownerChats = append(ownerChats, chat)
|
||||
}
|
||||
return resolveDefaultAgentChat(ownerChats)
|
||||
}
|
||||
|
||||
chat, err := db.GetChatByID(ctx, explicitChatID)
|
||||
if err != nil {
|
||||
if errors.Is(err, sql.ErrNoRows) {
|
||||
return database.Chat{}, errChatNotFound
|
||||
}
|
||||
return database.Chat{}, xerrors.Errorf("get chat by id: %w", err)
|
||||
}
|
||||
if !chat.AgentID.Valid || chat.AgentID.UUID != agentID {
|
||||
return database.Chat{}, errChatDoesNotBelongToAgent
|
||||
}
|
||||
if chat.OwnerID != workspaceOwnerID {
|
||||
return database.Chat{}, errChatDoesNotBelongToWorkspaceOwner
|
||||
}
|
||||
if !isActiveAgentChat(chat) {
|
||||
return database.Chat{}, errChatNotActive
|
||||
}
|
||||
return chat, nil
|
||||
}
|
||||
|
||||
func isActiveAgentChat(chat database.Chat) bool {
|
||||
if chat.Archived {
|
||||
return false
|
||||
}
|
||||
|
||||
switch chat.Status {
|
||||
case database.ChatStatusWaiting,
|
||||
database.ChatStatusPending,
|
||||
database.ChatStatusRunning,
|
||||
database.ChatStatusPaused,
|
||||
database.ChatStatusRequiresAction:
|
||||
return true
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
func clearAgentChatContext(
|
||||
ctx context.Context,
|
||||
db database.Store,
|
||||
chatID uuid.UUID,
|
||||
agentID uuid.UUID,
|
||||
workspaceOwnerID uuid.UUID,
|
||||
) error {
|
||||
return db.InTx(func(tx database.Store) error {
|
||||
locked, err := tx.GetChatByIDForUpdate(ctx, chatID)
|
||||
if err != nil {
|
||||
return xerrors.Errorf("lock chat: %w", err)
|
||||
}
|
||||
if !isActiveAgentChat(locked) {
|
||||
return errChatNotActive
|
||||
}
|
||||
if !locked.AgentID.Valid || locked.AgentID.UUID != agentID {
|
||||
return errChatDoesNotBelongToAgent
|
||||
}
|
||||
if locked.OwnerID != workspaceOwnerID {
|
||||
return errChatDoesNotBelongToWorkspaceOwner
|
||||
}
|
||||
messages, err := tx.GetChatMessagesByChatID(ctx, database.GetChatMessagesByChatIDParams{
|
||||
ChatID: chatID,
|
||||
AfterID: 0,
|
||||
})
|
||||
if err != nil {
|
||||
return xerrors.Errorf("get chat messages: %w", err)
|
||||
}
|
||||
hadInjectedContext := locked.LastInjectedContext.Valid
|
||||
var skillOnlyMessageIDs []int64
|
||||
for _, msg := range messages {
|
||||
if !msg.Content.Valid {
|
||||
continue
|
||||
}
|
||||
hasContextFile := messageHasPartTypes(msg.Content.RawMessage, codersdk.ChatMessagePartTypeContextFile)
|
||||
hasSkill := messageHasPartTypes(msg.Content.RawMessage, codersdk.ChatMessagePartTypeSkill)
|
||||
if hasContextFile || hasSkill {
|
||||
hadInjectedContext = true
|
||||
}
|
||||
if hasSkill && !hasContextFile {
|
||||
skillOnlyMessageIDs = append(skillOnlyMessageIDs, msg.ID)
|
||||
}
|
||||
}
|
||||
if !hadInjectedContext {
|
||||
return nil
|
||||
}
|
||||
if err := tx.SoftDeleteContextFileMessages(ctx, chatID); err != nil {
|
||||
return xerrors.Errorf("soft delete context-file messages: %w", err)
|
||||
}
|
||||
for _, messageID := range skillOnlyMessageIDs {
|
||||
if err := tx.SoftDeleteChatMessageByID(ctx, messageID); err != nil {
|
||||
return xerrors.Errorf("soft delete context message %d: %w", messageID, err)
|
||||
}
|
||||
}
|
||||
// Reset provider-side Responses chaining so the next turn replays
|
||||
// the post-clear history instead of inheriting cleared context.
|
||||
if err := tx.ClearChatMessageProviderResponseIDsByChatID(ctx, chatID); err != nil {
|
||||
return xerrors.Errorf("clear provider response chain: %w", err)
|
||||
}
|
||||
// Clear the injected-context cache inside the transaction so it is
|
||||
// atomic with the soft-deletes.
|
||||
param, err := chatd.BuildLastInjectedContext(nil)
|
||||
if err != nil {
|
||||
return xerrors.Errorf("clear injected context cache: %w", err)
|
||||
}
|
||||
if _, err := tx.UpdateChatLastInjectedContext(ctx, database.UpdateChatLastInjectedContextParams{
|
||||
ID: chatID,
|
||||
LastInjectedContext: param,
|
||||
}); err != nil {
|
||||
return xerrors.Errorf("clear injected context cache: %w", err)
|
||||
}
|
||||
return nil
|
||||
}, nil)
|
||||
}
|
||||
|
||||
// prependAgentChatContextSentinelIfNeeded adds an empty context-file
|
||||
// part when the request only carries skills. The turn pipeline uses
|
||||
// the sentinel's agent metadata to trust the skill parts.
|
||||
func prependAgentChatContextSentinelIfNeeded(
|
||||
parts []codersdk.ChatMessagePart,
|
||||
agentID uuid.UUID,
|
||||
operatingSystem string,
|
||||
directory string,
|
||||
) []codersdk.ChatMessagePart {
|
||||
hasContextFile := false
|
||||
hasSkill := false
|
||||
for _, part := range parts {
|
||||
switch part.Type {
|
||||
case codersdk.ChatMessagePartTypeContextFile:
|
||||
hasContextFile = true
|
||||
case codersdk.ChatMessagePartTypeSkill:
|
||||
hasSkill = true
|
||||
}
|
||||
if hasContextFile && hasSkill {
|
||||
return parts
|
||||
}
|
||||
}
|
||||
if !hasSkill || hasContextFile {
|
||||
return parts
|
||||
}
|
||||
return append([]codersdk.ChatMessagePart{{
|
||||
Type: codersdk.ChatMessagePartTypeContextFile,
|
||||
ContextFilePath: chatd.AgentChatContextSentinelPath,
|
||||
ContextFileAgentID: uuid.NullUUID{
|
||||
UUID: agentID,
|
||||
Valid: true,
|
||||
},
|
||||
ContextFileOS: operatingSystem,
|
||||
ContextFileDirectory: directory,
|
||||
}}, parts...)
|
||||
}
|
||||
|
||||
func sortChatMessagesByCreatedAtAndID(messages []database.ChatMessage) {
|
||||
sort.SliceStable(messages, func(i, j int) bool {
|
||||
if messages[i].CreatedAt.Equal(messages[j].CreatedAt) {
|
||||
return messages[i].ID < messages[j].ID
|
||||
}
|
||||
return messages[i].CreatedAt.Before(messages[j].CreatedAt)
|
||||
})
|
||||
}
|
||||
|
||||
// updateAgentChatLastInjectedContextFromMessages rebuilds the
|
||||
// injected-context cache from all persisted context-file and skill parts.
|
||||
func updateAgentChatLastInjectedContextFromMessages(
|
||||
ctx context.Context,
|
||||
logger slog.Logger,
|
||||
db database.Store,
|
||||
chatID uuid.UUID,
|
||||
) error {
|
||||
messages, err := db.GetChatMessagesByChatID(ctx, database.GetChatMessagesByChatIDParams{
|
||||
ChatID: chatID,
|
||||
AfterID: 0,
|
||||
})
|
||||
if err != nil {
|
||||
return xerrors.Errorf("load context messages for injected context: %w", err)
|
||||
}
|
||||
|
||||
sortChatMessagesByCreatedAtAndID(messages)
|
||||
|
||||
parts, err := chatd.CollectContextPartsFromMessages(ctx, logger, messages, true)
|
||||
if err != nil {
|
||||
return xerrors.Errorf("collect injected context parts: %w", err)
|
||||
}
|
||||
parts = chatd.FilterContextPartsToLatestAgent(parts)
|
||||
|
||||
param, err := chatd.BuildLastInjectedContext(parts)
|
||||
if err != nil {
|
||||
return xerrors.Errorf("update injected context: %w", err)
|
||||
}
|
||||
if _, err := db.UpdateChatLastInjectedContext(ctx, database.UpdateChatLastInjectedContextParams{
|
||||
ID: chatID,
|
||||
LastInjectedContext: param,
|
||||
}); err != nil {
|
||||
return xerrors.Errorf("update injected context: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func messageHasPartTypes(raw []byte, types ...codersdk.ChatMessagePartType) bool {
|
||||
var parts []codersdk.ChatMessagePart
|
||||
if err := json.Unmarshal(raw, &parts); err != nil {
|
||||
return false
|
||||
}
|
||||
for _, part := range parts {
|
||||
for _, typ := range types {
|
||||
if part.Type == typ {
|
||||
return true
|
||||
}
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// writeAgentChatError translates resolveAgentChat errors to HTTP
|
||||
// responses.
|
||||
func writeAgentChatError(
|
||||
ctx context.Context,
|
||||
rw http.ResponseWriter,
|
||||
err error,
|
||||
) {
|
||||
if errors.Is(err, errNoActiveChats) {
|
||||
httpapi.Write(ctx, rw, http.StatusNotFound, codersdk.Response{
|
||||
Message: "No active chats found for this agent.",
|
||||
})
|
||||
return
|
||||
}
|
||||
if errors.Is(err, errChatNotFound) {
|
||||
httpapi.Write(ctx, rw, http.StatusNotFound, codersdk.Response{
|
||||
Message: "Chat not found.",
|
||||
})
|
||||
return
|
||||
}
|
||||
if errors.Is(err, errChatDoesNotBelongToAgent) {
|
||||
httpapi.Write(ctx, rw, http.StatusForbidden, codersdk.Response{
|
||||
Message: "Chat does not belong to this agent.",
|
||||
})
|
||||
return
|
||||
}
|
||||
if errors.Is(err, errChatDoesNotBelongToWorkspaceOwner) {
|
||||
httpapi.Write(ctx, rw, http.StatusForbidden, codersdk.Response{
|
||||
Message: "Chat does not belong to this workspace owner.",
|
||||
})
|
||||
return
|
||||
}
|
||||
if errors.Is(err, errChatNotActive) {
|
||||
httpapi.Write(ctx, rw, http.StatusConflict, codersdk.Response{
|
||||
Message: "Cannot modify context: this chat is no longer active.",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
var multipleErr *multipleActiveChatsError
|
||||
if errors.As(err, &multipleErr) {
|
||||
httpapi.Write(ctx, rw, http.StatusConflict, codersdk.Response{
|
||||
Message: err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
|
||||
Message: "Failed to resolve chat.",
|
||||
Detail: err.Error(),
|
||||
})
|
||||
}
|
||||
|
||||
@@ -0,0 +1,76 @@
|
||||
package coderd
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"testing"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/coder/coder/v2/coderd/database"
|
||||
"github.com/coder/coder/v2/coderd/database/dbauthz"
|
||||
"github.com/coder/coder/v2/coderd/database/dbfake"
|
||||
"github.com/coder/coder/v2/coderd/database/dbgen"
|
||||
"github.com/coder/coder/v2/coderd/database/dbtestutil"
|
||||
"github.com/coder/coder/v2/testutil"
|
||||
)
|
||||
|
||||
func TestActiveAgentChatDefinitionsAgree(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctx := dbauthz.AsSystemRestricted(testutil.Context(t, testutil.WaitMedium))
|
||||
db, _ := dbtestutil.NewDB(t)
|
||||
|
||||
org, err := db.GetDefaultOrganization(ctx)
|
||||
require.NoError(t, err)
|
||||
|
||||
owner := dbgen.User(t, db, database.User{})
|
||||
workspace := dbfake.WorkspaceBuild(t, db, database.WorkspaceTable{
|
||||
OrganizationID: org.ID,
|
||||
OwnerID: owner.ID,
|
||||
}).WithAgent().Do()
|
||||
modelConfig := insertAgentChatTestModelConfig(ctx, t, db, owner.ID)
|
||||
|
||||
insertedChats := make([]database.Chat, 0, len(database.AllChatStatusValues())*2)
|
||||
for _, archived := range []bool{false, true} {
|
||||
for _, status := range database.AllChatStatusValues() {
|
||||
chat, err := db.InsertChat(ctx, database.InsertChatParams{
|
||||
Status: status,
|
||||
OwnerID: owner.ID,
|
||||
LastModelConfigID: modelConfig.ID,
|
||||
Title: fmt.Sprintf("%s-archived-%t", status, archived),
|
||||
AgentID: uuid.NullUUID{UUID: workspace.Agents[0].ID, Valid: true},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
if archived {
|
||||
_, err = db.ArchiveChatByID(ctx, chat.ID)
|
||||
require.NoError(t, err)
|
||||
|
||||
chat, err = db.GetChatByID(ctx, chat.ID)
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
insertedChats = append(insertedChats, chat)
|
||||
}
|
||||
}
|
||||
|
||||
activeChats, err := db.GetActiveChatsByAgentID(ctx, workspace.Agents[0].ID)
|
||||
require.NoError(t, err)
|
||||
|
||||
activeByID := make(map[uuid.UUID]bool, len(activeChats))
|
||||
for _, chat := range activeChats {
|
||||
activeByID[chat.ID] = true
|
||||
}
|
||||
|
||||
for _, chat := range insertedChats {
|
||||
require.Equalf(
|
||||
t,
|
||||
isActiveAgentChat(chat),
|
||||
activeByID[chat.ID],
|
||||
"status=%s archived=%t",
|
||||
chat.Status,
|
||||
chat.Archived,
|
||||
)
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,128 @@
|
||||
package coderd
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"encoding/json"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/sqlc-dev/pqtype"
|
||||
"github.com/stretchr/testify/require"
|
||||
"go.uber.org/mock/gomock"
|
||||
|
||||
"cdr.dev/slog/v3/sloggers/slogtest"
|
||||
"github.com/coder/coder/v2/coderd/database"
|
||||
"github.com/coder/coder/v2/coderd/database/dbauthz"
|
||||
"github.com/coder/coder/v2/coderd/database/dbmock"
|
||||
"github.com/coder/coder/v2/codersdk"
|
||||
)
|
||||
|
||||
func TestUpdateAgentChatLastInjectedContextFromMessagesUsesMessageIDTieBreaker(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctrl := gomock.NewController(t)
|
||||
db := dbmock.NewMockStore(ctrl)
|
||||
chatID := uuid.New()
|
||||
createdAt := time.Date(2026, time.April, 9, 13, 0, 0, 0, time.UTC)
|
||||
oldAgentID := uuid.New()
|
||||
newAgentID := uuid.New()
|
||||
|
||||
oldContent, err := json.Marshal([]codersdk.ChatMessagePart{{
|
||||
Type: codersdk.ChatMessagePartTypeContextFile,
|
||||
ContextFilePath: "/old/AGENTS.md",
|
||||
ContextFileContent: "old instructions",
|
||||
ContextFileAgentID: uuid.NullUUID{UUID: oldAgentID, Valid: true},
|
||||
}})
|
||||
require.NoError(t, err)
|
||||
newContent, err := json.Marshal([]codersdk.ChatMessagePart{{
|
||||
Type: codersdk.ChatMessagePartTypeContextFile,
|
||||
ContextFilePath: "/new/AGENTS.md",
|
||||
ContextFileContent: "new instructions",
|
||||
ContextFileAgentID: uuid.NullUUID{UUID: newAgentID, Valid: true},
|
||||
}})
|
||||
require.NoError(t, err)
|
||||
|
||||
db.EXPECT().GetChatMessagesByChatID(gomock.Any(), database.GetChatMessagesByChatIDParams{
|
||||
ChatID: chatID,
|
||||
AfterID: 0,
|
||||
}).Return([]database.ChatMessage{
|
||||
{
|
||||
ID: 2,
|
||||
CreatedAt: createdAt,
|
||||
Content: pqtype.NullRawMessage{
|
||||
RawMessage: newContent,
|
||||
Valid: true,
|
||||
},
|
||||
},
|
||||
{
|
||||
ID: 1,
|
||||
CreatedAt: createdAt,
|
||||
Content: pqtype.NullRawMessage{
|
||||
RawMessage: oldContent,
|
||||
Valid: true,
|
||||
},
|
||||
},
|
||||
}, nil)
|
||||
|
||||
db.EXPECT().UpdateChatLastInjectedContext(gomock.Any(), gomock.Any()).DoAndReturn(
|
||||
func(_ context.Context, arg database.UpdateChatLastInjectedContextParams) (database.Chat, error) {
|
||||
require.Equal(t, chatID, arg.ID)
|
||||
require.True(t, arg.LastInjectedContext.Valid)
|
||||
var cached []codersdk.ChatMessagePart
|
||||
require.NoError(t, json.Unmarshal(arg.LastInjectedContext.RawMessage, &cached))
|
||||
require.Len(t, cached, 1)
|
||||
require.Equal(t, "/new/AGENTS.md", cached[0].ContextFilePath)
|
||||
require.Equal(t, uuid.NullUUID{UUID: newAgentID, Valid: true}, cached[0].ContextFileAgentID)
|
||||
return database.Chat{}, nil
|
||||
},
|
||||
)
|
||||
|
||||
err = updateAgentChatLastInjectedContextFromMessages(
|
||||
context.Background(),
|
||||
slogtest.Make(t, &slogtest.Options{IgnoreErrors: true}),
|
||||
db,
|
||||
chatID,
|
||||
)
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
func insertAgentChatTestModelConfig(
|
||||
ctx context.Context,
|
||||
t testing.TB,
|
||||
db database.Store,
|
||||
userID uuid.UUID,
|
||||
) database.ChatModelConfig {
|
||||
t.Helper()
|
||||
|
||||
sysCtx := dbauthz.AsSystemRestricted(ctx)
|
||||
createdBy := uuid.NullUUID{UUID: userID, Valid: true}
|
||||
|
||||
_, err := db.InsertChatProvider(sysCtx, database.InsertChatProviderParams{
|
||||
Provider: "openai",
|
||||
DisplayName: "OpenAI",
|
||||
APIKey: "test-api-key",
|
||||
ApiKeyKeyID: sql.NullString{},
|
||||
CreatedBy: createdBy,
|
||||
Enabled: true,
|
||||
CentralApiKeyEnabled: true,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
model, err := db.InsertChatModelConfig(sysCtx, database.InsertChatModelConfigParams{
|
||||
Provider: "openai",
|
||||
Model: "gpt-4o-mini",
|
||||
DisplayName: "Test Model",
|
||||
CreatedBy: createdBy,
|
||||
UpdatedBy: createdBy,
|
||||
Enabled: true,
|
||||
IsDefault: true,
|
||||
ContextLimit: 128000,
|
||||
CompressionThreshold: 70,
|
||||
Options: json.RawMessage(`{}`),
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
return model
|
||||
}
|
||||
File diff suppressed because it is too large
Load Diff
@@ -91,7 +91,7 @@ func TestWorkspaceAgent(t *testing.T) {
|
||||
require.Equal(t, tmpDir, workspace.LatestBuild.Resources[0].Agents[0].Directory)
|
||||
_, err = anotherClient.WorkspaceAgent(ctx, workspace.LatestBuild.Resources[0].Agents[0].ID)
|
||||
require.NoError(t, err)
|
||||
require.True(t, workspace.LatestBuild.Resources[0].Agents[0].Health.Healthy)
|
||||
require.False(t, workspace.LatestBuild.Resources[0].Agents[0].Health.Healthy)
|
||||
})
|
||||
t.Run("HasFallbackTroubleshootingURL", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
@@ -260,6 +260,50 @@ func TestWorkspaceAgentLogs(t *testing.T) {
|
||||
require.Equal(t, "testing", logChunk[0].Output)
|
||||
require.Equal(t, "testing2", logChunk[1].Output)
|
||||
})
|
||||
t.Run("SanitizesNulBytesAndTracksSanitizedLength", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := testutil.Context(t, testutil.WaitMedium)
|
||||
client, db := coderdtest.NewWithDatabase(t, nil)
|
||||
user := coderdtest.CreateFirstUser(t, client)
|
||||
r := dbfake.WorkspaceBuild(t, db, database.WorkspaceTable{
|
||||
OrganizationID: user.OrganizationID,
|
||||
OwnerID: user.UserID,
|
||||
}).WithAgent().Do()
|
||||
|
||||
rawOutput := "before\x00after"
|
||||
sanitizedOutput := agentsdk.SanitizeLogOutput(rawOutput)
|
||||
agentClient := agentsdk.New(client.URL, agentsdk.WithFixedToken(r.AgentToken))
|
||||
err := agentClient.PatchLogs(ctx, agentsdk.PatchLogs{
|
||||
Logs: []agentsdk.Log{
|
||||
{
|
||||
CreatedAt: dbtime.Now(),
|
||||
Output: rawOutput,
|
||||
},
|
||||
},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
agent, err := db.GetWorkspaceAgentByID(dbauthz.AsSystemRestricted(ctx), r.Agents[0].ID)
|
||||
require.NoError(t, err)
|
||||
require.EqualValues(t, len(sanitizedOutput), agent.LogsLength)
|
||||
|
||||
workspace, err := client.Workspace(ctx, r.Workspace.ID)
|
||||
require.NoError(t, err)
|
||||
logs, closer, err := client.WorkspaceAgentLogsAfter(ctx, workspace.LatestBuild.Resources[0].Agents[0].ID, 0, true)
|
||||
require.NoError(t, err)
|
||||
defer func() {
|
||||
_ = closer.Close()
|
||||
}()
|
||||
|
||||
var logChunk []codersdk.WorkspaceAgentLog
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
case logChunk = <-logs:
|
||||
}
|
||||
require.NoError(t, ctx.Err())
|
||||
require.Len(t, logChunk, 1)
|
||||
require.Equal(t, sanitizedOutput, logChunk[0].Output)
|
||||
})
|
||||
t.Run("Close logs on outdated build", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := testutil.Context(t, testutil.WaitMedium)
|
||||
|
||||
+69
-12
@@ -213,6 +213,39 @@ func TestWorkspace(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
t.Run("Healthy", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
client := coderdtest.New(t, &coderdtest.Options{IncludeProvisionerDaemon: true})
|
||||
user := coderdtest.CreateFirstUser(t, client)
|
||||
authToken := uuid.NewString()
|
||||
version := coderdtest.CreateTemplateVersion(t, client, user.OrganizationID, &echo.Responses{
|
||||
Parse: echo.ParseComplete,
|
||||
ProvisionGraph: echo.ProvisionGraphWithAgent(authToken),
|
||||
})
|
||||
coderdtest.AwaitTemplateVersionJobCompleted(t, client, version.ID)
|
||||
template := coderdtest.CreateTemplate(t, client, user.OrganizationID, version.ID)
|
||||
workspace := coderdtest.CreateWorkspace(t, client, template.ID)
|
||||
coderdtest.AwaitWorkspaceBuildJobCompleted(t, client, workspace.LatestBuild.ID)
|
||||
|
||||
_ = agenttest.New(t, client.URL, authToken)
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong)
|
||||
defer cancel()
|
||||
|
||||
var err error
|
||||
testutil.Eventually(ctx, t, func(ctx context.Context) bool {
|
||||
workspace, err = client.Workspace(ctx, workspace.ID)
|
||||
return assert.NoError(t, err) && workspace.Health.Healthy
|
||||
}, testutil.IntervalMedium)
|
||||
|
||||
agent := workspace.LatestBuild.Resources[0].Agents[0]
|
||||
|
||||
assert.True(t, workspace.Health.Healthy)
|
||||
assert.Equal(t, []uuid.UUID{}, workspace.Health.FailingAgents)
|
||||
assert.True(t, agent.Health.Healthy)
|
||||
assert.Empty(t, agent.Health.Reason)
|
||||
})
|
||||
|
||||
t.Run("Connecting", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
client := coderdtest.New(t, &coderdtest.Options{IncludeProvisionerDaemon: true})
|
||||
user := coderdtest.CreateFirstUser(t, client)
|
||||
@@ -247,10 +280,10 @@ func TestWorkspace(t *testing.T) {
|
||||
|
||||
agent := workspace.LatestBuild.Resources[0].Agents[0]
|
||||
|
||||
assert.True(t, workspace.Health.Healthy)
|
||||
assert.Equal(t, []uuid.UUID{}, workspace.Health.FailingAgents)
|
||||
assert.True(t, agent.Health.Healthy)
|
||||
assert.Empty(t, agent.Health.Reason)
|
||||
assert.False(t, workspace.Health.Healthy)
|
||||
assert.Equal(t, []uuid.UUID{agent.ID}, workspace.Health.FailingAgents)
|
||||
assert.False(t, agent.Health.Healthy)
|
||||
assert.Equal(t, "agent has not yet connected", agent.Health.Reason)
|
||||
})
|
||||
|
||||
t.Run("Unhealthy", func(t *testing.T) {
|
||||
@@ -302,6 +335,7 @@ func TestWorkspace(t *testing.T) {
|
||||
t.Parallel()
|
||||
client := coderdtest.New(t, &coderdtest.Options{IncludeProvisionerDaemon: true})
|
||||
user := coderdtest.CreateFirstUser(t, client)
|
||||
a1AuthToken := uuid.NewString()
|
||||
version := coderdtest.CreateTemplateVersion(t, client, user.OrganizationID, &echo.Responses{
|
||||
Parse: echo.ParseComplete,
|
||||
ProvisionGraph: []*proto.Response{{
|
||||
@@ -313,7 +347,9 @@ func TestWorkspace(t *testing.T) {
|
||||
Agents: []*proto.Agent{{
|
||||
Id: uuid.NewString(),
|
||||
Name: "a1",
|
||||
Auth: &proto.Agent_Token{},
|
||||
Auth: &proto.Agent_Token{
|
||||
Token: a1AuthToken,
|
||||
},
|
||||
}, {
|
||||
Id: uuid.NewString(),
|
||||
Name: "a2",
|
||||
@@ -330,13 +366,21 @@ func TestWorkspace(t *testing.T) {
|
||||
workspace := coderdtest.CreateWorkspace(t, client, template.ID)
|
||||
coderdtest.AwaitWorkspaceBuildJobCompleted(t, client, workspace.LatestBuild.ID)
|
||||
|
||||
_ = agenttest.New(t, client.URL, a1AuthToken)
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong)
|
||||
defer cancel()
|
||||
|
||||
var err error
|
||||
testutil.Eventually(ctx, t, func(ctx context.Context) bool {
|
||||
workspace, err = client.Workspace(ctx, workspace.ID)
|
||||
return assert.NoError(t, err) && !workspace.Health.Healthy
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
// Wait for the mixed state: a1 connected (healthy)
|
||||
// and workspace unhealthy (because a2 timed out).
|
||||
agent1 := workspace.LatestBuild.Resources[0].Agents[0]
|
||||
return agent1.Health.Healthy && !workspace.Health.Healthy
|
||||
}, testutil.IntervalMedium)
|
||||
|
||||
assert.False(t, workspace.Health.Healthy)
|
||||
@@ -360,6 +404,7 @@ func TestWorkspace(t *testing.T) {
|
||||
// disconnected, but this should not make the workspace unhealthy.
|
||||
client, db := coderdtest.NewWithDatabase(t, &coderdtest.Options{IncludeProvisionerDaemon: true})
|
||||
user := coderdtest.CreateFirstUser(t, client)
|
||||
authToken := uuid.NewString()
|
||||
version := coderdtest.CreateTemplateVersion(t, client, user.OrganizationID, &echo.Responses{
|
||||
Parse: echo.ParseComplete,
|
||||
ProvisionGraph: []*proto.Response{{
|
||||
@@ -371,7 +416,9 @@ func TestWorkspace(t *testing.T) {
|
||||
Agents: []*proto.Agent{{
|
||||
Id: uuid.NewString(),
|
||||
Name: "parent",
|
||||
Auth: &proto.Agent_Token{},
|
||||
Auth: &proto.Agent_Token{
|
||||
Token: authToken,
|
||||
},
|
||||
}},
|
||||
}},
|
||||
},
|
||||
@@ -383,14 +430,23 @@ func TestWorkspace(t *testing.T) {
|
||||
workspace := coderdtest.CreateWorkspace(t, client, template.ID)
|
||||
coderdtest.AwaitWorkspaceBuildJobCompleted(t, client, workspace.LatestBuild.ID)
|
||||
|
||||
_ = agenttest.New(t, client.URL, authToken)
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong)
|
||||
defer cancel()
|
||||
|
||||
// Get the workspace and parent agent.
|
||||
workspace, err := client.Workspace(ctx, workspace.ID)
|
||||
require.NoError(t, err)
|
||||
parentAgent := workspace.LatestBuild.Resources[0].Agents[0]
|
||||
require.True(t, parentAgent.Health.Healthy, "parent agent should be healthy initially")
|
||||
// Wait for the parent agent to connect and be healthy.
|
||||
var parentAgent codersdk.WorkspaceAgent
|
||||
testutil.Eventually(ctx, t, func(ctx context.Context) bool {
|
||||
var err error
|
||||
workspace, err = client.Workspace(ctx, workspace.ID)
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
parentAgent = workspace.LatestBuild.Resources[0].Agents[0]
|
||||
return parentAgent.Health.Healthy
|
||||
}, testutil.IntervalMedium)
|
||||
require.True(t, parentAgent.Health.Healthy, "parent agent should be healthy")
|
||||
|
||||
// Create a sub-agent with a short connection timeout so it becomes
|
||||
// unhealthy quickly (simulating a devcontainer rebuild scenario).
|
||||
@@ -404,6 +460,7 @@ func TestWorkspace(t *testing.T) {
|
||||
// Wait for the sub-agent to become unhealthy due to timeout.
|
||||
var subAgentUnhealthy bool
|
||||
require.Eventually(t, func() bool {
|
||||
var err error
|
||||
workspace, err = client.Workspace(ctx, workspace.ID)
|
||||
if err != nil {
|
||||
return false
|
||||
|
||||
+727
-82
File diff suppressed because it is too large
Load Diff
@@ -8,6 +8,7 @@ import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"charm.land/fantasy"
|
||||
"github.com/google/uuid"
|
||||
"github.com/sqlc-dev/pqtype"
|
||||
"github.com/stretchr/testify/require"
|
||||
@@ -70,14 +71,14 @@ func TestRegenerateChatTitle_PersistsAndBroadcasts(t *testing.T) {
|
||||
updatedChat.Title = wantTitle
|
||||
|
||||
messageEvents := make(chan struct {
|
||||
payload coderdpubsub.ChatEvent
|
||||
payload codersdk.ChatWatchEvent
|
||||
err error
|
||||
}, 1)
|
||||
cancelSub, err := pubsub.SubscribeWithErr(
|
||||
coderdpubsub.ChatEventChannel(ownerID),
|
||||
coderdpubsub.HandleChatEvent(func(_ context.Context, payload coderdpubsub.ChatEvent, err error) {
|
||||
coderdpubsub.ChatWatchEventChannel(ownerID),
|
||||
coderdpubsub.HandleChatWatchEvent(func(_ context.Context, payload codersdk.ChatWatchEvent, err error) {
|
||||
messageEvents <- struct {
|
||||
payload coderdpubsub.ChatEvent
|
||||
payload codersdk.ChatWatchEvent
|
||||
err error
|
||||
}{payload: payload, err: err}
|
||||
}),
|
||||
@@ -183,7 +184,7 @@ func TestRegenerateChatTitle_PersistsAndBroadcasts(t *testing.T) {
|
||||
select {
|
||||
case event := <-messageEvents:
|
||||
require.NoError(t, event.err)
|
||||
require.Equal(t, coderdpubsub.ChatEventKindTitleChange, event.payload.Kind)
|
||||
require.Equal(t, codersdk.ChatWatchEventKindTitleChange, event.payload.Kind)
|
||||
require.Equal(t, chatID, event.payload.Chat.ID)
|
||||
require.Equal(t, wantTitle, event.payload.Chat.Title)
|
||||
case <-time.After(time.Second):
|
||||
@@ -233,14 +234,14 @@ func TestRegenerateChatTitle_PersistsAndBroadcasts_IdleChatReleasesManualLock(t
|
||||
unlockedChat.StartedAt = sql.NullTime{}
|
||||
|
||||
messageEvents := make(chan struct {
|
||||
payload coderdpubsub.ChatEvent
|
||||
payload codersdk.ChatWatchEvent
|
||||
err error
|
||||
}, 1)
|
||||
cancelSub, err := pubsub.SubscribeWithErr(
|
||||
coderdpubsub.ChatEventChannel(ownerID),
|
||||
coderdpubsub.HandleChatEvent(func(_ context.Context, payload coderdpubsub.ChatEvent, err error) {
|
||||
coderdpubsub.ChatWatchEventChannel(ownerID),
|
||||
coderdpubsub.HandleChatWatchEvent(func(_ context.Context, payload codersdk.ChatWatchEvent, err error) {
|
||||
messageEvents <- struct {
|
||||
payload coderdpubsub.ChatEvent
|
||||
payload codersdk.ChatWatchEvent
|
||||
err error
|
||||
}{payload: payload, err: err}
|
||||
}),
|
||||
@@ -372,7 +373,7 @@ func TestRegenerateChatTitle_PersistsAndBroadcasts_IdleChatReleasesManualLock(t
|
||||
select {
|
||||
case event := <-messageEvents:
|
||||
require.NoError(t, event.err)
|
||||
require.Equal(t, coderdpubsub.ChatEventKindTitleChange, event.payload.Kind)
|
||||
require.Equal(t, codersdk.ChatWatchEventKindTitleChange, event.payload.Kind)
|
||||
require.Equal(t, chatID, event.payload.Chat.ID)
|
||||
require.Equal(t, wantTitle, event.payload.Chat.Title)
|
||||
case <-time.After(time.Second):
|
||||
@@ -703,7 +704,33 @@ func TestPersistInstructionFilesSentinelWithSkills(t *testing.T) {
|
||||
gomock.Any(),
|
||||
agentID,
|
||||
).Return(workspaceAgent, nil).Times(1)
|
||||
db.EXPECT().InsertChatMessages(gomock.Any(), gomock.Any()).Return(nil, nil).AnyTimes()
|
||||
db.EXPECT().InsertChatMessages(gomock.Any(),
|
||||
gomock.Cond(func(x any) bool {
|
||||
arg, ok := x.(database.InsertChatMessagesParams)
|
||||
if !ok || arg.ChatID != chat.ID || len(arg.Content) != 1 {
|
||||
return false
|
||||
}
|
||||
var parts []codersdk.ChatMessagePart
|
||||
if err := json.Unmarshal([]byte(arg.Content[0]), &parts); err != nil {
|
||||
return false
|
||||
}
|
||||
foundMarker := false
|
||||
foundSkill := false
|
||||
for _, p := range parts {
|
||||
switch p.Type {
|
||||
case codersdk.ChatMessagePartTypeContextFile:
|
||||
if p.ContextFileAgentID == (uuid.NullUUID{UUID: agentID, Valid: true}) && p.ContextFileContent == "" {
|
||||
foundMarker = true
|
||||
}
|
||||
case codersdk.ChatMessagePartTypeSkill:
|
||||
if p.SkillName == "my-skill" && p.ContextFileAgentID == (uuid.NullUUID{UUID: agentID, Valid: true}) {
|
||||
foundSkill = true
|
||||
}
|
||||
}
|
||||
}
|
||||
return foundMarker && foundSkill
|
||||
}),
|
||||
).Return(nil, nil).Times(1)
|
||||
db.EXPECT().UpdateChatLastInjectedContext(gomock.Any(),
|
||||
gomock.Cond(func(x any) bool {
|
||||
arg, ok := x.(database.UpdateChatLastInjectedContextParams)
|
||||
@@ -2020,6 +2047,30 @@ func TestContextFileAgentID(t *testing.T) {
|
||||
require.True(t, ok)
|
||||
})
|
||||
|
||||
t.Run("IgnoresSkillOnlySentinel", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
instructionAgentID := uuid.New()
|
||||
sentinelAgentID := uuid.New()
|
||||
msgs := []database.ChatMessage{
|
||||
chatMessageWithParts([]codersdk.ChatMessagePart{{
|
||||
Type: codersdk.ChatMessagePartTypeContextFile,
|
||||
ContextFilePath: "/workspace/AGENTS.md",
|
||||
ContextFileAgentID: uuid.NullUUID{UUID: instructionAgentID, Valid: true},
|
||||
}}),
|
||||
chatMessageWithParts([]codersdk.ChatMessagePart{{
|
||||
Type: codersdk.ChatMessagePartTypeContextFile,
|
||||
ContextFilePath: AgentChatContextSentinelPath,
|
||||
ContextFileAgentID: uuid.NullUUID{
|
||||
UUID: sentinelAgentID,
|
||||
Valid: true,
|
||||
},
|
||||
}}),
|
||||
}
|
||||
id, ok := contextFileAgentID(msgs)
|
||||
require.Equal(t, instructionAgentID, id)
|
||||
require.True(t, ok)
|
||||
})
|
||||
|
||||
t.Run("SentinelWithoutAgentID", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
msgs := []database.ChatMessage{
|
||||
@@ -2036,6 +2087,492 @@ func TestContextFileAgentID(t *testing.T) {
|
||||
})
|
||||
}
|
||||
|
||||
func TestHasPersistedInstructionFiles(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
t.Run("IgnoresAgentChatContextSentinel", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
agentID := uuid.New()
|
||||
msgs := []database.ChatMessage{
|
||||
chatMessageWithParts([]codersdk.ChatMessagePart{{
|
||||
Type: codersdk.ChatMessagePartTypeContextFile,
|
||||
ContextFilePath: AgentChatContextSentinelPath,
|
||||
ContextFileAgentID: uuid.NullUUID{
|
||||
UUID: agentID,
|
||||
Valid: true,
|
||||
},
|
||||
}}),
|
||||
}
|
||||
require.False(t, hasPersistedInstructionFiles(msgs))
|
||||
})
|
||||
|
||||
t.Run("AcceptsPersistedInstructionFile", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
agentID := uuid.New()
|
||||
msgs := []database.ChatMessage{
|
||||
chatMessageWithParts([]codersdk.ChatMessagePart{{
|
||||
Type: codersdk.ChatMessagePartTypeContextFile,
|
||||
ContextFilePath: "/workspace/AGENTS.md",
|
||||
ContextFileContent: "repo instructions",
|
||||
ContextFileAgentID: uuid.NullUUID{UUID: agentID, Valid: true},
|
||||
}}),
|
||||
}
|
||||
require.True(t, hasPersistedInstructionFiles(msgs))
|
||||
})
|
||||
}
|
||||
|
||||
func TestInstructionFromContextFilesUsesLatestContextAgent(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
oldAgentID := uuid.New()
|
||||
newAgentID := uuid.New()
|
||||
msgs := []database.ChatMessage{
|
||||
chatMessageWithParts([]codersdk.ChatMessagePart{{
|
||||
Type: codersdk.ChatMessagePartTypeContextFile,
|
||||
ContextFilePath: "/old/AGENTS.md",
|
||||
ContextFileContent: "old instructions",
|
||||
ContextFileOS: "darwin",
|
||||
ContextFileDirectory: "/old",
|
||||
ContextFileAgentID: uuid.NullUUID{UUID: oldAgentID, Valid: true},
|
||||
}}),
|
||||
chatMessageWithParts([]codersdk.ChatMessagePart{{
|
||||
Type: codersdk.ChatMessagePartTypeContextFile,
|
||||
ContextFilePath: "/new/AGENTS.md",
|
||||
ContextFileContent: "new instructions",
|
||||
ContextFileOS: "linux",
|
||||
ContextFileDirectory: "/new",
|
||||
ContextFileAgentID: uuid.NullUUID{UUID: newAgentID, Valid: true},
|
||||
}}),
|
||||
}
|
||||
|
||||
got := instructionFromContextFiles(msgs)
|
||||
require.Contains(t, got, "new instructions")
|
||||
require.Contains(t, got, "Operating System: linux")
|
||||
require.Contains(t, got, "Working Directory: /new")
|
||||
require.NotContains(t, got, "old instructions")
|
||||
require.NotContains(t, got, "Operating System: darwin")
|
||||
}
|
||||
|
||||
func TestInstructionFromContextFilesKeepsLegacyUnstampedParts(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
oldAgentID := uuid.New()
|
||||
newAgentID := uuid.New()
|
||||
msgs := []database.ChatMessage{
|
||||
chatMessageWithParts([]codersdk.ChatMessagePart{{
|
||||
Type: codersdk.ChatMessagePartTypeContextFile,
|
||||
ContextFilePath: "/legacy/AGENTS.md",
|
||||
ContextFileContent: "legacy instructions",
|
||||
}}),
|
||||
chatMessageWithParts([]codersdk.ChatMessagePart{{
|
||||
Type: codersdk.ChatMessagePartTypeContextFile,
|
||||
ContextFilePath: "/old/AGENTS.md",
|
||||
ContextFileContent: "old instructions",
|
||||
ContextFileOS: "darwin",
|
||||
ContextFileDirectory: "/old",
|
||||
ContextFileAgentID: uuid.NullUUID{UUID: oldAgentID, Valid: true},
|
||||
}}),
|
||||
chatMessageWithParts([]codersdk.ChatMessagePart{{
|
||||
Type: codersdk.ChatMessagePartTypeContextFile,
|
||||
ContextFilePath: "/new/AGENTS.md",
|
||||
ContextFileContent: "new instructions",
|
||||
ContextFileOS: "linux",
|
||||
ContextFileDirectory: "/new",
|
||||
ContextFileAgentID: uuid.NullUUID{UUID: newAgentID, Valid: true},
|
||||
}}),
|
||||
}
|
||||
|
||||
got := instructionFromContextFiles(msgs)
|
||||
require.Contains(t, got, "legacy instructions")
|
||||
require.Contains(t, got, "new instructions")
|
||||
require.Contains(t, got, "Operating System: linux")
|
||||
require.Contains(t, got, "Working Directory: /new")
|
||||
require.NotContains(t, got, "old instructions")
|
||||
require.NotContains(t, got, "Operating System: darwin")
|
||||
}
|
||||
|
||||
func TestSkillsFromPartsKeepsLegacyUnstampedParts(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
oldAgentID := uuid.New()
|
||||
newAgentID := uuid.New()
|
||||
msgs := []database.ChatMessage{
|
||||
chatMessageWithParts([]codersdk.ChatMessagePart{{
|
||||
Type: codersdk.ChatMessagePartTypeSkill,
|
||||
SkillName: "repo-helper-legacy",
|
||||
SkillDir: "/skills/repo-helper-legacy",
|
||||
}}),
|
||||
chatMessageWithParts([]codersdk.ChatMessagePart{
|
||||
{
|
||||
Type: codersdk.ChatMessagePartTypeContextFile,
|
||||
ContextFilePath: "/old/AGENTS.md",
|
||||
ContextFileAgentID: uuid.NullUUID{UUID: oldAgentID, Valid: true},
|
||||
},
|
||||
{
|
||||
Type: codersdk.ChatMessagePartTypeSkill,
|
||||
SkillName: "repo-helper-old",
|
||||
SkillDir: "/skills/repo-helper-old",
|
||||
ContextFileAgentID: uuid.NullUUID{UUID: oldAgentID, Valid: true},
|
||||
},
|
||||
}),
|
||||
chatMessageWithParts([]codersdk.ChatMessagePart{
|
||||
{
|
||||
Type: codersdk.ChatMessagePartTypeContextFile,
|
||||
ContextFilePath: AgentChatContextSentinelPath,
|
||||
ContextFileAgentID: uuid.NullUUID{
|
||||
UUID: newAgentID,
|
||||
Valid: true,
|
||||
},
|
||||
},
|
||||
{
|
||||
Type: codersdk.ChatMessagePartTypeSkill,
|
||||
SkillName: "repo-helper-new",
|
||||
SkillDir: "/skills/repo-helper-new",
|
||||
ContextFileAgentID: uuid.NullUUID{UUID: newAgentID, Valid: true},
|
||||
},
|
||||
}),
|
||||
}
|
||||
|
||||
got := skillsFromParts(msgs)
|
||||
require.Equal(t, []chattool.SkillMeta{
|
||||
{Name: "repo-helper-legacy", Dir: "/skills/repo-helper-legacy"},
|
||||
{Name: "repo-helper-new", Dir: "/skills/repo-helper-new"},
|
||||
}, got)
|
||||
}
|
||||
|
||||
func TestSkillsFromPartsUsesLatestContextAgent(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
oldAgentID := uuid.New()
|
||||
newAgentID := uuid.New()
|
||||
msgs := []database.ChatMessage{
|
||||
chatMessageWithParts([]codersdk.ChatMessagePart{
|
||||
{
|
||||
Type: codersdk.ChatMessagePartTypeContextFile,
|
||||
ContextFilePath: "/old/AGENTS.md",
|
||||
ContextFileAgentID: uuid.NullUUID{UUID: oldAgentID, Valid: true},
|
||||
},
|
||||
{
|
||||
Type: codersdk.ChatMessagePartTypeSkill,
|
||||
SkillName: "repo-helper-old",
|
||||
SkillDir: "/skills/repo-helper-old",
|
||||
ContextFileAgentID: uuid.NullUUID{UUID: oldAgentID, Valid: true},
|
||||
},
|
||||
}),
|
||||
chatMessageWithParts([]codersdk.ChatMessagePart{
|
||||
{
|
||||
Type: codersdk.ChatMessagePartTypeContextFile,
|
||||
ContextFilePath: AgentChatContextSentinelPath,
|
||||
ContextFileAgentID: uuid.NullUUID{
|
||||
UUID: newAgentID,
|
||||
Valid: true,
|
||||
},
|
||||
},
|
||||
{
|
||||
Type: codersdk.ChatMessagePartTypeSkill,
|
||||
SkillName: "repo-helper-new",
|
||||
SkillDir: "/skills/repo-helper-new",
|
||||
ContextFileAgentID: uuid.NullUUID{UUID: newAgentID, Valid: true},
|
||||
},
|
||||
}),
|
||||
}
|
||||
|
||||
got := skillsFromParts(msgs)
|
||||
require.Equal(t, []chattool.SkillMeta{{
|
||||
Name: "repo-helper-new",
|
||||
Dir: "/skills/repo-helper-new",
|
||||
}}, got)
|
||||
}
|
||||
|
||||
func TestMergeSkillMetas(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
persisted := []chattool.SkillMeta{{
|
||||
Name: "repo-helper",
|
||||
Description: "Persisted skill",
|
||||
Dir: "/skills/repo-helper-old",
|
||||
}}
|
||||
discovered := []chattool.SkillMeta{
|
||||
{
|
||||
Name: "repo-helper",
|
||||
Description: "Discovered replacement",
|
||||
Dir: "/skills/repo-helper-new",
|
||||
MetaFile: "SKILL.md",
|
||||
},
|
||||
{
|
||||
Name: "deep-review",
|
||||
Description: "Discovered skill",
|
||||
Dir: "/skills/deep-review",
|
||||
},
|
||||
}
|
||||
|
||||
got := mergeSkillMetas(persisted, discovered)
|
||||
require.Equal(t, []chattool.SkillMeta{
|
||||
discovered[0],
|
||||
discovered[1],
|
||||
}, got)
|
||||
}
|
||||
|
||||
func TestSelectSkillMetasForInstructionRefresh(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
persisted := []chattool.SkillMeta{{Name: "persisted", Dir: "/skills/persisted"}}
|
||||
discovered := []chattool.SkillMeta{{Name: "discovered", Dir: "/skills/discovered"}}
|
||||
currentAgentID := uuid.New()
|
||||
otherAgentID := uuid.New()
|
||||
|
||||
t.Run("MergesCurrentAgentSkills", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
got := selectSkillMetasForInstructionRefresh(
|
||||
persisted,
|
||||
discovered,
|
||||
uuid.NullUUID{UUID: currentAgentID, Valid: true},
|
||||
uuid.NullUUID{UUID: currentAgentID, Valid: true},
|
||||
)
|
||||
require.Equal(t, []chattool.SkillMeta{discovered[0], persisted[0]}, got)
|
||||
})
|
||||
|
||||
t.Run("DropsStalePersistedSkillsWhenAgentChanged", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
got := selectSkillMetasForInstructionRefresh(
|
||||
persisted,
|
||||
discovered,
|
||||
uuid.NullUUID{UUID: currentAgentID, Valid: true},
|
||||
uuid.NullUUID{UUID: otherAgentID, Valid: true},
|
||||
)
|
||||
require.Equal(t, discovered, got)
|
||||
})
|
||||
|
||||
t.Run("PreservesPersistedSkillsWhenAgentLookupFails", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
got := selectSkillMetasForInstructionRefresh(
|
||||
persisted,
|
||||
nil,
|
||||
uuid.NullUUID{},
|
||||
uuid.NullUUID{UUID: otherAgentID, Valid: true},
|
||||
)
|
||||
require.Equal(t, persisted, got)
|
||||
})
|
||||
}
|
||||
|
||||
func TestResolveChainModeIgnoresSkillOnlySentinelMessages(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
modelConfigID := uuid.New()
|
||||
assistant := database.ChatMessage{
|
||||
Role: database.ChatMessageRoleAssistant,
|
||||
ProviderResponseID: sql.NullString{String: "resp-123", Valid: true},
|
||||
ModelConfigID: uuid.NullUUID{UUID: modelConfigID, Valid: true},
|
||||
}
|
||||
skillOnly := chatMessageWithParts([]codersdk.ChatMessagePart{
|
||||
{
|
||||
Type: codersdk.ChatMessagePartTypeContextFile,
|
||||
ContextFilePath: AgentChatContextSentinelPath,
|
||||
ContextFileAgentID: uuid.NullUUID{
|
||||
UUID: uuid.New(),
|
||||
Valid: true,
|
||||
},
|
||||
},
|
||||
{
|
||||
Type: codersdk.ChatMessagePartTypeSkill,
|
||||
SkillName: "repo-helper",
|
||||
SkillDir: "/skills/repo-helper",
|
||||
},
|
||||
})
|
||||
skillOnly.Role = database.ChatMessageRoleUser
|
||||
user := chatMessageWithParts([]codersdk.ChatMessagePart{{
|
||||
Type: codersdk.ChatMessagePartTypeText,
|
||||
Text: "latest user message",
|
||||
}})
|
||||
user.Role = database.ChatMessageRoleUser
|
||||
|
||||
got := resolveChainMode([]database.ChatMessage{assistant, skillOnly, user})
|
||||
require.Equal(t, "resp-123", got.previousResponseID)
|
||||
require.Equal(t, modelConfigID, got.modelConfigID)
|
||||
require.Equal(t, 2, got.trailingUserCount)
|
||||
require.Equal(t, 1, got.contributingTrailingUserCount)
|
||||
}
|
||||
|
||||
func TestFilterPromptForChainModeKeepsContributingUsersAcrossSkippedSentinelTurns(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
modelConfigID := uuid.New()
|
||||
priorUser := chatMessageWithParts([]codersdk.ChatMessagePart{{
|
||||
Type: codersdk.ChatMessagePartTypeText,
|
||||
Text: "prior user message",
|
||||
}})
|
||||
priorUser.Role = database.ChatMessageRoleUser
|
||||
assistant := database.ChatMessage{
|
||||
Role: database.ChatMessageRoleAssistant,
|
||||
ProviderResponseID: sql.NullString{String: "resp-123", Valid: true},
|
||||
ModelConfigID: uuid.NullUUID{UUID: modelConfigID, Valid: true},
|
||||
}
|
||||
firstTrailingUser := chatMessageWithParts([]codersdk.ChatMessagePart{{
|
||||
Type: codersdk.ChatMessagePartTypeText,
|
||||
Text: "first trailing user",
|
||||
}})
|
||||
firstTrailingUser.Role = database.ChatMessageRoleUser
|
||||
skillOnly := chatMessageWithParts([]codersdk.ChatMessagePart{
|
||||
{
|
||||
Type: codersdk.ChatMessagePartTypeContextFile,
|
||||
ContextFilePath: AgentChatContextSentinelPath,
|
||||
ContextFileAgentID: uuid.NullUUID{
|
||||
UUID: uuid.New(),
|
||||
Valid: true,
|
||||
},
|
||||
},
|
||||
{
|
||||
Type: codersdk.ChatMessagePartTypeSkill,
|
||||
SkillName: "repo-helper",
|
||||
SkillDir: "/skills/repo-helper",
|
||||
},
|
||||
})
|
||||
skillOnly.Role = database.ChatMessageRoleUser
|
||||
lastTrailingUser := chatMessageWithParts([]codersdk.ChatMessagePart{{
|
||||
Type: codersdk.ChatMessagePartTypeText,
|
||||
Text: "last trailing user",
|
||||
}})
|
||||
lastTrailingUser.Role = database.ChatMessageRoleUser
|
||||
|
||||
chainInfo := resolveChainMode([]database.ChatMessage{
|
||||
priorUser,
|
||||
assistant,
|
||||
firstTrailingUser,
|
||||
skillOnly,
|
||||
lastTrailingUser,
|
||||
})
|
||||
require.Equal(t, 3, chainInfo.trailingUserCount)
|
||||
require.Equal(t, 2, chainInfo.contributingTrailingUserCount)
|
||||
|
||||
prompt := []fantasy.Message{
|
||||
{
|
||||
Role: fantasy.MessageRoleSystem,
|
||||
Content: []fantasy.MessagePart{
|
||||
fantasy.TextPart{Text: "system instruction"},
|
||||
},
|
||||
},
|
||||
{
|
||||
Role: fantasy.MessageRoleUser,
|
||||
Content: []fantasy.MessagePart{
|
||||
fantasy.TextPart{Text: "prior user message"},
|
||||
},
|
||||
},
|
||||
{
|
||||
Role: fantasy.MessageRoleAssistant,
|
||||
Content: []fantasy.MessagePart{
|
||||
fantasy.TextPart{Text: "assistant reply"},
|
||||
},
|
||||
},
|
||||
{
|
||||
Role: fantasy.MessageRoleUser,
|
||||
Content: []fantasy.MessagePart{
|
||||
fantasy.TextPart{Text: "first trailing user"},
|
||||
},
|
||||
},
|
||||
{
|
||||
Role: fantasy.MessageRoleUser,
|
||||
Content: []fantasy.MessagePart{
|
||||
fantasy.TextPart{Text: "last trailing user"},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
got := filterPromptForChainMode(prompt, chainInfo)
|
||||
require.Len(t, got, 3)
|
||||
require.Equal(t, fantasy.MessageRoleSystem, got[0].Role)
|
||||
require.Equal(t, fantasy.MessageRoleUser, got[1].Role)
|
||||
require.Equal(t, fantasy.MessageRoleUser, got[2].Role)
|
||||
|
||||
firstPart, ok := fantasy.AsMessagePart[fantasy.TextPart](got[1].Content[0])
|
||||
require.True(t, ok)
|
||||
require.Equal(t, "first trailing user", firstPart.Text)
|
||||
lastPart, ok := fantasy.AsMessagePart[fantasy.TextPart](got[2].Content[0])
|
||||
require.True(t, ok)
|
||||
require.Equal(t, "last trailing user", lastPart.Text)
|
||||
}
|
||||
|
||||
func TestFilterPromptForChainModeUsesContributingTrailingUsers(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
modelConfigID := uuid.New()
|
||||
priorUser := chatMessageWithParts([]codersdk.ChatMessagePart{{
|
||||
Type: codersdk.ChatMessagePartTypeText,
|
||||
Text: "prior user message",
|
||||
}})
|
||||
priorUser.Role = database.ChatMessageRoleUser
|
||||
assistant := database.ChatMessage{
|
||||
Role: database.ChatMessageRoleAssistant,
|
||||
ProviderResponseID: sql.NullString{String: "resp-123", Valid: true},
|
||||
ModelConfigID: uuid.NullUUID{UUID: modelConfigID, Valid: true},
|
||||
}
|
||||
skillOnly := chatMessageWithParts([]codersdk.ChatMessagePart{
|
||||
{
|
||||
Type: codersdk.ChatMessagePartTypeContextFile,
|
||||
ContextFilePath: AgentChatContextSentinelPath,
|
||||
ContextFileAgentID: uuid.NullUUID{
|
||||
UUID: uuid.New(),
|
||||
Valid: true,
|
||||
},
|
||||
},
|
||||
{
|
||||
Type: codersdk.ChatMessagePartTypeSkill,
|
||||
SkillName: "repo-helper",
|
||||
SkillDir: "/skills/repo-helper",
|
||||
},
|
||||
})
|
||||
skillOnly.Role = database.ChatMessageRoleUser
|
||||
latestUser := chatMessageWithParts([]codersdk.ChatMessagePart{{
|
||||
Type: codersdk.ChatMessagePartTypeText,
|
||||
Text: "latest user message",
|
||||
}})
|
||||
latestUser.Role = database.ChatMessageRoleUser
|
||||
|
||||
chainInfo := resolveChainMode([]database.ChatMessage{
|
||||
priorUser,
|
||||
assistant,
|
||||
skillOnly,
|
||||
latestUser,
|
||||
})
|
||||
require.Equal(t, 2, chainInfo.trailingUserCount)
|
||||
require.Equal(t, 1, chainInfo.contributingTrailingUserCount)
|
||||
|
||||
prompt := []fantasy.Message{
|
||||
{
|
||||
Role: fantasy.MessageRoleSystem,
|
||||
Content: []fantasy.MessagePart{
|
||||
fantasy.TextPart{Text: "system instruction"},
|
||||
},
|
||||
},
|
||||
{
|
||||
Role: fantasy.MessageRoleUser,
|
||||
Content: []fantasy.MessagePart{
|
||||
fantasy.TextPart{Text: "prior user message"},
|
||||
},
|
||||
},
|
||||
{
|
||||
Role: fantasy.MessageRoleAssistant,
|
||||
Content: []fantasy.MessagePart{
|
||||
fantasy.TextPart{Text: "assistant reply"},
|
||||
},
|
||||
},
|
||||
{
|
||||
Role: fantasy.MessageRoleUser,
|
||||
Content: []fantasy.MessagePart{
|
||||
fantasy.TextPart{Text: "latest user message"},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
got := filterPromptForChainMode(prompt, chainInfo)
|
||||
require.Len(t, got, 2)
|
||||
require.Equal(t, fantasy.MessageRoleSystem, got[0].Role)
|
||||
require.Equal(t, fantasy.MessageRoleUser, got[1].Role)
|
||||
|
||||
part, ok := fantasy.AsMessagePart[fantasy.TextPart](got[1].Content[0])
|
||||
require.True(t, ok)
|
||||
require.Equal(t, "latest user message", part.Text)
|
||||
}
|
||||
|
||||
func chatMessageWithParts(parts []codersdk.ChatMessagePart) database.ChatMessage {
|
||||
raw, _ := json.Marshal(parts)
|
||||
return database.ChatMessage{
|
||||
|
||||
@@ -1531,6 +1531,70 @@ func TestRecoverStaleChatsPeriodically(t *testing.T) {
|
||||
}, testutil.WaitMedium, testutil.IntervalFast)
|
||||
}
|
||||
|
||||
func TestRecoverStaleRequiresActionChat(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
db, ps, rawDB := dbtestutil.NewDBWithSQLDB(t)
|
||||
|
||||
ctx := testutil.Context(t, testutil.WaitLong)
|
||||
user, model := seedChatDependencies(ctx, t, db)
|
||||
|
||||
// Use a very short stale threshold so the periodic recovery
|
||||
// kicks in quickly during the test.
|
||||
staleAfter := 500 * time.Millisecond
|
||||
|
||||
// Create a chat and set it to requires_action to simulate a
|
||||
// client that disappeared while the chat was waiting for
|
||||
// dynamic tool results.
|
||||
chat, err := db.InsertChat(ctx, database.InsertChatParams{
|
||||
Status: database.ChatStatusWaiting,
|
||||
OwnerID: user.ID,
|
||||
Title: "stale-requires-action",
|
||||
LastModelConfigID: model.ID,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
_, err = db.UpdateChatStatus(ctx, database.UpdateChatStatusParams{
|
||||
ID: chat.ID,
|
||||
Status: database.ChatStatusRequiresAction,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
// Backdate updated_at so the chat appears stale to the
|
||||
// recovery loop without needing time.Sleep.
|
||||
_, err = rawDB.ExecContext(ctx,
|
||||
"UPDATE chats SET updated_at = $1 WHERE id = $2",
|
||||
time.Now().Add(-time.Hour), chat.ID)
|
||||
require.NoError(t, err)
|
||||
|
||||
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true})
|
||||
server := chatd.New(chatd.Config{
|
||||
Logger: logger,
|
||||
Database: db,
|
||||
ReplicaID: uuid.New(),
|
||||
Pubsub: ps,
|
||||
PendingChatAcquireInterval: testutil.WaitLong,
|
||||
InFlightChatStaleAfter: staleAfter,
|
||||
})
|
||||
t.Cleanup(func() {
|
||||
require.NoError(t, server.Close())
|
||||
})
|
||||
|
||||
// The stale recovery should transition the requires_action
|
||||
// chat to error with the timeout message.
|
||||
var chatResult database.Chat
|
||||
require.Eventually(t, func() bool {
|
||||
chatResult, err = db.GetChatByID(ctx, chat.ID)
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
return chatResult.Status == database.ChatStatusError
|
||||
}, testutil.WaitMedium, testutil.IntervalFast)
|
||||
|
||||
require.Contains(t, chatResult.LastError.String, "Dynamic tool execution timed out")
|
||||
require.False(t, chatResult.WorkerID.Valid)
|
||||
}
|
||||
|
||||
func TestNewReplicaRecoversStaleChatFromDeadReplica(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
@@ -1882,6 +1946,518 @@ func TestPersistToolResultWithBinaryData(t *testing.T) {
|
||||
require.True(t, foundToolResultInSecondCall, "expected second streamed model call to include execute tool output")
|
||||
}
|
||||
|
||||
func TestDynamicToolCallPausesAndResumes(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
db, ps := dbtestutil.NewDB(t)
|
||||
ctx := testutil.Context(t, testutil.WaitLong)
|
||||
|
||||
// Track streaming calls to the mock LLM.
|
||||
var streamedCallCount atomic.Int32
|
||||
var streamedCallsMu sync.Mutex
|
||||
streamedCalls := make([]chattest.OpenAIRequest, 0, 2)
|
||||
|
||||
openAIURL := chattest.NewOpenAI(t, func(req *chattest.OpenAIRequest) chattest.OpenAIResponse {
|
||||
// Non-streaming requests are title generation — return a
|
||||
// simple title.
|
||||
if !req.Stream {
|
||||
return chattest.OpenAINonStreamingResponse("Dynamic tool test")
|
||||
}
|
||||
|
||||
// Capture the full request for later assertions.
|
||||
streamedCallsMu.Lock()
|
||||
streamedCalls = append(streamedCalls, chattest.OpenAIRequest{
|
||||
Messages: append([]chattest.OpenAIMessage(nil), req.Messages...),
|
||||
Tools: append([]chattest.OpenAITool(nil), req.Tools...),
|
||||
Stream: req.Stream,
|
||||
})
|
||||
streamedCallsMu.Unlock()
|
||||
|
||||
if streamedCallCount.Add(1) == 1 {
|
||||
// First call: the LLM invokes our dynamic tool.
|
||||
return chattest.OpenAIStreamingResponse(
|
||||
chattest.OpenAIToolCallChunk(
|
||||
"my_dynamic_tool",
|
||||
`{"input":"hello world"}`,
|
||||
),
|
||||
)
|
||||
}
|
||||
// Second call: the LLM returns a normal text response.
|
||||
return chattest.OpenAIStreamingResponse(
|
||||
chattest.OpenAITextChunks("Dynamic tool result received.")...,
|
||||
)
|
||||
})
|
||||
|
||||
user, model := seedChatDependenciesWithProvider(ctx, t, db, "openai-compat", openAIURL)
|
||||
|
||||
// Dynamic tools do not need a workspace connection, but the
|
||||
// chatd server always builds workspace tools. Use an active
|
||||
// server without an agent connection — the built-in tools
|
||||
// are never invoked because the only tool call targets our
|
||||
// dynamic tool.
|
||||
server := newActiveTestServer(t, db, ps)
|
||||
|
||||
// Create a chat with a dynamic tool.
|
||||
dynamicToolsJSON, err := json.Marshal([]mcpgo.Tool{{
|
||||
Name: "my_dynamic_tool",
|
||||
Description: "A test dynamic tool.",
|
||||
InputSchema: mcpgo.ToolInputSchema{
|
||||
Type: "object",
|
||||
Properties: map[string]any{
|
||||
"input": map[string]any{"type": "string"},
|
||||
},
|
||||
Required: []string{"input"},
|
||||
},
|
||||
}})
|
||||
require.NoError(t, err)
|
||||
|
||||
chat, err := server.CreateChat(ctx, chatd.CreateOptions{
|
||||
OwnerID: user.ID,
|
||||
Title: "dynamic-tool-pause-resume",
|
||||
ModelConfigID: model.ID,
|
||||
InitialUserContent: []codersdk.ChatMessagePart{
|
||||
codersdk.ChatMessageText("Please call the dynamic tool."),
|
||||
},
|
||||
DynamicTools: dynamicToolsJSON,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
// 1. Wait for the chat to reach requires_action status.
|
||||
var chatResult database.Chat
|
||||
require.Eventually(t, func() bool {
|
||||
got, getErr := db.GetChatByID(ctx, chat.ID)
|
||||
if getErr != nil {
|
||||
return false
|
||||
}
|
||||
chatResult = got
|
||||
return got.Status == database.ChatStatusRequiresAction ||
|
||||
got.Status == database.ChatStatusError
|
||||
}, testutil.WaitLong, testutil.IntervalFast)
|
||||
|
||||
require.Equal(t, database.ChatStatusRequiresAction, chatResult.Status,
|
||||
"expected requires_action, got %s (last_error=%q)",
|
||||
chatResult.Status, chatResult.LastError.String)
|
||||
|
||||
// 2. Read the assistant message to find the tool-call ID.
|
||||
var toolCallID string
|
||||
var toolCallFound bool
|
||||
testutil.Eventually(ctx, t, func(ctx context.Context) bool {
|
||||
messages, dbErr := db.GetChatMessagesByChatID(ctx, database.GetChatMessagesByChatIDParams{
|
||||
ChatID: chat.ID,
|
||||
AfterID: 0,
|
||||
})
|
||||
if dbErr != nil {
|
||||
return false
|
||||
}
|
||||
for _, msg := range messages {
|
||||
if msg.Role != database.ChatMessageRoleAssistant {
|
||||
continue
|
||||
}
|
||||
parts, parseErr := chatprompt.ParseContent(msg)
|
||||
if parseErr != nil {
|
||||
continue
|
||||
}
|
||||
for _, part := range parts {
|
||||
if part.Type == codersdk.ChatMessagePartTypeToolCall && part.ToolName == "my_dynamic_tool" {
|
||||
toolCallID = part.ToolCallID
|
||||
toolCallFound = true
|
||||
return true
|
||||
}
|
||||
}
|
||||
}
|
||||
return false
|
||||
}, testutil.IntervalFast)
|
||||
require.True(t, toolCallFound, "expected to find tool call for my_dynamic_tool")
|
||||
require.NotEmpty(t, toolCallID)
|
||||
|
||||
// 3. Submit tool results via SubmitToolResults.
|
||||
toolResultOutput := json.RawMessage(`{"result":"dynamic tool output"}`)
|
||||
err = server.SubmitToolResults(ctx, chatd.SubmitToolResultsOptions{
|
||||
ChatID: chat.ID,
|
||||
UserID: user.ID,
|
||||
ModelConfigID: chatResult.LastModelConfigID,
|
||||
Results: []codersdk.ToolResult{{
|
||||
ToolCallID: toolCallID,
|
||||
Output: toolResultOutput,
|
||||
}},
|
||||
DynamicTools: dynamicToolsJSON,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
// 4. Wait for the chat to reach a terminal status.
|
||||
require.Eventually(t, func() bool {
|
||||
got, getErr := db.GetChatByID(ctx, chat.ID)
|
||||
if getErr != nil {
|
||||
return false
|
||||
}
|
||||
chatResult = got
|
||||
return got.Status == database.ChatStatusWaiting || got.Status == database.ChatStatusError
|
||||
}, testutil.WaitLong, testutil.IntervalFast)
|
||||
|
||||
// 5. Verify the chat completed successfully.
|
||||
if chatResult.Status == database.ChatStatusError {
|
||||
require.FailNowf(t, "chat run failed", "last_error=%q", chatResult.LastError.String)
|
||||
}
|
||||
|
||||
// 6. Verify the mock received exactly 2 streaming calls.
|
||||
require.Equal(t, int32(2), streamedCallCount.Load(),
|
||||
"expected exactly 2 streaming calls to the LLM")
|
||||
|
||||
streamedCallsMu.Lock()
|
||||
recordedCalls := append([]chattest.OpenAIRequest(nil), streamedCalls...)
|
||||
streamedCallsMu.Unlock()
|
||||
require.Len(t, recordedCalls, 2)
|
||||
|
||||
// 7. Verify the dynamic tool appeared in the first call's tool list.
|
||||
var foundDynamicTool bool
|
||||
for _, tool := range recordedCalls[0].Tools {
|
||||
if tool.Function.Name == "my_dynamic_tool" {
|
||||
foundDynamicTool = true
|
||||
break
|
||||
}
|
||||
}
|
||||
require.True(t, foundDynamicTool,
|
||||
"expected 'my_dynamic_tool' in the first LLM call's tool list")
|
||||
|
||||
// 8. Verify the second call's messages contain the tool result.
|
||||
var foundToolResultInSecondCall bool
|
||||
for _, message := range recordedCalls[1].Messages {
|
||||
if message.Role != "tool" {
|
||||
continue
|
||||
}
|
||||
if strings.Contains(message.Content, "dynamic tool output") {
|
||||
foundToolResultInSecondCall = true
|
||||
break
|
||||
}
|
||||
}
|
||||
require.True(t, foundToolResultInSecondCall,
|
||||
"expected second LLM call to include the submitted dynamic tool result")
|
||||
}
|
||||
|
||||
func TestDynamicToolCallMixedWithBuiltIn(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
db, ps := dbtestutil.NewDB(t)
|
||||
ctx := testutil.Context(t, testutil.WaitLong)
|
||||
|
||||
// Track streaming calls to the mock LLM.
|
||||
var streamedCallCount atomic.Int32
|
||||
var streamedCallsMu sync.Mutex
|
||||
streamedCalls := make([]chattest.OpenAIRequest, 0, 2)
|
||||
|
||||
openAIURL := chattest.NewOpenAI(t, func(req *chattest.OpenAIRequest) chattest.OpenAIResponse {
|
||||
if !req.Stream {
|
||||
return chattest.OpenAINonStreamingResponse("Mixed tool test")
|
||||
}
|
||||
|
||||
streamedCallsMu.Lock()
|
||||
streamedCalls = append(streamedCalls, chattest.OpenAIRequest{
|
||||
Messages: append([]chattest.OpenAIMessage(nil), req.Messages...),
|
||||
Tools: append([]chattest.OpenAITool(nil), req.Tools...),
|
||||
Stream: req.Stream,
|
||||
})
|
||||
streamedCallsMu.Unlock()
|
||||
|
||||
if streamedCallCount.Add(1) == 1 {
|
||||
// First call: return TWO tool calls in one
|
||||
// response — a built-in tool (read_file) and a
|
||||
// dynamic tool (my_dynamic_tool).
|
||||
builtinChunk := chattest.OpenAIToolCallChunk(
|
||||
"read_file",
|
||||
`{"path":"/tmp/test.txt"}`,
|
||||
)
|
||||
dynamicChunk := chattest.OpenAIToolCallChunk(
|
||||
"my_dynamic_tool",
|
||||
`{"input":"hello world"}`,
|
||||
)
|
||||
// Merge both tool calls into one chunk with
|
||||
// separate indices so the LLM appears to have
|
||||
// requested both tools simultaneously.
|
||||
mergedChunk := builtinChunk
|
||||
dynCall := dynamicChunk.Choices[0].ToolCalls[0]
|
||||
dynCall.Index = 1
|
||||
mergedChunk.Choices[0].ToolCalls = append(
|
||||
mergedChunk.Choices[0].ToolCalls,
|
||||
dynCall,
|
||||
)
|
||||
return chattest.OpenAIStreamingResponse(mergedChunk)
|
||||
}
|
||||
// Second call (after tool results): normal text
|
||||
// response.
|
||||
return chattest.OpenAIStreamingResponse(
|
||||
chattest.OpenAITextChunks("All done.")...,
|
||||
)
|
||||
})
|
||||
|
||||
user, model := seedChatDependenciesWithProvider(ctx, t, db, "openai-compat", openAIURL)
|
||||
server := newActiveTestServer(t, db, ps)
|
||||
|
||||
// Create a chat with a dynamic tool.
|
||||
dynamicToolsJSON, err := json.Marshal([]mcpgo.Tool{{
|
||||
Name: "my_dynamic_tool",
|
||||
Description: "A test dynamic tool.",
|
||||
InputSchema: mcpgo.ToolInputSchema{
|
||||
Type: "object",
|
||||
Properties: map[string]any{
|
||||
"input": map[string]any{"type": "string"},
|
||||
},
|
||||
Required: []string{"input"},
|
||||
},
|
||||
}})
|
||||
require.NoError(t, err)
|
||||
|
||||
chat, err := server.CreateChat(ctx, chatd.CreateOptions{
|
||||
OwnerID: user.ID,
|
||||
Title: "mixed-builtin-dynamic",
|
||||
ModelConfigID: model.ID,
|
||||
InitialUserContent: []codersdk.ChatMessagePart{
|
||||
codersdk.ChatMessageText("Call both tools."),
|
||||
},
|
||||
DynamicTools: dynamicToolsJSON,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
// 1. Wait for the chat to reach requires_action status.
|
||||
var chatResult database.Chat
|
||||
require.Eventually(t, func() bool {
|
||||
got, getErr := db.GetChatByID(ctx, chat.ID)
|
||||
if getErr != nil {
|
||||
return false
|
||||
}
|
||||
chatResult = got
|
||||
return got.Status == database.ChatStatusRequiresAction ||
|
||||
got.Status == database.ChatStatusError
|
||||
}, testutil.WaitLong, testutil.IntervalFast)
|
||||
|
||||
require.Equal(t, database.ChatStatusRequiresAction, chatResult.Status,
|
||||
"expected requires_action, got %s (last_error=%q)",
|
||||
chatResult.Status, chatResult.LastError.String)
|
||||
|
||||
// 2. Verify the built-in tool (read_file) was already
|
||||
// executed by checking that a tool result message
|
||||
// exists for it in the database.
|
||||
var builtinToolResultFound bool
|
||||
var toolCallID string
|
||||
testutil.Eventually(ctx, t, func(ctx context.Context) bool {
|
||||
messages, dbErr := db.GetChatMessagesByChatID(ctx, database.GetChatMessagesByChatIDParams{
|
||||
ChatID: chat.ID,
|
||||
AfterID: 0,
|
||||
})
|
||||
if dbErr != nil {
|
||||
return false
|
||||
}
|
||||
for _, msg := range messages {
|
||||
parts, parseErr := chatprompt.ParseContent(msg)
|
||||
if parseErr != nil {
|
||||
continue
|
||||
}
|
||||
for _, part := range parts {
|
||||
// Check for the built-in tool result.
|
||||
if part.Type == codersdk.ChatMessagePartTypeToolResult && part.ToolName == "read_file" {
|
||||
builtinToolResultFound = true
|
||||
}
|
||||
// Find the dynamic tool call ID.
|
||||
if part.Type == codersdk.ChatMessagePartTypeToolCall && part.ToolName == "my_dynamic_tool" {
|
||||
toolCallID = part.ToolCallID
|
||||
}
|
||||
}
|
||||
}
|
||||
return builtinToolResultFound && toolCallID != ""
|
||||
}, testutil.IntervalFast)
|
||||
|
||||
require.True(t, builtinToolResultFound,
|
||||
"expected read_file tool result in the DB before dynamic tool resolution")
|
||||
require.NotEmpty(t, toolCallID)
|
||||
|
||||
// 3. Submit dynamic tool results.
|
||||
err = server.SubmitToolResults(ctx, chatd.SubmitToolResultsOptions{
|
||||
ChatID: chat.ID,
|
||||
UserID: user.ID,
|
||||
ModelConfigID: chatResult.LastModelConfigID,
|
||||
Results: []codersdk.ToolResult{{
|
||||
ToolCallID: toolCallID,
|
||||
Output: json.RawMessage(`{"result":"dynamic output"}`),
|
||||
}},
|
||||
DynamicTools: dynamicToolsJSON,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
// 4. Wait for the chat to complete.
|
||||
require.Eventually(t, func() bool {
|
||||
got, getErr := db.GetChatByID(ctx, chat.ID)
|
||||
if getErr != nil {
|
||||
return false
|
||||
}
|
||||
chatResult = got
|
||||
return got.Status == database.ChatStatusWaiting || got.Status == database.ChatStatusError
|
||||
}, testutil.WaitLong, testutil.IntervalFast)
|
||||
|
||||
if chatResult.Status == database.ChatStatusError {
|
||||
require.FailNowf(t, "chat run failed", "last_error=%q", chatResult.LastError.String)
|
||||
}
|
||||
|
||||
// 5. Verify the LLM received exactly 2 streaming calls.
|
||||
require.Equal(t, int32(2), streamedCallCount.Load(),
|
||||
"expected exactly 2 streaming calls to the LLM")
|
||||
}
|
||||
|
||||
func TestSubmitToolResultsConcurrency(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
db, ps := dbtestutil.NewDB(t)
|
||||
ctx := testutil.Context(t, testutil.WaitLong)
|
||||
|
||||
// The mock LLM returns a dynamic tool call on the first streaming
|
||||
// request, then a plain text reply on the second.
|
||||
var streamedCallCount atomic.Int32
|
||||
openAIURL := chattest.NewOpenAI(t, func(req *chattest.OpenAIRequest) chattest.OpenAIResponse {
|
||||
if !req.Stream {
|
||||
return chattest.OpenAINonStreamingResponse("Concurrency test")
|
||||
}
|
||||
if streamedCallCount.Add(1) == 1 {
|
||||
return chattest.OpenAIStreamingResponse(
|
||||
chattest.OpenAIToolCallChunk(
|
||||
"my_dynamic_tool",
|
||||
`{"input":"hello"}`,
|
||||
),
|
||||
)
|
||||
}
|
||||
return chattest.OpenAIStreamingResponse(
|
||||
chattest.OpenAITextChunks("Done.")...,
|
||||
)
|
||||
})
|
||||
|
||||
user, model := seedChatDependenciesWithProvider(ctx, t, db, "openai-compat", openAIURL)
|
||||
server := newActiveTestServer(t, db, ps)
|
||||
|
||||
// Create a chat with a dynamic tool.
|
||||
dynamicToolsJSON, err := json.Marshal([]mcpgo.Tool{{
|
||||
Name: "my_dynamic_tool",
|
||||
Description: "A test dynamic tool.",
|
||||
InputSchema: mcpgo.ToolInputSchema{
|
||||
Type: "object",
|
||||
Properties: map[string]any{
|
||||
"input": map[string]any{"type": "string"},
|
||||
},
|
||||
Required: []string{"input"},
|
||||
},
|
||||
}})
|
||||
require.NoError(t, err)
|
||||
|
||||
chat, err := server.CreateChat(ctx, chatd.CreateOptions{
|
||||
OwnerID: user.ID,
|
||||
Title: "concurrency-tool-results",
|
||||
ModelConfigID: model.ID,
|
||||
InitialUserContent: []codersdk.ChatMessagePart{
|
||||
codersdk.ChatMessageText("Please call the dynamic tool."),
|
||||
},
|
||||
DynamicTools: dynamicToolsJSON,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
// Wait for the chat to reach requires_action status.
|
||||
var chatResult database.Chat
|
||||
require.Eventually(t, func() bool {
|
||||
got, getErr := db.GetChatByID(ctx, chat.ID)
|
||||
if getErr != nil {
|
||||
return false
|
||||
}
|
||||
chatResult = got
|
||||
return got.Status == database.ChatStatusRequiresAction ||
|
||||
got.Status == database.ChatStatusError
|
||||
}, testutil.WaitLong, testutil.IntervalFast)
|
||||
require.Equal(t, database.ChatStatusRequiresAction, chatResult.Status,
|
||||
"expected requires_action, got %s (last_error=%q)",
|
||||
chatResult.Status, chatResult.LastError.String)
|
||||
|
||||
// Find the tool call ID from the assistant message.
|
||||
var toolCallID string
|
||||
testutil.Eventually(ctx, t, func(ctx context.Context) bool {
|
||||
messages, dbErr := db.GetChatMessagesByChatID(ctx, database.GetChatMessagesByChatIDParams{
|
||||
ChatID: chat.ID,
|
||||
AfterID: 0,
|
||||
})
|
||||
if dbErr != nil {
|
||||
return false
|
||||
}
|
||||
for _, msg := range messages {
|
||||
if msg.Role != database.ChatMessageRoleAssistant {
|
||||
continue
|
||||
}
|
||||
parts, parseErr := chatprompt.ParseContent(msg)
|
||||
if parseErr != nil {
|
||||
continue
|
||||
}
|
||||
for _, part := range parts {
|
||||
if part.Type == codersdk.ChatMessagePartTypeToolCall && part.ToolName == "my_dynamic_tool" {
|
||||
toolCallID = part.ToolCallID
|
||||
return true
|
||||
}
|
||||
}
|
||||
}
|
||||
return false
|
||||
}, testutil.IntervalFast)
|
||||
require.NotEmpty(t, toolCallID)
|
||||
|
||||
// Spawn N goroutines that all try to submit tool results at the
|
||||
// same time. Exactly one should succeed; the rest must get a
|
||||
// ToolResultStatusConflictError.
|
||||
const numGoroutines = 10
|
||||
var (
|
||||
wg sync.WaitGroup
|
||||
ready = make(chan struct{})
|
||||
successes atomic.Int32
|
||||
conflicts atomic.Int32
|
||||
unexpectedErrors = make(chan error, numGoroutines)
|
||||
)
|
||||
|
||||
for range numGoroutines {
|
||||
wg.Go(func() {
|
||||
// Wait for all goroutines to be ready.
|
||||
<-ready
|
||||
|
||||
submitErr := server.SubmitToolResults(ctx, chatd.SubmitToolResultsOptions{
|
||||
ChatID: chat.ID,
|
||||
UserID: user.ID,
|
||||
ModelConfigID: chatResult.LastModelConfigID,
|
||||
Results: []codersdk.ToolResult{{
|
||||
ToolCallID: toolCallID,
|
||||
Output: json.RawMessage(`{"result":"concurrent output"}`),
|
||||
}},
|
||||
DynamicTools: dynamicToolsJSON,
|
||||
})
|
||||
|
||||
if submitErr == nil {
|
||||
successes.Add(1)
|
||||
return
|
||||
}
|
||||
var conflict *chatd.ToolResultStatusConflictError
|
||||
if errors.As(submitErr, &conflict) {
|
||||
conflicts.Add(1)
|
||||
return
|
||||
}
|
||||
// Collect unexpected errors for assertion
|
||||
// outside the goroutine (require.NoError
|
||||
// calls t.FailNow which is illegal here).
|
||||
unexpectedErrors <- submitErr
|
||||
})
|
||||
}
|
||||
// Release all goroutines at once.
|
||||
close(ready)
|
||||
|
||||
wg.Wait()
|
||||
close(unexpectedErrors)
|
||||
|
||||
for ue := range unexpectedErrors {
|
||||
require.NoError(t, ue, "unexpected error from SubmitToolResults")
|
||||
}
|
||||
|
||||
require.Equal(t, int32(1), successes.Load(),
|
||||
"expected exactly 1 goroutine to succeed")
|
||||
require.Equal(t, int32(numGoroutines-1), conflicts.Load(),
|
||||
"expected %d conflict errors", numGoroutines-1)
|
||||
}
|
||||
|
||||
func ptrRef[T any](v T) *T {
|
||||
return &v
|
||||
}
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user