Compare commits
67 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 88a2c3644e | |||
| cdb1499631 | |||
| 9c52b0b862 | |||
| 0f4a784b62 | |||
| 4d1b687865 | |||
| 3b9cf94b63 | |||
| bd467ce443 | |||
| c67c93982b | |||
| 2f52de7cfc | |||
| 0552b927b2 | |||
| 16b1b6865d | |||
| 897533f08d | |||
| 3e25cc9238 | |||
| bb64cab8a5 | |||
| b149433138 | |||
| ee4dccb898 | |||
| 8dff1cbc57 | |||
| a62ead8588 | |||
| b68c14dd04 | |||
| 508114d484 | |||
| e0fbb0e4ec | |||
| 7bde763b66 | |||
| 36141fafad | |||
| 3462c31f43 | |||
| a0ea71b74c | |||
| 0a14bb529e | |||
| 2c32d84f12 | |||
| 76d89f59af | |||
| 1a3a92bd1b | |||
| 4018320614 | |||
| d9700baa8d | |||
| 82456ff62e | |||
| 83fd4cf5c2 | |||
| 38d4da82b9 | |||
| 19e0e0e8e6 | |||
| 1d0653cdab | |||
| 95cff8c5fb | |||
| ad2415ede7 | |||
| 1e40cea199 | |||
| 9d6557d173 | |||
| 224db483d7 | |||
| 8237822441 | |||
| 65bf7c3b18 | |||
| 76cbc580f0 | |||
| 391b22aef7 | |||
| f8e8f979a2 | |||
| fb0ed1162b | |||
| 3f519744aa | |||
| 2505f6245f | |||
| 29ad2c6201 | |||
| 27e5ff0a8e | |||
| 128a7c23e6 | |||
| efb19eb748 | |||
| 2c499484b7 | |||
| 33d9d0d875 | |||
| f219834f5c | |||
| 7a94a683c4 | |||
| 2e6fdf2344 | |||
| 3d139c1a24 | |||
| f957981c8b | |||
| 584c61acb5 | |||
| f95a5202bf | |||
| d954460380 | |||
| f4240bb8c1 | |||
| 7caef4987f | |||
| 9b91af8ab7 | |||
| 506fba9ebf |
@@ -18,35 +18,35 @@ The 5.x era resolves years of module system ambiguity and cleans house on legacy
|
||||
|
||||
The left column reflects patterns still common before TypeScript 5.x. Write the right column instead. The "Since" column tells you the minimum TypeScript version required.
|
||||
|
||||
| Old pattern | Modern replacement | Since |
|
||||
| ---------------------------------------------------------------------------------------------------------------- | ---------------------------------------------------------------------------------------------------------------------------- | -------------------------------- | ------ |
|
||||
| `--experimentalDecorators` + legacy decorator signatures | Standard decorators (TC39): `function dec(target, context: ClassMethodDecoratorContext)` — no flag needed | 5.0 |
|
||||
| Requiring callers to add `as const` at call sites | `<const T extends HasNames>(arg: T)` — `const` modifier on type parameter | 5.0 |
|
||||
| `--importsNotUsedAsValues` + `--preserveValueImports` | `--verbatimModuleSyntax` | 5.0 |
|
||||
| `import { Foo } from "..."` when `Foo` is only used as a type | `import { type Foo } from "..."` or `import type { Foo } from "..."` | 5.0 |
|
||||
| `"extends": "@tsconfig/strictest/tsconfig.json"` chain | `"extends": ["@tsconfig/strictest/tsconfig.json", "./tsconfig.base.json"]` (array form) | 5.0 |
|
||||
| `try { ... } finally { resource.close(); resource.delete(); }` | `using resource = acquireResource()` — calls `[Symbol.dispose]()` automatically | 5.2 |
|
||||
| `try { ... } finally { await resource.close() }` | `await using resource = acquireAsyncResource()` | 5.2 |
|
||||
| Ad-hoc cleanup with multiple `try/finally` blocks | `using cleanup = new DisposableStack(); cleanup.defer(() => ...)` | 5.2 |
|
||||
| `import data from "./data.json" assert { type: "json" }` | `import data from "./data.json" with { type: "json" }` | 5.3 |
|
||||
| `.filter(Boolean)` or `.filter(x => !!x)` to remove nulls | `.filter(x => x !== undefined)` or `.filter(x => x !== null)` (infers type predicate) | 5.5 |
|
||||
| Extra phantom type param to block inference bleed: `<C extends string, D extends C>` | `NoInfer<C>` on the parameter you don't want to drive inference | 5.4 |
|
||||
| `/** @typedef {import("./types").Foo} Foo */` in JS files | `/** @import { Foo } from "./types" */` (JSDoc `@import` tag) | 5.5 |
|
||||
| `myArray.reverse()` mutating in place | `myArray.toReversed()` (returns new array) | 5.2 |
|
||||
| `myArray.sort(cmp)` mutating in place | `myArray.toSorted(cmp)` (returns new array) | 5.2 |
|
||||
| `const copy = [...arr]; copy[i] = v` | `arr.with(i, v)` (returns new array) | 5.2 |
|
||||
| Manual `has`/`get`/`set` pattern on `Map` | `map.getOrInsert(key, defaultValue)` or `getOrInsertComputed(key, fn)` | 6.0 RC |
|
||||
| `new RegExp(str.replace(/[.\*+?^${}() | [\]\\]/g, '\\$&'))` | `new RegExp(RegExp.escape(str))` | 6.0 RC |
|
||||
| `--moduleResolution node` (node10) | `--moduleResolution nodenext` (Node.js) or `--moduleResolution bundler` (bundlers/Bun) | 6.0 RC |
|
||||
| `"baseUrl": "./src"` + `"@app/*": ["app/*"]` in paths | Remove `baseUrl`; use `"@app/*": ["./src/app/*"]` in paths directly | 6.0 RC |
|
||||
| `module Foo { export const x = 1; }` | `namespace Foo { export const x = 1; }` | 6.0 RC |
|
||||
| `export * from "..."` when all re-exported members are types | `export type * from "..."` (or `export type * as ns from "..."`) | 5.0 |
|
||||
| `function f(): undefined { return undefined; }` — explicit return required in `: undefined`-returning function | Remove the `return` entirely; `undefined`-returning functions no longer require any return statement | 5.1 |
|
||||
| Manual type predicate annotation on a simple arrow: `(x: T \| undefined): x is T => x !== undefined` | Remove the annotation; TypeScript infers `x is T` from `!== null/undefined` and `instanceof` checks automatically | 5.5 |
|
||||
| `const val = obj[key]; if (typeof val === "string") { use(val); }` — extract to const to narrow indexed access | `if (typeof obj[key] === "string") { obj[key].toUpperCase(); }` directly — both `obj` and `key` must be effectively constant | 5.5 |
|
||||
| Copy narrowed `let`/param to a `const`, or restructure code to escape stale closure narrowing after reassignment | Remove the copy; narrowing survives into closures created after the last assignment to the variable | 5.4 |
|
||||
| `(arr as string[]).filter(...)` or restructure to avoid "not callable" errors on `string[] \| number[]` | Call `.filter`, `.find`, `.some`, `.every`, `.reduce` directly on union-of-array types | 5.2 |
|
||||
| `if`/`else` chain used to work around lack of narrowing inside a `switch (true)` body | `switch (true)` — each `case` condition now narrows the tested variable in its clause | 5.3 |
|
||||
| Old pattern | Modern replacement | Since |
|
||||
| ---------------------------------------------------------------------------------------------------------------- | ---------------------------------------------------------------------------------------------------------------------------- | ------ |
|
||||
| `--experimentalDecorators` + legacy decorator signatures | Standard decorators (TC39): `function dec(target, context: ClassMethodDecoratorContext)` — no flag needed | 5.0 |
|
||||
| Requiring callers to add `as const` at call sites | `<const T extends HasNames>(arg: T)` — `const` modifier on type parameter | 5.0 |
|
||||
| `--importsNotUsedAsValues` + `--preserveValueImports` | `--verbatimModuleSyntax` | 5.0 |
|
||||
| `import { Foo } from "..."` when `Foo` is only used as a type | `import { type Foo } from "..."` or `import type { Foo } from "..."` | 5.0 |
|
||||
| `"extends": "@tsconfig/strictest/tsconfig.json"` chain | `"extends": ["@tsconfig/strictest/tsconfig.json", "./tsconfig.base.json"]` (array form) | 5.0 |
|
||||
| `try { ... } finally { resource.close(); resource.delete(); }` | `using resource = acquireResource()` — calls `[Symbol.dispose]()` automatically | 5.2 |
|
||||
| `try { ... } finally { await resource.close() }` | `await using resource = acquireAsyncResource()` | 5.2 |
|
||||
| Ad-hoc cleanup with multiple `try/finally` blocks | `using cleanup = new DisposableStack(); cleanup.defer(() => ...)` | 5.2 |
|
||||
| `import data from "./data.json" assert { type: "json" }` | `import data from "./data.json" with { type: "json" }` | 5.3 |
|
||||
| `.filter(Boolean)` or `.filter(x => !!x)` to remove nulls | `.filter(x => x !== undefined)` or `.filter(x => x !== null)` (infers type predicate) | 5.5 |
|
||||
| Extra phantom type param to block inference bleed: `<C extends string, D extends C>` | `NoInfer<C>` on the parameter you don't want to drive inference | 5.4 |
|
||||
| `/** @typedef {import("./types").Foo} Foo */` in JS files | `/** @import { Foo } from "./types" */` (JSDoc `@import` tag) | 5.5 |
|
||||
| `myArray.reverse()` mutating in place | `myArray.toReversed()` (returns new array) | 5.2 |
|
||||
| `myArray.sort(cmp)` mutating in place | `myArray.toSorted(cmp)` (returns new array) | 5.2 |
|
||||
| `const copy = [...arr]; copy[i] = v` | `arr.with(i, v)` (returns new array) | 5.2 |
|
||||
| Manual `has`/`get`/`set` pattern on `Map` | `map.getOrInsert(key, defaultValue)` or `getOrInsertComputed(key, fn)` | 6.0 RC |
|
||||
| `new RegExp(str.replace(/[.\*+?^${}()\[\]\\]/g, '\\$&'))` | `new RegExp(RegExp.escape(str))` | 6.0 RC |
|
||||
| `--moduleResolution node` (node10) | `--moduleResolution nodenext` (Node.js) or `--moduleResolution bundler` (bundlers/Bun) | 6.0 RC |
|
||||
| `"baseUrl": "./src"` + `"@app/*": ["app/*"]` in paths | Remove `baseUrl`; use `"@app/*": ["./src/app/*"]` in paths directly | 6.0 RC |
|
||||
| `module Foo { export const x = 1; }` | `namespace Foo { export const x = 1; }` | 6.0 RC |
|
||||
| `export * from "..."` when all re-exported members are types | `export type * from "..."` (or `export type * as ns from "..."`) | 5.0 |
|
||||
| `function f(): undefined { return undefined; }` — explicit return required in `: undefined`-returning function | Remove the `return` entirely; `undefined`-returning functions no longer require any return statement | 5.1 |
|
||||
| Manual type predicate annotation on a simple arrow: `(x: T \| undefined): x is T => x !== undefined` | Remove the annotation; TypeScript infers `x is T` from `!== null/undefined` and `instanceof` checks automatically | 5.5 |
|
||||
| `const val = obj[key]; if (typeof val === "string") { use(val); }` — extract to const to narrow indexed access | `if (typeof obj[key] === "string") { obj[key].toUpperCase(); }` directly — both `obj` and `key` must be effectively constant | 5.5 |
|
||||
| Copy narrowed `let`/param to a `const`, or restructure code to escape stale closure narrowing after reassignment | Remove the copy; narrowing survives into closures created after the last assignment to the variable | 5.4 |
|
||||
| `(arr as string[]).filter(...)` or restructure to avoid "not callable" errors on `string[] \| number[]` | Call `.filter`, `.find`, `.some`, `.every`, `.reduce` directly on union-of-array types | 5.2 |
|
||||
| `if`/`else` chain used to work around lack of narrowing inside a `switch (true)` body | `switch (true)` — each `case` condition now narrows the tested variable in its clause | 5.3 |
|
||||
|
||||
## New capabilities
|
||||
|
||||
|
||||
@@ -91,12 +91,6 @@ updates:
|
||||
emotion:
|
||||
patterns:
|
||||
- "@emotion*"
|
||||
exclude-patterns:
|
||||
- "jest-runner-eslint"
|
||||
jest:
|
||||
patterns:
|
||||
- "jest"
|
||||
- "@types/jest"
|
||||
vite:
|
||||
patterns:
|
||||
- "vite*"
|
||||
|
||||
@@ -84,6 +84,7 @@ jobs:
|
||||
PR_TITLE: ${{ github.event.pull_request.title }}
|
||||
PR_URL: ${{ github.event.pull_request.html_url }}
|
||||
MERGE_SHA: ${{ github.event.pull_request.merge_commit_sha }}
|
||||
SENDER: ${{ github.event.sender.login }}
|
||||
steps:
|
||||
- uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
|
||||
with:
|
||||
@@ -139,6 +140,7 @@ jobs:
|
||||
|
||||
Original PR: #${PR_NUMBER} — ${PR_TITLE}
|
||||
Merge commit: ${MERGE_SHA}
|
||||
Requested by: @${SENDER}
|
||||
EOF
|
||||
)
|
||||
|
||||
@@ -171,4 +173,6 @@ jobs:
|
||||
--base "$RELEASE_VERSION" \
|
||||
--head "$BACKPORT_BRANCH" \
|
||||
--title "$TITLE" \
|
||||
--body "$BODY"
|
||||
--body "$BODY" \
|
||||
--assignee "$SENDER" \
|
||||
--reviewer "$SENDER"
|
||||
|
||||
@@ -42,6 +42,7 @@ jobs:
|
||||
PR_TITLE: ${{ github.event.pull_request.title }}
|
||||
PR_URL: ${{ github.event.pull_request.html_url }}
|
||||
MERGE_SHA: ${{ github.event.pull_request.merge_commit_sha }}
|
||||
SENDER: ${{ github.event.sender.login }}
|
||||
steps:
|
||||
- uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
|
||||
with:
|
||||
@@ -116,6 +117,7 @@ jobs:
|
||||
|
||||
Original PR: #${PR_NUMBER} — ${PR_TITLE}
|
||||
Merge commit: ${MERGE_SHA}
|
||||
Requested by: @${SENDER}
|
||||
EOF
|
||||
)
|
||||
|
||||
@@ -132,8 +134,19 @@ jobs:
|
||||
exit 0
|
||||
fi
|
||||
|
||||
gh pr create \
|
||||
--base "$RELEASE_BRANCH" \
|
||||
--head "$BACKPORT_BRANCH" \
|
||||
--title "$TITLE" \
|
||||
--body "$BODY"
|
||||
NEW_PR_URL=$(
|
||||
gh pr create \
|
||||
--base "$RELEASE_BRANCH" \
|
||||
--head "$BACKPORT_BRANCH" \
|
||||
--title "$TITLE" \
|
||||
--body "$BODY" \
|
||||
--assignee "$SENDER" \
|
||||
--reviewer "$SENDER"
|
||||
)
|
||||
|
||||
# Comment on the original PR to notify the author.
|
||||
COMMENT="Cherry-pick PR created: ${NEW_PR_URL}"
|
||||
if [ "$CONFLICT" = true ]; then
|
||||
COMMENT="${COMMENT} (⚠️ conflicts need manual resolution)"
|
||||
fi
|
||||
gh pr comment "$PR_NUMBER" --body "$COMMENT"
|
||||
|
||||
@@ -0,0 +1,93 @@
|
||||
# Ensures that only bug fixes are cherry-picked to release branches.
|
||||
# PRs targeting release/* must have a title starting with "fix:" or "fix(scope):".
|
||||
name: PR Cherry-Pick Check
|
||||
|
||||
on:
|
||||
# zizmor: ignore[dangerous-triggers] Only reads PR metadata and comments; does not checkout PR code.
|
||||
pull_request_target:
|
||||
types: [opened, reopened, edited]
|
||||
branches:
|
||||
- "release/*"
|
||||
|
||||
permissions:
|
||||
pull-requests: write
|
||||
|
||||
jobs:
|
||||
check-cherry-pick:
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: Harden Runner
|
||||
uses: step-security/harden-runner@fe104658747b27e96e4f7e80cd0a94068e53901d # v2.16.1
|
||||
with:
|
||||
egress-policy: audit
|
||||
|
||||
- name: Check PR title for bug fix
|
||||
uses: actions/github-script@60a0d83039c74a4aee543508d2ffcb1c3799cdea # v7.0.1
|
||||
with:
|
||||
script: |
|
||||
const title = context.payload.pull_request.title;
|
||||
const prNumber = context.payload.pull_request.number;
|
||||
const baseBranch = context.payload.pull_request.base.ref;
|
||||
const author = context.payload.pull_request.user.login;
|
||||
|
||||
console.log(`PR #${prNumber}: "${title}" -> ${baseBranch}`);
|
||||
|
||||
// Match conventional commit "fix:" or "fix(scope):" prefix.
|
||||
const isBugFix = /^fix(\(.+\))?:/.test(title);
|
||||
|
||||
if (isBugFix) {
|
||||
console.log("PR title indicates a bug fix. No action needed.");
|
||||
return;
|
||||
}
|
||||
|
||||
console.log("PR title does not indicate a bug fix. Commenting.");
|
||||
|
||||
// Check for an existing comment from this bot to avoid duplicates
|
||||
// on title edits.
|
||||
const { data: comments } = await github.rest.issues.listComments({
|
||||
owner: context.repo.owner,
|
||||
repo: context.repo.repo,
|
||||
issue_number: prNumber,
|
||||
});
|
||||
|
||||
const marker = "<!-- cherry-pick-check -->";
|
||||
const existingComment = comments.find(
|
||||
(c) => c.body && c.body.includes(marker),
|
||||
);
|
||||
|
||||
const body = [
|
||||
marker,
|
||||
`👋 Hey @${author}!`,
|
||||
"",
|
||||
`This PR is targeting the \`${baseBranch}\` release branch, but its title does not start with \`fix:\` or \`fix(scope):\`.`,
|
||||
"",
|
||||
"Only **bug fixes** should be cherry-picked to release branches. If this is a bug fix, please update the PR title to match the conventional commit format:",
|
||||
"",
|
||||
"```",
|
||||
"fix: description of the bug fix",
|
||||
"fix(scope): description of the bug fix",
|
||||
"```",
|
||||
"",
|
||||
"If this is **not** a bug fix, it likely should not target a release branch.",
|
||||
].join("\n");
|
||||
|
||||
if (existingComment) {
|
||||
console.log(`Updating existing comment ${existingComment.id}.`);
|
||||
await github.rest.issues.updateComment({
|
||||
owner: context.repo.owner,
|
||||
repo: context.repo.repo,
|
||||
comment_id: existingComment.id,
|
||||
body,
|
||||
});
|
||||
} else {
|
||||
await github.rest.issues.createComment({
|
||||
owner: context.repo.owner,
|
||||
repo: context.repo.repo,
|
||||
issue_number: prNumber,
|
||||
body,
|
||||
});
|
||||
}
|
||||
|
||||
core.warning(
|
||||
`PR #${prNumber} targets ${baseBranch} but is not a bug fix. Title must start with "fix:" or "fix(scope):".`,
|
||||
);
|
||||
@@ -91,6 +91,59 @@ define atomic_write
|
||||
mv "$$tmpfile" "$@" && rm -rf "$$tmpdir"
|
||||
endef
|
||||
|
||||
# Helper binary targets. Built with go build -o to avoid caching
|
||||
# link-stage executables in GOCACHE. Each binary is a real Make
|
||||
# target so parallel -j builds serialize correctly instead of
|
||||
# racing on the same output path.
|
||||
|
||||
_gen/bin/apitypings: $(wildcard scripts/apitypings/*.go) | _gen
|
||||
@mkdir -p _gen/bin
|
||||
go build -o $@ ./scripts/apitypings
|
||||
|
||||
_gen/bin/auditdocgen: $(wildcard scripts/auditdocgen/*.go) | _gen
|
||||
@mkdir -p _gen/bin
|
||||
go build -o $@ ./scripts/auditdocgen
|
||||
|
||||
_gen/bin/check-scopes: $(wildcard scripts/check-scopes/*.go) | _gen
|
||||
@mkdir -p _gen/bin
|
||||
go build -o $@ ./scripts/check-scopes
|
||||
|
||||
_gen/bin/clidocgen: $(wildcard scripts/clidocgen/*.go) | _gen
|
||||
@mkdir -p _gen/bin
|
||||
go build -o $@ ./scripts/clidocgen
|
||||
|
||||
_gen/bin/dbdump: $(wildcard coderd/database/gen/dump/*.go) | _gen
|
||||
@mkdir -p _gen/bin
|
||||
go build -o $@ ./coderd/database/gen/dump
|
||||
|
||||
_gen/bin/examplegen: $(wildcard scripts/examplegen/*.go) | _gen
|
||||
@mkdir -p _gen/bin
|
||||
go build -o $@ ./scripts/examplegen
|
||||
|
||||
_gen/bin/gensite: $(wildcard scripts/gensite/*.go) | _gen
|
||||
@mkdir -p _gen/bin
|
||||
go build -o $@ ./scripts/gensite
|
||||
|
||||
_gen/bin/apikeyscopesgen: $(wildcard scripts/apikeyscopesgen/*.go) | _gen
|
||||
@mkdir -p _gen/bin
|
||||
go build -o $@ ./scripts/apikeyscopesgen
|
||||
|
||||
_gen/bin/metricsdocgen: $(wildcard scripts/metricsdocgen/*.go) | _gen
|
||||
@mkdir -p _gen/bin
|
||||
go build -o $@ ./scripts/metricsdocgen
|
||||
|
||||
_gen/bin/metricsdocgen-scanner: $(wildcard scripts/metricsdocgen/scanner/*.go) | _gen
|
||||
@mkdir -p _gen/bin
|
||||
go build -o $@ ./scripts/metricsdocgen/scanner
|
||||
|
||||
_gen/bin/modeloptionsgen: $(wildcard scripts/modeloptionsgen/*.go) | _gen
|
||||
@mkdir -p _gen/bin
|
||||
go build -o $@ ./scripts/modeloptionsgen
|
||||
|
||||
_gen/bin/typegen: $(wildcard scripts/typegen/*.go) | _gen
|
||||
@mkdir -p _gen/bin
|
||||
go build -o $@ ./scripts/typegen
|
||||
|
||||
# Shared temp directory for atomic writes. Lives at the project root
|
||||
# so all targets share the same filesystem, and is gitignored.
|
||||
# Order-only prerequisite: recipes that need it depend on | _gen
|
||||
@@ -201,6 +254,7 @@ endif
|
||||
|
||||
clean:
|
||||
rm -rf build/ site/build/ site/out/
|
||||
rm -rf _gen/bin
|
||||
mkdir -p build/
|
||||
git restore site/out/
|
||||
.PHONY: clean
|
||||
@@ -654,8 +708,8 @@ lint/go:
|
||||
go tool github.com/coder/paralleltestctx/cmd/paralleltestctx -custom-funcs="testutil.Context" ./...
|
||||
.PHONY: lint/go
|
||||
|
||||
lint/examples:
|
||||
go run ./scripts/examplegen/main.go -lint
|
||||
lint/examples: | _gen/bin/examplegen
|
||||
_gen/bin/examplegen -lint
|
||||
.PHONY: lint/examples
|
||||
|
||||
# Use shfmt to determine the shell files, takes editorconfig into consideration.
|
||||
@@ -693,8 +747,8 @@ lint/actions/zizmor:
|
||||
.PHONY: lint/actions/zizmor
|
||||
|
||||
# Verify api_key_scope enum contains all RBAC <resource>:<action> values.
|
||||
lint/check-scopes: coderd/database/dump.sql
|
||||
go run ./scripts/check-scopes
|
||||
lint/check-scopes: coderd/database/dump.sql | _gen/bin/check-scopes
|
||||
_gen/bin/check-scopes
|
||||
.PHONY: lint/check-scopes
|
||||
|
||||
# Verify migrations do not hardcode the public schema.
|
||||
@@ -734,8 +788,8 @@ lint/typos: build/typos-$(TYPOS_VERSION)
|
||||
# The pre-push hook is allowlisted, see scripts/githooks/pre-push.
|
||||
#
|
||||
# pre-commit uses two phases: gen+fmt first, then lint+build. This
|
||||
# avoids races where gen's `go run` creates temporary .go files that
|
||||
# lint's find-based checks pick up. Within each phase, targets run in
|
||||
# avoids races where gen creates temporary .go files that lint's
|
||||
# find-based checks pick up. Within each phase, targets run in
|
||||
# parallel via -j. It fails if any tracked files have unstaged
|
||||
# changes afterward.
|
||||
|
||||
@@ -949,8 +1003,8 @@ gen/mark-fresh:
|
||||
|
||||
# Runs migrations to output a dump of the database schema after migrations are
|
||||
# applied.
|
||||
coderd/database/dump.sql: coderd/database/gen/dump/main.go $(wildcard coderd/database/migrations/*.sql)
|
||||
go run ./coderd/database/gen/dump/main.go
|
||||
coderd/database/dump.sql: coderd/database/gen/dump/main.go $(wildcard coderd/database/migrations/*.sql) | _gen/bin/dbdump
|
||||
_gen/bin/dbdump
|
||||
touch "$@"
|
||||
|
||||
# Generates Go code for querying the database.
|
||||
@@ -1067,88 +1121,88 @@ enterprise/aibridged/proto/aibridged.pb.go: enterprise/aibridged/proto/aibridged
|
||||
--go-drpc_opt=paths=source_relative \
|
||||
./enterprise/aibridged/proto/aibridged.proto
|
||||
|
||||
site/src/api/typesGenerated.ts: site/node_modules/.installed $(wildcard scripts/apitypings/*) $(shell find ./codersdk $(FIND_EXCLUSIONS) -type f -name '*.go') | _gen
|
||||
$(call atomic_write,go run -C ./scripts/apitypings main.go,./scripts/biome_format.sh)
|
||||
site/src/api/typesGenerated.ts: site/node_modules/.installed $(wildcard scripts/apitypings/*) $(shell find ./codersdk $(FIND_EXCLUSIONS) -type f -name '*.go') | _gen _gen/bin/apitypings
|
||||
$(call atomic_write,_gen/bin/apitypings,./scripts/biome_format.sh)
|
||||
|
||||
site/e2e/provisionerGenerated.ts: site/node_modules/.installed provisionerd/proto/provisionerd.pb.go provisionersdk/proto/provisioner.pb.go
|
||||
(cd site/ && pnpm run gen:provisioner)
|
||||
touch "$@"
|
||||
|
||||
site/src/theme/icons.json: site/node_modules/.installed $(wildcard scripts/gensite/*) $(wildcard site/static/icon/*) | _gen
|
||||
site/src/theme/icons.json: site/node_modules/.installed $(wildcard scripts/gensite/*) $(wildcard site/static/icon/*) | _gen _gen/bin/gensite
|
||||
tmpdir=$$(mktemp -d -p _gen) && tmpfile=$$(realpath "$$tmpdir")/$(notdir $@) && \
|
||||
go run ./scripts/gensite/ -icons "$$tmpfile" && \
|
||||
_gen/bin/gensite -icons "$$tmpfile" && \
|
||||
./scripts/biome_format.sh "$$tmpfile" && \
|
||||
mv "$$tmpfile" "$@" && rm -rf "$$tmpdir"
|
||||
|
||||
examples/examples.gen.json: scripts/examplegen/main.go examples/examples.go $(shell find ./examples/templates) | _gen
|
||||
$(call atomic_write,go run ./scripts/examplegen/main.go)
|
||||
examples/examples.gen.json: scripts/examplegen/main.go examples/examples.go $(shell find ./examples/templates) | _gen _gen/bin/examplegen
|
||||
$(call atomic_write,_gen/bin/examplegen)
|
||||
|
||||
coderd/rbac/object_gen.go: scripts/typegen/rbacobject.gotmpl scripts/typegen/main.go coderd/rbac/object.go coderd/rbac/policy/policy.go | _gen
|
||||
$(call atomic_write,go run ./scripts/typegen/main.go rbac object)
|
||||
coderd/rbac/object_gen.go: scripts/typegen/rbacobject.gotmpl scripts/typegen/main.go coderd/rbac/object.go coderd/rbac/policy/policy.go | _gen _gen/bin/typegen
|
||||
$(call atomic_write,_gen/bin/typegen rbac object)
|
||||
touch "$@"
|
||||
|
||||
# NOTE: depends on object_gen.go because `go run` compiles
|
||||
# coderd/rbac which includes it.
|
||||
# NOTE: depends on object_gen.go because the generator build
|
||||
# compiles coderd/rbac which includes it.
|
||||
coderd/rbac/scopes_constants_gen.go: scripts/typegen/scopenames.gotmpl scripts/typegen/main.go coderd/rbac/policy/policy.go \
|
||||
coderd/rbac/object_gen.go | _gen
|
||||
coderd/rbac/object_gen.go | _gen _gen/bin/typegen
|
||||
# Write to a temp file first to avoid truncating the package
|
||||
# during build since the generator imports the rbac package.
|
||||
$(call atomic_write,go run ./scripts/typegen/main.go rbac scopenames)
|
||||
$(call atomic_write,_gen/bin/typegen rbac scopenames)
|
||||
touch "$@"
|
||||
|
||||
# NOTE: depends on object_gen.go and scopes_constants_gen.go because
|
||||
# `go run` compiles coderd/rbac which includes both.
|
||||
# the generator build compiles coderd/rbac which includes both.
|
||||
codersdk/rbacresources_gen.go: scripts/typegen/codersdk.gotmpl scripts/typegen/main.go coderd/rbac/object.go coderd/rbac/policy/policy.go \
|
||||
coderd/rbac/object_gen.go coderd/rbac/scopes_constants_gen.go | _gen
|
||||
coderd/rbac/object_gen.go coderd/rbac/scopes_constants_gen.go | _gen _gen/bin/typegen
|
||||
# Write to a temp file to avoid truncating the target, which
|
||||
# would break the codersdk package and any parallel build targets.
|
||||
$(call atomic_write,go run scripts/typegen/main.go rbac codersdk)
|
||||
$(call atomic_write,_gen/bin/typegen rbac codersdk)
|
||||
touch "$@"
|
||||
|
||||
# NOTE: depends on object_gen.go and scopes_constants_gen.go because
|
||||
# `go run` compiles coderd/rbac which includes both.
|
||||
# the generator build compiles coderd/rbac which includes both.
|
||||
codersdk/apikey_scopes_gen.go: scripts/apikeyscopesgen/main.go coderd/rbac/scopes_catalog.go coderd/rbac/scopes.go \
|
||||
coderd/rbac/object_gen.go coderd/rbac/scopes_constants_gen.go | _gen
|
||||
coderd/rbac/object_gen.go coderd/rbac/scopes_constants_gen.go | _gen _gen/bin/apikeyscopesgen
|
||||
# Generate SDK constants for external API key scopes.
|
||||
$(call atomic_write,go run ./scripts/apikeyscopesgen)
|
||||
$(call atomic_write,_gen/bin/apikeyscopesgen)
|
||||
touch "$@"
|
||||
|
||||
# NOTE: depends on object_gen.go and scopes_constants_gen.go because
|
||||
# `go run` compiles coderd/rbac which includes both.
|
||||
# the generator build compiles coderd/rbac which includes both.
|
||||
site/src/api/rbacresourcesGenerated.ts: site/node_modules/.installed scripts/typegen/codersdk.gotmpl scripts/typegen/main.go coderd/rbac/object.go coderd/rbac/policy/policy.go \
|
||||
coderd/rbac/object_gen.go coderd/rbac/scopes_constants_gen.go | _gen
|
||||
$(call atomic_write,go run scripts/typegen/main.go rbac typescript,./scripts/biome_format.sh)
|
||||
coderd/rbac/object_gen.go coderd/rbac/scopes_constants_gen.go | _gen _gen/bin/typegen
|
||||
$(call atomic_write,_gen/bin/typegen rbac typescript,./scripts/biome_format.sh)
|
||||
|
||||
site/src/api/countriesGenerated.ts: site/node_modules/.installed scripts/typegen/countries.tstmpl scripts/typegen/main.go codersdk/countries.go | _gen
|
||||
$(call atomic_write,go run scripts/typegen/main.go countries,./scripts/biome_format.sh)
|
||||
site/src/api/countriesGenerated.ts: site/node_modules/.installed scripts/typegen/countries.tstmpl scripts/typegen/main.go codersdk/countries.go | _gen _gen/bin/typegen
|
||||
$(call atomic_write,_gen/bin/typegen countries,./scripts/biome_format.sh)
|
||||
|
||||
site/src/api/chatModelOptionsGenerated.json: scripts/modeloptionsgen/main.go codersdk/chats.go | _gen
|
||||
$(call atomic_write,go run ./scripts/modeloptionsgen/main.go | tail -n +2,./scripts/biome_format.sh)
|
||||
site/src/api/chatModelOptionsGenerated.json: scripts/modeloptionsgen/main.go codersdk/chats.go | _gen _gen/bin/modeloptionsgen
|
||||
$(call atomic_write,_gen/bin/modeloptionsgen | tail -n +2,./scripts/biome_format.sh)
|
||||
|
||||
scripts/metricsdocgen/generated_metrics: $(GO_SRC_FILES) | _gen
|
||||
$(call atomic_write,go run ./scripts/metricsdocgen/scanner)
|
||||
scripts/metricsdocgen/generated_metrics: $(GO_SRC_FILES) | _gen _gen/bin/metricsdocgen-scanner
|
||||
$(call atomic_write,_gen/bin/metricsdocgen-scanner)
|
||||
|
||||
docs/admin/integrations/prometheus.md: node_modules/.installed scripts/metricsdocgen/main.go scripts/metricsdocgen/metrics scripts/metricsdocgen/generated_metrics | _gen
|
||||
docs/admin/integrations/prometheus.md: node_modules/.installed scripts/metricsdocgen/main.go scripts/metricsdocgen/metrics scripts/metricsdocgen/generated_metrics | _gen _gen/bin/metricsdocgen
|
||||
tmpdir=$$(mktemp -d -p _gen) && tmpfile=$$(realpath "$$tmpdir")/$(notdir $@) && cp "$@" "$$tmpfile" && \
|
||||
go run scripts/metricsdocgen/main.go --prometheus-doc-file="$$tmpfile" && \
|
||||
_gen/bin/metricsdocgen --prometheus-doc-file="$$tmpfile" && \
|
||||
pnpm exec markdownlint-cli2 --fix "$$tmpfile" && \
|
||||
pnpm exec markdown-table-formatter "$$tmpfile" && \
|
||||
mv "$$tmpfile" "$@" && rm -rf "$$tmpdir"
|
||||
|
||||
docs/reference/cli/index.md: node_modules/.installed scripts/clidocgen/main.go examples/examples.gen.json $(GO_SRC_FILES) | _gen
|
||||
docs/reference/cli/index.md: node_modules/.installed scripts/clidocgen/main.go examples/examples.gen.json $(GO_SRC_FILES) | _gen _gen/bin/clidocgen
|
||||
tmpdir=$$(mktemp -d -p _gen) && \
|
||||
tmpdir=$$(realpath "$$tmpdir") && \
|
||||
mkdir -p "$$tmpdir/docs/reference/cli" && \
|
||||
cp docs/manifest.json "$$tmpdir/docs/manifest.json" && \
|
||||
CI=true DOCS_DIR="$$tmpdir/docs" go run ./scripts/clidocgen && \
|
||||
CI=true DOCS_DIR="$$tmpdir/docs" _gen/bin/clidocgen && \
|
||||
pnpm exec markdownlint-cli2 --fix "$$tmpdir/docs/reference/cli/*.md" && \
|
||||
pnpm exec markdown-table-formatter "$$tmpdir/docs/reference/cli/*.md" && \
|
||||
for f in "$$tmpdir/docs/reference/cli/"*.md; do mv "$$f" "docs/reference/cli/$$(basename "$$f")"; done && \
|
||||
rm -rf "$$tmpdir"
|
||||
|
||||
docs/admin/security/audit-logs.md: node_modules/.installed coderd/database/querier.go scripts/auditdocgen/main.go enterprise/audit/table.go coderd/rbac/object_gen.go | _gen
|
||||
docs/admin/security/audit-logs.md: node_modules/.installed coderd/database/querier.go scripts/auditdocgen/main.go enterprise/audit/table.go coderd/rbac/object_gen.go | _gen _gen/bin/auditdocgen
|
||||
tmpdir=$$(mktemp -d -p _gen) && tmpfile=$$(realpath "$$tmpdir")/$(notdir $@) && cp "$@" "$$tmpfile" && \
|
||||
go run scripts/auditdocgen/main.go --audit-doc-file="$$tmpfile" && \
|
||||
_gen/bin/auditdocgen --audit-doc-file="$$tmpfile" && \
|
||||
pnpm exec markdownlint-cli2 --fix "$$tmpfile" && \
|
||||
pnpm exec markdown-table-formatter "$$tmpfile" && \
|
||||
mv "$$tmpfile" "$@" && rm -rf "$$tmpdir"
|
||||
|
||||
@@ -2862,6 +2862,126 @@ func TestAPI(t *testing.T) {
|
||||
"rebuilt agent should include updated display apps")
|
||||
})
|
||||
|
||||
// Verify that when a terraform-managed subagent is injected into
|
||||
// a devcontainer, the Directory field sent to Create reflects
|
||||
// the container-internal workspaceFolder from devcontainer
|
||||
// read-configuration, not the host-side workspace_folder from
|
||||
// the terraform resource. This is the scenario described in
|
||||
// https://linear.app/codercom/issue/PRODUCT-259:
|
||||
// 1. Non-terraform subagent → directory = /workspaces/foo (correct)
|
||||
// 2. Terraform subagent → directory was stuck on host path (bug)
|
||||
t.Run("TerraformDefinedSubAgentUsesContainerInternalDirectory", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
if runtime.GOOS == "windows" {
|
||||
t.Skip("Dev Container tests are not supported on Windows (this test uses mocks but fails due to Windows paths)")
|
||||
}
|
||||
|
||||
var (
|
||||
ctx = testutil.Context(t, testutil.WaitMedium)
|
||||
logger = slogtest.Make(t, &slogtest.Options{IgnoreErrors: true}).Leveled(slog.LevelDebug)
|
||||
mCtrl = gomock.NewController(t)
|
||||
|
||||
terraformAgentID = uuid.New()
|
||||
containerID = "test-container-id"
|
||||
|
||||
// Given: A container with a host-side workspace folder.
|
||||
terraformContainer = codersdk.WorkspaceAgentContainer{
|
||||
ID: containerID,
|
||||
FriendlyName: "test-container",
|
||||
Image: "test-image",
|
||||
Running: true,
|
||||
CreatedAt: time.Now(),
|
||||
Labels: map[string]string{
|
||||
agentcontainers.DevcontainerLocalFolderLabel: "/home/coder/project",
|
||||
agentcontainers.DevcontainerConfigFileLabel: "/home/coder/project/.devcontainer/devcontainer.json",
|
||||
},
|
||||
}
|
||||
|
||||
// Given: A terraform-defined devcontainer whose
|
||||
// workspace_folder is the HOST-side path (set by provisioner).
|
||||
terraformDevcontainer = codersdk.WorkspaceAgentDevcontainer{
|
||||
ID: uuid.New(),
|
||||
Name: "terraform-devcontainer",
|
||||
WorkspaceFolder: "/home/coder/project",
|
||||
ConfigPath: "/home/coder/project/.devcontainer/devcontainer.json",
|
||||
SubagentID: uuid.NullUUID{UUID: terraformAgentID, Valid: true},
|
||||
}
|
||||
|
||||
fCCLI = &fakeContainerCLI{
|
||||
containers: codersdk.WorkspaceAgentListContainersResponse{
|
||||
Containers: []codersdk.WorkspaceAgentContainer{terraformContainer},
|
||||
},
|
||||
arch: runtime.GOARCH,
|
||||
}
|
||||
|
||||
// Given: devcontainer read-configuration returns the
|
||||
// CONTAINER-INTERNAL workspace folder.
|
||||
fDCCLI = &fakeDevcontainerCLI{
|
||||
upID: containerID,
|
||||
readConfig: agentcontainers.DevcontainerConfig{
|
||||
Workspace: agentcontainers.DevcontainerWorkspace{
|
||||
WorkspaceFolder: "/workspaces/project",
|
||||
},
|
||||
MergedConfiguration: agentcontainers.DevcontainerMergedConfiguration{
|
||||
Customizations: agentcontainers.DevcontainerMergedCustomizations{
|
||||
Coder: []agentcontainers.CoderCustomization{{}},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
mSAC = acmock.NewMockSubAgentClient(mCtrl)
|
||||
createCalls = make(chan agentcontainers.SubAgent, 1)
|
||||
closed bool
|
||||
)
|
||||
|
||||
mSAC.EXPECT().List(gomock.Any()).Return([]agentcontainers.SubAgent{}, nil).AnyTimes()
|
||||
|
||||
mSAC.EXPECT().Create(gomock.Any(), gomock.Any()).DoAndReturn(
|
||||
func(_ context.Context, agent agentcontainers.SubAgent) (agentcontainers.SubAgent, error) {
|
||||
agent.AuthToken = uuid.New()
|
||||
createCalls <- agent
|
||||
return agent, nil
|
||||
},
|
||||
).Times(1)
|
||||
|
||||
mSAC.EXPECT().Delete(gomock.Any(), gomock.Any()).DoAndReturn(func(_ context.Context, _ uuid.UUID) error {
|
||||
assert.True(t, closed, "Delete should only be called after Close")
|
||||
return nil
|
||||
}).AnyTimes()
|
||||
|
||||
api := agentcontainers.NewAPI(logger,
|
||||
agentcontainers.WithContainerCLI(fCCLI),
|
||||
agentcontainers.WithDevcontainerCLI(fDCCLI),
|
||||
agentcontainers.WithDevcontainers(
|
||||
[]codersdk.WorkspaceAgentDevcontainer{terraformDevcontainer},
|
||||
[]codersdk.WorkspaceAgentScript{{ID: terraformDevcontainer.ID, LogSourceID: uuid.New()}},
|
||||
),
|
||||
agentcontainers.WithSubAgentClient(mSAC),
|
||||
agentcontainers.WithSubAgentURL("test-subagent-url"),
|
||||
agentcontainers.WithWatcher(watcher.NewNoop()),
|
||||
)
|
||||
api.Start()
|
||||
defer func() {
|
||||
closed = true
|
||||
api.Close()
|
||||
}()
|
||||
|
||||
// When: The devcontainer is created (triggering injection).
|
||||
err := api.CreateDevcontainer(terraformDevcontainer.WorkspaceFolder, terraformDevcontainer.ConfigPath)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Then: The subagent sent to Create has the correct
|
||||
// container-internal directory, not the host path.
|
||||
createdAgent := testutil.RequireReceive(ctx, t, createCalls)
|
||||
assert.Equal(t, terraformAgentID, createdAgent.ID,
|
||||
"agent should use terraform-defined ID")
|
||||
assert.Equal(t, "/workspaces/project", createdAgent.Directory,
|
||||
"directory should be the container-internal path from devcontainer "+
|
||||
"read-configuration, not the host-side workspace_folder")
|
||||
})
|
||||
|
||||
t.Run("Error", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
|
||||
@@ -134,6 +134,33 @@ func Config(workingDir string) (workspacesdk.ContextConfigResponse, []string) {
|
||||
}, ResolvePaths(mcpConfigFile, workingDir)
|
||||
}
|
||||
|
||||
// ContextPartsFromDir reads instruction files and discovers skills
|
||||
// from a specific directory, using default file names. This is used
|
||||
// by the CLI chat context commands to read context from an arbitrary
|
||||
// directory without consulting agent env vars.
|
||||
func ContextPartsFromDir(dir string) []codersdk.ChatMessagePart {
|
||||
var parts []codersdk.ChatMessagePart
|
||||
|
||||
if entry, found := readInstructionFileFromDir(dir, DefaultInstructionsFile); found {
|
||||
parts = append(parts, entry)
|
||||
}
|
||||
|
||||
// Reuse ResolvePaths so CLI skill discovery follows the same
|
||||
// project-relative path handling as agent config resolution.
|
||||
skillParts := discoverSkills(
|
||||
ResolvePaths(strings.Join([]string{DefaultSkillsDir, "skills"}, ","), dir),
|
||||
DefaultSkillMetaFile,
|
||||
)
|
||||
parts = append(parts, skillParts...)
|
||||
|
||||
// Guarantee non-nil slice.
|
||||
if parts == nil {
|
||||
parts = []codersdk.ChatMessagePart{}
|
||||
}
|
||||
|
||||
return parts
|
||||
}
|
||||
|
||||
// MCPConfigFiles returns the resolved MCP configuration file
|
||||
// paths for the agent's MCP manager.
|
||||
func (api *API) MCPConfigFiles() []string {
|
||||
|
||||
@@ -23,18 +23,144 @@ func filterParts(parts []codersdk.ChatMessagePart, t codersdk.ChatMessagePartTyp
|
||||
return out
|
||||
}
|
||||
|
||||
func TestConfig(t *testing.T) {
|
||||
t.Run("Defaults", func(t *testing.T) {
|
||||
fakeHome := t.TempDir()
|
||||
t.Setenv("HOME", fakeHome)
|
||||
t.Setenv("USERPROFILE", fakeHome)
|
||||
func writeSkillMetaFileInRoot(t *testing.T, skillsRoot, name, description string) string {
|
||||
t.Helper()
|
||||
|
||||
// Clear all env vars so defaults are used.
|
||||
t.Setenv(agentcontextconfig.EnvInstructionsDirs, "")
|
||||
t.Setenv(agentcontextconfig.EnvInstructionsFile, "")
|
||||
t.Setenv(agentcontextconfig.EnvSkillsDirs, "")
|
||||
t.Setenv(agentcontextconfig.EnvSkillMetaFile, "")
|
||||
t.Setenv(agentcontextconfig.EnvMCPConfigFiles, "")
|
||||
skillDir := filepath.Join(skillsRoot, name)
|
||||
require.NoError(t, os.MkdirAll(skillDir, 0o755))
|
||||
require.NoError(t, os.WriteFile(
|
||||
filepath.Join(skillDir, "SKILL.md"),
|
||||
[]byte("---\nname: "+name+"\ndescription: "+description+"\n---\nSkill body"),
|
||||
0o600,
|
||||
))
|
||||
|
||||
return skillDir
|
||||
}
|
||||
|
||||
func writeSkillMetaFile(t *testing.T, dir, name, description string) string {
|
||||
t.Helper()
|
||||
return writeSkillMetaFileInRoot(t, filepath.Join(dir, ".agents", "skills"), name, description)
|
||||
}
|
||||
|
||||
func TestContextPartsFromDir(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
t.Run("ReturnsInstructionFilePart", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
dir := t.TempDir()
|
||||
instructionPath := filepath.Join(dir, "AGENTS.md")
|
||||
require.NoError(t, os.WriteFile(instructionPath, []byte("project instructions"), 0o600))
|
||||
|
||||
parts := agentcontextconfig.ContextPartsFromDir(dir)
|
||||
contextParts := filterParts(parts, codersdk.ChatMessagePartTypeContextFile)
|
||||
skillParts := filterParts(parts, codersdk.ChatMessagePartTypeSkill)
|
||||
|
||||
require.Len(t, parts, 1)
|
||||
require.Len(t, contextParts, 1)
|
||||
require.Empty(t, skillParts)
|
||||
require.Equal(t, instructionPath, contextParts[0].ContextFilePath)
|
||||
require.Equal(t, "project instructions", contextParts[0].ContextFileContent)
|
||||
require.False(t, contextParts[0].ContextFileTruncated)
|
||||
})
|
||||
|
||||
t.Run("ReturnsSkillParts", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
dir := t.TempDir()
|
||||
skillDir := writeSkillMetaFile(t, dir, "my-skill", "A test skill")
|
||||
|
||||
parts := agentcontextconfig.ContextPartsFromDir(dir)
|
||||
contextParts := filterParts(parts, codersdk.ChatMessagePartTypeContextFile)
|
||||
skillParts := filterParts(parts, codersdk.ChatMessagePartTypeSkill)
|
||||
|
||||
require.Len(t, parts, 1)
|
||||
require.Empty(t, contextParts)
|
||||
require.Len(t, skillParts, 1)
|
||||
require.Equal(t, "my-skill", skillParts[0].SkillName)
|
||||
require.Equal(t, "A test skill", skillParts[0].SkillDescription)
|
||||
require.Equal(t, skillDir, skillParts[0].SkillDir)
|
||||
require.Equal(t, "SKILL.md", skillParts[0].ContextFileSkillMetaFile)
|
||||
})
|
||||
|
||||
t.Run("ReturnsSkillPartsFromSkillsDir", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
dir := t.TempDir()
|
||||
skillDir := writeSkillMetaFileInRoot(
|
||||
t,
|
||||
filepath.Join(dir, "skills"),
|
||||
"my-skill",
|
||||
"A test skill",
|
||||
)
|
||||
|
||||
parts := agentcontextconfig.ContextPartsFromDir(dir)
|
||||
contextParts := filterParts(parts, codersdk.ChatMessagePartTypeContextFile)
|
||||
skillParts := filterParts(parts, codersdk.ChatMessagePartTypeSkill)
|
||||
|
||||
require.Len(t, parts, 1)
|
||||
require.Empty(t, contextParts)
|
||||
require.Len(t, skillParts, 1)
|
||||
require.Equal(t, "my-skill", skillParts[0].SkillName)
|
||||
require.Equal(t, "A test skill", skillParts[0].SkillDescription)
|
||||
require.Equal(t, skillDir, skillParts[0].SkillDir)
|
||||
require.Equal(t, "SKILL.md", skillParts[0].ContextFileSkillMetaFile)
|
||||
})
|
||||
|
||||
t.Run("ReturnsEmptyForEmptyDir", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
parts := agentcontextconfig.ContextPartsFromDir(t.TempDir())
|
||||
|
||||
require.NotNil(t, parts)
|
||||
require.Empty(t, parts)
|
||||
})
|
||||
|
||||
t.Run("ReturnsCombinedResults", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
dir := t.TempDir()
|
||||
instructionPath := filepath.Join(dir, "AGENTS.md")
|
||||
require.NoError(t, os.WriteFile(instructionPath, []byte("combined instructions"), 0o600))
|
||||
skillDir := writeSkillMetaFile(t, dir, "combined-skill", "Combined test skill")
|
||||
|
||||
parts := agentcontextconfig.ContextPartsFromDir(dir)
|
||||
contextParts := filterParts(parts, codersdk.ChatMessagePartTypeContextFile)
|
||||
skillParts := filterParts(parts, codersdk.ChatMessagePartTypeSkill)
|
||||
|
||||
require.Len(t, parts, 2)
|
||||
require.Len(t, contextParts, 1)
|
||||
require.Len(t, skillParts, 1)
|
||||
require.Equal(t, instructionPath, contextParts[0].ContextFilePath)
|
||||
require.Equal(t, "combined instructions", contextParts[0].ContextFileContent)
|
||||
require.Equal(t, "combined-skill", skillParts[0].SkillName)
|
||||
require.Equal(t, skillDir, skillParts[0].SkillDir)
|
||||
})
|
||||
}
|
||||
|
||||
func setupConfigTestEnv(t *testing.T, overrides map[string]string) string {
|
||||
t.Helper()
|
||||
|
||||
fakeHome := t.TempDir()
|
||||
t.Setenv("HOME", fakeHome)
|
||||
t.Setenv("USERPROFILE", fakeHome)
|
||||
t.Setenv(agentcontextconfig.EnvInstructionsDirs, "")
|
||||
t.Setenv(agentcontextconfig.EnvInstructionsFile, "")
|
||||
t.Setenv(agentcontextconfig.EnvSkillsDirs, "")
|
||||
t.Setenv(agentcontextconfig.EnvSkillMetaFile, "")
|
||||
t.Setenv(agentcontextconfig.EnvMCPConfigFiles, "")
|
||||
|
||||
for key, value := range overrides {
|
||||
t.Setenv(key, value)
|
||||
}
|
||||
|
||||
return fakeHome
|
||||
}
|
||||
|
||||
func TestConfig(t *testing.T) {
|
||||
//nolint:paralleltest // Uses t.Setenv to mutate process-wide environment.
|
||||
t.Run("Defaults", func(t *testing.T) {
|
||||
setupConfigTestEnv(t, nil)
|
||||
|
||||
workDir := platformAbsPath("work")
|
||||
cfg, mcpFiles := agentcontextconfig.Config(workDir)
|
||||
@@ -46,20 +172,18 @@ func TestConfig(t *testing.T) {
|
||||
require.Equal(t, []string{filepath.Join(workDir, ".mcp.json")}, mcpFiles)
|
||||
})
|
||||
|
||||
//nolint:paralleltest // Uses t.Setenv to mutate process-wide environment.
|
||||
t.Run("CustomEnvVars", func(t *testing.T) {
|
||||
fakeHome := t.TempDir()
|
||||
t.Setenv("HOME", fakeHome)
|
||||
t.Setenv("USERPROFILE", fakeHome)
|
||||
|
||||
optInstructions := t.TempDir()
|
||||
optSkills := t.TempDir()
|
||||
optMCP := platformAbsPath("opt", "mcp.json")
|
||||
|
||||
t.Setenv(agentcontextconfig.EnvInstructionsDirs, optInstructions)
|
||||
t.Setenv(agentcontextconfig.EnvInstructionsFile, "CUSTOM.md")
|
||||
t.Setenv(agentcontextconfig.EnvSkillsDirs, optSkills)
|
||||
t.Setenv(agentcontextconfig.EnvSkillMetaFile, "META.yaml")
|
||||
t.Setenv(agentcontextconfig.EnvMCPConfigFiles, optMCP)
|
||||
setupConfigTestEnv(t, map[string]string{
|
||||
agentcontextconfig.EnvInstructionsDirs: optInstructions,
|
||||
agentcontextconfig.EnvInstructionsFile: "CUSTOM.md",
|
||||
agentcontextconfig.EnvSkillsDirs: optSkills,
|
||||
agentcontextconfig.EnvSkillMetaFile: "META.yaml",
|
||||
agentcontextconfig.EnvMCPConfigFiles: optMCP,
|
||||
})
|
||||
|
||||
// Create files matching the custom names so we can
|
||||
// verify the env vars actually change lookup behavior.
|
||||
@@ -85,15 +209,12 @@ func TestConfig(t *testing.T) {
|
||||
require.Equal(t, "META.yaml", skillParts[0].ContextFileSkillMetaFile)
|
||||
})
|
||||
|
||||
//nolint:paralleltest // Uses t.Setenv to mutate process-wide environment.
|
||||
t.Run("WhitespaceInFileNames", func(t *testing.T) {
|
||||
fakeHome := t.TempDir()
|
||||
t.Setenv("HOME", fakeHome)
|
||||
t.Setenv("USERPROFILE", fakeHome)
|
||||
fakeHome := setupConfigTestEnv(t, map[string]string{
|
||||
agentcontextconfig.EnvInstructionsFile: " CLAUDE.md ",
|
||||
})
|
||||
t.Setenv(agentcontextconfig.EnvInstructionsDirs, fakeHome)
|
||||
t.Setenv(agentcontextconfig.EnvInstructionsFile, " CLAUDE.md ")
|
||||
t.Setenv(agentcontextconfig.EnvSkillsDirs, "")
|
||||
t.Setenv(agentcontextconfig.EnvSkillMetaFile, "")
|
||||
t.Setenv(agentcontextconfig.EnvMCPConfigFiles, "")
|
||||
|
||||
workDir := t.TempDir()
|
||||
// Create a file matching the trimmed name.
|
||||
@@ -106,19 +227,13 @@ func TestConfig(t *testing.T) {
|
||||
require.Equal(t, "hello", ctxFiles[0].ContextFileContent)
|
||||
})
|
||||
|
||||
//nolint:paralleltest // Uses t.Setenv to mutate process-wide environment.
|
||||
t.Run("CommaSeparatedDirs", func(t *testing.T) {
|
||||
fakeHome := t.TempDir()
|
||||
t.Setenv("HOME", fakeHome)
|
||||
t.Setenv("USERPROFILE", fakeHome)
|
||||
|
||||
a := t.TempDir()
|
||||
b := t.TempDir()
|
||||
|
||||
t.Setenv(agentcontextconfig.EnvInstructionsDirs, a+","+b)
|
||||
t.Setenv(agentcontextconfig.EnvInstructionsFile, "")
|
||||
t.Setenv(agentcontextconfig.EnvSkillsDirs, "")
|
||||
t.Setenv(agentcontextconfig.EnvSkillMetaFile, "")
|
||||
t.Setenv(agentcontextconfig.EnvMCPConfigFiles, "")
|
||||
setupConfigTestEnv(t, map[string]string{
|
||||
agentcontextconfig.EnvInstructionsDirs: a + "," + b,
|
||||
})
|
||||
|
||||
// Put instruction files in both dirs.
|
||||
require.NoError(t, os.WriteFile(filepath.Join(a, "AGENTS.md"), []byte("from a"), 0o600))
|
||||
@@ -133,17 +248,10 @@ func TestConfig(t *testing.T) {
|
||||
require.Equal(t, "from b", ctxFiles[1].ContextFileContent)
|
||||
})
|
||||
|
||||
//nolint:paralleltest // Uses t.Setenv to mutate process-wide environment.
|
||||
t.Run("ReadsInstructionFiles", func(t *testing.T) {
|
||||
t.Setenv(agentcontextconfig.EnvInstructionsDirs, "")
|
||||
t.Setenv(agentcontextconfig.EnvInstructionsFile, "")
|
||||
t.Setenv(agentcontextconfig.EnvSkillsDirs, "")
|
||||
t.Setenv(agentcontextconfig.EnvSkillMetaFile, "")
|
||||
t.Setenv(agentcontextconfig.EnvMCPConfigFiles, "")
|
||||
|
||||
workDir := t.TempDir()
|
||||
fakeHome := t.TempDir()
|
||||
t.Setenv("HOME", fakeHome)
|
||||
t.Setenv("USERPROFILE", fakeHome)
|
||||
fakeHome := setupConfigTestEnv(t, nil)
|
||||
|
||||
// Create ~/.coder/AGENTS.md
|
||||
coderDir := filepath.Join(fakeHome, ".coder")
|
||||
@@ -164,16 +272,9 @@ func TestConfig(t *testing.T) {
|
||||
require.False(t, ctxFiles[0].ContextFileTruncated)
|
||||
})
|
||||
|
||||
//nolint:paralleltest // Uses t.Setenv to mutate process-wide environment.
|
||||
t.Run("ReadsWorkingDirInstructionFile", func(t *testing.T) {
|
||||
fakeHome := t.TempDir()
|
||||
t.Setenv("HOME", fakeHome)
|
||||
t.Setenv("USERPROFILE", fakeHome)
|
||||
t.Setenv(agentcontextconfig.EnvInstructionsDirs, "")
|
||||
t.Setenv(agentcontextconfig.EnvInstructionsFile, "")
|
||||
t.Setenv(agentcontextconfig.EnvSkillsDirs, "")
|
||||
t.Setenv(agentcontextconfig.EnvSkillMetaFile, "")
|
||||
t.Setenv(agentcontextconfig.EnvMCPConfigFiles, "")
|
||||
|
||||
setupConfigTestEnv(t, nil)
|
||||
workDir := t.TempDir()
|
||||
|
||||
// Create AGENTS.md in the working directory.
|
||||
@@ -193,16 +294,9 @@ func TestConfig(t *testing.T) {
|
||||
require.Equal(t, filepath.Join(workDir, "AGENTS.md"), ctxFiles[0].ContextFilePath)
|
||||
})
|
||||
|
||||
//nolint:paralleltest // Uses t.Setenv to mutate process-wide environment.
|
||||
t.Run("TruncatesLargeInstructionFile", func(t *testing.T) {
|
||||
fakeHome := t.TempDir()
|
||||
t.Setenv("HOME", fakeHome)
|
||||
t.Setenv("USERPROFILE", fakeHome)
|
||||
t.Setenv(agentcontextconfig.EnvInstructionsDirs, "")
|
||||
t.Setenv(agentcontextconfig.EnvInstructionsFile, "")
|
||||
t.Setenv(agentcontextconfig.EnvSkillsDirs, "")
|
||||
t.Setenv(agentcontextconfig.EnvSkillMetaFile, "")
|
||||
t.Setenv(agentcontextconfig.EnvMCPConfigFiles, "")
|
||||
|
||||
setupConfigTestEnv(t, nil)
|
||||
workDir := t.TempDir()
|
||||
largeContent := strings.Repeat("a", 64*1024+100)
|
||||
require.NoError(t, os.WriteFile(filepath.Join(workDir, "AGENTS.md"), []byte(largeContent), 0o600))
|
||||
@@ -215,79 +309,47 @@ func TestConfig(t *testing.T) {
|
||||
require.Len(t, ctxFiles[0].ContextFileContent, 64*1024)
|
||||
})
|
||||
|
||||
t.Run("SanitizesHTMLComments", func(t *testing.T) {
|
||||
fakeHome := t.TempDir()
|
||||
t.Setenv("HOME", fakeHome)
|
||||
t.Setenv("USERPROFILE", fakeHome)
|
||||
t.Setenv(agentcontextconfig.EnvInstructionsDirs, "")
|
||||
t.Setenv(agentcontextconfig.EnvInstructionsFile, "")
|
||||
t.Setenv(agentcontextconfig.EnvSkillsDirs, "")
|
||||
t.Setenv(agentcontextconfig.EnvSkillMetaFile, "")
|
||||
t.Setenv(agentcontextconfig.EnvMCPConfigFiles, "")
|
||||
sanitizationTests := []struct {
|
||||
name string
|
||||
input string
|
||||
expected string
|
||||
}{
|
||||
{
|
||||
name: "SanitizesHTMLComments",
|
||||
input: "visible\n<!-- hidden -->content",
|
||||
expected: "visible\ncontent",
|
||||
},
|
||||
{
|
||||
name: "SanitizesInvisibleUnicode",
|
||||
input: "before\u200bafter",
|
||||
expected: "beforeafter",
|
||||
},
|
||||
{
|
||||
name: "NormalizesCRLF",
|
||||
input: "line1\r\nline2\rline3",
|
||||
expected: "line1\nline2\nline3",
|
||||
},
|
||||
}
|
||||
//nolint:paralleltest // Uses t.Setenv to mutate process-wide environment.
|
||||
for _, tt := range sanitizationTests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
setupConfigTestEnv(t, nil)
|
||||
workDir := t.TempDir()
|
||||
require.NoError(t, os.WriteFile(
|
||||
filepath.Join(workDir, "AGENTS.md"),
|
||||
[]byte(tt.input),
|
||||
0o600,
|
||||
))
|
||||
|
||||
workDir := t.TempDir()
|
||||
require.NoError(t, os.WriteFile(
|
||||
filepath.Join(workDir, "AGENTS.md"),
|
||||
[]byte("visible\n<!-- hidden -->content"),
|
||||
0o600,
|
||||
))
|
||||
cfg, _ := agentcontextconfig.Config(workDir)
|
||||
|
||||
cfg, _ := agentcontextconfig.Config(workDir)
|
||||
|
||||
ctxFiles := filterParts(cfg.Parts, codersdk.ChatMessagePartTypeContextFile)
|
||||
require.Len(t, ctxFiles, 1)
|
||||
require.Equal(t, "visible\ncontent", ctxFiles[0].ContextFileContent)
|
||||
})
|
||||
|
||||
t.Run("SanitizesInvisibleUnicode", func(t *testing.T) {
|
||||
fakeHome := t.TempDir()
|
||||
t.Setenv("HOME", fakeHome)
|
||||
t.Setenv("USERPROFILE", fakeHome)
|
||||
t.Setenv(agentcontextconfig.EnvInstructionsDirs, "")
|
||||
t.Setenv(agentcontextconfig.EnvInstructionsFile, "")
|
||||
t.Setenv(agentcontextconfig.EnvSkillsDirs, "")
|
||||
t.Setenv(agentcontextconfig.EnvSkillMetaFile, "")
|
||||
t.Setenv(agentcontextconfig.EnvMCPConfigFiles, "")
|
||||
|
||||
workDir := t.TempDir()
|
||||
// U+200B (zero-width space) should be stripped.
|
||||
require.NoError(t, os.WriteFile(
|
||||
filepath.Join(workDir, "AGENTS.md"),
|
||||
[]byte("before\u200bafter"),
|
||||
0o600,
|
||||
))
|
||||
|
||||
cfg, _ := agentcontextconfig.Config(workDir)
|
||||
|
||||
ctxFiles := filterParts(cfg.Parts, codersdk.ChatMessagePartTypeContextFile)
|
||||
require.Len(t, ctxFiles, 1)
|
||||
require.Equal(t, "beforeafter", ctxFiles[0].ContextFileContent)
|
||||
})
|
||||
|
||||
t.Run("NormalizesCRLF", func(t *testing.T) {
|
||||
fakeHome := t.TempDir()
|
||||
t.Setenv("HOME", fakeHome)
|
||||
t.Setenv("USERPROFILE", fakeHome)
|
||||
t.Setenv(agentcontextconfig.EnvInstructionsDirs, "")
|
||||
t.Setenv(agentcontextconfig.EnvInstructionsFile, "")
|
||||
t.Setenv(agentcontextconfig.EnvSkillsDirs, "")
|
||||
t.Setenv(agentcontextconfig.EnvSkillMetaFile, "")
|
||||
t.Setenv(agentcontextconfig.EnvMCPConfigFiles, "")
|
||||
|
||||
workDir := t.TempDir()
|
||||
require.NoError(t, os.WriteFile(
|
||||
filepath.Join(workDir, "AGENTS.md"),
|
||||
[]byte("line1\r\nline2\rline3"),
|
||||
0o600,
|
||||
))
|
||||
|
||||
cfg, _ := agentcontextconfig.Config(workDir)
|
||||
|
||||
ctxFiles := filterParts(cfg.Parts, codersdk.ChatMessagePartTypeContextFile)
|
||||
require.Len(t, ctxFiles, 1)
|
||||
require.Equal(t, "line1\nline2\nline3", ctxFiles[0].ContextFileContent)
|
||||
})
|
||||
ctxFiles := filterParts(cfg.Parts, codersdk.ChatMessagePartTypeContextFile)
|
||||
require.Len(t, ctxFiles, 1)
|
||||
require.Equal(t, tt.expected, ctxFiles[0].ContextFileContent)
|
||||
})
|
||||
}
|
||||
|
||||
//nolint:paralleltest // Uses t.Setenv to mutate process-wide environment.
|
||||
t.Run("DiscoversSkills", func(t *testing.T) {
|
||||
fakeHome := t.TempDir()
|
||||
t.Setenv("HOME", fakeHome)
|
||||
@@ -320,17 +382,13 @@ func TestConfig(t *testing.T) {
|
||||
require.Equal(t, "SKILL.md", skillParts[0].ContextFileSkillMetaFile)
|
||||
})
|
||||
|
||||
//nolint:paralleltest // Uses t.Setenv to mutate process-wide environment.
|
||||
t.Run("SkipsMissingDirs", func(t *testing.T) {
|
||||
fakeHome := t.TempDir()
|
||||
t.Setenv("HOME", fakeHome)
|
||||
t.Setenv("USERPROFILE", fakeHome)
|
||||
|
||||
nonExistent := filepath.Join(t.TempDir(), "does-not-exist")
|
||||
t.Setenv(agentcontextconfig.EnvInstructionsDirs, nonExistent)
|
||||
t.Setenv(agentcontextconfig.EnvInstructionsFile, "")
|
||||
t.Setenv(agentcontextconfig.EnvSkillsDirs, nonExistent)
|
||||
t.Setenv(agentcontextconfig.EnvSkillMetaFile, "")
|
||||
t.Setenv(agentcontextconfig.EnvMCPConfigFiles, "")
|
||||
setupConfigTestEnv(t, map[string]string{
|
||||
agentcontextconfig.EnvInstructionsDirs: nonExistent,
|
||||
agentcontextconfig.EnvSkillsDirs: nonExistent,
|
||||
})
|
||||
|
||||
workDir := t.TempDir()
|
||||
cfg, _ := agentcontextconfig.Config(workDir)
|
||||
@@ -340,17 +398,13 @@ func TestConfig(t *testing.T) {
|
||||
require.Empty(t, cfg.Parts)
|
||||
})
|
||||
|
||||
//nolint:paralleltest // Uses t.Setenv to mutate process-wide environment.
|
||||
t.Run("MCPConfigFilesResolvedSeparately", func(t *testing.T) {
|
||||
fakeHome := t.TempDir()
|
||||
t.Setenv("HOME", fakeHome)
|
||||
t.Setenv("USERPROFILE", fakeHome)
|
||||
t.Setenv(agentcontextconfig.EnvInstructionsDirs, fakeHome)
|
||||
t.Setenv(agentcontextconfig.EnvInstructionsFile, "")
|
||||
t.Setenv(agentcontextconfig.EnvSkillsDirs, "")
|
||||
t.Setenv(agentcontextconfig.EnvSkillMetaFile, "")
|
||||
|
||||
optMCP := platformAbsPath("opt", "custom.json")
|
||||
t.Setenv(agentcontextconfig.EnvMCPConfigFiles, optMCP)
|
||||
fakeHome := setupConfigTestEnv(t, map[string]string{
|
||||
agentcontextconfig.EnvMCPConfigFiles: optMCP,
|
||||
})
|
||||
t.Setenv(agentcontextconfig.EnvInstructionsDirs, fakeHome)
|
||||
|
||||
workDir := t.TempDir()
|
||||
_, mcpFiles := agentcontextconfig.Config(workDir)
|
||||
@@ -358,14 +412,10 @@ func TestConfig(t *testing.T) {
|
||||
require.Equal(t, []string{optMCP}, mcpFiles)
|
||||
})
|
||||
|
||||
//nolint:paralleltest // Uses t.Setenv to mutate process-wide environment.
|
||||
t.Run("SkillNameMustMatchDir", func(t *testing.T) {
|
||||
fakeHome := t.TempDir()
|
||||
t.Setenv("HOME", fakeHome)
|
||||
t.Setenv("USERPROFILE", fakeHome)
|
||||
fakeHome := setupConfigTestEnv(t, nil)
|
||||
t.Setenv(agentcontextconfig.EnvInstructionsDirs, fakeHome)
|
||||
t.Setenv(agentcontextconfig.EnvInstructionsFile, "")
|
||||
t.Setenv(agentcontextconfig.EnvSkillMetaFile, "")
|
||||
t.Setenv(agentcontextconfig.EnvMCPConfigFiles, "")
|
||||
|
||||
workDir := t.TempDir()
|
||||
skillsDir := filepath.Join(workDir, "skills")
|
||||
@@ -385,14 +435,10 @@ func TestConfig(t *testing.T) {
|
||||
require.Empty(t, skillParts)
|
||||
})
|
||||
|
||||
//nolint:paralleltest // Uses t.Setenv to mutate process-wide environment.
|
||||
t.Run("DuplicateSkillsFirstWins", func(t *testing.T) {
|
||||
fakeHome := t.TempDir()
|
||||
t.Setenv("HOME", fakeHome)
|
||||
t.Setenv("USERPROFILE", fakeHome)
|
||||
fakeHome := setupConfigTestEnv(t, nil)
|
||||
t.Setenv(agentcontextconfig.EnvInstructionsDirs, fakeHome)
|
||||
t.Setenv(agentcontextconfig.EnvInstructionsFile, "")
|
||||
t.Setenv(agentcontextconfig.EnvSkillMetaFile, "")
|
||||
t.Setenv(agentcontextconfig.EnvMCPConfigFiles, "")
|
||||
|
||||
workDir := t.TempDir()
|
||||
skillsDir1 := filepath.Join(workDir, "skills1")
|
||||
|
||||
+1141
-1038
File diff suppressed because it is too large
Load Diff
@@ -98,6 +98,21 @@ message Manifest {
|
||||
repeated WorkspaceApp apps = 11;
|
||||
repeated WorkspaceAgentMetadata.Description metadata = 12;
|
||||
repeated WorkspaceAgentDevcontainer devcontainers = 17;
|
||||
repeated WorkspaceSecret secrets = 19;
|
||||
}
|
||||
|
||||
// WorkspaceSecret is a secret included in the agent manifest
|
||||
// for injection into a workspace.
|
||||
message WorkspaceSecret {
|
||||
// Environment variable name to inject (e.g. "GITHUB_TOKEN").
|
||||
// Empty string means this secret is not injected as an env var.
|
||||
string env_name = 1;
|
||||
// File path to write the secret value to (e.g.
|
||||
// "~/.aws/credentials"). Empty string means this secret is not
|
||||
// written to a file.
|
||||
string file_path = 2;
|
||||
// The decrypted secret value.
|
||||
bytes value = 3;
|
||||
}
|
||||
|
||||
message WorkspaceAgentDevcontainer {
|
||||
|
||||
@@ -5,7 +5,9 @@ import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"io"
|
||||
"mime/multipart"
|
||||
"net/http"
|
||||
"net/textproto"
|
||||
"strconv"
|
||||
"sync"
|
||||
"time"
|
||||
@@ -620,6 +622,11 @@ func (a *API) handleRecordingStop(rw http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
defer artifact.Reader.Close()
|
||||
defer func() {
|
||||
if artifact.ThumbnailReader != nil {
|
||||
_ = artifact.ThumbnailReader.Close()
|
||||
}
|
||||
}()
|
||||
|
||||
if artifact.Size > workspacesdk.MaxRecordingSize {
|
||||
a.logger.Warn(ctx, "recording file exceeds maximum size",
|
||||
@@ -633,10 +640,60 @@ func (a *API) handleRecordingStop(rw http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
rw.Header().Set("Content-Type", "video/mp4")
|
||||
rw.Header().Set("Content-Length", strconv.FormatInt(artifact.Size, 10))
|
||||
// Discard the thumbnail if it exceeds the maximum size.
|
||||
// The server-side consumer also enforces this per-part, but
|
||||
// rejecting it here avoids streaming a large thumbnail over
|
||||
// the wire for nothing.
|
||||
if artifact.ThumbnailReader != nil && artifact.ThumbnailSize > workspacesdk.MaxThumbnailSize {
|
||||
a.logger.Warn(ctx, "thumbnail file exceeds maximum size, omitting",
|
||||
slog.F("recording_id", recordingID),
|
||||
slog.F("size", artifact.ThumbnailSize),
|
||||
slog.F("max_size", workspacesdk.MaxThumbnailSize),
|
||||
)
|
||||
_ = artifact.ThumbnailReader.Close()
|
||||
artifact.ThumbnailReader = nil
|
||||
artifact.ThumbnailSize = 0
|
||||
}
|
||||
|
||||
// The multipart response is best-effort: once WriteHeader(200) is
|
||||
// called, CreatePart failures produce a truncated response without
|
||||
// the closing boundary. The server-side consumer handles this
|
||||
// gracefully, preserving any parts read before the error.
|
||||
mw := multipart.NewWriter(rw)
|
||||
defer mw.Close()
|
||||
rw.Header().Set("Content-Type", "multipart/mixed; boundary="+mw.Boundary())
|
||||
rw.WriteHeader(http.StatusOK)
|
||||
_, _ = io.Copy(rw, artifact.Reader)
|
||||
|
||||
// Part 1: video/mp4 (always present).
|
||||
videoPart, err := mw.CreatePart(textproto.MIMEHeader{
|
||||
"Content-Type": {"video/mp4"},
|
||||
})
|
||||
if err != nil {
|
||||
a.logger.Warn(ctx, "failed to create video multipart part",
|
||||
slog.F("recording_id", recordingID),
|
||||
slog.Error(err))
|
||||
return
|
||||
}
|
||||
if _, err := io.Copy(videoPart, artifact.Reader); err != nil {
|
||||
a.logger.Warn(ctx, "failed to write video multipart part",
|
||||
slog.F("recording_id", recordingID),
|
||||
slog.Error(err))
|
||||
return
|
||||
}
|
||||
|
||||
// Part 2: image/jpeg (present only when thumbnail was extracted).
|
||||
if artifact.ThumbnailReader != nil {
|
||||
thumbPart, err := mw.CreatePart(textproto.MIMEHeader{
|
||||
"Content-Type": {"image/jpeg"},
|
||||
})
|
||||
if err != nil {
|
||||
a.logger.Warn(ctx, "failed to create thumbnail multipart part",
|
||||
slog.F("recording_id", recordingID),
|
||||
slog.Error(err))
|
||||
return
|
||||
}
|
||||
_, _ = io.Copy(thumbPart, artifact.ThumbnailReader)
|
||||
}
|
||||
}
|
||||
|
||||
// coordFromAction extracts the coordinate pair from a DesktopAction,
|
||||
|
||||
@@ -4,12 +4,17 @@ import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"mime"
|
||||
"mime/multipart"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"os"
|
||||
"slices"
|
||||
"strings"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
@@ -59,6 +64,8 @@ type fakeDesktop struct {
|
||||
lastKeyDown string
|
||||
lastKeyUp string
|
||||
|
||||
thumbnailData []byte // if set, StopRecording includes a thumbnail
|
||||
|
||||
// Recording tracking (guarded by recMu).
|
||||
recMu sync.Mutex
|
||||
recordings map[string]string // ID → file path
|
||||
@@ -187,10 +194,15 @@ func (f *fakeDesktop) StopRecording(_ context.Context, recordingID string) (*age
|
||||
_ = file.Close()
|
||||
return nil, err
|
||||
}
|
||||
return &agentdesktop.RecordingArtifact{
|
||||
artifact := &agentdesktop.RecordingArtifact{
|
||||
Reader: file,
|
||||
Size: info.Size(),
|
||||
}, nil
|
||||
}
|
||||
if f.thumbnailData != nil {
|
||||
artifact.ThumbnailReader = io.NopCloser(bytes.NewReader(f.thumbnailData))
|
||||
artifact.ThumbnailSize = int64(len(f.thumbnailData))
|
||||
}
|
||||
return artifact, nil
|
||||
}
|
||||
|
||||
func (f *fakeDesktop) RecordActivity() {
|
||||
@@ -785,8 +797,8 @@ func TestRecordingStartStop(t *testing.T) {
|
||||
req = httptest.NewRequest(http.MethodPost, "/recording/stop", bytes.NewReader(stopBody))
|
||||
handler.ServeHTTP(rr, req)
|
||||
require.Equal(t, http.StatusOK, rr.Code)
|
||||
assert.Equal(t, "video/mp4", rr.Header().Get("Content-Type"))
|
||||
assert.Equal(t, []byte("fake-mp4-data-"+testRecIDDefault+"-1"), rr.Body.Bytes())
|
||||
parts := parseMultipartParts(t, rr.Header().Get("Content-Type"), rr.Body.Bytes())
|
||||
assert.Equal(t, []byte("fake-mp4-data-"+testRecIDDefault+"-1"), parts["video/mp4"])
|
||||
}
|
||||
|
||||
func TestRecordingStartFails(t *testing.T) {
|
||||
@@ -847,8 +859,8 @@ func TestRecordingStartIdempotent(t *testing.T) {
|
||||
req := httptest.NewRequest(http.MethodPost, "/recording/stop", bytes.NewReader(stopBody))
|
||||
handler.ServeHTTP(rr, req)
|
||||
require.Equal(t, http.StatusOK, rr.Code)
|
||||
assert.Equal(t, "video/mp4", rr.Header().Get("Content-Type"))
|
||||
assert.Equal(t, []byte("fake-mp4-data-"+testRecIDStartIdempotent+"-1"), rr.Body.Bytes())
|
||||
parts := parseMultipartParts(t, rr.Header().Get("Content-Type"), rr.Body.Bytes())
|
||||
assert.Equal(t, []byte("fake-mp4-data-"+testRecIDStartIdempotent+"-1"), parts["video/mp4"])
|
||||
}
|
||||
|
||||
func TestRecordingStopIdempotent(t *testing.T) {
|
||||
@@ -872,7 +884,7 @@ func TestRecordingStopIdempotent(t *testing.T) {
|
||||
require.Equal(t, http.StatusOK, rr.Code)
|
||||
|
||||
// Stop twice - both should succeed with identical data.
|
||||
var bodies [2][]byte
|
||||
var videoParts [2][]byte
|
||||
for i := range 2 {
|
||||
body, err := json.Marshal(map[string]string{"recording_id": testRecIDStopIdempotent})
|
||||
require.NoError(t, err)
|
||||
@@ -880,10 +892,10 @@ func TestRecordingStopIdempotent(t *testing.T) {
|
||||
request := httptest.NewRequest(http.MethodPost, "/recording/stop", bytes.NewReader(body))
|
||||
handler.ServeHTTP(recorder, request)
|
||||
require.Equal(t, http.StatusOK, recorder.Code)
|
||||
assert.Equal(t, "video/mp4", recorder.Header().Get("Content-Type"))
|
||||
bodies[i] = recorder.Body.Bytes()
|
||||
parts := parseMultipartParts(t, recorder.Header().Get("Content-Type"), recorder.Body.Bytes())
|
||||
videoParts[i] = parts["video/mp4"]
|
||||
}
|
||||
assert.Equal(t, bodies[0], bodies[1])
|
||||
assert.Equal(t, videoParts[0], videoParts[1])
|
||||
}
|
||||
|
||||
func TestRecordingStopInvalidIDFormat(t *testing.T) {
|
||||
@@ -1004,8 +1016,8 @@ func TestRecordingMultipleSimultaneous(t *testing.T) {
|
||||
req := httptest.NewRequest(http.MethodPost, "/recording/stop", bytes.NewReader(body))
|
||||
handler.ServeHTTP(rr, req)
|
||||
require.Equal(t, http.StatusOK, rr.Code)
|
||||
assert.Equal(t, "video/mp4", rr.Header().Get("Content-Type"))
|
||||
assert.Equal(t, expected[id], rr.Body.Bytes())
|
||||
parts := parseMultipartParts(t, rr.Header().Get("Content-Type"), rr.Body.Bytes())
|
||||
assert.Equal(t, expected[id], parts["video/mp4"])
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1112,8 +1124,8 @@ func TestRecordingStartAfterCompleted(t *testing.T) {
|
||||
req = httptest.NewRequest(http.MethodPost, "/recording/stop", bytes.NewReader(stopBody))
|
||||
handler.ServeHTTP(rr, req)
|
||||
require.Equal(t, http.StatusOK, rr.Code)
|
||||
assert.Equal(t, "video/mp4", rr.Header().Get("Content-Type"))
|
||||
firstData := rr.Body.Bytes()
|
||||
firstParts := parseMultipartParts(t, rr.Header().Get("Content-Type"), rr.Body.Bytes())
|
||||
firstData := firstParts["video/mp4"]
|
||||
require.NotEmpty(t, firstData)
|
||||
|
||||
// Step 3: Start again with the same ID - should succeed
|
||||
@@ -1128,8 +1140,8 @@ func TestRecordingStartAfterCompleted(t *testing.T) {
|
||||
req = httptest.NewRequest(http.MethodPost, "/recording/stop", bytes.NewReader(stopBody))
|
||||
handler.ServeHTTP(rr, req)
|
||||
require.Equal(t, http.StatusOK, rr.Code)
|
||||
assert.Equal(t, "video/mp4", rr.Header().Get("Content-Type"))
|
||||
secondData := rr.Body.Bytes()
|
||||
secondParts := parseMultipartParts(t, rr.Header().Get("Content-Type"), rr.Body.Bytes())
|
||||
secondData := secondParts["video/mp4"]
|
||||
require.NotEmpty(t, secondData)
|
||||
|
||||
// The two recordings should have different data because the
|
||||
@@ -1235,3 +1247,166 @@ func TestRecordingStopCorrupted(t *testing.T) {
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "Recording is corrupted.", respStop.Message)
|
||||
}
|
||||
|
||||
// parseMultipartParts parses a multipart/mixed response and returns
|
||||
// a map from Content-Type to body bytes.
|
||||
func parseMultipartParts(t *testing.T, contentType string, body []byte) map[string][]byte {
|
||||
t.Helper()
|
||||
_, params, err := mime.ParseMediaType(contentType)
|
||||
require.NoError(t, err, "parse Content-Type")
|
||||
boundary := params["boundary"]
|
||||
require.NotEmpty(t, boundary, "missing boundary")
|
||||
mr := multipart.NewReader(bytes.NewReader(body), boundary)
|
||||
parts := make(map[string][]byte)
|
||||
for {
|
||||
part, err := mr.NextPart()
|
||||
if errors.Is(err, io.EOF) {
|
||||
break
|
||||
}
|
||||
require.NoError(t, err, "unexpected multipart parse error")
|
||||
ct := part.Header.Get("Content-Type")
|
||||
data, readErr := io.ReadAll(part)
|
||||
require.NoError(t, readErr)
|
||||
parts[ct] = data
|
||||
}
|
||||
return parts
|
||||
}
|
||||
|
||||
func TestHandleRecordingStop_WithThumbnail(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
logger := slogtest.Make(t, nil)
|
||||
// Create a fake JPEG header: 0xFF 0xD8 0xFF followed by 509 zero bytes.
|
||||
thumbnail := make([]byte, 512)
|
||||
thumbnail[0] = 0xff
|
||||
thumbnail[1] = 0xd8
|
||||
thumbnail[2] = 0xff
|
||||
|
||||
fake := &fakeDesktop{
|
||||
startCfg: agentdesktop.DisplayConfig{Width: 1920, Height: 1080},
|
||||
thumbnailData: thumbnail,
|
||||
}
|
||||
api := agentdesktop.NewAPI(logger, fake, nil)
|
||||
defer api.Close()
|
||||
|
||||
handler := api.Routes()
|
||||
|
||||
// Start recording.
|
||||
recID := uuid.New().String()
|
||||
startBody, err := json.Marshal(map[string]string{"recording_id": recID})
|
||||
require.NoError(t, err)
|
||||
rr := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodPost, "/recording/start", bytes.NewReader(startBody))
|
||||
handler.ServeHTTP(rr, req)
|
||||
require.Equal(t, http.StatusOK, rr.Code)
|
||||
|
||||
// Stop recording.
|
||||
stopBody, err := json.Marshal(map[string]string{"recording_id": recID})
|
||||
require.NoError(t, err)
|
||||
rr = httptest.NewRecorder()
|
||||
req = httptest.NewRequest(http.MethodPost, "/recording/stop", bytes.NewReader(stopBody))
|
||||
handler.ServeHTTP(rr, req)
|
||||
require.Equal(t, http.StatusOK, rr.Code)
|
||||
|
||||
// Verify multipart response.
|
||||
ct := rr.Header().Get("Content-Type")
|
||||
assert.True(t, strings.HasPrefix(ct, "multipart/mixed"),
|
||||
"expected multipart/mixed Content-Type, got %s", ct)
|
||||
|
||||
parts := parseMultipartParts(t, ct, rr.Body.Bytes())
|
||||
assert.Len(t, parts, 2, "expected exactly 2 parts (video + thumbnail)")
|
||||
|
||||
// The fake writes "fake-mp4-data-<id>-<counter>" as the MP4 content.
|
||||
expectedMP4 := []byte("fake-mp4-data-" + recID + "-1")
|
||||
assert.Equal(t, expectedMP4, parts["video/mp4"])
|
||||
assert.Equal(t, thumbnail, parts["image/jpeg"])
|
||||
}
|
||||
|
||||
func TestHandleRecordingStop_NoThumbnail(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
logger := slogtest.Make(t, nil)
|
||||
fake := &fakeDesktop{
|
||||
startCfg: agentdesktop.DisplayConfig{Width: 1920, Height: 1080},
|
||||
}
|
||||
api := agentdesktop.NewAPI(logger, fake, nil)
|
||||
defer api.Close()
|
||||
|
||||
handler := api.Routes()
|
||||
|
||||
// Start recording.
|
||||
recID := uuid.New().String()
|
||||
startBody, err := json.Marshal(map[string]string{"recording_id": recID})
|
||||
require.NoError(t, err)
|
||||
rr := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodPost, "/recording/start", bytes.NewReader(startBody))
|
||||
handler.ServeHTTP(rr, req)
|
||||
require.Equal(t, http.StatusOK, rr.Code)
|
||||
|
||||
// Stop recording.
|
||||
stopBody, err := json.Marshal(map[string]string{"recording_id": recID})
|
||||
require.NoError(t, err)
|
||||
rr = httptest.NewRecorder()
|
||||
req = httptest.NewRequest(http.MethodPost, "/recording/stop", bytes.NewReader(stopBody))
|
||||
handler.ServeHTTP(rr, req)
|
||||
require.Equal(t, http.StatusOK, rr.Code)
|
||||
|
||||
// Verify multipart response.
|
||||
ct := rr.Header().Get("Content-Type")
|
||||
assert.True(t, strings.HasPrefix(ct, "multipart/mixed"),
|
||||
"expected multipart/mixed Content-Type, got %s", ct)
|
||||
|
||||
parts := parseMultipartParts(t, ct, rr.Body.Bytes())
|
||||
assert.Len(t, parts, 1, "expected exactly 1 part (video only)")
|
||||
|
||||
expectedMP4 := []byte("fake-mp4-data-" + recID + "-1")
|
||||
assert.Equal(t, expectedMP4, parts["video/mp4"])
|
||||
}
|
||||
|
||||
func TestHandleRecordingStop_OversizedThumbnail(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
logger := slogtest.Make(t, nil)
|
||||
// Create thumbnail data that exceeds MaxThumbnailSize.
|
||||
oversizedThumb := make([]byte, workspacesdk.MaxThumbnailSize+1)
|
||||
oversizedThumb[0] = 0xff
|
||||
oversizedThumb[1] = 0xd8
|
||||
oversizedThumb[2] = 0xff
|
||||
|
||||
fake := &fakeDesktop{
|
||||
startCfg: agentdesktop.DisplayConfig{Width: 1920, Height: 1080},
|
||||
thumbnailData: oversizedThumb,
|
||||
}
|
||||
api := agentdesktop.NewAPI(logger, fake, nil)
|
||||
defer api.Close()
|
||||
|
||||
handler := api.Routes()
|
||||
|
||||
// Start recording.
|
||||
recID := uuid.New().String()
|
||||
startBody, err := json.Marshal(map[string]string{"recording_id": recID})
|
||||
require.NoError(t, err)
|
||||
rr := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodPost, "/recording/start", bytes.NewReader(startBody))
|
||||
handler.ServeHTTP(rr, req)
|
||||
require.Equal(t, http.StatusOK, rr.Code)
|
||||
|
||||
// Stop recording.
|
||||
stopBody, err := json.Marshal(map[string]string{"recording_id": recID})
|
||||
require.NoError(t, err)
|
||||
rr = httptest.NewRecorder()
|
||||
req = httptest.NewRequest(http.MethodPost, "/recording/stop", bytes.NewReader(stopBody))
|
||||
handler.ServeHTTP(rr, req)
|
||||
require.Equal(t, http.StatusOK, rr.Code)
|
||||
|
||||
// Verify multipart response contains only the video part.
|
||||
ct := rr.Header().Get("Content-Type")
|
||||
assert.True(t, strings.HasPrefix(ct, "multipart/mixed"),
|
||||
"expected multipart/mixed Content-Type, got %s", ct)
|
||||
|
||||
parts := parseMultipartParts(t, ct, rr.Body.Bytes())
|
||||
assert.Len(t, parts, 1, "expected exactly 1 part (video only, oversized thumbnail discarded)")
|
||||
|
||||
expectedMP4 := []byte("fake-mp4-data-" + recID + "-1")
|
||||
assert.Equal(t, expectedMP4, parts["video/mp4"])
|
||||
}
|
||||
|
||||
@@ -105,6 +105,11 @@ type RecordingArtifact struct {
|
||||
Reader io.ReadCloser
|
||||
// Size is the byte length of the MP4 content.
|
||||
Size int64
|
||||
// ThumbnailReader is the JPEG thumbnail. May be nil if no
|
||||
// thumbnail was produced. Callers must close it when done.
|
||||
ThumbnailReader io.ReadCloser
|
||||
// ThumbnailSize is the byte length of the thumbnail.
|
||||
ThumbnailSize int64
|
||||
}
|
||||
|
||||
// DisplayConfig describes a running desktop session.
|
||||
|
||||
@@ -56,6 +56,7 @@ type screenshotOutput struct {
|
||||
type recordingProcess struct {
|
||||
cmd *exec.Cmd
|
||||
filePath string
|
||||
thumbPath string
|
||||
stopped bool
|
||||
killed bool // true when the process was SIGKILLed
|
||||
done chan struct{} // closed when cmd.Wait() returns
|
||||
@@ -383,13 +384,20 @@ func (p *portableDesktop) StartRecording(ctx context.Context, recordingID string
|
||||
}
|
||||
}
|
||||
// Completed recording - discard old file, start fresh.
|
||||
if err := os.Remove(rec.filePath); err != nil && !os.IsNotExist(err) {
|
||||
if err := os.Remove(rec.filePath); err != nil && !errors.Is(err, os.ErrNotExist) {
|
||||
p.logger.Warn(ctx, "failed to remove old recording file",
|
||||
slog.F("recording_id", recordingID),
|
||||
slog.F("file_path", rec.filePath),
|
||||
slog.Error(err),
|
||||
)
|
||||
}
|
||||
if err := os.Remove(rec.thumbPath); err != nil && !errors.Is(err, os.ErrNotExist) {
|
||||
p.logger.Warn(ctx, "failed to remove old thumbnail file",
|
||||
slog.F("recording_id", recordingID),
|
||||
slog.F("thumbnail_path", rec.thumbPath),
|
||||
slog.Error(err),
|
||||
)
|
||||
}
|
||||
delete(p.recordings, recordingID)
|
||||
}
|
||||
|
||||
@@ -406,6 +414,7 @@ func (p *portableDesktop) StartRecording(ctx context.Context, recordingID string
|
||||
}
|
||||
|
||||
filePath := filepath.Join(os.TempDir(), "coder-recording-"+recordingID+".mp4")
|
||||
thumbPath := filepath.Join(os.TempDir(), "coder-recording-"+recordingID+".thumb.jpg")
|
||||
|
||||
// Use a background context so the process outlives the HTTP
|
||||
// request that triggered it.
|
||||
@@ -419,6 +428,7 @@ func (p *portableDesktop) StartRecording(ctx context.Context, recordingID string
|
||||
"--idle-speedup", "20",
|
||||
"--idle-min-duration", "0.35",
|
||||
"--idle-noise-tolerance", "-38dB",
|
||||
"--thumbnail", thumbPath,
|
||||
filePath)
|
||||
|
||||
if err := cmd.Start(); err != nil {
|
||||
@@ -427,9 +437,10 @@ func (p *portableDesktop) StartRecording(ctx context.Context, recordingID string
|
||||
}
|
||||
|
||||
rec := &recordingProcess{
|
||||
cmd: cmd,
|
||||
filePath: filePath,
|
||||
done: make(chan struct{}),
|
||||
cmd: cmd,
|
||||
filePath: filePath,
|
||||
thumbPath: thumbPath,
|
||||
done: make(chan struct{}),
|
||||
}
|
||||
go func() {
|
||||
rec.waitErr = cmd.Wait()
|
||||
@@ -499,10 +510,35 @@ func (p *portableDesktop) StopRecording(ctx context.Context, recordingID string)
|
||||
_ = f.Close()
|
||||
return nil, xerrors.Errorf("stat recording artifact: %w", err)
|
||||
}
|
||||
return &RecordingArtifact{
|
||||
artifact := &RecordingArtifact{
|
||||
Reader: f,
|
||||
Size: info.Size(),
|
||||
}, nil
|
||||
}
|
||||
// Attach thumbnail if the subprocess wrote one.
|
||||
thumbFile, err := os.Open(rec.thumbPath)
|
||||
if err != nil {
|
||||
p.logger.Warn(ctx, "thumbnail not available",
|
||||
slog.F("thumbnail_path", rec.thumbPath),
|
||||
slog.Error(err))
|
||||
return artifact, nil
|
||||
}
|
||||
thumbInfo, err := thumbFile.Stat()
|
||||
if err != nil {
|
||||
_ = thumbFile.Close()
|
||||
p.logger.Warn(ctx, "thumbnail stat failed",
|
||||
slog.F("thumbnail_path", rec.thumbPath),
|
||||
slog.Error(err))
|
||||
return artifact, nil
|
||||
}
|
||||
if thumbInfo.Size() == 0 {
|
||||
_ = thumbFile.Close()
|
||||
p.logger.Warn(ctx, "thumbnail file is empty",
|
||||
slog.F("thumbnail_path", rec.thumbPath))
|
||||
return artifact, nil
|
||||
}
|
||||
artifact.ThumbnailReader = thumbFile
|
||||
artifact.ThumbnailSize = thumbInfo.Size()
|
||||
return artifact, nil
|
||||
}
|
||||
|
||||
// lockedStopRecordingProcess stops a single recording via stopOnce.
|
||||
@@ -571,18 +607,33 @@ func (p *portableDesktop) lockedCleanStaleRecordings(ctx context.Context) {
|
||||
}
|
||||
info, err := os.Stat(rec.filePath)
|
||||
if err != nil {
|
||||
// File already removed or inaccessible; drop entry.
|
||||
// File already removed or inaccessible; clean up
|
||||
// any leftover thumbnail and drop the entry.
|
||||
if err := os.Remove(rec.thumbPath); err != nil && !errors.Is(err, os.ErrNotExist) {
|
||||
p.logger.Warn(ctx, "failed to remove stale thumbnail file",
|
||||
slog.F("recording_id", id),
|
||||
slog.F("thumbnail_path", rec.thumbPath),
|
||||
slog.Error(err),
|
||||
)
|
||||
}
|
||||
delete(p.recordings, id)
|
||||
continue
|
||||
}
|
||||
if p.clock.Since(info.ModTime()) > time.Hour {
|
||||
if err := os.Remove(rec.filePath); err != nil && !os.IsNotExist(err) {
|
||||
if err := os.Remove(rec.filePath); err != nil && !errors.Is(err, os.ErrNotExist) {
|
||||
p.logger.Warn(ctx, "failed to remove stale recording file",
|
||||
slog.F("recording_id", id),
|
||||
slog.F("file_path", rec.filePath),
|
||||
slog.Error(err),
|
||||
)
|
||||
}
|
||||
if err := os.Remove(rec.thumbPath); err != nil && !errors.Is(err, os.ErrNotExist) {
|
||||
p.logger.Warn(ctx, "failed to remove stale thumbnail file",
|
||||
slog.F("recording_id", id),
|
||||
slog.F("thumbnail_path", rec.thumbPath),
|
||||
slog.Error(err),
|
||||
)
|
||||
}
|
||||
delete(p.recordings, id)
|
||||
}
|
||||
}
|
||||
@@ -603,13 +654,14 @@ func (p *portableDesktop) Close() error {
|
||||
// Snapshot recording file paths and idle goroutine channels
|
||||
// for cleanup, then clear the map.
|
||||
type recEntry struct {
|
||||
id string
|
||||
filePath string
|
||||
idleDone chan struct{}
|
||||
id string
|
||||
filePath string
|
||||
thumbPath string
|
||||
idleDone chan struct{}
|
||||
}
|
||||
var allRecs []recEntry
|
||||
for id, rec := range p.recordings {
|
||||
allRecs = append(allRecs, recEntry{id: id, filePath: rec.filePath, idleDone: rec.idleDone})
|
||||
allRecs = append(allRecs, recEntry{id: id, filePath: rec.filePath, thumbPath: rec.thumbPath, idleDone: rec.idleDone})
|
||||
delete(p.recordings, id)
|
||||
}
|
||||
session := p.session
|
||||
@@ -630,13 +682,20 @@ func (p *portableDesktop) Close() error {
|
||||
go func() {
|
||||
defer close(cleanupDone)
|
||||
for _, entry := range allRecs {
|
||||
if err := os.Remove(entry.filePath); err != nil && !os.IsNotExist(err) {
|
||||
if err := os.Remove(entry.filePath); err != nil && !errors.Is(err, os.ErrNotExist) {
|
||||
p.logger.Warn(context.Background(), "failed to remove recording file on close",
|
||||
slog.F("recording_id", entry.id),
|
||||
slog.F("file_path", entry.filePath),
|
||||
slog.Error(err),
|
||||
)
|
||||
}
|
||||
if err := os.Remove(entry.thumbPath); err != nil && !errors.Is(err, os.ErrNotExist) {
|
||||
p.logger.Warn(context.Background(), "failed to remove thumbnail file on close",
|
||||
slog.F("recording_id", entry.id),
|
||||
slog.F("thumbnail_path", entry.thumbPath),
|
||||
slog.Error(err),
|
||||
)
|
||||
}
|
||||
}
|
||||
if session != nil {
|
||||
session.cancel()
|
||||
|
||||
@@ -2,6 +2,7 @@ package agentdesktop
|
||||
|
||||
import (
|
||||
"context"
|
||||
"io"
|
||||
"os"
|
||||
"os/exec"
|
||||
"path/filepath"
|
||||
@@ -584,6 +585,7 @@ func TestPortableDesktop_StartRecording(t *testing.T) {
|
||||
joined := strings.Join(cmd, " ")
|
||||
if strings.Contains(joined, "record") && strings.Contains(joined, "coder-recording-"+recID) {
|
||||
found = true
|
||||
assert.Contains(t, joined, "--thumbnail", "record command should include --thumbnail flag")
|
||||
break
|
||||
}
|
||||
}
|
||||
@@ -666,6 +668,66 @@ func TestPortableDesktop_StopRecording_ReturnsArtifact(t *testing.T) {
|
||||
defer artifact.Reader.Close()
|
||||
assert.Equal(t, int64(len("fake-mp4-data")), artifact.Size)
|
||||
|
||||
// No thumbnail file exists, so ThumbnailReader should be nil.
|
||||
assert.Nil(t, artifact.ThumbnailReader, "ThumbnailReader should be nil when no thumbnail file exists")
|
||||
|
||||
require.NoError(t, pd.Close())
|
||||
}
|
||||
|
||||
func TestPortableDesktop_StopRecording_WithThumbnail(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
logger := slogtest.Make(t, nil)
|
||||
rec := &recordedExecer{
|
||||
scripts: map[string]string{
|
||||
"record": `trap 'exit 0' INT; sleep 120 & wait`,
|
||||
"up": `printf '{"vncPort":5901,"geometry":"1920x1080"}\n' && sleep 120`,
|
||||
},
|
||||
}
|
||||
|
||||
clk := quartz.NewReal()
|
||||
pd := &portableDesktop{
|
||||
logger: logger,
|
||||
execer: rec,
|
||||
scriptBinDir: t.TempDir(),
|
||||
clock: clk,
|
||||
binPath: "portabledesktop",
|
||||
recordings: make(map[string]*recordingProcess),
|
||||
}
|
||||
pd.lastDesktopActionAt.Store(clk.Now().UnixNano())
|
||||
|
||||
ctx := t.Context()
|
||||
recID := uuid.New().String()
|
||||
err := pd.StartRecording(ctx, recID)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Write a dummy MP4 file at the expected path.
|
||||
filePath := filepath.Join(os.TempDir(), "coder-recording-"+recID+".mp4")
|
||||
require.NoError(t, os.WriteFile(filePath, []byte("fake-mp4-data"), 0o600))
|
||||
t.Cleanup(func() { _ = os.Remove(filePath) })
|
||||
|
||||
// Write a thumbnail file at the expected path.
|
||||
thumbPath := filepath.Join(os.TempDir(), "coder-recording-"+recID+".thumb.jpg")
|
||||
thumbContent := []byte("fake-jpeg-thumbnail")
|
||||
require.NoError(t, os.WriteFile(thumbPath, thumbContent, 0o600))
|
||||
t.Cleanup(func() { _ = os.Remove(thumbPath) })
|
||||
|
||||
artifact, err := pd.StopRecording(ctx, recID)
|
||||
require.NoError(t, err)
|
||||
defer artifact.Reader.Close()
|
||||
|
||||
assert.Equal(t, int64(len("fake-mp4-data")), artifact.Size)
|
||||
|
||||
// Thumbnail should be attached.
|
||||
require.NotNil(t, artifact.ThumbnailReader, "ThumbnailReader should be non-nil when thumbnail file exists")
|
||||
defer artifact.ThumbnailReader.Close()
|
||||
assert.Equal(t, int64(len(thumbContent)), artifact.ThumbnailSize)
|
||||
|
||||
// Read and verify thumbnail content.
|
||||
thumbData, err := io.ReadAll(artifact.ThumbnailReader)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, thumbContent, thumbData)
|
||||
|
||||
require.NoError(t, pd.Close())
|
||||
}
|
||||
|
||||
@@ -750,12 +812,18 @@ func TestPortableDesktop_IdleTimeout_StopsRecordings(t *testing.T) {
|
||||
stopTrap := clk.Trap().NewTimer("agentdesktop", "stop_timeout")
|
||||
|
||||
// Advance past idle timeout to trigger the stop-all.
|
||||
clk.Advance(idleTimeout)
|
||||
clk.Advance(idleTimeout).MustWait(ctx)
|
||||
|
||||
// Wait for the stop timer to be created, then release it.
|
||||
stopTrap.MustWait(ctx).MustRelease(ctx)
|
||||
stopTrap.Close()
|
||||
|
||||
// Advance past the 15s stop timeout so the process is
|
||||
// forcibly killed. Without this the test depends on the real
|
||||
// shell handling SIGINT promptly, which is unreliable on
|
||||
// macOS CI runners (the flake in #1461).
|
||||
clk.Advance(15 * time.Second).MustWait(ctx)
|
||||
|
||||
// The recording process should now be stopped.
|
||||
require.Eventually(t, func() bool {
|
||||
pd.mu.Lock()
|
||||
@@ -877,11 +945,17 @@ func TestPortableDesktop_IdleTimeout_MultipleRecordings(t *testing.T) {
|
||||
stopTrap := clk.Trap().NewTimer("agentdesktop", "stop_timeout")
|
||||
|
||||
// Advance past idle timeout.
|
||||
clk.Advance(idleTimeout)
|
||||
clk.Advance(idleTimeout).MustWait(ctx)
|
||||
|
||||
// Wait for both stop timers.
|
||||
// Each idle monitor goroutine serializes on p.mu, so the
|
||||
// second stop timer is only created after the first stop
|
||||
// completes. Advance past the 15s stop timeout after each
|
||||
// release so the process is forcibly killed instead of
|
||||
// depending on SIGINT (unreliable on macOS — see #1461).
|
||||
stopTrap.MustWait(ctx).MustRelease(ctx)
|
||||
clk.Advance(15 * time.Second).MustWait(ctx)
|
||||
stopTrap.MustWait(ctx).MustRelease(ctx)
|
||||
clk.Advance(15 * time.Second).MustWait(ctx)
|
||||
stopTrap.Close()
|
||||
|
||||
// Both recordings should be stopped.
|
||||
|
||||
@@ -87,6 +87,12 @@ func IsDevVersion(v string) bool {
|
||||
return strings.Contains(v, "-"+develPreRelease)
|
||||
}
|
||||
|
||||
// IsRCVersion returns true if the version has a release candidate
|
||||
// pre-release tag, e.g. "v2.31.0-rc.0".
|
||||
func IsRCVersion(v string) bool {
|
||||
return strings.Contains(v, "-rc.")
|
||||
}
|
||||
|
||||
// IsDev returns true if this is a development build.
|
||||
// CI builds are also considered development builds.
|
||||
func IsDev() bool {
|
||||
|
||||
@@ -102,3 +102,29 @@ func TestBuildInfo(t *testing.T) {
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestIsRCVersion(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
cases := []struct {
|
||||
name string
|
||||
version string
|
||||
expected bool
|
||||
}{
|
||||
{"RC0", "v2.31.0-rc.0", true},
|
||||
{"RC1WithBuild", "v2.31.0-rc.1+abc123", true},
|
||||
{"RC10", "v2.31.0-rc.10", true},
|
||||
{"RCDevel", "v2.33.0-rc.1-devel+727ec00f7", true},
|
||||
{"DevelVersion", "v2.31.0-devel+abc123", false},
|
||||
{"StableVersion", "v2.31.0", false},
|
||||
{"DevNoVersion", "v0.0.0-devel+abc123", false},
|
||||
{"BetaVersion", "v2.31.0-beta.1", false},
|
||||
}
|
||||
|
||||
for _, c := range cases {
|
||||
t.Run(c.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
require.Equal(t, c.expected, buildinfo.IsRCVersion(c.version))
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
+194
@@ -0,0 +1,194 @@
|
||||
package cli
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"golang.org/x/xerrors"
|
||||
|
||||
"github.com/coder/coder/v2/agent/agentcontextconfig"
|
||||
"github.com/coder/coder/v2/codersdk/agentsdk"
|
||||
"github.com/coder/serpent"
|
||||
)
|
||||
|
||||
func (r *RootCmd) chatCommand() *serpent.Command {
|
||||
return &serpent.Command{
|
||||
Use: "chat",
|
||||
Short: "Manage agent chats",
|
||||
Long: "Commands for interacting with chats from within a workspace.",
|
||||
Handler: func(i *serpent.Invocation) error {
|
||||
return i.Command.HelpHandler(i)
|
||||
},
|
||||
Children: []*serpent.Command{
|
||||
r.chatContextCommand(),
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func (r *RootCmd) chatContextCommand() *serpent.Command {
|
||||
return &serpent.Command{
|
||||
Use: "context",
|
||||
Short: "Manage chat context",
|
||||
Long: "Add or clear context files and skills for an active chat session.",
|
||||
Handler: func(i *serpent.Invocation) error {
|
||||
return i.Command.HelpHandler(i)
|
||||
},
|
||||
Children: []*serpent.Command{
|
||||
r.chatContextAddCommand(),
|
||||
r.chatContextClearCommand(),
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func (*RootCmd) chatContextAddCommand() *serpent.Command {
|
||||
var (
|
||||
dir string
|
||||
chatID string
|
||||
)
|
||||
agentAuth := &AgentAuth{}
|
||||
cmd := &serpent.Command{
|
||||
Use: "add",
|
||||
Short: "Add context to an active chat",
|
||||
Long: "Read instruction files and discover skills from a directory, then add " +
|
||||
"them as context to an active chat session. Multiple calls " +
|
||||
"are additive.",
|
||||
Handler: func(inv *serpent.Invocation) error {
|
||||
ctx := inv.Context()
|
||||
ctx, stop := inv.SignalNotifyContext(ctx, StopSignals...)
|
||||
defer stop()
|
||||
|
||||
if dir == "" && inv.Environ.Get("CODER") != "true" {
|
||||
return xerrors.New("this command must be run inside a Coder workspace (set --dir to override)")
|
||||
}
|
||||
|
||||
client, err := agentAuth.CreateClient()
|
||||
if err != nil {
|
||||
return xerrors.Errorf("create agent client: %w", err)
|
||||
}
|
||||
|
||||
resolvedDir := dir
|
||||
if resolvedDir == "" {
|
||||
resolvedDir, err = os.Getwd()
|
||||
if err != nil {
|
||||
return xerrors.Errorf("get working directory: %w", err)
|
||||
}
|
||||
}
|
||||
resolvedDir, err = filepath.Abs(resolvedDir)
|
||||
if err != nil {
|
||||
return xerrors.Errorf("resolve directory: %w", err)
|
||||
}
|
||||
info, err := os.Stat(resolvedDir)
|
||||
if err != nil {
|
||||
return xerrors.Errorf("cannot read directory %q: %w", resolvedDir, err)
|
||||
}
|
||||
if !info.IsDir() {
|
||||
return xerrors.Errorf("%q is not a directory", resolvedDir)
|
||||
}
|
||||
|
||||
parts := agentcontextconfig.ContextPartsFromDir(resolvedDir)
|
||||
if len(parts) == 0 {
|
||||
_, _ = fmt.Fprintln(inv.Stderr, "No context files or skills found in "+resolvedDir)
|
||||
return nil
|
||||
}
|
||||
|
||||
// Resolve chat ID from flag or auto-detect.
|
||||
resolvedChatID, err := parseChatID(chatID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
resp, err := client.AddChatContext(ctx, agentsdk.AddChatContextRequest{
|
||||
ChatID: resolvedChatID,
|
||||
Parts: parts,
|
||||
})
|
||||
if err != nil {
|
||||
return xerrors.Errorf("add chat context: %w", err)
|
||||
}
|
||||
|
||||
_, _ = fmt.Fprintf(inv.Stdout, "Added %d context part(s) to chat %s\n", resp.Count, resp.ChatID)
|
||||
return nil
|
||||
},
|
||||
Options: serpent.OptionSet{
|
||||
{
|
||||
Name: "Directory",
|
||||
Flag: "dir",
|
||||
Description: "Directory to read context files and skills from. Defaults to the current working directory.",
|
||||
Value: serpent.StringOf(&dir),
|
||||
},
|
||||
{
|
||||
Name: "Chat ID",
|
||||
Flag: "chat",
|
||||
Env: "CODER_CHAT_ID",
|
||||
Description: "Chat ID to add context to. Auto-detected from CODER_CHAT_ID, the only active chat, or the only top-level active chat.",
|
||||
Value: serpent.StringOf(&chatID),
|
||||
},
|
||||
},
|
||||
}
|
||||
agentAuth.AttachOptions(cmd, false)
|
||||
return cmd
|
||||
}
|
||||
|
||||
func (*RootCmd) chatContextClearCommand() *serpent.Command {
|
||||
var chatID string
|
||||
agentAuth := &AgentAuth{}
|
||||
cmd := &serpent.Command{
|
||||
Use: "clear",
|
||||
Short: "Clear context from an active chat",
|
||||
Long: "Soft-delete all context-file and skill messages from an active chat. " +
|
||||
"The next turn will re-fetch default context from the agent.",
|
||||
Handler: func(inv *serpent.Invocation) error {
|
||||
ctx := inv.Context()
|
||||
ctx, stop := inv.SignalNotifyContext(ctx, StopSignals...)
|
||||
defer stop()
|
||||
|
||||
client, err := agentAuth.CreateClient()
|
||||
if err != nil {
|
||||
return xerrors.Errorf("create agent client: %w", err)
|
||||
}
|
||||
|
||||
resolvedChatID, err := parseChatID(chatID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
resp, err := client.ClearChatContext(ctx, agentsdk.ClearChatContextRequest{
|
||||
ChatID: resolvedChatID,
|
||||
})
|
||||
if err != nil {
|
||||
return xerrors.Errorf("clear chat context: %w", err)
|
||||
}
|
||||
|
||||
if resp.ChatID == uuid.Nil {
|
||||
_, _ = fmt.Fprintln(inv.Stdout, "No active chats to clear.")
|
||||
} else {
|
||||
_, _ = fmt.Fprintf(inv.Stdout, "Cleared context from chat %s\n", resp.ChatID)
|
||||
}
|
||||
return nil
|
||||
},
|
||||
Options: serpent.OptionSet{{
|
||||
Name: "Chat ID",
|
||||
Flag: "chat",
|
||||
Env: "CODER_CHAT_ID",
|
||||
Description: "Chat ID to clear context from. Auto-detected from CODER_CHAT_ID, the only active chat, or the only top-level active chat.",
|
||||
Value: serpent.StringOf(&chatID),
|
||||
}},
|
||||
}
|
||||
agentAuth.AttachOptions(cmd, false)
|
||||
return cmd
|
||||
}
|
||||
|
||||
// parseChatID returns the chat UUID from the flag value (which
|
||||
// serpent already populates from --chat or CODER_CHAT_ID). Returns
|
||||
// uuid.Nil if empty (the server will auto-detect).
|
||||
func parseChatID(flagValue string) (uuid.UUID, error) {
|
||||
if flagValue == "" {
|
||||
return uuid.Nil, nil
|
||||
}
|
||||
parsed, err := uuid.Parse(flagValue)
|
||||
if err != nil {
|
||||
return uuid.Nil, xerrors.Errorf("invalid chat ID %q: %w", flagValue, err)
|
||||
}
|
||||
return parsed, nil
|
||||
}
|
||||
@@ -0,0 +1,46 @@
|
||||
package cli_test
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/coder/coder/v2/cli/clitest"
|
||||
)
|
||||
|
||||
func TestExpChatContextAdd(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
t.Run("RequiresWorkspaceOrDir", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
inv, _ := clitest.New(t, "exp", "chat", "context", "add")
|
||||
|
||||
err := inv.Run()
|
||||
require.Error(t, err)
|
||||
require.Contains(t, err.Error(), "this command must be run inside a Coder workspace")
|
||||
})
|
||||
|
||||
t.Run("AllowsExplicitDir", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
inv, _ := clitest.New(t, "exp", "chat", "context", "add", "--dir", t.TempDir())
|
||||
|
||||
err := inv.Run()
|
||||
if err != nil {
|
||||
require.NotContains(t, err.Error(), "this command must be run inside a Coder workspace")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("AllowsWorkspaceEnv", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
inv, _ := clitest.New(t, "exp", "chat", "context", "add")
|
||||
inv.Environ.Set("CODER", "true")
|
||||
|
||||
err := inv.Run()
|
||||
if err != nil {
|
||||
require.NotContains(t, err.Error(), "this command must be run inside a Coder workspace")
|
||||
}
|
||||
})
|
||||
}
|
||||
+30
-5
@@ -7,6 +7,7 @@ import (
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"flag"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
@@ -102,6 +103,7 @@ func (r *RootCmd) CoreSubcommands() []*serpent.Command {
|
||||
r.portForward(),
|
||||
r.publickey(),
|
||||
r.resetPassword(),
|
||||
r.secrets(),
|
||||
r.sharing(),
|
||||
r.state(),
|
||||
r.tasksCommand(),
|
||||
@@ -148,6 +150,7 @@ func (r *RootCmd) AGPLExperimental() []*serpent.Command {
|
||||
return []*serpent.Command{
|
||||
r.scaletestCmd(),
|
||||
r.errorExample(),
|
||||
r.chatCommand(),
|
||||
r.mcpCommand(),
|
||||
r.promptExample(),
|
||||
r.rptyCommand(),
|
||||
@@ -710,7 +713,7 @@ func (r *RootCmd) createHTTPClient(ctx context.Context, serverURL *url.URL, inv
|
||||
transport = wrapTransportWithTelemetryHeader(transport, inv)
|
||||
transport = wrapTransportWithUserAgentHeader(transport, inv)
|
||||
if !r.noVersionCheck {
|
||||
transport = wrapTransportWithVersionMismatchCheck(transport, inv, buildinfo.Version(), func(ctx context.Context) (codersdk.BuildInfoResponse, error) {
|
||||
transport = wrapTransportWithVersionCheck(transport, inv, buildinfo.Version(), func(ctx context.Context) (codersdk.BuildInfoResponse, error) {
|
||||
// Create a new client without any wrapped transport
|
||||
// otherwise it creates an infinite loop!
|
||||
basicClient := codersdk.New(serverURL)
|
||||
@@ -1434,6 +1437,21 @@ func defaultUpgradeMessage(version string) string {
|
||||
return fmt.Sprintf("download the server version with: 'curl -L https://coder.com/install.sh | sh -s -- --version %s'", version)
|
||||
}
|
||||
|
||||
// serverVersionMessage returns a warning message if the server version
|
||||
// is a release candidate or development build. Returns empty string
|
||||
// for stable versions. RC is checked before devel because RC dev
|
||||
// builds (e.g. v2.33.0-rc.1-devel+hash) contain both tags.
|
||||
func serverVersionMessage(serverVersion string) string {
|
||||
switch {
|
||||
case buildinfo.IsRCVersion(serverVersion):
|
||||
return fmt.Sprintf("the server is running a release candidate of Coder (%s)", serverVersion)
|
||||
case buildinfo.IsDevVersion(serverVersion):
|
||||
return fmt.Sprintf("the server is running a development version of Coder (%s)", serverVersion)
|
||||
default:
|
||||
return ""
|
||||
}
|
||||
}
|
||||
|
||||
// wrapTransportWithEntitlementsCheck adds a middleware to the HTTP transport
|
||||
// that checks for entitlement warnings and prints them to the user.
|
||||
func wrapTransportWithEntitlementsCheck(rt http.RoundTripper, w io.Writer) http.RoundTripper {
|
||||
@@ -1452,10 +1470,10 @@ func wrapTransportWithEntitlementsCheck(rt http.RoundTripper, w io.Writer) http.
|
||||
})
|
||||
}
|
||||
|
||||
// wrapTransportWithVersionMismatchCheck adds a middleware to the HTTP transport
|
||||
// that checks for version mismatches between the client and server. If a mismatch
|
||||
// is detected, a warning is printed to the user.
|
||||
func wrapTransportWithVersionMismatchCheck(rt http.RoundTripper, inv *serpent.Invocation, clientVersion string, getBuildInfo func(ctx context.Context) (codersdk.BuildInfoResponse, error)) http.RoundTripper {
|
||||
// wrapTransportWithVersionCheck adds a middleware to the HTTP transport
|
||||
// that checks the server version and warns about development builds,
|
||||
// release candidates, and client/server version mismatches.
|
||||
func wrapTransportWithVersionCheck(rt http.RoundTripper, inv *serpent.Invocation, clientVersion string, getBuildInfo func(ctx context.Context) (codersdk.BuildInfoResponse, error)) http.RoundTripper {
|
||||
var once sync.Once
|
||||
return roundTripper(func(req *http.Request) (*http.Response, error) {
|
||||
res, err := rt.RoundTrip(req)
|
||||
@@ -1467,9 +1485,16 @@ func wrapTransportWithVersionMismatchCheck(rt http.RoundTripper, inv *serpent.In
|
||||
if serverVersion == "" {
|
||||
return
|
||||
}
|
||||
// Warn about non-stable server versions. Skip
|
||||
// during tests to avoid polluting golden files.
|
||||
if msg := serverVersionMessage(serverVersion); msg != "" && flag.Lookup("test.v") == nil {
|
||||
warning := pretty.Sprint(cliui.DefaultStyles.Warn, msg)
|
||||
_, _ = fmt.Fprintln(inv.Stderr, warning)
|
||||
}
|
||||
if buildinfo.VersionsMatch(clientVersion, serverVersion) {
|
||||
return
|
||||
}
|
||||
|
||||
upgradeMessage := defaultUpgradeMessage(semver.Canonical(serverVersion))
|
||||
if serverInfo, err := getBuildInfo(inv.Context()); err == nil {
|
||||
switch {
|
||||
|
||||
@@ -91,7 +91,7 @@ func Test_formatExamples(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func Test_wrapTransportWithVersionMismatchCheck(t *testing.T) {
|
||||
func Test_wrapTransportWithVersionCheck(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
t.Run("NoOutput", func(t *testing.T) {
|
||||
@@ -102,7 +102,7 @@ func Test_wrapTransportWithVersionMismatchCheck(t *testing.T) {
|
||||
var buf bytes.Buffer
|
||||
inv := cmd.Invoke()
|
||||
inv.Stderr = &buf
|
||||
rt := wrapTransportWithVersionMismatchCheck(roundTripper(func(req *http.Request) (*http.Response, error) {
|
||||
rt := wrapTransportWithVersionCheck(roundTripper(func(req *http.Request) (*http.Response, error) {
|
||||
return &http.Response{
|
||||
StatusCode: http.StatusOK,
|
||||
Header: http.Header{
|
||||
@@ -131,7 +131,7 @@ func Test_wrapTransportWithVersionMismatchCheck(t *testing.T) {
|
||||
inv := cmd.Invoke()
|
||||
inv.Stderr = &buf
|
||||
expectedUpgradeMessage := "My custom upgrade message"
|
||||
rt := wrapTransportWithVersionMismatchCheck(roundTripper(func(req *http.Request) (*http.Response, error) {
|
||||
rt := wrapTransportWithVersionCheck(roundTripper(func(req *http.Request) (*http.Response, error) {
|
||||
return &http.Response{
|
||||
StatusCode: http.StatusOK,
|
||||
Header: http.Header{
|
||||
@@ -159,6 +159,53 @@ func Test_wrapTransportWithVersionMismatchCheck(t *testing.T) {
|
||||
expectedOutput := fmt.Sprintln(pretty.Sprint(cliui.DefaultStyles.Warn, fmtOutput))
|
||||
require.Equal(t, expectedOutput, buf.String())
|
||||
})
|
||||
|
||||
t.Run("ServerStableVersion", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
r := &RootCmd{}
|
||||
cmd, err := r.Command(nil)
|
||||
require.NoError(t, err)
|
||||
var buf bytes.Buffer
|
||||
inv := cmd.Invoke()
|
||||
inv.Stderr = &buf
|
||||
rt := wrapTransportWithVersionCheck(roundTripper(func(req *http.Request) (*http.Response, error) {
|
||||
return &http.Response{
|
||||
StatusCode: http.StatusOK,
|
||||
Header: http.Header{
|
||||
codersdk.BuildVersionHeader: []string{"v2.31.0"},
|
||||
},
|
||||
Body: io.NopCloser(nil),
|
||||
}, nil
|
||||
}), inv, "v2.31.0", nil)
|
||||
req := httptest.NewRequest(http.MethodGet, "http://example.com", nil)
|
||||
res, err := rt.RoundTrip(req)
|
||||
require.NoError(t, err)
|
||||
defer res.Body.Close()
|
||||
require.Empty(t, buf.String())
|
||||
})
|
||||
}
|
||||
|
||||
func Test_serverVersionMessage(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
cases := []struct {
|
||||
name string
|
||||
version string
|
||||
expected string
|
||||
}{
|
||||
{"Stable", "v2.31.0", ""},
|
||||
{"Dev", "v0.0.0-devel+abc123", "the server is running a development version of Coder (v0.0.0-devel+abc123)"},
|
||||
{"RC", "v2.31.0-rc.1", "the server is running a release candidate of Coder (v2.31.0-rc.1)"},
|
||||
{"RCDevel", "v2.33.0-rc.1-devel+727ec00f7", "the server is running a release candidate of Coder (v2.33.0-rc.1-devel+727ec00f7)"},
|
||||
{"Empty", "", ""},
|
||||
}
|
||||
|
||||
for _, c := range cases {
|
||||
t.Run(c.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
require.Equal(t, c.expected, serverVersionMessage(c.version))
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func Test_wrapTransportWithTelemetryHeader(t *testing.T) {
|
||||
|
||||
+340
@@ -0,0 +1,340 @@
|
||||
package cli
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/dustin/go-humanize"
|
||||
"golang.org/x/xerrors"
|
||||
|
||||
"github.com/coder/coder/v2/cli/cliui"
|
||||
"github.com/coder/coder/v2/codersdk"
|
||||
"github.com/coder/pretty"
|
||||
"github.com/coder/serpent"
|
||||
)
|
||||
|
||||
func (r *RootCmd) secrets() *serpent.Command {
|
||||
cmd := &serpent.Command{
|
||||
Use: "secret",
|
||||
Aliases: []string{"secrets"},
|
||||
Short: "Manage personal secrets",
|
||||
Long: FormatExamples(
|
||||
Example{
|
||||
Description: "Create a secret",
|
||||
Command: "coder secret create openai-key --value \"$SECRET_VALUE\" --description \"Personal OPENAI_API key\" --inject-env OPEN_AI_KEY --inject-file \"~/.openai-key\"",
|
||||
},
|
||||
Example{
|
||||
Description: "Update a secret",
|
||||
Command: "coder secret update openai-key --value \"$NEW_SECRET_VALUE\" --description \"Updated description\" --inject-env NEW_ENV_NAME --inject-file \"~/.new-path\"",
|
||||
},
|
||||
Example{
|
||||
Description: "List your secrets",
|
||||
Command: "coder secret list",
|
||||
},
|
||||
Example{
|
||||
Description: "Show a specific secret",
|
||||
Command: "coder secret list openai-key",
|
||||
},
|
||||
Example{
|
||||
Description: "Delete a secret",
|
||||
Command: "coder secret delete openai-key",
|
||||
},
|
||||
),
|
||||
Handler: func(inv *serpent.Invocation) error {
|
||||
return inv.Command.HelpHandler(inv)
|
||||
},
|
||||
Children: []*serpent.Command{
|
||||
r.secretCreate(),
|
||||
r.secretUpdate(),
|
||||
r.secretList(),
|
||||
r.secretDelete(),
|
||||
},
|
||||
}
|
||||
|
||||
return cmd
|
||||
}
|
||||
|
||||
func (r *RootCmd) secretCreate() *serpent.Command {
|
||||
var (
|
||||
value string
|
||||
description string
|
||||
injectEnv string
|
||||
injectFile string
|
||||
)
|
||||
|
||||
cmd := &serpent.Command{
|
||||
Use: "create <name>",
|
||||
Short: "Create a secret",
|
||||
Middleware: serpent.Chain(
|
||||
serpent.RequireNArgs(1),
|
||||
),
|
||||
Options: serpent.OptionSet{
|
||||
{
|
||||
Name: "value",
|
||||
Flag: "value",
|
||||
Description: "Set the secret value. This flag is required.",
|
||||
Value: serpent.StringOf(&value),
|
||||
Required: true,
|
||||
},
|
||||
{
|
||||
Name: "description",
|
||||
Flag: "description",
|
||||
Description: "Set the secret description.",
|
||||
Value: serpent.StringOf(&description),
|
||||
},
|
||||
{
|
||||
Name: "inject-env",
|
||||
Flag: "inject-env",
|
||||
Description: "Inject the secret into workspaces as an environment variable.",
|
||||
Value: serpent.StringOf(&injectEnv),
|
||||
},
|
||||
{
|
||||
Name: "inject-file",
|
||||
Flag: "inject-file",
|
||||
Description: "Inject the secret into workspaces as a file.",
|
||||
Value: serpent.StringOf(&injectFile),
|
||||
},
|
||||
},
|
||||
Handler: func(inv *serpent.Invocation) error {
|
||||
client, err := r.InitClient(inv)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
secret, err := client.CreateUserSecret(inv.Context(), codersdk.Me, codersdk.CreateUserSecretRequest{
|
||||
Name: inv.Args[0],
|
||||
Value: value,
|
||||
Description: description,
|
||||
EnvName: injectEnv,
|
||||
FilePath: injectFile,
|
||||
})
|
||||
if err != nil {
|
||||
return xerrors.Errorf("create secret %q: %w", inv.Args[0], err)
|
||||
}
|
||||
|
||||
_, _ = fmt.Fprintf(inv.Stdout, "Created secret %s.\n", cliui.Keyword(secret.Name))
|
||||
return nil
|
||||
},
|
||||
}
|
||||
|
||||
return cmd
|
||||
}
|
||||
|
||||
func (r *RootCmd) secretUpdate() *serpent.Command {
|
||||
var (
|
||||
value string
|
||||
description string
|
||||
injectEnv string
|
||||
injectFile string
|
||||
)
|
||||
|
||||
cmd := &serpent.Command{
|
||||
Use: "update <name>",
|
||||
Short: "Update a secret",
|
||||
Middleware: serpent.Chain(
|
||||
serpent.RequireNArgs(1),
|
||||
),
|
||||
Options: serpent.OptionSet{
|
||||
{
|
||||
Name: "value",
|
||||
Flag: "value",
|
||||
Description: "Update the secret value.",
|
||||
Value: serpent.StringOf(&value),
|
||||
},
|
||||
{
|
||||
Name: "description",
|
||||
Flag: "description",
|
||||
Description: "Update the secret description. Pass an empty string to clear it.",
|
||||
Value: serpent.StringOf(&description),
|
||||
},
|
||||
{
|
||||
Name: "inject-env",
|
||||
Flag: "inject-env",
|
||||
Description: "Update the environment variable injection target. Pass an empty string to clear it.",
|
||||
Value: serpent.StringOf(&injectEnv),
|
||||
},
|
||||
{
|
||||
Name: "inject-file",
|
||||
Flag: "inject-file",
|
||||
Description: "Update the file injection target. Pass an empty string to clear it.",
|
||||
Value: serpent.StringOf(&injectFile),
|
||||
},
|
||||
},
|
||||
Handler: func(inv *serpent.Invocation) error {
|
||||
client, err := r.InitClient(inv)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
req := codersdk.UpdateUserSecretRequest{}
|
||||
if userSetOption(inv, "value") {
|
||||
req.Value = &value
|
||||
}
|
||||
if userSetOption(inv, "description") {
|
||||
req.Description = &description
|
||||
}
|
||||
if userSetOption(inv, "inject-env") {
|
||||
req.EnvName = &injectEnv
|
||||
}
|
||||
if userSetOption(inv, "inject-file") {
|
||||
req.FilePath = &injectFile
|
||||
}
|
||||
|
||||
secret, err := client.UpdateUserSecret(inv.Context(), codersdk.Me, inv.Args[0], req)
|
||||
if err != nil {
|
||||
return xerrors.Errorf("update secret %q: %w", inv.Args[0], err)
|
||||
}
|
||||
|
||||
_, _ = fmt.Fprintf(inv.Stdout, "Updated secret %s.\n", cliui.Keyword(secret.Name))
|
||||
return nil
|
||||
},
|
||||
}
|
||||
|
||||
return cmd
|
||||
}
|
||||
|
||||
type secretListRow struct {
|
||||
codersdk.UserSecret `table:"-"`
|
||||
|
||||
Name string `json:"-" table:"name,default_sort"`
|
||||
Updated string `json:"-" table:"updated"`
|
||||
Env string `json:"-" table:"env"`
|
||||
File string `json:"-" table:"file"`
|
||||
Description string `json:"-" table:"description"`
|
||||
}
|
||||
|
||||
func secretListRowFromSecret(secret codersdk.UserSecret) secretListRow {
|
||||
return secretListRow{
|
||||
UserSecret: secret,
|
||||
Name: secret.Name,
|
||||
Updated: humanize.Time(secret.UpdatedAt),
|
||||
Env: secret.EnvName,
|
||||
File: secret.FilePath,
|
||||
Description: secret.Description,
|
||||
}
|
||||
}
|
||||
|
||||
func (r *RootCmd) secretList() *serpent.Command {
|
||||
formatter := cliui.NewOutputFormatter(
|
||||
cliui.ChangeFormatterData(
|
||||
cliui.TableFormat(
|
||||
[]secretListRow{},
|
||||
[]string{"name", "updated", "env", "file", "description"},
|
||||
),
|
||||
func(data any) (any, error) {
|
||||
switch rows := data.(type) {
|
||||
case []secretListRow:
|
||||
return rows, nil
|
||||
case secretListRow:
|
||||
return []secretListRow{rows}, nil
|
||||
default:
|
||||
return nil, xerrors.Errorf("expected []secretListRow or secretListRow, got %T", data)
|
||||
}
|
||||
},
|
||||
),
|
||||
cliui.ChangeFormatterData(
|
||||
cliui.JSONFormat(),
|
||||
func(data any) (any, error) {
|
||||
switch rows := data.(type) {
|
||||
case []secretListRow:
|
||||
secrets := make([]codersdk.UserSecret, len(rows))
|
||||
for i := range rows {
|
||||
secrets[i] = rows[i].UserSecret
|
||||
}
|
||||
return secrets, nil
|
||||
case secretListRow:
|
||||
return []codersdk.UserSecret{rows.UserSecret}, nil
|
||||
default:
|
||||
return nil, xerrors.Errorf("expected []secretListRow or secretListRow, got %T", data)
|
||||
}
|
||||
},
|
||||
),
|
||||
)
|
||||
|
||||
cmd := &serpent.Command{
|
||||
Use: "list [name]",
|
||||
Aliases: []string{"ls"},
|
||||
Short: "List secrets, or show one by name",
|
||||
Middleware: serpent.RequireRangeArgs(0, 1),
|
||||
Handler: func(inv *serpent.Invocation) error {
|
||||
client, err := r.InitClient(inv)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
var data any
|
||||
if len(inv.Args) == 1 {
|
||||
secret, err := client.UserSecretByName(inv.Context(), codersdk.Me, inv.Args[0])
|
||||
if err != nil {
|
||||
return xerrors.Errorf("get secret %q: %w", inv.Args[0], err)
|
||||
}
|
||||
data = secretListRowFromSecret(secret)
|
||||
} else {
|
||||
secrets, err := client.UserSecrets(inv.Context(), codersdk.Me)
|
||||
if err != nil {
|
||||
return xerrors.Errorf("list secrets: %w", err)
|
||||
}
|
||||
|
||||
rows := make([]secretListRow, len(secrets))
|
||||
for i := range secrets {
|
||||
rows[i] = secretListRowFromSecret(secrets[i])
|
||||
}
|
||||
data = rows
|
||||
}
|
||||
|
||||
out, err := formatter.Format(inv.Context(), data)
|
||||
if err != nil {
|
||||
return xerrors.Errorf("format secrets: %w", err)
|
||||
}
|
||||
if out == "" {
|
||||
cliui.Infof(inv.Stderr, "No secrets found.")
|
||||
return nil
|
||||
}
|
||||
|
||||
_, err = fmt.Fprintln(inv.Stdout, out)
|
||||
return err
|
||||
},
|
||||
}
|
||||
|
||||
formatter.AttachOptions(&cmd.Options)
|
||||
return cmd
|
||||
}
|
||||
|
||||
func (r *RootCmd) secretDelete() *serpent.Command {
|
||||
cmd := &serpent.Command{
|
||||
Use: "delete <name>",
|
||||
Aliases: []string{"remove"},
|
||||
Short: "Delete a secret",
|
||||
Middleware: serpent.Chain(
|
||||
serpent.RequireNArgs(1),
|
||||
),
|
||||
Options: serpent.OptionSet{
|
||||
cliui.SkipPromptOption(),
|
||||
},
|
||||
Handler: func(inv *serpent.Invocation) error {
|
||||
client, err := r.InitClient(inv)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
name := inv.Args[0]
|
||||
_, err = cliui.Prompt(inv, cliui.PromptOptions{
|
||||
Text: fmt.Sprintf("Delete secret %s?", pretty.Sprint(cliui.DefaultStyles.Code, name)),
|
||||
IsConfirm: true,
|
||||
Default: cliui.ConfirmNo,
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err = client.DeleteUserSecret(inv.Context(), codersdk.Me, name); err != nil {
|
||||
return xerrors.Errorf("delete secret %q: %w", name, err)
|
||||
}
|
||||
|
||||
_, _ = fmt.Fprintf(inv.Stdout, "Deleted secret %s at %s.\n", cliui.Keyword(name), cliui.Timestamp(time.Now()))
|
||||
return nil
|
||||
},
|
||||
}
|
||||
|
||||
return cmd
|
||||
}
|
||||
@@ -0,0 +1,356 @@
|
||||
package cli_test
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/coder/coder/v2/cli/clitest"
|
||||
"github.com/coder/coder/v2/coderd/coderdtest"
|
||||
"github.com/coder/coder/v2/codersdk"
|
||||
"github.com/coder/coder/v2/pty/ptytest"
|
||||
"github.com/coder/coder/v2/testutil"
|
||||
)
|
||||
|
||||
func TestSecretCreate(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
t.Run("MissingValue", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
client := coderdtest.New(t, nil)
|
||||
_ = coderdtest.CreateFirstUser(t, client)
|
||||
|
||||
inv, root := clitest.New(t, "secret", "create", "openai-key")
|
||||
clitest.SetupConfig(t, client, root)
|
||||
|
||||
ctx := testutil.Context(t, testutil.WaitMedium)
|
||||
err := inv.WithContext(ctx).Run()
|
||||
require.ErrorContains(t, err, "Missing values for the required flags: value")
|
||||
})
|
||||
|
||||
t.Run("Success", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
client := coderdtest.New(t, nil)
|
||||
_ = coderdtest.CreateFirstUser(t, client)
|
||||
|
||||
inv, root := clitest.New(
|
||||
t,
|
||||
"secret",
|
||||
"create",
|
||||
"openai-key",
|
||||
"--value", "super-secret-value",
|
||||
"--description", "Personal OPENAI_API key",
|
||||
"--inject-env", "OPEN_AI_KEY",
|
||||
"--inject-file", "~/.openai-key",
|
||||
)
|
||||
output := clitest.Capture(inv)
|
||||
clitest.SetupConfig(t, client, root)
|
||||
|
||||
ctx := testutil.Context(t, testutil.WaitMedium)
|
||||
err := inv.WithContext(ctx).Run()
|
||||
require.NoError(t, err)
|
||||
require.Contains(t, output.Stdout(), "openai-key")
|
||||
|
||||
secret, err := client.UserSecretByName(ctx, codersdk.Me, "openai-key")
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "openai-key", secret.Name)
|
||||
require.Equal(t, "Personal OPENAI_API key", secret.Description)
|
||||
require.Equal(t, "OPEN_AI_KEY", secret.EnvName)
|
||||
require.Equal(t, "~/.openai-key", secret.FilePath)
|
||||
})
|
||||
}
|
||||
|
||||
func TestSecretUpdate(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
t.Run("ServerValidationError", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
client := coderdtest.New(t, nil)
|
||||
_ = coderdtest.CreateFirstUser(t, client)
|
||||
|
||||
setupCtx := testutil.Context(t, testutil.WaitMedium)
|
||||
_, err := client.CreateUserSecret(setupCtx, codersdk.Me, codersdk.CreateUserSecretRequest{
|
||||
Name: "my-secret",
|
||||
Value: "original-value",
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
inv, root := clitest.New(t, "secret", "update", "my-secret")
|
||||
clitest.SetupConfig(t, client, root)
|
||||
|
||||
ctx := testutil.Context(t, testutil.WaitMedium)
|
||||
err = inv.WithContext(ctx).Run()
|
||||
require.ErrorContains(t, err, "At least one field must be provided")
|
||||
})
|
||||
|
||||
t.Run("AllowsClearingFields", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
client := coderdtest.New(t, nil)
|
||||
_ = coderdtest.CreateFirstUser(t, client)
|
||||
|
||||
setupCtx := testutil.Context(t, testutil.WaitMedium)
|
||||
_, err := client.CreateUserSecret(setupCtx, codersdk.Me, codersdk.CreateUserSecretRequest{
|
||||
Name: "my-secret",
|
||||
Value: "original-value",
|
||||
Description: "original description",
|
||||
EnvName: "MY_SECRET",
|
||||
FilePath: "~/.my-secret",
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
inv, root := clitest.New(
|
||||
t,
|
||||
"secret",
|
||||
"update",
|
||||
"my-secret",
|
||||
"--value", "rotated-secret",
|
||||
"--description", "",
|
||||
"--inject-env", "",
|
||||
"--inject-file", "",
|
||||
)
|
||||
output := clitest.Capture(inv)
|
||||
clitest.SetupConfig(t, client, root)
|
||||
|
||||
ctx := testutil.Context(t, testutil.WaitMedium)
|
||||
err = inv.WithContext(ctx).Run()
|
||||
require.NoError(t, err)
|
||||
require.Contains(t, output.Stdout(), "my-secret")
|
||||
|
||||
secret, err := client.UserSecretByName(ctx, codersdk.Me, "my-secret")
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "", secret.Description)
|
||||
require.Equal(t, "", secret.EnvName)
|
||||
require.Equal(t, "", secret.FilePath)
|
||||
})
|
||||
}
|
||||
|
||||
func TestSecretList(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
t.Run("TableOutput", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
client := coderdtest.New(t, nil)
|
||||
_ = coderdtest.CreateFirstUser(t, client)
|
||||
|
||||
setupCtx := testutil.Context(t, testutil.WaitMedium)
|
||||
_, err := client.CreateUserSecret(setupCtx, codersdk.Me, codersdk.CreateUserSecretRequest{
|
||||
Name: "aws-creds",
|
||||
Value: "aws-value",
|
||||
Description: "AWS credentials",
|
||||
FilePath: "~/.aws/creds",
|
||||
})
|
||||
require.NoError(t, err)
|
||||
_, err = client.CreateUserSecret(setupCtx, codersdk.Me, codersdk.CreateUserSecretRequest{
|
||||
Name: "github-token",
|
||||
Value: "ghp_xxxxxxxxxxxx",
|
||||
Description: "Personal GitHub access token",
|
||||
EnvName: "GITHUB_TOKEN",
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
inv, root := clitest.New(t, "secret", "list")
|
||||
output := clitest.Capture(inv)
|
||||
clitest.SetupConfig(t, client, root)
|
||||
|
||||
ctx := testutil.Context(t, testutil.WaitMedium)
|
||||
err = inv.WithContext(ctx).Run()
|
||||
require.NoError(t, err)
|
||||
|
||||
out := output.Stdout()
|
||||
assert.Contains(t, out, "NAME")
|
||||
assert.Contains(t, out, "UPDATED")
|
||||
assert.Contains(t, out, "ENV")
|
||||
assert.Contains(t, out, "FILE")
|
||||
assert.Contains(t, out, "DESCRIPTION")
|
||||
assert.Contains(t, out, "github-token")
|
||||
assert.Contains(t, out, "GITHUB_TOKEN")
|
||||
assert.Contains(t, out, "aws-creds")
|
||||
assert.Contains(t, out, "~/.aws/creds")
|
||||
})
|
||||
|
||||
t.Run("JSONOutput", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
client := coderdtest.New(t, nil)
|
||||
_ = coderdtest.CreateFirstUser(t, client)
|
||||
|
||||
setupCtx := testutil.Context(t, testutil.WaitMedium)
|
||||
created, err := client.CreateUserSecret(setupCtx, codersdk.Me, codersdk.CreateUserSecretRequest{
|
||||
Name: "github-token",
|
||||
Value: "ghp_xxxxxxxxxxxx",
|
||||
Description: "Personal GitHub access token",
|
||||
EnvName: "GITHUB_TOKEN",
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
inv, root := clitest.New(t, "secret", "list", "--output=json")
|
||||
output := clitest.Capture(inv)
|
||||
clitest.SetupConfig(t, client, root)
|
||||
|
||||
ctx := testutil.Context(t, testutil.WaitMedium)
|
||||
err = inv.WithContext(ctx).Run()
|
||||
require.NoError(t, err)
|
||||
|
||||
var got []codersdk.UserSecret
|
||||
require.NoError(t, json.Unmarshal([]byte(output.Stdout()), &got))
|
||||
require.Len(t, got, 1)
|
||||
require.Equal(t, created, got[0])
|
||||
})
|
||||
|
||||
t.Run("SingleSecretTableOutput", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
client := coderdtest.New(t, nil)
|
||||
_ = coderdtest.CreateFirstUser(t, client)
|
||||
|
||||
setupCtx := testutil.Context(t, testutil.WaitMedium)
|
||||
_, err := client.CreateUserSecret(setupCtx, codersdk.Me, codersdk.CreateUserSecretRequest{
|
||||
Name: "aws-creds",
|
||||
Value: "aws-value",
|
||||
Description: "AWS credentials",
|
||||
FilePath: "~/.aws/creds",
|
||||
})
|
||||
require.NoError(t, err)
|
||||
_, err = client.CreateUserSecret(setupCtx, codersdk.Me, codersdk.CreateUserSecretRequest{
|
||||
Name: "github-token",
|
||||
Value: "ghp_xxxxxxxxxxxx",
|
||||
Description: "Personal GitHub access token",
|
||||
EnvName: "GITHUB_TOKEN",
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
inv, root := clitest.New(t, "secret", "list", "github-token")
|
||||
output := clitest.Capture(inv)
|
||||
clitest.SetupConfig(t, client, root)
|
||||
|
||||
ctx := testutil.Context(t, testutil.WaitMedium)
|
||||
err = inv.WithContext(ctx).Run()
|
||||
require.NoError(t, err)
|
||||
|
||||
out := output.Stdout()
|
||||
assert.Contains(t, out, "NAME")
|
||||
assert.Contains(t, out, "UPDATED")
|
||||
assert.Contains(t, out, "ENV")
|
||||
assert.Contains(t, out, "FILE")
|
||||
assert.Contains(t, out, "DESCRIPTION")
|
||||
assert.Contains(t, out, "github-token")
|
||||
assert.Contains(t, out, "GITHUB_TOKEN")
|
||||
assert.NotContains(t, out, "aws-creds")
|
||||
assert.NotContains(t, out, "~/.aws/creds")
|
||||
})
|
||||
|
||||
t.Run("SingleSecretJSONOutput", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
client := coderdtest.New(t, nil)
|
||||
_ = coderdtest.CreateFirstUser(t, client)
|
||||
|
||||
setupCtx := testutil.Context(t, testutil.WaitMedium)
|
||||
created, err := client.CreateUserSecret(setupCtx, codersdk.Me, codersdk.CreateUserSecretRequest{
|
||||
Name: "github-token",
|
||||
Value: "ghp_xxxxxxxxxxxx",
|
||||
Description: "Personal GitHub access token",
|
||||
EnvName: "GITHUB_TOKEN",
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
inv, root := clitest.New(t, "secret", "list", "github-token", "--output=json")
|
||||
output := clitest.Capture(inv)
|
||||
clitest.SetupConfig(t, client, root)
|
||||
|
||||
ctx := testutil.Context(t, testutil.WaitMedium)
|
||||
err = inv.WithContext(ctx).Run()
|
||||
require.NoError(t, err)
|
||||
|
||||
var got []codersdk.UserSecret
|
||||
require.NoError(t, json.Unmarshal([]byte(output.Stdout()), &got))
|
||||
require.Len(t, got, 1)
|
||||
require.Equal(t, created, got[0])
|
||||
})
|
||||
|
||||
t.Run("EmptyState", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
client := coderdtest.New(t, nil)
|
||||
_ = coderdtest.CreateFirstUser(t, client)
|
||||
|
||||
inv, root := clitest.New(t, "secret", "list")
|
||||
output := clitest.Capture(inv)
|
||||
clitest.SetupConfig(t, client, root)
|
||||
|
||||
ctx := testutil.Context(t, testutil.WaitMedium)
|
||||
err := inv.WithContext(ctx).Run()
|
||||
require.NoError(t, err)
|
||||
assert.Contains(t, output.Stderr(), "No secrets found.")
|
||||
})
|
||||
}
|
||||
|
||||
func TestSecretDelete(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
t.Run("Success", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
client := coderdtest.New(t, nil)
|
||||
_ = coderdtest.CreateFirstUser(t, client)
|
||||
|
||||
setupCtx := testutil.Context(t, testutil.WaitMedium)
|
||||
_, err := client.CreateUserSecret(setupCtx, codersdk.Me, codersdk.CreateUserSecretRequest{
|
||||
Name: "github-token",
|
||||
Value: "ghp_xxxxxxxxxxxx",
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
inv, root := clitest.New(t, "secret", "delete", "github-token")
|
||||
clitest.SetupConfig(t, client, root)
|
||||
|
||||
ctx := testutil.Context(t, testutil.WaitMedium)
|
||||
inv = inv.WithContext(ctx)
|
||||
pty := ptytest.New(t).Attach(inv)
|
||||
waiter := clitest.StartWithWaiter(t, inv)
|
||||
pty.ExpectMatchContext(ctx, "Delete secret")
|
||||
pty.ExpectMatchContext(ctx, "github-token")
|
||||
pty.WriteLine("yes")
|
||||
pty.ExpectMatchContext(ctx, "Deleted secret")
|
||||
|
||||
require.NoError(t, waiter.Wait())
|
||||
|
||||
_, err = client.UserSecretByName(setupCtx, codersdk.Me, "github-token")
|
||||
require.Error(t, err)
|
||||
var sdkErr *codersdk.Error
|
||||
require.ErrorAs(t, err, &sdkErr)
|
||||
require.Equal(t, http.StatusNotFound, sdkErr.StatusCode())
|
||||
})
|
||||
|
||||
t.Run("NotFound", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
client := coderdtest.New(t, nil)
|
||||
_ = coderdtest.CreateFirstUser(t, client)
|
||||
|
||||
inv, root := clitest.New(t, "secret", "delete", "missing-secret")
|
||||
clitest.SetupConfig(t, client, root)
|
||||
|
||||
ctx := testutil.Context(t, testutil.WaitMedium)
|
||||
inv = inv.WithContext(ctx)
|
||||
pty := ptytest.New(t).Attach(inv)
|
||||
waiter := clitest.StartWithWaiter(t, inv)
|
||||
pty.ExpectMatchContext(ctx, "Delete secret")
|
||||
pty.ExpectMatchContext(ctx, "missing-secret")
|
||||
pty.WriteLine("yes")
|
||||
|
||||
err := waiter.Wait()
|
||||
require.ErrorContains(t, err, `delete secret "missing-secret"`)
|
||||
var sdkErr *codersdk.Error
|
||||
require.ErrorAs(t, err, &sdkErr)
|
||||
require.Equal(t, http.StatusNotFound, sdkErr.StatusCode())
|
||||
})
|
||||
}
|
||||
+6
-4
@@ -69,15 +69,17 @@ var (
|
||||
// isRetryableError checks for transient connection errors worth
|
||||
// retrying: DNS failures, connection refused, and server 5xx.
|
||||
func isRetryableError(err error) bool {
|
||||
if err == nil {
|
||||
return false
|
||||
}
|
||||
if xerrors.Is(err, context.Canceled) || xerrors.Is(err, context.DeadlineExceeded) {
|
||||
if err == nil || xerrors.Is(err, context.Canceled) {
|
||||
return false
|
||||
}
|
||||
// Check connection errors before context.DeadlineExceeded because
|
||||
// net.Dialer.Timeout produces *net.OpError that matches both.
|
||||
if codersdk.IsConnectionError(err) {
|
||||
return true
|
||||
}
|
||||
if xerrors.Is(err, context.DeadlineExceeded) {
|
||||
return false
|
||||
}
|
||||
var sdkErr *codersdk.Error
|
||||
if xerrors.As(err, &sdkErr) {
|
||||
return sdkErr.StatusCode() >= 500
|
||||
|
||||
@@ -516,6 +516,23 @@ func TestIsRetryableError(t *testing.T) {
|
||||
assert.Equal(t, tt.retryable, isRetryableError(tt.err))
|
||||
})
|
||||
}
|
||||
|
||||
// net.Dialer.Timeout produces *net.OpError that matches both
|
||||
// IsConnectionError and context.DeadlineExceeded. Verify it is retryable.
|
||||
t.Run("DialTimeout", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx, cancel := context.WithDeadline(context.Background(), time.Now())
|
||||
defer cancel()
|
||||
<-ctx.Done() // ensure deadline has fired
|
||||
_, err := (&net.Dialer{}).DialContext(ctx, "tcp", "127.0.0.1:1")
|
||||
require.Error(t, err)
|
||||
// Proves the ambiguity: this error matches BOTH checks.
|
||||
require.ErrorIs(t, err, context.DeadlineExceeded)
|
||||
require.ErrorAs(t, err, new(*net.OpError))
|
||||
assert.True(t, isRetryableError(err))
|
||||
// Also when wrapped, as runCoderConnectStdio does.
|
||||
assert.True(t, isRetryableError(xerrors.Errorf("dial coder connect: %w", err)))
|
||||
})
|
||||
}
|
||||
|
||||
func TestRetryWithInterval(t *testing.T) {
|
||||
|
||||
Vendored
+1
@@ -43,6 +43,7 @@ SUBCOMMANDS:
|
||||
password
|
||||
restart Restart a workspace
|
||||
schedule Schedule automated start and stop times for workspaces
|
||||
secret Manage personal secrets
|
||||
server Start a Coder server
|
||||
show Display details of a workspace's resources and agents
|
||||
speedtest Run upload and download tests from your machine to a
|
||||
|
||||
+1
-1
@@ -11,7 +11,7 @@ OPTIONS:
|
||||
-O, --org string, $CODER_ORGANIZATION
|
||||
Select which organization (uuid or name) to use.
|
||||
|
||||
-c, --column [id|created at|started at|completed at|canceled at|error|error code|status|worker id|worker name|file id|tags|queue position|queue size|organization id|initiator id|template version id|workspace build id|type|available workers|template version name|template id|template name|template display name|template icon|workspace id|workspace name|logs overflowed|organization|queue] (default: created at,id,type,template display name,status,queue,tags)
|
||||
-c, --column [id|created at|started at|completed at|canceled at|error|error code|status|worker id|worker name|file id|tags|queue position|queue size|organization id|initiator id|template version id|workspace build id|type|available workers|template version name|template id|template name|template display name|template icon|workspace id|workspace name|workspace build transition|logs overflowed|organization|queue] (default: created at,id,type,template display name,status,queue,tags)
|
||||
Columns to display in table output.
|
||||
|
||||
-i, --initiator string, $CODER_PROVISIONER_JOB_LIST_INITIATOR
|
||||
|
||||
@@ -58,7 +58,8 @@
|
||||
"template_display_name": "",
|
||||
"template_icon": "",
|
||||
"workspace_id": "===========[workspace ID]===========",
|
||||
"workspace_name": "test-workspace"
|
||||
"workspace_name": "test-workspace",
|
||||
"workspace_build_transition": "start"
|
||||
},
|
||||
"logs_overflowed": false,
|
||||
"organization_name": "Coder"
|
||||
|
||||
+41
@@ -0,0 +1,41 @@
|
||||
coder v0.0.0-devel
|
||||
|
||||
USAGE:
|
||||
coder secret
|
||||
|
||||
Manage personal secrets
|
||||
|
||||
Aliases: secrets
|
||||
|
||||
- Create a secret:
|
||||
|
||||
$ coder secret create openai-key --value "$SECRET_VALUE" --description
|
||||
"Personal OPENAI_API key" --inject-env OPEN_AI_KEY --inject-file
|
||||
"~/.openai-key"
|
||||
|
||||
- Update a secret:
|
||||
|
||||
$ coder secret update openai-key --value "$NEW_SECRET_VALUE"
|
||||
--description "Updated description" --inject-env NEW_ENV_NAME --inject-file
|
||||
"~/.new-path"
|
||||
|
||||
- List your secrets:
|
||||
|
||||
$ coder secret list
|
||||
|
||||
- Show a specific secret:
|
||||
|
||||
$ coder secret list openai-key
|
||||
|
||||
- Delete a secret:
|
||||
|
||||
$ coder secret delete openai-key
|
||||
|
||||
SUBCOMMANDS:
|
||||
create Create a secret
|
||||
delete Delete a secret
|
||||
list List secrets, or show one by name
|
||||
update Update a secret
|
||||
|
||||
———
|
||||
Run `coder --help` for a list of global options.
|
||||
+22
@@ -0,0 +1,22 @@
|
||||
coder v0.0.0-devel
|
||||
|
||||
USAGE:
|
||||
coder secret create [flags] <name>
|
||||
|
||||
Create a secret
|
||||
|
||||
OPTIONS:
|
||||
--description string
|
||||
Set the secret description.
|
||||
|
||||
--inject-env string
|
||||
Inject the secret into workspaces as an environment variable.
|
||||
|
||||
--inject-file string
|
||||
Inject the secret into workspaces as a file.
|
||||
|
||||
--value string
|
||||
Set the secret value. This flag is required.
|
||||
|
||||
———
|
||||
Run `coder --help` for a list of global options.
|
||||
+15
@@ -0,0 +1,15 @@
|
||||
coder v0.0.0-devel
|
||||
|
||||
USAGE:
|
||||
coder secret delete [flags] <name>
|
||||
|
||||
Delete a secret
|
||||
|
||||
Aliases: remove, rm
|
||||
|
||||
OPTIONS:
|
||||
-y, --yes bool
|
||||
Bypass confirmation prompts.
|
||||
|
||||
———
|
||||
Run `coder --help` for a list of global options.
|
||||
+18
@@ -0,0 +1,18 @@
|
||||
coder v0.0.0-devel
|
||||
|
||||
USAGE:
|
||||
coder secret list [flags] [name]
|
||||
|
||||
List secrets, or show one by name
|
||||
|
||||
Aliases: ls
|
||||
|
||||
OPTIONS:
|
||||
-c, --column [name|updated|env|file|description] (default: name,updated,env,file,description)
|
||||
Columns to display in table output.
|
||||
|
||||
-o, --output table|json (default: table)
|
||||
Output format.
|
||||
|
||||
———
|
||||
Run `coder --help` for a list of global options.
|
||||
+23
@@ -0,0 +1,23 @@
|
||||
coder v0.0.0-devel
|
||||
|
||||
USAGE:
|
||||
coder secret update [flags] <name>
|
||||
|
||||
Update a secret
|
||||
|
||||
OPTIONS:
|
||||
--description string
|
||||
Update the secret description. Pass an empty string to clear it.
|
||||
|
||||
--inject-env string
|
||||
Update the environment variable injection target. Pass an empty string
|
||||
to clear it.
|
||||
|
||||
--inject-file string
|
||||
Update the file injection target. Pass an empty string to clear it.
|
||||
|
||||
--value string
|
||||
Update the secret value.
|
||||
|
||||
———
|
||||
Run `coder --help` for a list of global options.
|
||||
@@ -77,8 +77,9 @@ func (a *LogsAPI) BatchCreateLogs(ctx context.Context, req *agentproto.BatchCrea
|
||||
level := make([]database.LogLevel, 0)
|
||||
outputLength := 0
|
||||
for _, logEntry := range req.Logs {
|
||||
output = append(output, logEntry.Output)
|
||||
outputLength += len(logEntry.Output)
|
||||
sanitizedOutput := agentsdk.SanitizeLogOutput(logEntry.Output)
|
||||
output = append(output, sanitizedOutput)
|
||||
outputLength += len(sanitizedOutput)
|
||||
|
||||
var dbLevel database.LogLevel
|
||||
switch logEntry.Level {
|
||||
|
||||
@@ -139,6 +139,59 @@ func TestBatchCreateLogs(t *testing.T) {
|
||||
require.True(t, publishWorkspaceAgentLogsUpdateCalled)
|
||||
})
|
||||
|
||||
t.Run("SanitizesOutput", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
dbM := dbmock.NewMockStore(gomock.NewController(t))
|
||||
now := dbtime.Now()
|
||||
api := &agentapi.LogsAPI{
|
||||
AgentFn: func(context.Context) (database.WorkspaceAgent, error) {
|
||||
return agent, nil
|
||||
},
|
||||
Database: dbM,
|
||||
Log: testutil.Logger(t),
|
||||
TimeNowFn: func() time.Time {
|
||||
return now
|
||||
},
|
||||
}
|
||||
|
||||
rawOutput := "before\x00middle\xc3\x28after"
|
||||
sanitizedOutput := agentsdk.SanitizeLogOutput(rawOutput)
|
||||
expectedOutputLength := int32(len(sanitizedOutput)) //nolint:gosec // Test-controlled string length is small.
|
||||
req := &agentproto.BatchCreateLogsRequest{
|
||||
LogSourceId: logSource.ID[:],
|
||||
Logs: []*agentproto.Log{
|
||||
{
|
||||
CreatedAt: timestamppb.New(now),
|
||||
Level: agentproto.Log_WARN,
|
||||
Output: rawOutput,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
dbM.EXPECT().InsertWorkspaceAgentLogs(gomock.Any(), database.InsertWorkspaceAgentLogsParams{
|
||||
AgentID: agent.ID,
|
||||
LogSourceID: logSource.ID,
|
||||
CreatedAt: now,
|
||||
Output: []string{sanitizedOutput},
|
||||
Level: []database.LogLevel{database.LogLevelWarn},
|
||||
OutputLength: expectedOutputLength,
|
||||
}).Return([]database.WorkspaceAgentLog{
|
||||
{
|
||||
AgentID: agent.ID,
|
||||
CreatedAt: now,
|
||||
ID: 1,
|
||||
Output: sanitizedOutput,
|
||||
Level: database.LogLevelWarn,
|
||||
LogSourceID: logSource.ID,
|
||||
},
|
||||
}, nil)
|
||||
|
||||
resp, err := api.BatchCreateLogs(context.Background(), req)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, &agentproto.BatchCreateLogsResponse{}, resp)
|
||||
})
|
||||
|
||||
t.Run("NoWorkspacePublishIfNotFirstLogs", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
|
||||
@@ -71,7 +71,7 @@ func (a *SubAgentAPI) CreateSubAgent(ctx context.Context, req *agentproto.Create
|
||||
// An ID is only given in the request when it is a terraform-defined devcontainer
|
||||
// that has attached resources. These subagents are pre-provisioned by terraform
|
||||
// (the agent record already exists), so we update configurable fields like
|
||||
// display_apps rather than creating a new agent.
|
||||
// display_apps and directory rather than creating a new agent.
|
||||
if req.Id != nil {
|
||||
id, err := uuid.FromBytes(req.Id)
|
||||
if err != nil {
|
||||
@@ -97,6 +97,16 @@ func (a *SubAgentAPI) CreateSubAgent(ctx context.Context, req *agentproto.Create
|
||||
return nil, xerrors.Errorf("update workspace agent display apps: %w", err)
|
||||
}
|
||||
|
||||
if req.Directory != "" {
|
||||
if err := a.Database.UpdateWorkspaceAgentDirectoryByID(ctx, database.UpdateWorkspaceAgentDirectoryByIDParams{
|
||||
ID: id,
|
||||
Directory: req.Directory,
|
||||
UpdatedAt: createdAt,
|
||||
}); err != nil {
|
||||
return nil, xerrors.Errorf("update workspace agent directory: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
return &agentproto.CreateSubAgentResponse{
|
||||
Agent: &agentproto.SubAgent{
|
||||
Name: subAgent.Name,
|
||||
|
||||
@@ -1267,11 +1267,11 @@ func TestSubAgentAPI(t *testing.T) {
|
||||
agentID, err := uuid.FromBytes(resp.Agent.Id)
|
||||
require.NoError(t, err)
|
||||
|
||||
// And: The database agent's other fields are unchanged.
|
||||
// And: The database agent's name, architecture, and OS are unchanged.
|
||||
updatedAgent, err := db.GetWorkspaceAgentByID(dbauthz.AsSystemRestricted(ctx), agentID)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, baseChildAgent.Name, updatedAgent.Name)
|
||||
require.Equal(t, baseChildAgent.Directory, updatedAgent.Directory)
|
||||
require.Equal(t, "/different/path", updatedAgent.Directory)
|
||||
require.Equal(t, baseChildAgent.Architecture, updatedAgent.Architecture)
|
||||
require.Equal(t, baseChildAgent.OperatingSystem, updatedAgent.OperatingSystem)
|
||||
|
||||
@@ -1280,6 +1280,42 @@ func TestSubAgentAPI(t *testing.T) {
|
||||
require.Equal(t, database.DisplayAppWebTerminal, updatedAgent.DisplayApps[0])
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "OK_DirectoryUpdated",
|
||||
setup: func(t *testing.T, db database.Store, agent database.WorkspaceAgent) *proto.CreateSubAgentRequest {
|
||||
// Given: An existing child agent with a stale host-side
|
||||
// directory (as set by the provisioner at build time).
|
||||
childAgent := dbgen.WorkspaceAgent(t, db, database.WorkspaceAgent{
|
||||
ParentID: uuid.NullUUID{Valid: true, UUID: agent.ID},
|
||||
ResourceID: agent.ResourceID,
|
||||
Name: baseChildAgent.Name,
|
||||
Directory: "/home/coder/project",
|
||||
Architecture: baseChildAgent.Architecture,
|
||||
OperatingSystem: baseChildAgent.OperatingSystem,
|
||||
DisplayApps: baseChildAgent.DisplayApps,
|
||||
})
|
||||
|
||||
// When: Agent injection sends the correct
|
||||
// container-internal path.
|
||||
return &proto.CreateSubAgentRequest{
|
||||
Id: childAgent.ID[:],
|
||||
Directory: "/workspaces/project",
|
||||
DisplayApps: []proto.CreateSubAgentRequest_DisplayApp{
|
||||
proto.CreateSubAgentRequest_WEB_TERMINAL,
|
||||
},
|
||||
}
|
||||
},
|
||||
check: func(t *testing.T, ctx context.Context, db database.Store, resp *proto.CreateSubAgentResponse, agent database.WorkspaceAgent) {
|
||||
agentID, err := uuid.FromBytes(resp.Agent.Id)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Then: Directory is updated to the container-internal
|
||||
// path.
|
||||
updatedAgent, err := db.GetWorkspaceAgentByID(dbauthz.AsSystemRestricted(ctx), agentID)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "/workspaces/project", updatedAgent.Directory)
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "Error/MalformedID",
|
||||
setup: func(t *testing.T, db database.Store, agent database.WorkspaceAgent) *proto.CreateSubAgentRequest {
|
||||
|
||||
Generated
+281
@@ -9514,6 +9514,212 @@ const docTemplate = `{
|
||||
]
|
||||
}
|
||||
},
|
||||
"/users/{user}/secrets": {
|
||||
"get": {
|
||||
"produces": [
|
||||
"application/json"
|
||||
],
|
||||
"tags": [
|
||||
"Secrets"
|
||||
],
|
||||
"summary": "List user secrets",
|
||||
"operationId": "list-user-secrets",
|
||||
"parameters": [
|
||||
{
|
||||
"type": "string",
|
||||
"description": "User ID, username, or me",
|
||||
"name": "user",
|
||||
"in": "path",
|
||||
"required": true
|
||||
}
|
||||
],
|
||||
"responses": {
|
||||
"200": {
|
||||
"description": "OK",
|
||||
"schema": {
|
||||
"type": "array",
|
||||
"items": {
|
||||
"$ref": "#/definitions/codersdk.UserSecret"
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
"security": [
|
||||
{
|
||||
"CoderSessionToken": []
|
||||
}
|
||||
]
|
||||
},
|
||||
"post": {
|
||||
"consumes": [
|
||||
"application/json"
|
||||
],
|
||||
"produces": [
|
||||
"application/json"
|
||||
],
|
||||
"tags": [
|
||||
"Secrets"
|
||||
],
|
||||
"summary": "Create a new user secret",
|
||||
"operationId": "create-a-new-user-secret",
|
||||
"parameters": [
|
||||
{
|
||||
"type": "string",
|
||||
"description": "User ID, username, or me",
|
||||
"name": "user",
|
||||
"in": "path",
|
||||
"required": true
|
||||
},
|
||||
{
|
||||
"description": "Create secret request",
|
||||
"name": "request",
|
||||
"in": "body",
|
||||
"required": true,
|
||||
"schema": {
|
||||
"$ref": "#/definitions/codersdk.CreateUserSecretRequest"
|
||||
}
|
||||
}
|
||||
],
|
||||
"responses": {
|
||||
"201": {
|
||||
"description": "Created",
|
||||
"schema": {
|
||||
"$ref": "#/definitions/codersdk.UserSecret"
|
||||
}
|
||||
}
|
||||
},
|
||||
"security": [
|
||||
{
|
||||
"CoderSessionToken": []
|
||||
}
|
||||
]
|
||||
}
|
||||
},
|
||||
"/users/{user}/secrets/{name}": {
|
||||
"get": {
|
||||
"produces": [
|
||||
"application/json"
|
||||
],
|
||||
"tags": [
|
||||
"Secrets"
|
||||
],
|
||||
"summary": "Get a user secret by name",
|
||||
"operationId": "get-a-user-secret-by-name",
|
||||
"parameters": [
|
||||
{
|
||||
"type": "string",
|
||||
"description": "User ID, username, or me",
|
||||
"name": "user",
|
||||
"in": "path",
|
||||
"required": true
|
||||
},
|
||||
{
|
||||
"type": "string",
|
||||
"description": "Secret name",
|
||||
"name": "name",
|
||||
"in": "path",
|
||||
"required": true
|
||||
}
|
||||
],
|
||||
"responses": {
|
||||
"200": {
|
||||
"description": "OK",
|
||||
"schema": {
|
||||
"$ref": "#/definitions/codersdk.UserSecret"
|
||||
}
|
||||
}
|
||||
},
|
||||
"security": [
|
||||
{
|
||||
"CoderSessionToken": []
|
||||
}
|
||||
]
|
||||
},
|
||||
"delete": {
|
||||
"tags": [
|
||||
"Secrets"
|
||||
],
|
||||
"summary": "Delete a user secret",
|
||||
"operationId": "delete-a-user-secret",
|
||||
"parameters": [
|
||||
{
|
||||
"type": "string",
|
||||
"description": "User ID, username, or me",
|
||||
"name": "user",
|
||||
"in": "path",
|
||||
"required": true
|
||||
},
|
||||
{
|
||||
"type": "string",
|
||||
"description": "Secret name",
|
||||
"name": "name",
|
||||
"in": "path",
|
||||
"required": true
|
||||
}
|
||||
],
|
||||
"responses": {
|
||||
"204": {
|
||||
"description": "No Content"
|
||||
}
|
||||
},
|
||||
"security": [
|
||||
{
|
||||
"CoderSessionToken": []
|
||||
}
|
||||
]
|
||||
},
|
||||
"patch": {
|
||||
"consumes": [
|
||||
"application/json"
|
||||
],
|
||||
"produces": [
|
||||
"application/json"
|
||||
],
|
||||
"tags": [
|
||||
"Secrets"
|
||||
],
|
||||
"summary": "Update a user secret",
|
||||
"operationId": "update-a-user-secret",
|
||||
"parameters": [
|
||||
{
|
||||
"type": "string",
|
||||
"description": "User ID, username, or me",
|
||||
"name": "user",
|
||||
"in": "path",
|
||||
"required": true
|
||||
},
|
||||
{
|
||||
"type": "string",
|
||||
"description": "Secret name",
|
||||
"name": "name",
|
||||
"in": "path",
|
||||
"required": true
|
||||
},
|
||||
{
|
||||
"description": "Update secret request",
|
||||
"name": "request",
|
||||
"in": "body",
|
||||
"required": true,
|
||||
"schema": {
|
||||
"$ref": "#/definitions/codersdk.UpdateUserSecretRequest"
|
||||
}
|
||||
}
|
||||
],
|
||||
"responses": {
|
||||
"200": {
|
||||
"description": "OK",
|
||||
"schema": {
|
||||
"$ref": "#/definitions/codersdk.UserSecret"
|
||||
}
|
||||
}
|
||||
},
|
||||
"security": [
|
||||
{
|
||||
"CoderSessionToken": []
|
||||
}
|
||||
]
|
||||
}
|
||||
},
|
||||
"/users/{user}/status/activate": {
|
||||
"put": {
|
||||
"produces": [
|
||||
@@ -13239,6 +13445,12 @@ const docTemplate = `{
|
||||
"$ref": "#/definitions/codersdk.AIBridgeAgenticAction"
|
||||
}
|
||||
},
|
||||
"credential_hint": {
|
||||
"type": "string"
|
||||
},
|
||||
"credential_kind": {
|
||||
"type": "string"
|
||||
},
|
||||
"ended_at": {
|
||||
"type": "string",
|
||||
"format": "date-time"
|
||||
@@ -15142,6 +15354,26 @@ const docTemplate = `{
|
||||
}
|
||||
}
|
||||
},
|
||||
"codersdk.CreateUserSecretRequest": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"description": {
|
||||
"type": "string"
|
||||
},
|
||||
"env_name": {
|
||||
"type": "string"
|
||||
},
|
||||
"file_path": {
|
||||
"type": "string"
|
||||
},
|
||||
"name": {
|
||||
"type": "string"
|
||||
},
|
||||
"value": {
|
||||
"type": "string"
|
||||
}
|
||||
}
|
||||
},
|
||||
"codersdk.CreateWorkspaceBuildReason": {
|
||||
"type": "string",
|
||||
"enum": [
|
||||
@@ -18917,6 +19149,9 @@ const docTemplate = `{
|
||||
"template_version_name": {
|
||||
"type": "string"
|
||||
},
|
||||
"workspace_build_transition": {
|
||||
"$ref": "#/definitions/codersdk.WorkspaceTransition"
|
||||
},
|
||||
"workspace_id": {
|
||||
"type": "string",
|
||||
"format": "uuid"
|
||||
@@ -21271,6 +21506,23 @@ const docTemplate = `{
|
||||
}
|
||||
}
|
||||
},
|
||||
"codersdk.UpdateUserSecretRequest": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"description": {
|
||||
"type": "string"
|
||||
},
|
||||
"env_name": {
|
||||
"type": "string"
|
||||
},
|
||||
"file_path": {
|
||||
"type": "string"
|
||||
},
|
||||
"value": {
|
||||
"type": "string"
|
||||
}
|
||||
}
|
||||
},
|
||||
"codersdk.UpdateWorkspaceACL": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
@@ -21726,6 +21978,35 @@ const docTemplate = `{
|
||||
}
|
||||
}
|
||||
},
|
||||
"codersdk.UserSecret": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"created_at": {
|
||||
"type": "string",
|
||||
"format": "date-time"
|
||||
},
|
||||
"description": {
|
||||
"type": "string"
|
||||
},
|
||||
"env_name": {
|
||||
"type": "string"
|
||||
},
|
||||
"file_path": {
|
||||
"type": "string"
|
||||
},
|
||||
"id": {
|
||||
"type": "string",
|
||||
"format": "uuid"
|
||||
},
|
||||
"name": {
|
||||
"type": "string"
|
||||
},
|
||||
"updated_at": {
|
||||
"type": "string",
|
||||
"format": "date-time"
|
||||
}
|
||||
}
|
||||
},
|
||||
"codersdk.UserStatus": {
|
||||
"type": "string",
|
||||
"enum": [
|
||||
|
||||
Generated
+259
@@ -8431,6 +8431,190 @@
|
||||
]
|
||||
}
|
||||
},
|
||||
"/users/{user}/secrets": {
|
||||
"get": {
|
||||
"produces": ["application/json"],
|
||||
"tags": ["Secrets"],
|
||||
"summary": "List user secrets",
|
||||
"operationId": "list-user-secrets",
|
||||
"parameters": [
|
||||
{
|
||||
"type": "string",
|
||||
"description": "User ID, username, or me",
|
||||
"name": "user",
|
||||
"in": "path",
|
||||
"required": true
|
||||
}
|
||||
],
|
||||
"responses": {
|
||||
"200": {
|
||||
"description": "OK",
|
||||
"schema": {
|
||||
"type": "array",
|
||||
"items": {
|
||||
"$ref": "#/definitions/codersdk.UserSecret"
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
"security": [
|
||||
{
|
||||
"CoderSessionToken": []
|
||||
}
|
||||
]
|
||||
},
|
||||
"post": {
|
||||
"consumes": ["application/json"],
|
||||
"produces": ["application/json"],
|
||||
"tags": ["Secrets"],
|
||||
"summary": "Create a new user secret",
|
||||
"operationId": "create-a-new-user-secret",
|
||||
"parameters": [
|
||||
{
|
||||
"type": "string",
|
||||
"description": "User ID, username, or me",
|
||||
"name": "user",
|
||||
"in": "path",
|
||||
"required": true
|
||||
},
|
||||
{
|
||||
"description": "Create secret request",
|
||||
"name": "request",
|
||||
"in": "body",
|
||||
"required": true,
|
||||
"schema": {
|
||||
"$ref": "#/definitions/codersdk.CreateUserSecretRequest"
|
||||
}
|
||||
}
|
||||
],
|
||||
"responses": {
|
||||
"201": {
|
||||
"description": "Created",
|
||||
"schema": {
|
||||
"$ref": "#/definitions/codersdk.UserSecret"
|
||||
}
|
||||
}
|
||||
},
|
||||
"security": [
|
||||
{
|
||||
"CoderSessionToken": []
|
||||
}
|
||||
]
|
||||
}
|
||||
},
|
||||
"/users/{user}/secrets/{name}": {
|
||||
"get": {
|
||||
"produces": ["application/json"],
|
||||
"tags": ["Secrets"],
|
||||
"summary": "Get a user secret by name",
|
||||
"operationId": "get-a-user-secret-by-name",
|
||||
"parameters": [
|
||||
{
|
||||
"type": "string",
|
||||
"description": "User ID, username, or me",
|
||||
"name": "user",
|
||||
"in": "path",
|
||||
"required": true
|
||||
},
|
||||
{
|
||||
"type": "string",
|
||||
"description": "Secret name",
|
||||
"name": "name",
|
||||
"in": "path",
|
||||
"required": true
|
||||
}
|
||||
],
|
||||
"responses": {
|
||||
"200": {
|
||||
"description": "OK",
|
||||
"schema": {
|
||||
"$ref": "#/definitions/codersdk.UserSecret"
|
||||
}
|
||||
}
|
||||
},
|
||||
"security": [
|
||||
{
|
||||
"CoderSessionToken": []
|
||||
}
|
||||
]
|
||||
},
|
||||
"delete": {
|
||||
"tags": ["Secrets"],
|
||||
"summary": "Delete a user secret",
|
||||
"operationId": "delete-a-user-secret",
|
||||
"parameters": [
|
||||
{
|
||||
"type": "string",
|
||||
"description": "User ID, username, or me",
|
||||
"name": "user",
|
||||
"in": "path",
|
||||
"required": true
|
||||
},
|
||||
{
|
||||
"type": "string",
|
||||
"description": "Secret name",
|
||||
"name": "name",
|
||||
"in": "path",
|
||||
"required": true
|
||||
}
|
||||
],
|
||||
"responses": {
|
||||
"204": {
|
||||
"description": "No Content"
|
||||
}
|
||||
},
|
||||
"security": [
|
||||
{
|
||||
"CoderSessionToken": []
|
||||
}
|
||||
]
|
||||
},
|
||||
"patch": {
|
||||
"consumes": ["application/json"],
|
||||
"produces": ["application/json"],
|
||||
"tags": ["Secrets"],
|
||||
"summary": "Update a user secret",
|
||||
"operationId": "update-a-user-secret",
|
||||
"parameters": [
|
||||
{
|
||||
"type": "string",
|
||||
"description": "User ID, username, or me",
|
||||
"name": "user",
|
||||
"in": "path",
|
||||
"required": true
|
||||
},
|
||||
{
|
||||
"type": "string",
|
||||
"description": "Secret name",
|
||||
"name": "name",
|
||||
"in": "path",
|
||||
"required": true
|
||||
},
|
||||
{
|
||||
"description": "Update secret request",
|
||||
"name": "request",
|
||||
"in": "body",
|
||||
"required": true,
|
||||
"schema": {
|
||||
"$ref": "#/definitions/codersdk.UpdateUserSecretRequest"
|
||||
}
|
||||
}
|
||||
],
|
||||
"responses": {
|
||||
"200": {
|
||||
"description": "OK",
|
||||
"schema": {
|
||||
"$ref": "#/definitions/codersdk.UserSecret"
|
||||
}
|
||||
}
|
||||
},
|
||||
"security": [
|
||||
{
|
||||
"CoderSessionToken": []
|
||||
}
|
||||
]
|
||||
}
|
||||
},
|
||||
"/users/{user}/status/activate": {
|
||||
"put": {
|
||||
"produces": ["application/json"],
|
||||
@@ -11809,6 +11993,12 @@
|
||||
"$ref": "#/definitions/codersdk.AIBridgeAgenticAction"
|
||||
}
|
||||
},
|
||||
"credential_hint": {
|
||||
"type": "string"
|
||||
},
|
||||
"credential_kind": {
|
||||
"type": "string"
|
||||
},
|
||||
"ended_at": {
|
||||
"type": "string",
|
||||
"format": "date-time"
|
||||
@@ -13643,6 +13833,26 @@
|
||||
}
|
||||
}
|
||||
},
|
||||
"codersdk.CreateUserSecretRequest": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"description": {
|
||||
"type": "string"
|
||||
},
|
||||
"env_name": {
|
||||
"type": "string"
|
||||
},
|
||||
"file_path": {
|
||||
"type": "string"
|
||||
},
|
||||
"name": {
|
||||
"type": "string"
|
||||
},
|
||||
"value": {
|
||||
"type": "string"
|
||||
}
|
||||
}
|
||||
},
|
||||
"codersdk.CreateWorkspaceBuildReason": {
|
||||
"type": "string",
|
||||
"enum": [
|
||||
@@ -17299,6 +17509,9 @@
|
||||
"template_version_name": {
|
||||
"type": "string"
|
||||
},
|
||||
"workspace_build_transition": {
|
||||
"$ref": "#/definitions/codersdk.WorkspaceTransition"
|
||||
},
|
||||
"workspace_id": {
|
||||
"type": "string",
|
||||
"format": "uuid"
|
||||
@@ -19545,6 +19758,23 @@
|
||||
}
|
||||
}
|
||||
},
|
||||
"codersdk.UpdateUserSecretRequest": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"description": {
|
||||
"type": "string"
|
||||
},
|
||||
"env_name": {
|
||||
"type": "string"
|
||||
},
|
||||
"file_path": {
|
||||
"type": "string"
|
||||
},
|
||||
"value": {
|
||||
"type": "string"
|
||||
}
|
||||
}
|
||||
},
|
||||
"codersdk.UpdateWorkspaceACL": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
@@ -19975,6 +20205,35 @@
|
||||
}
|
||||
}
|
||||
},
|
||||
"codersdk.UserSecret": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"created_at": {
|
||||
"type": "string",
|
||||
"format": "date-time"
|
||||
},
|
||||
"description": {
|
||||
"type": "string"
|
||||
},
|
||||
"env_name": {
|
||||
"type": "string"
|
||||
},
|
||||
"file_path": {
|
||||
"type": "string"
|
||||
},
|
||||
"id": {
|
||||
"type": "string",
|
||||
"format": "uuid"
|
||||
},
|
||||
"name": {
|
||||
"type": "string"
|
||||
},
|
||||
"updated_at": {
|
||||
"type": "string",
|
||||
"format": "date-time"
|
||||
}
|
||||
}
|
||||
},
|
||||
"codersdk.UserStatus": {
|
||||
"type": "string",
|
||||
"enum": ["active", "dormant", "suspended"],
|
||||
|
||||
@@ -1608,6 +1608,15 @@ func New(options *Options) *API {
|
||||
|
||||
r.Get("/gitsshkey", api.gitSSHKey)
|
||||
r.Put("/gitsshkey", api.regenerateGitSSHKey)
|
||||
r.Route("/secrets", func(r chi.Router) {
|
||||
r.Post("/", api.postUserSecret)
|
||||
r.Get("/", api.getUserSecrets)
|
||||
r.Route("/{name}", func(r chi.Router) {
|
||||
r.Get("/", api.getUserSecret)
|
||||
r.Patch("/", api.patchUserSecret)
|
||||
r.Delete("/", api.deleteUserSecret)
|
||||
})
|
||||
})
|
||||
r.Route("/notifications", func(r chi.Router) {
|
||||
r.Route("/preferences", func(r chi.Router) {
|
||||
r.Get("/", api.userNotificationPreferences)
|
||||
@@ -1653,6 +1662,10 @@ func New(options *Options) *API {
|
||||
r.Get("/gitsshkey", api.agentGitSSHKey)
|
||||
r.Post("/log-source", api.workspaceAgentPostLogSource)
|
||||
r.Get("/reinit", api.workspaceAgentReinit)
|
||||
r.Route("/experimental", func(r chi.Router) {
|
||||
r.Post("/chat-context", api.workspaceAgentAddChatContext)
|
||||
r.Delete("/chat-context", api.workspaceAgentClearChatContext)
|
||||
})
|
||||
r.Route("/tasks/{task}", func(r chi.Router) {
|
||||
r.Post("/log-snapshot", api.postWorkspaceAgentTaskLogSnapshot)
|
||||
})
|
||||
|
||||
@@ -147,6 +147,10 @@ func parseSwaggerComment(commentGroup *ast.CommentGroup) SwaggerComment {
|
||||
return c
|
||||
}
|
||||
|
||||
func isExperimentalEndpoint(route string) bool {
|
||||
return strings.HasPrefix(route, "/workspaceagents/me/experimental/")
|
||||
}
|
||||
|
||||
func VerifySwaggerDefinitions(t *testing.T, router chi.Router, swaggerComments []SwaggerComment) {
|
||||
assertUniqueRoutes(t, swaggerComments)
|
||||
assertSingleAnnotations(t, swaggerComments)
|
||||
@@ -165,6 +169,9 @@ func VerifySwaggerDefinitions(t *testing.T, router chi.Router, swaggerComments [
|
||||
if strings.HasSuffix(route, "/*") {
|
||||
return
|
||||
}
|
||||
if isExperimentalEndpoint(route) {
|
||||
return
|
||||
}
|
||||
|
||||
c := findSwaggerCommentByMethodAndRoute(swaggerComments, method, route)
|
||||
assert.NotNil(t, c, "Missing @Router annotation")
|
||||
|
||||
@@ -538,6 +538,12 @@ func WorkspaceAgent(derpMap *tailcfg.DERPMap, coordinator tailnet.Coordinator,
|
||||
switch {
|
||||
case workspaceAgent.Status != codersdk.WorkspaceAgentConnected && workspaceAgent.LifecycleState == codersdk.WorkspaceAgentLifecycleOff:
|
||||
workspaceAgent.Health.Reason = "agent is not running"
|
||||
case workspaceAgent.Status == codersdk.WorkspaceAgentConnecting:
|
||||
// Note: the case above catches connecting+off as "not running".
|
||||
// This case handles connecting agents with a non-off lifecycle
|
||||
// (e.g. "created" or "starting"), where the agent binary has
|
||||
// not yet established a connection to coderd.
|
||||
workspaceAgent.Health.Reason = "agent has not yet connected"
|
||||
case workspaceAgent.Status == codersdk.WorkspaceAgentTimeout:
|
||||
workspaceAgent.Health.Reason = "agent is taking too long to connect"
|
||||
case workspaceAgent.Status == codersdk.WorkspaceAgentDisconnected:
|
||||
@@ -1234,6 +1240,8 @@ func buildAIBridgeThread(
|
||||
if rootIntc != nil {
|
||||
thread.Model = rootIntc.Model
|
||||
thread.Provider = rootIntc.Provider
|
||||
thread.CredentialKind = string(rootIntc.CredentialKind)
|
||||
thread.CredentialHint = rootIntc.CredentialHint
|
||||
// Get first user prompt from root interception.
|
||||
// A thread can only have one prompt, by definition, since we currently
|
||||
// only store the last prompt observed in an interception.
|
||||
@@ -1715,3 +1723,41 @@ func ChatDiffStatus(chatID uuid.UUID, status *database.ChatDiffStatus) codersdk.
|
||||
|
||||
return result
|
||||
}
|
||||
|
||||
// UserSecret converts a database ListUserSecretsRow (metadata only,
|
||||
// no value) to an SDK UserSecret.
|
||||
func UserSecret(secret database.ListUserSecretsRow) codersdk.UserSecret {
|
||||
return codersdk.UserSecret{
|
||||
ID: secret.ID,
|
||||
Name: secret.Name,
|
||||
Description: secret.Description,
|
||||
EnvName: secret.EnvName,
|
||||
FilePath: secret.FilePath,
|
||||
CreatedAt: secret.CreatedAt,
|
||||
UpdatedAt: secret.UpdatedAt,
|
||||
}
|
||||
}
|
||||
|
||||
// UserSecretFromFull converts a full database UserSecret row to an
|
||||
// SDK UserSecret, omitting the value and encryption key ID.
|
||||
func UserSecretFromFull(secret database.UserSecret) codersdk.UserSecret {
|
||||
return codersdk.UserSecret{
|
||||
ID: secret.ID,
|
||||
Name: secret.Name,
|
||||
Description: secret.Description,
|
||||
EnvName: secret.EnvName,
|
||||
FilePath: secret.FilePath,
|
||||
CreatedAt: secret.CreatedAt,
|
||||
UpdatedAt: secret.UpdatedAt,
|
||||
}
|
||||
}
|
||||
|
||||
// UserSecrets converts a slice of database ListUserSecretsRow to
|
||||
// SDK UserSecret values.
|
||||
func UserSecrets(secrets []database.ListUserSecretsRow) []codersdk.UserSecret {
|
||||
result := make([]codersdk.UserSecret, 0, len(secrets))
|
||||
for _, s := range secrets {
|
||||
result = append(result, UserSecret(s))
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
@@ -1708,6 +1708,17 @@ func (q *querier) CleanupDeletedMCPServerIDsFromChats(ctx context.Context) error
|
||||
return q.db.CleanupDeletedMCPServerIDsFromChats(ctx)
|
||||
}
|
||||
|
||||
func (q *querier) ClearChatMessageProviderResponseIDsByChatID(ctx context.Context, chatID uuid.UUID) error {
|
||||
chat, err := q.db.GetChatByID(ctx, chatID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if err := q.authorizeContext(ctx, policy.ActionUpdate, chat); err != nil {
|
||||
return err
|
||||
}
|
||||
return q.db.ClearChatMessageProviderResponseIDsByChatID(ctx, chatID)
|
||||
}
|
||||
|
||||
func (q *querier) CountAIBridgeInterceptions(ctx context.Context, arg database.CountAIBridgeInterceptionsParams) (int64, error) {
|
||||
prep, err := prepareSQLFilter(ctx, q.auth, policy.ActionRead, rbac.ResourceAibridgeInterception.Type)
|
||||
if err != nil {
|
||||
@@ -2169,10 +2180,10 @@ func (q *querier) DeleteUserChatProviderKey(ctx context.Context, arg database.De
|
||||
return q.db.DeleteUserChatProviderKey(ctx, arg)
|
||||
}
|
||||
|
||||
func (q *querier) DeleteUserSecretByUserIDAndName(ctx context.Context, arg database.DeleteUserSecretByUserIDAndNameParams) error {
|
||||
func (q *querier) DeleteUserSecretByUserIDAndName(ctx context.Context, arg database.DeleteUserSecretByUserIDAndNameParams) (int64, error) {
|
||||
obj := rbac.ResourceUserSecret.WithOwner(arg.UserID.String())
|
||||
if err := q.authorizeContext(ctx, policy.ActionDelete, obj); err != nil {
|
||||
return err
|
||||
return 0, err
|
||||
}
|
||||
return q.db.DeleteUserSecretByUserIDAndName(ctx, arg)
|
||||
}
|
||||
@@ -2413,6 +2424,10 @@ func (q *querier) GetActiveAISeatCount(ctx context.Context) (int64, error) {
|
||||
return q.db.GetActiveAISeatCount(ctx)
|
||||
}
|
||||
|
||||
func (q *querier) GetActiveChatsByAgentID(ctx context.Context, agentID uuid.UUID) ([]database.Chat, error) {
|
||||
return fetchWithPostFilter(q.auth, policy.ActionRead, q.db.GetActiveChatsByAgentID)(ctx, agentID)
|
||||
}
|
||||
|
||||
func (q *querier) GetActivePresetPrebuildSchedules(ctx context.Context) ([]database.TemplateVersionPresetPrebuildSchedule, error) {
|
||||
if err := q.authorizeContext(ctx, policy.ActionRead, rbac.ResourceTemplate.All()); err != nil {
|
||||
return nil, err
|
||||
@@ -3386,11 +3401,11 @@ func (q *querier) GetPRInsightsPerModel(ctx context.Context, arg database.GetPRI
|
||||
return q.db.GetPRInsightsPerModel(ctx, arg)
|
||||
}
|
||||
|
||||
func (q *querier) GetPRInsightsRecentPRs(ctx context.Context, arg database.GetPRInsightsRecentPRsParams) ([]database.GetPRInsightsRecentPRsRow, error) {
|
||||
func (q *querier) GetPRInsightsPullRequests(ctx context.Context, arg database.GetPRInsightsPullRequestsParams) ([]database.GetPRInsightsPullRequestsRow, error) {
|
||||
if err := q.authorizeContext(ctx, policy.ActionRead, rbac.ResourceDeploymentConfig); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return q.db.GetPRInsightsRecentPRs(ctx, arg)
|
||||
return q.db.GetPRInsightsPullRequests(ctx, arg)
|
||||
}
|
||||
|
||||
func (q *querier) GetPRInsightsSummary(ctx context.Context, arg database.GetPRInsightsSummaryParams) (database.GetPRInsightsSummaryRow, error) {
|
||||
@@ -5728,6 +5743,17 @@ func (q *querier) SoftDeleteChatMessagesAfterID(ctx context.Context, arg databas
|
||||
return q.db.SoftDeleteChatMessagesAfterID(ctx, arg)
|
||||
}
|
||||
|
||||
func (q *querier) SoftDeleteContextFileMessages(ctx context.Context, chatID uuid.UUID) error {
|
||||
chat, err := q.db.GetChatByID(ctx, chatID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if err := q.authorizeContext(ctx, policy.ActionUpdate, chat); err != nil {
|
||||
return err
|
||||
}
|
||||
return q.db.SoftDeleteContextFileMessages(ctx, chatID)
|
||||
}
|
||||
|
||||
func (q *querier) TryAcquireLock(ctx context.Context, id int64) (bool, error) {
|
||||
return q.db.TryAcquireLock(ctx, id)
|
||||
}
|
||||
@@ -6757,6 +6783,19 @@ func (q *querier) UpdateWorkspaceAgentConnectionByID(ctx context.Context, arg da
|
||||
return q.db.UpdateWorkspaceAgentConnectionByID(ctx, arg)
|
||||
}
|
||||
|
||||
func (q *querier) UpdateWorkspaceAgentDirectoryByID(ctx context.Context, arg database.UpdateWorkspaceAgentDirectoryByIDParams) error {
|
||||
workspace, err := q.db.GetWorkspaceByAgentID(ctx, arg.ID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := q.authorizeContext(ctx, policy.ActionUpdateAgent, workspace); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return q.db.UpdateWorkspaceAgentDirectoryByID(ctx, arg)
|
||||
}
|
||||
|
||||
func (q *querier) UpdateWorkspaceAgentDisplayAppsByID(ctx context.Context, arg database.UpdateWorkspaceAgentDisplayAppsByIDParams) error {
|
||||
workspace, err := q.db.GetWorkspaceByAgentID(ctx, arg.ID)
|
||||
if err != nil {
|
||||
|
||||
@@ -478,6 +478,24 @@ func (s *MethodTestSuite) TestChats() {
|
||||
dbm.EXPECT().GetChatsByWorkspaceIDs(gomock.Any(), arg).Return([]database.Chat{chatA, chatB}, nil).AnyTimes()
|
||||
check.Args(arg).Asserts(chatA, policy.ActionRead, chatB, policy.ActionRead).Returns([]database.Chat{chatA, chatB})
|
||||
}))
|
||||
s.Run("GetActiveChatsByAgentID", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
|
||||
chat := testutil.Fake(s.T(), faker, database.Chat{})
|
||||
agentID := uuid.New()
|
||||
dbm.EXPECT().GetActiveChatsByAgentID(gomock.Any(), agentID).Return([]database.Chat{chat}, nil).AnyTimes()
|
||||
check.Args(agentID).Asserts(chat, policy.ActionRead).Returns([]database.Chat{chat})
|
||||
}))
|
||||
s.Run("SoftDeleteContextFileMessages", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
|
||||
chat := testutil.Fake(s.T(), faker, database.Chat{})
|
||||
dbm.EXPECT().GetChatByID(gomock.Any(), chat.ID).Return(chat, nil).AnyTimes()
|
||||
dbm.EXPECT().SoftDeleteContextFileMessages(gomock.Any(), chat.ID).Return(nil).AnyTimes()
|
||||
check.Args(chat.ID).Asserts(chat, policy.ActionUpdate).Returns()
|
||||
}))
|
||||
s.Run("ClearChatMessageProviderResponseIDsByChatID", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
|
||||
chat := testutil.Fake(s.T(), faker, database.Chat{})
|
||||
dbm.EXPECT().GetChatByID(gomock.Any(), chat.ID).Return(chat, nil).AnyTimes()
|
||||
dbm.EXPECT().ClearChatMessageProviderResponseIDsByChatID(gomock.Any(), chat.ID).Return(nil).AnyTimes()
|
||||
check.Args(chat.ID).Asserts(chat, policy.ActionUpdate).Returns()
|
||||
}))
|
||||
s.Run("GetChatCostPerChat", s.Mocked(func(dbm *dbmock.MockStore, _ *gofakeit.Faker, check *expects) {
|
||||
arg := database.GetChatCostPerChatParams{
|
||||
OwnerID: uuid.New(),
|
||||
@@ -2243,9 +2261,9 @@ func (s *MethodTestSuite) TestTemplate() {
|
||||
dbm.EXPECT().GetPRInsightsPerModel(gomock.Any(), arg).Return([]database.GetPRInsightsPerModelRow{}, nil).AnyTimes()
|
||||
check.Args(arg).Asserts(rbac.ResourceDeploymentConfig, policy.ActionRead)
|
||||
}))
|
||||
s.Run("GetPRInsightsRecentPRs", s.Mocked(func(dbm *dbmock.MockStore, _ *gofakeit.Faker, check *expects) {
|
||||
arg := database.GetPRInsightsRecentPRsParams{}
|
||||
dbm.EXPECT().GetPRInsightsRecentPRs(gomock.Any(), arg).Return([]database.GetPRInsightsRecentPRsRow{}, nil).AnyTimes()
|
||||
s.Run("GetPRInsightsPullRequests", s.Mocked(func(dbm *dbmock.MockStore, _ *gofakeit.Faker, check *expects) {
|
||||
arg := database.GetPRInsightsPullRequestsParams{}
|
||||
dbm.EXPECT().GetPRInsightsPullRequests(gomock.Any(), arg).Return([]database.GetPRInsightsPullRequestsRow{}, nil).AnyTimes()
|
||||
check.Args(arg).Asserts(rbac.ResourceDeploymentConfig, policy.ActionRead)
|
||||
}))
|
||||
s.Run("GetTelemetryTaskEvents", s.Mocked(func(dbm *dbmock.MockStore, _ *gofakeit.Faker, check *expects) {
|
||||
@@ -2917,6 +2935,17 @@ func (s *MethodTestSuite) TestWorkspace() {
|
||||
dbm.EXPECT().UpdateWorkspaceAgentStartupByID(gomock.Any(), arg).Return(nil).AnyTimes()
|
||||
check.Args(arg).Asserts(w, policy.ActionUpdate).Returns()
|
||||
}))
|
||||
s.Run("UpdateWorkspaceAgentDirectoryByID", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
|
||||
w := testutil.Fake(s.T(), faker, database.Workspace{})
|
||||
agt := testutil.Fake(s.T(), faker, database.WorkspaceAgent{})
|
||||
arg := database.UpdateWorkspaceAgentDirectoryByIDParams{
|
||||
ID: agt.ID,
|
||||
Directory: "/workspaces/project",
|
||||
}
|
||||
dbm.EXPECT().GetWorkspaceByAgentID(gomock.Any(), agt.ID).Return(w, nil).AnyTimes()
|
||||
dbm.EXPECT().UpdateWorkspaceAgentDirectoryByID(gomock.Any(), arg).Return(nil).AnyTimes()
|
||||
check.Args(arg).Asserts(w, policy.ActionUpdateAgent).Returns()
|
||||
}))
|
||||
s.Run("UpdateWorkspaceAgentDisplayAppsByID", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
|
||||
w := testutil.Fake(s.T(), faker, database.Workspace{})
|
||||
agt := testutil.Fake(s.T(), faker, database.WorkspaceAgent{})
|
||||
@@ -5413,10 +5442,10 @@ func (s *MethodTestSuite) TestUserSecrets() {
|
||||
s.Run("DeleteUserSecretByUserIDAndName", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
|
||||
user := testutil.Fake(s.T(), faker, database.User{})
|
||||
arg := database.DeleteUserSecretByUserIDAndNameParams{UserID: user.ID, Name: "test"}
|
||||
dbm.EXPECT().DeleteUserSecretByUserIDAndName(gomock.Any(), arg).Return(nil).AnyTimes()
|
||||
dbm.EXPECT().DeleteUserSecretByUserIDAndName(gomock.Any(), arg).Return(int64(1), nil).AnyTimes()
|
||||
check.Args(arg).
|
||||
Asserts(rbac.ResourceUserSecret.WithOwner(user.ID.String()), policy.ActionDelete).
|
||||
Returns()
|
||||
Returns(int64(1))
|
||||
}))
|
||||
}
|
||||
|
||||
|
||||
@@ -280,6 +280,14 @@ func (m queryMetricsStore) CleanupDeletedMCPServerIDsFromChats(ctx context.Conte
|
||||
return r0
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) ClearChatMessageProviderResponseIDsByChatID(ctx context.Context, chatID uuid.UUID) error {
|
||||
start := time.Now()
|
||||
r0 := m.s.ClearChatMessageProviderResponseIDsByChatID(ctx, chatID)
|
||||
m.queryLatencies.WithLabelValues("ClearChatMessageProviderResponseIDsByChatID").Observe(time.Since(start).Seconds())
|
||||
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "ClearChatMessageProviderResponseIDsByChatID").Inc()
|
||||
return r0
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) CountAIBridgeInterceptions(ctx context.Context, arg database.CountAIBridgeInterceptionsParams) (int64, error) {
|
||||
start := time.Now()
|
||||
r0, r1 := m.s.CountAIBridgeInterceptions(ctx, arg)
|
||||
@@ -728,12 +736,12 @@ func (m queryMetricsStore) DeleteUserChatProviderKey(ctx context.Context, arg da
|
||||
return r0
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) DeleteUserSecretByUserIDAndName(ctx context.Context, arg database.DeleteUserSecretByUserIDAndNameParams) error {
|
||||
func (m queryMetricsStore) DeleteUserSecretByUserIDAndName(ctx context.Context, arg database.DeleteUserSecretByUserIDAndNameParams) (int64, error) {
|
||||
start := time.Now()
|
||||
r0 := m.s.DeleteUserSecretByUserIDAndName(ctx, arg)
|
||||
r0, r1 := m.s.DeleteUserSecretByUserIDAndName(ctx, arg)
|
||||
m.queryLatencies.WithLabelValues("DeleteUserSecretByUserIDAndName").Observe(time.Since(start).Seconds())
|
||||
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "DeleteUserSecretByUserIDAndName").Inc()
|
||||
return r0
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) DeleteWebpushSubscriptionByUserIDAndEndpoint(ctx context.Context, arg database.DeleteWebpushSubscriptionByUserIDAndEndpointParams) error {
|
||||
@@ -968,6 +976,14 @@ func (m queryMetricsStore) GetActiveAISeatCount(ctx context.Context) (int64, err
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) GetActiveChatsByAgentID(ctx context.Context, agentID uuid.UUID) ([]database.Chat, error) {
|
||||
start := time.Now()
|
||||
r0, r1 := m.s.GetActiveChatsByAgentID(ctx, agentID)
|
||||
m.queryLatencies.WithLabelValues("GetActiveChatsByAgentID").Observe(time.Since(start).Seconds())
|
||||
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "GetActiveChatsByAgentID").Inc()
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) GetActivePresetPrebuildSchedules(ctx context.Context) ([]database.TemplateVersionPresetPrebuildSchedule, error) {
|
||||
start := time.Now()
|
||||
r0, r1 := m.s.GetActivePresetPrebuildSchedules(ctx)
|
||||
@@ -1976,11 +1992,11 @@ func (m queryMetricsStore) GetPRInsightsPerModel(ctx context.Context, arg databa
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) GetPRInsightsRecentPRs(ctx context.Context, arg database.GetPRInsightsRecentPRsParams) ([]database.GetPRInsightsRecentPRsRow, error) {
|
||||
func (m queryMetricsStore) GetPRInsightsPullRequests(ctx context.Context, arg database.GetPRInsightsPullRequestsParams) ([]database.GetPRInsightsPullRequestsRow, error) {
|
||||
start := time.Now()
|
||||
r0, r1 := m.s.GetPRInsightsRecentPRs(ctx, arg)
|
||||
m.queryLatencies.WithLabelValues("GetPRInsightsRecentPRs").Observe(time.Since(start).Seconds())
|
||||
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "GetPRInsightsRecentPRs").Inc()
|
||||
r0, r1 := m.s.GetPRInsightsPullRequests(ctx, arg)
|
||||
m.queryLatencies.WithLabelValues("GetPRInsightsPullRequests").Observe(time.Since(start).Seconds())
|
||||
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "GetPRInsightsPullRequests").Inc()
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
@@ -4104,6 +4120,14 @@ func (m queryMetricsStore) SoftDeleteChatMessagesAfterID(ctx context.Context, ar
|
||||
return r0
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) SoftDeleteContextFileMessages(ctx context.Context, chatID uuid.UUID) error {
|
||||
start := time.Now()
|
||||
r0 := m.s.SoftDeleteContextFileMessages(ctx, chatID)
|
||||
m.queryLatencies.WithLabelValues("SoftDeleteContextFileMessages").Observe(time.Since(start).Seconds())
|
||||
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "SoftDeleteContextFileMessages").Inc()
|
||||
return r0
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) TryAcquireLock(ctx context.Context, pgTryAdvisoryXactLock int64) (bool, error) {
|
||||
start := time.Now()
|
||||
r0, r1 := m.s.TryAcquireLock(ctx, pgTryAdvisoryXactLock)
|
||||
@@ -4816,6 +4840,14 @@ func (m queryMetricsStore) UpdateWorkspaceAgentConnectionByID(ctx context.Contex
|
||||
return r0
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) UpdateWorkspaceAgentDirectoryByID(ctx context.Context, arg database.UpdateWorkspaceAgentDirectoryByIDParams) error {
|
||||
start := time.Now()
|
||||
r0 := m.s.UpdateWorkspaceAgentDirectoryByID(ctx, arg)
|
||||
m.queryLatencies.WithLabelValues("UpdateWorkspaceAgentDirectoryByID").Observe(time.Since(start).Seconds())
|
||||
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "UpdateWorkspaceAgentDirectoryByID").Inc()
|
||||
return r0
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) UpdateWorkspaceAgentDisplayAppsByID(ctx context.Context, arg database.UpdateWorkspaceAgentDisplayAppsByIDParams) error {
|
||||
start := time.Now()
|
||||
r0 := m.s.UpdateWorkspaceAgentDisplayAppsByID(ctx, arg)
|
||||
|
||||
@@ -363,6 +363,20 @@ func (mr *MockStoreMockRecorder) CleanupDeletedMCPServerIDsFromChats(ctx any) *g
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CleanupDeletedMCPServerIDsFromChats", reflect.TypeOf((*MockStore)(nil).CleanupDeletedMCPServerIDsFromChats), ctx)
|
||||
}
|
||||
|
||||
// ClearChatMessageProviderResponseIDsByChatID mocks base method.
|
||||
func (m *MockStore) ClearChatMessageProviderResponseIDsByChatID(ctx context.Context, chatID uuid.UUID) error {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "ClearChatMessageProviderResponseIDsByChatID", ctx, chatID)
|
||||
ret0, _ := ret[0].(error)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// ClearChatMessageProviderResponseIDsByChatID indicates an expected call of ClearChatMessageProviderResponseIDsByChatID.
|
||||
func (mr *MockStoreMockRecorder) ClearChatMessageProviderResponseIDsByChatID(ctx, chatID any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ClearChatMessageProviderResponseIDsByChatID", reflect.TypeOf((*MockStore)(nil).ClearChatMessageProviderResponseIDsByChatID), ctx, chatID)
|
||||
}
|
||||
|
||||
// CountAIBridgeInterceptions mocks base method.
|
||||
func (m *MockStore) CountAIBridgeInterceptions(ctx context.Context, arg database.CountAIBridgeInterceptionsParams) (int64, error) {
|
||||
m.ctrl.T.Helper()
|
||||
@@ -1230,11 +1244,12 @@ func (mr *MockStoreMockRecorder) DeleteUserChatProviderKey(ctx, arg any) *gomock
|
||||
}
|
||||
|
||||
// DeleteUserSecretByUserIDAndName mocks base method.
|
||||
func (m *MockStore) DeleteUserSecretByUserIDAndName(ctx context.Context, arg database.DeleteUserSecretByUserIDAndNameParams) error {
|
||||
func (m *MockStore) DeleteUserSecretByUserIDAndName(ctx context.Context, arg database.DeleteUserSecretByUserIDAndNameParams) (int64, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "DeleteUserSecretByUserIDAndName", ctx, arg)
|
||||
ret0, _ := ret[0].(error)
|
||||
return ret0
|
||||
ret0, _ := ret[0].(int64)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// DeleteUserSecretByUserIDAndName indicates an expected call of DeleteUserSecretByUserIDAndName.
|
||||
@@ -1667,6 +1682,21 @@ func (mr *MockStoreMockRecorder) GetActiveAISeatCount(ctx any) *gomock.Call {
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetActiveAISeatCount", reflect.TypeOf((*MockStore)(nil).GetActiveAISeatCount), ctx)
|
||||
}
|
||||
|
||||
// GetActiveChatsByAgentID mocks base method.
|
||||
func (m *MockStore) GetActiveChatsByAgentID(ctx context.Context, agentID uuid.UUID) ([]database.Chat, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "GetActiveChatsByAgentID", ctx, agentID)
|
||||
ret0, _ := ret[0].([]database.Chat)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// GetActiveChatsByAgentID indicates an expected call of GetActiveChatsByAgentID.
|
||||
func (mr *MockStoreMockRecorder) GetActiveChatsByAgentID(ctx, agentID any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetActiveChatsByAgentID", reflect.TypeOf((*MockStore)(nil).GetActiveChatsByAgentID), ctx, agentID)
|
||||
}
|
||||
|
||||
// GetActivePresetPrebuildSchedules mocks base method.
|
||||
func (m *MockStore) GetActivePresetPrebuildSchedules(ctx context.Context) ([]database.TemplateVersionPresetPrebuildSchedule, error) {
|
||||
m.ctrl.T.Helper()
|
||||
@@ -3662,19 +3692,19 @@ func (mr *MockStoreMockRecorder) GetPRInsightsPerModel(ctx, arg any) *gomock.Cal
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetPRInsightsPerModel", reflect.TypeOf((*MockStore)(nil).GetPRInsightsPerModel), ctx, arg)
|
||||
}
|
||||
|
||||
// GetPRInsightsRecentPRs mocks base method.
|
||||
func (m *MockStore) GetPRInsightsRecentPRs(ctx context.Context, arg database.GetPRInsightsRecentPRsParams) ([]database.GetPRInsightsRecentPRsRow, error) {
|
||||
// GetPRInsightsPullRequests mocks base method.
|
||||
func (m *MockStore) GetPRInsightsPullRequests(ctx context.Context, arg database.GetPRInsightsPullRequestsParams) ([]database.GetPRInsightsPullRequestsRow, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "GetPRInsightsRecentPRs", ctx, arg)
|
||||
ret0, _ := ret[0].([]database.GetPRInsightsRecentPRsRow)
|
||||
ret := m.ctrl.Call(m, "GetPRInsightsPullRequests", ctx, arg)
|
||||
ret0, _ := ret[0].([]database.GetPRInsightsPullRequestsRow)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// GetPRInsightsRecentPRs indicates an expected call of GetPRInsightsRecentPRs.
|
||||
func (mr *MockStoreMockRecorder) GetPRInsightsRecentPRs(ctx, arg any) *gomock.Call {
|
||||
// GetPRInsightsPullRequests indicates an expected call of GetPRInsightsPullRequests.
|
||||
func (mr *MockStoreMockRecorder) GetPRInsightsPullRequests(ctx, arg any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetPRInsightsRecentPRs", reflect.TypeOf((*MockStore)(nil).GetPRInsightsRecentPRs), ctx, arg)
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetPRInsightsPullRequests", reflect.TypeOf((*MockStore)(nil).GetPRInsightsPullRequests), ctx, arg)
|
||||
}
|
||||
|
||||
// GetPRInsightsSummary mocks base method.
|
||||
@@ -7780,6 +7810,20 @@ func (mr *MockStoreMockRecorder) SoftDeleteChatMessagesAfterID(ctx, arg any) *go
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SoftDeleteChatMessagesAfterID", reflect.TypeOf((*MockStore)(nil).SoftDeleteChatMessagesAfterID), ctx, arg)
|
||||
}
|
||||
|
||||
// SoftDeleteContextFileMessages mocks base method.
|
||||
func (m *MockStore) SoftDeleteContextFileMessages(ctx context.Context, chatID uuid.UUID) error {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "SoftDeleteContextFileMessages", ctx, chatID)
|
||||
ret0, _ := ret[0].(error)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// SoftDeleteContextFileMessages indicates an expected call of SoftDeleteContextFileMessages.
|
||||
func (mr *MockStoreMockRecorder) SoftDeleteContextFileMessages(ctx, chatID any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SoftDeleteContextFileMessages", reflect.TypeOf((*MockStore)(nil).SoftDeleteContextFileMessages), ctx, chatID)
|
||||
}
|
||||
|
||||
// TryAcquireLock mocks base method.
|
||||
func (m *MockStore) TryAcquireLock(ctx context.Context, pgTryAdvisoryXactLock int64) (bool, error) {
|
||||
m.ctrl.T.Helper()
|
||||
@@ -9076,6 +9120,20 @@ func (mr *MockStoreMockRecorder) UpdateWorkspaceAgentConnectionByID(ctx, arg any
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateWorkspaceAgentConnectionByID", reflect.TypeOf((*MockStore)(nil).UpdateWorkspaceAgentConnectionByID), ctx, arg)
|
||||
}
|
||||
|
||||
// UpdateWorkspaceAgentDirectoryByID mocks base method.
|
||||
func (m *MockStore) UpdateWorkspaceAgentDirectoryByID(ctx context.Context, arg database.UpdateWorkspaceAgentDirectoryByIDParams) error {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "UpdateWorkspaceAgentDirectoryByID", ctx, arg)
|
||||
ret0, _ := ret[0].(error)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// UpdateWorkspaceAgentDirectoryByID indicates an expected call of UpdateWorkspaceAgentDirectoryByID.
|
||||
func (mr *MockStoreMockRecorder) UpdateWorkspaceAgentDirectoryByID(ctx, arg any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateWorkspaceAgentDirectoryByID", reflect.TypeOf((*MockStore)(nil).UpdateWorkspaceAgentDirectoryByID), ctx, arg)
|
||||
}
|
||||
|
||||
// UpdateWorkspaceAgentDisplayAppsByID mocks base method.
|
||||
func (m *MockStore) UpdateWorkspaceAgentDisplayAppsByID(ctx context.Context, arg database.UpdateWorkspaceAgentDisplayAppsByIDParams) error {
|
||||
m.ctrl.T.Helper()
|
||||
|
||||
Generated
+2
-2
@@ -3783,14 +3783,14 @@ CREATE INDEX idx_chat_providers_enabled ON chat_providers USING btree (enabled);
|
||||
|
||||
CREATE INDEX idx_chat_queued_messages_chat_id ON chat_queued_messages USING btree (chat_id);
|
||||
|
||||
CREATE INDEX idx_chats_agent_id ON chats USING btree (agent_id) WHERE (agent_id IS NOT NULL);
|
||||
|
||||
CREATE INDEX idx_chats_labels ON chats USING gin (labels);
|
||||
|
||||
CREATE INDEX idx_chats_last_model_config_id ON chats USING btree (last_model_config_id);
|
||||
|
||||
CREATE INDEX idx_chats_owner ON chats USING btree (owner_id);
|
||||
|
||||
CREATE INDEX idx_chats_owner_updated_id ON chats USING btree (owner_id, updated_at DESC, id DESC);
|
||||
|
||||
CREATE INDEX idx_chats_parent_chat_id ON chats USING btree (parent_chat_id);
|
||||
|
||||
CREATE INDEX idx_chats_pending ON chats USING btree (status) WHERE (status = 'pending'::chat_status);
|
||||
|
||||
@@ -0,0 +1 @@
|
||||
DROP INDEX IF EXISTS idx_chats_agent_id;
|
||||
@@ -0,0 +1 @@
|
||||
CREATE INDEX idx_chats_agent_id ON chats(agent_id) WHERE agent_id IS NOT NULL;
|
||||
@@ -0,0 +1 @@
|
||||
CREATE INDEX idx_chats_owner_updated_id ON chats (owner_id, updated_at DESC, id DESC);
|
||||
@@ -0,0 +1,5 @@
|
||||
-- The GetChats ORDER BY changed from (updated_at, id) DESC to a 4-column
|
||||
-- expression sort (pinned-first flag, negated pin_order, updated_at, id).
|
||||
-- This index was purpose-built for the old sort and no longer provides
|
||||
-- read benefit. The simpler idx_chats_owner covers the owner_id filter.
|
||||
DROP INDEX IF EXISTS idx_chats_owner_updated_id;
|
||||
@@ -76,6 +76,7 @@ type sqlcQuerier interface {
|
||||
CleanTailnetLostPeers(ctx context.Context) error
|
||||
CleanTailnetTunnels(ctx context.Context) error
|
||||
CleanupDeletedMCPServerIDsFromChats(ctx context.Context) error
|
||||
ClearChatMessageProviderResponseIDsByChatID(ctx context.Context, chatID uuid.UUID) error
|
||||
CountAIBridgeInterceptions(ctx context.Context, arg CountAIBridgeInterceptionsParams) (int64, error)
|
||||
CountAIBridgeSessions(ctx context.Context, arg CountAIBridgeSessionsParams) (int64, error)
|
||||
CountAuditLogs(ctx context.Context, arg CountAuditLogsParams) (int64, error)
|
||||
@@ -168,7 +169,7 @@ type sqlcQuerier interface {
|
||||
DeleteTask(ctx context.Context, arg DeleteTaskParams) (uuid.UUID, error)
|
||||
DeleteUserChatCompactionThreshold(ctx context.Context, arg DeleteUserChatCompactionThresholdParams) error
|
||||
DeleteUserChatProviderKey(ctx context.Context, arg DeleteUserChatProviderKeyParams) error
|
||||
DeleteUserSecretByUserIDAndName(ctx context.Context, arg DeleteUserSecretByUserIDAndNameParams) error
|
||||
DeleteUserSecretByUserIDAndName(ctx context.Context, arg DeleteUserSecretByUserIDAndNameParams) (int64, error)
|
||||
DeleteWebpushSubscriptionByUserIDAndEndpoint(ctx context.Context, arg DeleteWebpushSubscriptionByUserIDAndEndpointParams) error
|
||||
DeleteWebpushSubscriptions(ctx context.Context, ids []uuid.UUID) error
|
||||
DeleteWorkspaceACLByID(ctx context.Context, id uuid.UUID) error
|
||||
@@ -215,6 +216,7 @@ type sqlcQuerier interface {
|
||||
GetAPIKeysByUserID(ctx context.Context, arg GetAPIKeysByUserIDParams) ([]APIKey, error)
|
||||
GetAPIKeysLastUsedAfter(ctx context.Context, lastUsed time.Time) ([]APIKey, error)
|
||||
GetActiveAISeatCount(ctx context.Context) (int64, error)
|
||||
GetActiveChatsByAgentID(ctx context.Context, agentID uuid.UUID) ([]Chat, error)
|
||||
GetActivePresetPrebuildSchedules(ctx context.Context) ([]TemplateVersionPresetPrebuildSchedule, error)
|
||||
GetActiveUserCount(ctx context.Context, includeSystem bool) (int64, error)
|
||||
GetActiveWorkspaceBuildsByTemplateID(ctx context.Context, templateID uuid.UUID) ([]WorkspaceBuild, error)
|
||||
@@ -416,11 +418,12 @@ type sqlcQuerier interface {
|
||||
// per PR for state/additions/deletions/model (model comes from the
|
||||
// most recent chat).
|
||||
GetPRInsightsPerModel(ctx context.Context, arg GetPRInsightsPerModelParams) ([]GetPRInsightsPerModelRow, error)
|
||||
// Returns individual PR rows with cost for the recent PRs table.
|
||||
// Returns all individual PR rows with cost for the selected time range.
|
||||
// Uses two CTEs: pr_costs sums cost for the PR-linked chat and its
|
||||
// direct children (that lack their own PR), and deduped picks one row
|
||||
// per PR for metadata.
|
||||
GetPRInsightsRecentPRs(ctx context.Context, arg GetPRInsightsRecentPRsParams) ([]GetPRInsightsRecentPRsRow, error)
|
||||
// per PR for metadata. A safety-cap LIMIT guards against unexpectedly
|
||||
// large result sets from direct API callers.
|
||||
GetPRInsightsPullRequests(ctx context.Context, arg GetPRInsightsPullRequestsParams) ([]GetPRInsightsPullRequestsRow, error)
|
||||
// PR Insights queries for the /agents analytics dashboard.
|
||||
// These aggregate data from chat_diff_statuses (PR metadata) joined
|
||||
// with chats and chat_messages (cost) to power the PR Insights view.
|
||||
@@ -893,6 +896,7 @@ type sqlcQuerier interface {
|
||||
SelectUsageEventsForPublishing(ctx context.Context, now time.Time) ([]UsageEvent, error)
|
||||
SoftDeleteChatMessageByID(ctx context.Context, id int64) error
|
||||
SoftDeleteChatMessagesAfterID(ctx context.Context, arg SoftDeleteChatMessagesAfterIDParams) error
|
||||
SoftDeleteContextFileMessages(ctx context.Context, chatID uuid.UUID) error
|
||||
// Non blocking lock. Returns true if the lock was acquired, false otherwise.
|
||||
//
|
||||
// This must be called from within a transaction. The lock will be automatically
|
||||
@@ -1008,6 +1012,7 @@ type sqlcQuerier interface {
|
||||
UpdateWorkspace(ctx context.Context, arg UpdateWorkspaceParams) (WorkspaceTable, error)
|
||||
UpdateWorkspaceACLByID(ctx context.Context, arg UpdateWorkspaceACLByIDParams) error
|
||||
UpdateWorkspaceAgentConnectionByID(ctx context.Context, arg UpdateWorkspaceAgentConnectionByIDParams) error
|
||||
UpdateWorkspaceAgentDirectoryByID(ctx context.Context, arg UpdateWorkspaceAgentDirectoryByIDParams) error
|
||||
UpdateWorkspaceAgentDisplayAppsByID(ctx context.Context, arg UpdateWorkspaceAgentDisplayAppsByIDParams) error
|
||||
UpdateWorkspaceAgentLifecycleStateByID(ctx context.Context, arg UpdateWorkspaceAgentLifecycleStateByIDParams) error
|
||||
UpdateWorkspaceAgentLogOverflowByID(ctx context.Context, arg UpdateWorkspaceAgentLogOverflowByIDParams) error
|
||||
|
||||
@@ -7376,7 +7376,7 @@ func TestUserSecretsCRUDOperations(t *testing.T) {
|
||||
assert.Equal(t, "WORKFLOW_ENV", updatedSecret.EnvName) // EnvName unchanged
|
||||
|
||||
// 6. DELETE
|
||||
err = db.DeleteUserSecretByUserIDAndName(ctx, database.DeleteUserSecretByUserIDAndNameParams{
|
||||
_, err = db.DeleteUserSecretByUserIDAndName(ctx, database.DeleteUserSecretByUserIDAndNameParams{
|
||||
UserID: testUser.ID,
|
||||
Name: "workflow-secret",
|
||||
})
|
||||
@@ -10408,11 +10408,10 @@ func TestGetPRInsights(t *testing.T) {
|
||||
assert.Equal(t, int64(1), summary.TotalPrsCreated)
|
||||
assert.Equal(t, int64(8_000_000), summary.TotalCostMicros)
|
||||
|
||||
recent, err := store.GetPRInsightsRecentPRs(context.Background(), database.GetPRInsightsRecentPRsParams{
|
||||
recent, err := store.GetPRInsightsPullRequests(context.Background(), database.GetPRInsightsPullRequestsParams{
|
||||
StartDate: startDate,
|
||||
EndDate: endDate,
|
||||
OwnerID: noOwner,
|
||||
LimitVal: 20,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.Len(t, recent, 1)
|
||||
@@ -10442,11 +10441,10 @@ func TestGetPRInsights(t *testing.T) {
|
||||
assert.Equal(t, int64(1), summary.TotalPrsMerged)
|
||||
|
||||
// RecentPRs ordered by created_at DESC: chatB is newer.
|
||||
recent, err := store.GetPRInsightsRecentPRs(context.Background(), database.GetPRInsightsRecentPRsParams{
|
||||
recent, err := store.GetPRInsightsPullRequests(context.Background(), database.GetPRInsightsPullRequestsParams{
|
||||
StartDate: startDate,
|
||||
EndDate: endDate,
|
||||
OwnerID: noOwner,
|
||||
LimitVal: 20,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.Len(t, recent, 2)
|
||||
@@ -10491,11 +10489,10 @@ func TestGetPRInsights(t *testing.T) {
|
||||
assert.Equal(t, int64(1), summary.TotalPrsCreated)
|
||||
assert.Equal(t, int64(1), summary.TotalPrsMerged)
|
||||
|
||||
recent, err := store.GetPRInsightsRecentPRs(context.Background(), database.GetPRInsightsRecentPRsParams{
|
||||
recent, err := store.GetPRInsightsPullRequests(context.Background(), database.GetPRInsightsPullRequestsParams{
|
||||
StartDate: startDate,
|
||||
EndDate: endDate,
|
||||
OwnerID: noOwner,
|
||||
LimitVal: 20,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.Len(t, recent, 1)
|
||||
@@ -10533,11 +10530,10 @@ func TestGetPRInsights(t *testing.T) {
|
||||
assert.Equal(t, int64(9_000_000), summary.TotalCostMicros)
|
||||
|
||||
// RecentPRs should return 1 row with the full tree cost.
|
||||
recent, err := store.GetPRInsightsRecentPRs(context.Background(), database.GetPRInsightsRecentPRsParams{
|
||||
recent, err := store.GetPRInsightsPullRequests(context.Background(), database.GetPRInsightsPullRequestsParams{
|
||||
StartDate: startDate,
|
||||
EndDate: endDate,
|
||||
OwnerID: noOwner,
|
||||
LimitVal: 20,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.Len(t, recent, 1)
|
||||
@@ -10575,11 +10571,10 @@ func TestGetPRInsights(t *testing.T) {
|
||||
assert.Equal(t, int64(2), summary.TotalPrsCreated)
|
||||
assert.Equal(t, int64(8_000_000), summary.TotalCostMicros)
|
||||
|
||||
recent, err := store.GetPRInsightsRecentPRs(context.Background(), database.GetPRInsightsRecentPRsParams{
|
||||
recent, err := store.GetPRInsightsPullRequests(context.Background(), database.GetPRInsightsPullRequestsParams{
|
||||
StartDate: startDate,
|
||||
EndDate: endDate,
|
||||
OwnerID: noOwner,
|
||||
LimitVal: 20,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.Len(t, recent, 2)
|
||||
@@ -10621,11 +10616,10 @@ func TestGetPRInsights(t *testing.T) {
|
||||
assert.Equal(t, int64(2), summary.TotalPrsCreated)
|
||||
assert.Equal(t, int64(17_000_000), summary.TotalCostMicros)
|
||||
|
||||
recent, err := store.GetPRInsightsRecentPRs(context.Background(), database.GetPRInsightsRecentPRsParams{
|
||||
recent, err := store.GetPRInsightsPullRequests(context.Background(), database.GetPRInsightsPullRequestsParams{
|
||||
StartDate: startDate,
|
||||
EndDate: endDate,
|
||||
OwnerID: noOwner,
|
||||
LimitVal: 20,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.Len(t, recent, 2)
|
||||
@@ -10658,11 +10652,10 @@ func TestGetPRInsights(t *testing.T) {
|
||||
assert.Equal(t, int64(2), summary.TotalPrsCreated)
|
||||
assert.Equal(t, int64(10_000_000), summary.TotalCostMicros)
|
||||
|
||||
recent, err := store.GetPRInsightsRecentPRs(context.Background(), database.GetPRInsightsRecentPRsParams{
|
||||
recent, err := store.GetPRInsightsPullRequests(context.Background(), database.GetPRInsightsPullRequestsParams{
|
||||
StartDate: startDate,
|
||||
EndDate: endDate,
|
||||
OwnerID: noOwner,
|
||||
LimitVal: 20,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.Len(t, recent, 2)
|
||||
@@ -10695,11 +10688,10 @@ func TestGetPRInsights(t *testing.T) {
|
||||
assert.Equal(t, int64(1), summary.TotalPrsCreated)
|
||||
assert.Equal(t, int64(15_000_000), summary.TotalCostMicros)
|
||||
|
||||
recent, err := store.GetPRInsightsRecentPRs(context.Background(), database.GetPRInsightsRecentPRsParams{
|
||||
recent, err := store.GetPRInsightsPullRequests(context.Background(), database.GetPRInsightsPullRequestsParams{
|
||||
StartDate: startDate,
|
||||
EndDate: endDate,
|
||||
OwnerID: noOwner,
|
||||
LimitVal: 20,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.Len(t, recent, 1)
|
||||
@@ -10724,11 +10716,10 @@ func TestGetPRInsights(t *testing.T) {
|
||||
assert.Equal(t, int64(1), summary.TotalPrsCreated)
|
||||
assert.Equal(t, int64(0), summary.TotalCostMicros)
|
||||
|
||||
recent, err := store.GetPRInsightsRecentPRs(context.Background(), database.GetPRInsightsRecentPRsParams{
|
||||
recent, err := store.GetPRInsightsPullRequests(context.Background(), database.GetPRInsightsPullRequestsParams{
|
||||
StartDate: startDate,
|
||||
EndDate: endDate,
|
||||
OwnerID: noOwner,
|
||||
LimitVal: 20,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.Len(t, recent, 1)
|
||||
@@ -10767,11 +10758,10 @@ func TestGetPRInsights(t *testing.T) {
|
||||
require.Len(t, byModel, 1)
|
||||
assert.Equal(t, modelName, byModel[0].DisplayName)
|
||||
|
||||
recent, err := store.GetPRInsightsRecentPRs(context.Background(), database.GetPRInsightsRecentPRsParams{
|
||||
recent, err := store.GetPRInsightsPullRequests(context.Background(), database.GetPRInsightsPullRequestsParams{
|
||||
StartDate: startDate,
|
||||
EndDate: endDate,
|
||||
OwnerID: noOwner,
|
||||
LimitVal: 20,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.Len(t, recent, 1)
|
||||
@@ -10803,6 +10793,30 @@ func TestGetPRInsights(t *testing.T) {
|
||||
assert.Equal(t, int64(8_000_000), summary.TotalCostMicros)
|
||||
assert.Equal(t, int64(5_000_000), summary.MergedCostMicros)
|
||||
})
|
||||
|
||||
t.Run("AllPRsReturnedWithSafetyCap", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
store, userID, mcID := setupChatInfra(t)
|
||||
|
||||
// Create 25 distinct PRs — more than the old LIMIT 20 — and
|
||||
// verify all are returned.
|
||||
const prCount = 25
|
||||
for i := range prCount {
|
||||
chat := createChat(t, store, userID, mcID, fmt.Sprintf("chat-%d", i))
|
||||
insertCostMessage(t, store, chat.ID, userID, mcID, 1_000_000)
|
||||
linkPR(t, store, chat.ID,
|
||||
fmt.Sprintf("https://github.com/org/repo/pull/%d", 100+i),
|
||||
"merged", fmt.Sprintf("fix: pr-%d", i), 10, 2, 1)
|
||||
}
|
||||
|
||||
recent, err := store.GetPRInsightsPullRequests(context.Background(), database.GetPRInsightsPullRequestsParams{
|
||||
StartDate: startDate,
|
||||
EndDate: endDate,
|
||||
OwnerID: noOwner,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
assert.Len(t, recent, prCount, "all PRs within the date range should be returned")
|
||||
})
|
||||
}
|
||||
|
||||
func TestChatPinOrderQueries(t *testing.T) {
|
||||
|
||||
+164
-53
@@ -3218,7 +3218,7 @@ func (q *sqlQuerier) GetPRInsightsPerModel(ctx context.Context, arg GetPRInsight
|
||||
return items, nil
|
||||
}
|
||||
|
||||
const getPRInsightsRecentPRs = `-- name: GetPRInsightsRecentPRs :many
|
||||
const getPRInsightsPullRequests = `-- name: GetPRInsightsPullRequests :many
|
||||
WITH pr_costs AS (
|
||||
SELECT
|
||||
prc.pr_key,
|
||||
@@ -3238,9 +3238,9 @@ WITH pr_costs AS (
|
||||
AND cds2.pull_request_state IS NOT NULL
|
||||
))
|
||||
WHERE cds.pull_request_state IS NOT NULL
|
||||
AND c.created_at >= $2::timestamptz
|
||||
AND c.created_at < $3::timestamptz
|
||||
AND ($4::uuid IS NULL OR c.owner_id = $4::uuid)
|
||||
AND c.created_at >= $1::timestamptz
|
||||
AND c.created_at < $2::timestamptz
|
||||
AND ($3::uuid IS NULL OR c.owner_id = $3::uuid)
|
||||
) prc
|
||||
LEFT JOIN LATERAL (
|
||||
SELECT COALESCE(SUM(cm.total_cost_micros), 0) AS cost_micros
|
||||
@@ -3275,9 +3275,9 @@ deduped AS (
|
||||
JOIN chats c ON c.id = cds.chat_id
|
||||
LEFT JOIN chat_model_configs cmc ON cmc.id = c.last_model_config_id
|
||||
WHERE cds.pull_request_state IS NOT NULL
|
||||
AND c.created_at >= $2::timestamptz
|
||||
AND c.created_at < $3::timestamptz
|
||||
AND ($4::uuid IS NULL OR c.owner_id = $4::uuid)
|
||||
AND c.created_at >= $1::timestamptz
|
||||
AND c.created_at < $2::timestamptz
|
||||
AND ($3::uuid IS NULL OR c.owner_id = $3::uuid)
|
||||
ORDER BY COALESCE(NULLIF(cds.url, ''), c.id::text), c.created_at DESC, c.id DESC
|
||||
)
|
||||
SELECT chat_id, pr_title, pr_url, pr_number, state, draft, additions, deletions, changed_files, commits, approved, changes_requested, reviewer_count, author_login, author_avatar_url, base_branch, model_display_name, cost_micros, created_at FROM (
|
||||
@@ -3305,17 +3305,16 @@ SELECT chat_id, pr_title, pr_url, pr_number, state, draft, additions, deletions,
|
||||
JOIN pr_costs pc ON pc.pr_key = d.pr_key
|
||||
) sub
|
||||
ORDER BY sub.created_at DESC
|
||||
LIMIT $1::int
|
||||
LIMIT 500
|
||||
`
|
||||
|
||||
type GetPRInsightsRecentPRsParams struct {
|
||||
LimitVal int32 `db:"limit_val" json:"limit_val"`
|
||||
type GetPRInsightsPullRequestsParams struct {
|
||||
StartDate time.Time `db:"start_date" json:"start_date"`
|
||||
EndDate time.Time `db:"end_date" json:"end_date"`
|
||||
OwnerID uuid.NullUUID `db:"owner_id" json:"owner_id"`
|
||||
}
|
||||
|
||||
type GetPRInsightsRecentPRsRow struct {
|
||||
type GetPRInsightsPullRequestsRow struct {
|
||||
ChatID uuid.UUID `db:"chat_id" json:"chat_id"`
|
||||
PrTitle string `db:"pr_title" json:"pr_title"`
|
||||
PrUrl sql.NullString `db:"pr_url" json:"pr_url"`
|
||||
@@ -3337,24 +3336,20 @@ type GetPRInsightsRecentPRsRow struct {
|
||||
CreatedAt time.Time `db:"created_at" json:"created_at"`
|
||||
}
|
||||
|
||||
// Returns individual PR rows with cost for the recent PRs table.
|
||||
// Returns all individual PR rows with cost for the selected time range.
|
||||
// Uses two CTEs: pr_costs sums cost for the PR-linked chat and its
|
||||
// direct children (that lack their own PR), and deduped picks one row
|
||||
// per PR for metadata.
|
||||
func (q *sqlQuerier) GetPRInsightsRecentPRs(ctx context.Context, arg GetPRInsightsRecentPRsParams) ([]GetPRInsightsRecentPRsRow, error) {
|
||||
rows, err := q.db.QueryContext(ctx, getPRInsightsRecentPRs,
|
||||
arg.LimitVal,
|
||||
arg.StartDate,
|
||||
arg.EndDate,
|
||||
arg.OwnerID,
|
||||
)
|
||||
// per PR for metadata. A safety-cap LIMIT guards against unexpectedly
|
||||
// large result sets from direct API callers.
|
||||
func (q *sqlQuerier) GetPRInsightsPullRequests(ctx context.Context, arg GetPRInsightsPullRequestsParams) ([]GetPRInsightsPullRequestsRow, error) {
|
||||
rows, err := q.db.QueryContext(ctx, getPRInsightsPullRequests, arg.StartDate, arg.EndDate, arg.OwnerID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
var items []GetPRInsightsRecentPRsRow
|
||||
var items []GetPRInsightsPullRequestsRow
|
||||
for rows.Next() {
|
||||
var i GetPRInsightsRecentPRsRow
|
||||
var i GetPRInsightsPullRequestsRow
|
||||
if err := rows.Scan(
|
||||
&i.ChatID,
|
||||
&i.PrTitle,
|
||||
@@ -4505,6 +4500,19 @@ func (q *sqlQuerier) BackoffChatDiffStatus(ctx context.Context, arg BackoffChatD
|
||||
return err
|
||||
}
|
||||
|
||||
const clearChatMessageProviderResponseIDsByChatID = `-- name: ClearChatMessageProviderResponseIDsByChatID :exec
|
||||
UPDATE chat_messages
|
||||
SET provider_response_id = NULL
|
||||
WHERE chat_id = $1::uuid
|
||||
AND deleted = false
|
||||
AND provider_response_id IS NOT NULL
|
||||
`
|
||||
|
||||
func (q *sqlQuerier) ClearChatMessageProviderResponseIDsByChatID(ctx context.Context, chatID uuid.UUID) error {
|
||||
_, err := q.db.ExecContext(ctx, clearChatMessageProviderResponseIDsByChatID, chatID)
|
||||
return err
|
||||
}
|
||||
|
||||
const countEnabledModelsWithoutPricing = `-- name: CountEnabledModelsWithoutPricing :one
|
||||
SELECT COUNT(*)::bigint AS count
|
||||
FROM chat_model_configs
|
||||
@@ -4603,6 +4611,66 @@ func (q *sqlQuerier) DeleteOldChats(ctx context.Context, arg DeleteOldChatsParam
|
||||
return result.RowsAffected()
|
||||
}
|
||||
|
||||
const getActiveChatsByAgentID = `-- name: GetActiveChatsByAgentID :many
|
||||
SELECT id, owner_id, workspace_id, title, status, worker_id, started_at, heartbeat_at, created_at, updated_at, parent_chat_id, root_chat_id, last_model_config_id, archived, last_error, mode, mcp_server_ids, labels, build_id, agent_id, pin_order, last_read_message_id, last_injected_context, dynamic_tools
|
||||
FROM chats
|
||||
WHERE agent_id = $1::uuid
|
||||
AND archived = false
|
||||
-- Active statuses only: waiting, pending, running, paused,
|
||||
-- requires_action.
|
||||
-- Excludes completed and error (terminal states).
|
||||
AND status IN ('waiting', 'running', 'paused', 'pending', 'requires_action')
|
||||
ORDER BY updated_at DESC
|
||||
`
|
||||
|
||||
func (q *sqlQuerier) GetActiveChatsByAgentID(ctx context.Context, agentID uuid.UUID) ([]Chat, error) {
|
||||
rows, err := q.db.QueryContext(ctx, getActiveChatsByAgentID, agentID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
var items []Chat
|
||||
for rows.Next() {
|
||||
var i Chat
|
||||
if err := rows.Scan(
|
||||
&i.ID,
|
||||
&i.OwnerID,
|
||||
&i.WorkspaceID,
|
||||
&i.Title,
|
||||
&i.Status,
|
||||
&i.WorkerID,
|
||||
&i.StartedAt,
|
||||
&i.HeartbeatAt,
|
||||
&i.CreatedAt,
|
||||
&i.UpdatedAt,
|
||||
&i.ParentChatID,
|
||||
&i.RootChatID,
|
||||
&i.LastModelConfigID,
|
||||
&i.Archived,
|
||||
&i.LastError,
|
||||
&i.Mode,
|
||||
pq.Array(&i.MCPServerIDs),
|
||||
&i.Labels,
|
||||
&i.BuildID,
|
||||
&i.AgentID,
|
||||
&i.PinOrder,
|
||||
&i.LastReadMessageID,
|
||||
&i.LastInjectedContext,
|
||||
&i.DynamicTools,
|
||||
); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
items = append(items, i)
|
||||
}
|
||||
if err := rows.Close(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := rows.Err(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return items, nil
|
||||
}
|
||||
|
||||
const getChatByID = `-- name: GetChatByID :one
|
||||
SELECT
|
||||
id, owner_id, workspace_id, title, status, worker_id, started_at, heartbeat_at, created_at, updated_at, parent_chat_id, root_chat_id, last_model_config_id, archived, last_error, mode, mcp_server_ids, labels, build_id, agent_id, pin_order, last_read_message_id, last_injected_context, dynamic_tools
|
||||
@@ -5750,20 +5818,18 @@ WHERE
|
||||
ELSE chats.archived = $2 :: boolean
|
||||
END
|
||||
AND CASE
|
||||
-- This allows using the last element on a page as effectively a cursor.
|
||||
-- This is an important option for scripts that need to paginate without
|
||||
-- duplicating or missing data.
|
||||
-- Cursor pagination: the last element on a page acts as the cursor.
|
||||
-- The 4-tuple matches the ORDER BY below. All columns sort DESC
|
||||
-- (pin_order is negated so lower values sort first in DESC order),
|
||||
-- which lets us use a single tuple < comparison.
|
||||
WHEN $3 :: uuid != '00000000-0000-0000-0000-000000000000'::uuid THEN (
|
||||
-- The pagination cursor is the last ID of the previous page.
|
||||
-- The query is ordered by the updated_at field, so select all
|
||||
-- rows before the cursor.
|
||||
(updated_at, id) < (
|
||||
(CASE WHEN pin_order > 0 THEN 1 ELSE 0 END, -pin_order, updated_at, id) < (
|
||||
SELECT
|
||||
updated_at, id
|
||||
CASE WHEN c2.pin_order > 0 THEN 1 ELSE 0 END, -c2.pin_order, c2.updated_at, c2.id
|
||||
FROM
|
||||
chats
|
||||
chats c2
|
||||
WHERE
|
||||
id = $3
|
||||
c2.id = $3
|
||||
)
|
||||
)
|
||||
ELSE true
|
||||
@@ -5775,9 +5841,15 @@ WHERE
|
||||
-- Authorize Filter clause will be injected below in GetAuthorizedChats
|
||||
-- @authorize_filter
|
||||
ORDER BY
|
||||
-- Deterministic and consistent ordering of all rows, even if they share
|
||||
-- a timestamp. This is to ensure consistent pagination.
|
||||
(updated_at, id) DESC OFFSET $5
|
||||
-- Pinned chats (pin_order > 0) sort before unpinned ones. Within
|
||||
-- pinned chats, lower pin_order values come first. The negation
|
||||
-- trick (-pin_order) keeps all sort columns DESC so the cursor
|
||||
-- tuple < comparison works with uniform direction.
|
||||
CASE WHEN pin_order > 0 THEN 1 ELSE 0 END DESC,
|
||||
-pin_order DESC,
|
||||
updated_at DESC,
|
||||
id DESC
|
||||
OFFSET $5
|
||||
LIMIT
|
||||
-- The chat list is unbounded and expected to grow large.
|
||||
-- Default to 50 to prevent accidental excessively large queries.
|
||||
@@ -6706,6 +6778,18 @@ func (q *sqlQuerier) SoftDeleteChatMessagesAfterID(ctx context.Context, arg Soft
|
||||
return err
|
||||
}
|
||||
|
||||
const softDeleteContextFileMessages = `-- name: SoftDeleteContextFileMessages :exec
|
||||
UPDATE chat_messages SET deleted = true
|
||||
WHERE chat_id = $1::uuid
|
||||
AND deleted = false
|
||||
AND content::jsonb @> '[{"type": "context-file"}]'
|
||||
`
|
||||
|
||||
func (q *sqlQuerier) SoftDeleteContextFileMessages(ctx context.Context, chatID uuid.UUID) error {
|
||||
_, err := q.db.ExecContext(ctx, softDeleteContextFileMessages, chatID)
|
||||
return err
|
||||
}
|
||||
|
||||
const unarchiveChatByID = `-- name: UnarchiveChatByID :many
|
||||
WITH chats AS (
|
||||
UPDATE chats SET
|
||||
@@ -17434,7 +17518,8 @@ SELECT
|
||||
w.id AS workspace_id,
|
||||
COALESCE(w.name, '') AS workspace_name,
|
||||
-- Include the name of the provisioner_daemon associated to the job
|
||||
COALESCE(pd.name, '') AS worker_name
|
||||
COALESCE(pd.name, '') AS worker_name,
|
||||
wb.transition as workspace_build_transition
|
||||
FROM
|
||||
provisioner_jobs pj
|
||||
LEFT JOIN
|
||||
@@ -17479,7 +17564,8 @@ GROUP BY
|
||||
t.icon,
|
||||
w.id,
|
||||
w.name,
|
||||
pd.name
|
||||
pd.name,
|
||||
wb.transition
|
||||
ORDER BY
|
||||
pj.created_at DESC
|
||||
LIMIT
|
||||
@@ -17496,18 +17582,19 @@ type GetProvisionerJobsByOrganizationAndStatusWithQueuePositionAndProvisionerPar
|
||||
}
|
||||
|
||||
type GetProvisionerJobsByOrganizationAndStatusWithQueuePositionAndProvisionerRow struct {
|
||||
ProvisionerJob ProvisionerJob `db:"provisioner_job" json:"provisioner_job"`
|
||||
QueuePosition int64 `db:"queue_position" json:"queue_position"`
|
||||
QueueSize int64 `db:"queue_size" json:"queue_size"`
|
||||
AvailableWorkers []uuid.UUID `db:"available_workers" json:"available_workers"`
|
||||
TemplateVersionName string `db:"template_version_name" json:"template_version_name"`
|
||||
TemplateID uuid.NullUUID `db:"template_id" json:"template_id"`
|
||||
TemplateName string `db:"template_name" json:"template_name"`
|
||||
TemplateDisplayName string `db:"template_display_name" json:"template_display_name"`
|
||||
TemplateIcon string `db:"template_icon" json:"template_icon"`
|
||||
WorkspaceID uuid.NullUUID `db:"workspace_id" json:"workspace_id"`
|
||||
WorkspaceName string `db:"workspace_name" json:"workspace_name"`
|
||||
WorkerName string `db:"worker_name" json:"worker_name"`
|
||||
ProvisionerJob ProvisionerJob `db:"provisioner_job" json:"provisioner_job"`
|
||||
QueuePosition int64 `db:"queue_position" json:"queue_position"`
|
||||
QueueSize int64 `db:"queue_size" json:"queue_size"`
|
||||
AvailableWorkers []uuid.UUID `db:"available_workers" json:"available_workers"`
|
||||
TemplateVersionName string `db:"template_version_name" json:"template_version_name"`
|
||||
TemplateID uuid.NullUUID `db:"template_id" json:"template_id"`
|
||||
TemplateName string `db:"template_name" json:"template_name"`
|
||||
TemplateDisplayName string `db:"template_display_name" json:"template_display_name"`
|
||||
TemplateIcon string `db:"template_icon" json:"template_icon"`
|
||||
WorkspaceID uuid.NullUUID `db:"workspace_id" json:"workspace_id"`
|
||||
WorkspaceName string `db:"workspace_name" json:"workspace_name"`
|
||||
WorkerName string `db:"worker_name" json:"worker_name"`
|
||||
WorkspaceBuildTransition NullWorkspaceTransition `db:"workspace_build_transition" json:"workspace_build_transition"`
|
||||
}
|
||||
|
||||
func (q *sqlQuerier) GetProvisionerJobsByOrganizationAndStatusWithQueuePositionAndProvisioner(ctx context.Context, arg GetProvisionerJobsByOrganizationAndStatusWithQueuePositionAndProvisionerParams) ([]GetProvisionerJobsByOrganizationAndStatusWithQueuePositionAndProvisionerRow, error) {
|
||||
@@ -17559,6 +17646,7 @@ func (q *sqlQuerier) GetProvisionerJobsByOrganizationAndStatusWithQueuePositionA
|
||||
&i.WorkspaceID,
|
||||
&i.WorkspaceName,
|
||||
&i.WorkerName,
|
||||
&i.WorkspaceBuildTransition,
|
||||
); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -23042,7 +23130,7 @@ func (q *sqlQuerier) CreateUserSecret(ctx context.Context, arg CreateUserSecretP
|
||||
return i, err
|
||||
}
|
||||
|
||||
const deleteUserSecretByUserIDAndName = `-- name: DeleteUserSecretByUserIDAndName :exec
|
||||
const deleteUserSecretByUserIDAndName = `-- name: DeleteUserSecretByUserIDAndName :execrows
|
||||
DELETE FROM user_secrets
|
||||
WHERE user_id = $1 AND name = $2
|
||||
`
|
||||
@@ -23052,9 +23140,12 @@ type DeleteUserSecretByUserIDAndNameParams struct {
|
||||
Name string `db:"name" json:"name"`
|
||||
}
|
||||
|
||||
func (q *sqlQuerier) DeleteUserSecretByUserIDAndName(ctx context.Context, arg DeleteUserSecretByUserIDAndNameParams) error {
|
||||
_, err := q.db.ExecContext(ctx, deleteUserSecretByUserIDAndName, arg.UserID, arg.Name)
|
||||
return err
|
||||
func (q *sqlQuerier) DeleteUserSecretByUserIDAndName(ctx context.Context, arg DeleteUserSecretByUserIDAndNameParams) (int64, error) {
|
||||
result, err := q.db.ExecContext(ctx, deleteUserSecretByUserIDAndName, arg.UserID, arg.Name)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return result.RowsAffected()
|
||||
}
|
||||
|
||||
const getUserSecretByUserIDAndName = `-- name: GetUserSecretByUserIDAndName :one
|
||||
@@ -26728,6 +26819,26 @@ func (q *sqlQuerier) UpdateWorkspaceAgentConnectionByID(ctx context.Context, arg
|
||||
return err
|
||||
}
|
||||
|
||||
const updateWorkspaceAgentDirectoryByID = `-- name: UpdateWorkspaceAgentDirectoryByID :exec
|
||||
UPDATE
|
||||
workspace_agents
|
||||
SET
|
||||
directory = $2, updated_at = $3
|
||||
WHERE
|
||||
id = $1
|
||||
`
|
||||
|
||||
type UpdateWorkspaceAgentDirectoryByIDParams struct {
|
||||
ID uuid.UUID `db:"id" json:"id"`
|
||||
Directory string `db:"directory" json:"directory"`
|
||||
UpdatedAt time.Time `db:"updated_at" json:"updated_at"`
|
||||
}
|
||||
|
||||
func (q *sqlQuerier) UpdateWorkspaceAgentDirectoryByID(ctx context.Context, arg UpdateWorkspaceAgentDirectoryByIDParams) error {
|
||||
_, err := q.db.ExecContext(ctx, updateWorkspaceAgentDirectoryByID, arg.ID, arg.Directory, arg.UpdatedAt)
|
||||
return err
|
||||
}
|
||||
|
||||
const updateWorkspaceAgentDisplayAppsByID = `-- name: UpdateWorkspaceAgentDisplayAppsByID :exec
|
||||
UPDATE
|
||||
workspace_agents
|
||||
|
||||
@@ -173,11 +173,12 @@ JOIN pr_costs pc ON pc.pr_key = d.pr_key
|
||||
GROUP BY d.model_config_id, d.display_name, d.model, d.provider
|
||||
ORDER BY total_prs DESC;
|
||||
|
||||
-- name: GetPRInsightsRecentPRs :many
|
||||
-- Returns individual PR rows with cost for the recent PRs table.
|
||||
-- name: GetPRInsightsPullRequests :many
|
||||
-- Returns all individual PR rows with cost for the selected time range.
|
||||
-- Uses two CTEs: pr_costs sums cost for the PR-linked chat and its
|
||||
-- direct children (that lack their own PR), and deduped picks one row
|
||||
-- per PR for metadata.
|
||||
-- per PR for metadata. A safety-cap LIMIT guards against unexpectedly
|
||||
-- large result sets from direct API callers.
|
||||
WITH pr_costs AS (
|
||||
SELECT
|
||||
prc.pr_key,
|
||||
@@ -264,4 +265,4 @@ SELECT * FROM (
|
||||
JOIN pr_costs pc ON pc.pr_key = d.pr_key
|
||||
) sub
|
||||
ORDER BY sub.created_at DESC
|
||||
LIMIT @limit_val::int;
|
||||
LIMIT 500;
|
||||
|
||||
@@ -353,20 +353,18 @@ WHERE
|
||||
ELSE chats.archived = sqlc.narg('archived') :: boolean
|
||||
END
|
||||
AND CASE
|
||||
-- This allows using the last element on a page as effectively a cursor.
|
||||
-- This is an important option for scripts that need to paginate without
|
||||
-- duplicating or missing data.
|
||||
-- Cursor pagination: the last element on a page acts as the cursor.
|
||||
-- The 4-tuple matches the ORDER BY below. All columns sort DESC
|
||||
-- (pin_order is negated so lower values sort first in DESC order),
|
||||
-- which lets us use a single tuple < comparison.
|
||||
WHEN @after_id :: uuid != '00000000-0000-0000-0000-000000000000'::uuid THEN (
|
||||
-- The pagination cursor is the last ID of the previous page.
|
||||
-- The query is ordered by the updated_at field, so select all
|
||||
-- rows before the cursor.
|
||||
(updated_at, id) < (
|
||||
(CASE WHEN pin_order > 0 THEN 1 ELSE 0 END, -pin_order, updated_at, id) < (
|
||||
SELECT
|
||||
updated_at, id
|
||||
CASE WHEN c2.pin_order > 0 THEN 1 ELSE 0 END, -c2.pin_order, c2.updated_at, c2.id
|
||||
FROM
|
||||
chats
|
||||
chats c2
|
||||
WHERE
|
||||
id = @after_id
|
||||
c2.id = @after_id
|
||||
)
|
||||
)
|
||||
ELSE true
|
||||
@@ -378,9 +376,15 @@ WHERE
|
||||
-- Authorize Filter clause will be injected below in GetAuthorizedChats
|
||||
-- @authorize_filter
|
||||
ORDER BY
|
||||
-- Deterministic and consistent ordering of all rows, even if they share
|
||||
-- a timestamp. This is to ensure consistent pagination.
|
||||
(updated_at, id) DESC OFFSET @offset_opt
|
||||
-- Pinned chats (pin_order > 0) sort before unpinned ones. Within
|
||||
-- pinned chats, lower pin_order values come first. The negation
|
||||
-- trick (-pin_order) keeps all sort columns DESC so the cursor
|
||||
-- tuple < comparison works with uniform direction.
|
||||
CASE WHEN pin_order > 0 THEN 1 ELSE 0 END DESC,
|
||||
-pin_order DESC,
|
||||
updated_at DESC,
|
||||
id DESC
|
||||
OFFSET @offset_opt
|
||||
LIMIT
|
||||
-- The chat list is unbounded and expected to grow large.
|
||||
-- Default to 50 to prevent accidental excessively large queries.
|
||||
@@ -1293,3 +1297,26 @@ GROUP BY cm.chat_id;
|
||||
SELECT id, provider, model, context_limit, enabled, is_default
|
||||
FROM chat_model_configs
|
||||
WHERE deleted = false;
|
||||
-- name: GetActiveChatsByAgentID :many
|
||||
SELECT *
|
||||
FROM chats
|
||||
WHERE agent_id = @agent_id::uuid
|
||||
AND archived = false
|
||||
-- Active statuses only: waiting, pending, running, paused,
|
||||
-- requires_action.
|
||||
-- Excludes completed and error (terminal states).
|
||||
AND status IN ('waiting', 'running', 'paused', 'pending', 'requires_action')
|
||||
ORDER BY updated_at DESC;
|
||||
|
||||
-- name: ClearChatMessageProviderResponseIDsByChatID :exec
|
||||
UPDATE chat_messages
|
||||
SET provider_response_id = NULL
|
||||
WHERE chat_id = @chat_id::uuid
|
||||
AND deleted = false
|
||||
AND provider_response_id IS NOT NULL;
|
||||
|
||||
-- name: SoftDeleteContextFileMessages :exec
|
||||
UPDATE chat_messages SET deleted = true
|
||||
WHERE chat_id = @chat_id::uuid
|
||||
AND deleted = false
|
||||
AND content::jsonb @> '[{"type": "context-file"}]';
|
||||
|
||||
@@ -195,7 +195,8 @@ SELECT
|
||||
w.id AS workspace_id,
|
||||
COALESCE(w.name, '') AS workspace_name,
|
||||
-- Include the name of the provisioner_daemon associated to the job
|
||||
COALESCE(pd.name, '') AS worker_name
|
||||
COALESCE(pd.name, '') AS worker_name,
|
||||
wb.transition as workspace_build_transition
|
||||
FROM
|
||||
provisioner_jobs pj
|
||||
LEFT JOIN
|
||||
@@ -240,7 +241,8 @@ GROUP BY
|
||||
t.icon,
|
||||
w.id,
|
||||
w.name,
|
||||
pd.name
|
||||
pd.name,
|
||||
wb.transition
|
||||
ORDER BY
|
||||
pj.created_at DESC
|
||||
LIMIT
|
||||
|
||||
@@ -56,6 +56,6 @@ SET
|
||||
WHERE user_id = @user_id AND name = @name
|
||||
RETURNING *;
|
||||
|
||||
-- name: DeleteUserSecretByUserIDAndName :exec
|
||||
-- name: DeleteUserSecretByUserIDAndName :execrows
|
||||
DELETE FROM user_secrets
|
||||
WHERE user_id = @user_id AND name = @name;
|
||||
|
||||
@@ -190,6 +190,14 @@ SET
|
||||
WHERE
|
||||
id = $1;
|
||||
|
||||
-- name: UpdateWorkspaceAgentDirectoryByID :exec
|
||||
UPDATE
|
||||
workspace_agents
|
||||
SET
|
||||
directory = $2, updated_at = $3
|
||||
WHERE
|
||||
id = $1;
|
||||
|
||||
-- name: GetWorkspaceAgentLogsAfter :many
|
||||
SELECT
|
||||
*
|
||||
|
||||
+70
-77
@@ -137,8 +137,9 @@ func publishChatConfigEvent(logger slog.Logger, ps dbpubsub.Pubsub, kind pubsub.
|
||||
func (api *API) watchChats(rw http.ResponseWriter, r *http.Request) {
|
||||
ctx := r.Context()
|
||||
apiKey := httpmw.APIKey(r)
|
||||
logger := api.Logger.Named("chat_watcher")
|
||||
|
||||
sendEvent, senderClosed, err := httpapi.OneWayWebSocketEventSender(api.Logger)(rw, r)
|
||||
conn, err := websocket.Accept(rw, r, nil)
|
||||
if err != nil {
|
||||
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
|
||||
Message: "Failed to open chat watch stream.",
|
||||
@@ -146,54 +147,44 @@ func (api *API) watchChats(rw http.ResponseWriter, r *http.Request) {
|
||||
})
|
||||
return
|
||||
}
|
||||
defer func() {
|
||||
<-senderClosed
|
||||
}()
|
||||
|
||||
cancelSubscribe, err := api.Pubsub.SubscribeWithErr(pubsub.ChatEventChannel(apiKey.UserID),
|
||||
pubsub.HandleChatEvent(
|
||||
func(ctx context.Context, payload pubsub.ChatEvent, err error) {
|
||||
ctx, cancel := context.WithCancel(ctx)
|
||||
defer cancel()
|
||||
|
||||
_ = conn.CloseRead(context.Background())
|
||||
|
||||
ctx, wsNetConn := codersdk.WebsocketNetConn(ctx, conn, websocket.MessageText)
|
||||
defer wsNetConn.Close()
|
||||
|
||||
go httpapi.HeartbeatClose(ctx, logger, cancel, conn)
|
||||
|
||||
// The encoder is only written from the SubscribeWithErr callback,
|
||||
// which delivers serially per subscription. Do not add a second
|
||||
// write path without introducing synchronization.
|
||||
encoder := json.NewEncoder(wsNetConn)
|
||||
|
||||
cancelSubscribe, err := api.Pubsub.SubscribeWithErr(pubsub.ChatWatchEventChannel(apiKey.UserID),
|
||||
pubsub.HandleChatWatchEvent(
|
||||
func(ctx context.Context, payload codersdk.ChatWatchEvent, err error) {
|
||||
if err != nil {
|
||||
api.Logger.Error(ctx, "chat event subscription error", slog.Error(err))
|
||||
logger.Error(ctx, "chat watch event subscription error", slog.Error(err))
|
||||
return
|
||||
}
|
||||
if err := sendEvent(codersdk.ServerSentEvent{
|
||||
Type: codersdk.ServerSentEventTypeData,
|
||||
Data: payload,
|
||||
}); err != nil {
|
||||
api.Logger.Debug(ctx, "failed to send chat event", slog.Error(err))
|
||||
if err := encoder.Encode(payload); err != nil {
|
||||
logger.Debug(ctx, "failed to send chat watch event", slog.Error(err))
|
||||
cancel()
|
||||
return
|
||||
}
|
||||
},
|
||||
))
|
||||
if err != nil {
|
||||
if err := sendEvent(codersdk.ServerSentEvent{
|
||||
Type: codersdk.ServerSentEventTypeError,
|
||||
Data: codersdk.Response{
|
||||
Message: "Internal error subscribing to chat events.",
|
||||
Detail: err.Error(),
|
||||
},
|
||||
}); err != nil {
|
||||
api.Logger.Debug(ctx, "failed to send chat subscribe error event", slog.Error(err))
|
||||
}
|
||||
logger.Error(ctx, "failed to subscribe to chat watch events", slog.Error(err))
|
||||
_ = conn.Close(websocket.StatusInternalError, "Failed to subscribe to chat events.")
|
||||
return
|
||||
}
|
||||
defer cancelSubscribe()
|
||||
|
||||
// Send initial ping to signal the connection is ready.
|
||||
if err := sendEvent(codersdk.ServerSentEvent{
|
||||
Type: codersdk.ServerSentEventTypePing,
|
||||
}); err != nil {
|
||||
api.Logger.Debug(ctx, "failed to send chat ping event", slog.Error(err))
|
||||
}
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case <-senderClosed:
|
||||
return
|
||||
}
|
||||
}
|
||||
<-ctx.Done()
|
||||
}
|
||||
|
||||
// EXPERIMENTAL: chatsByWorkspace returns a mapping of workspace ID to
|
||||
@@ -1819,9 +1810,9 @@ func (api *API) patchChat(rw http.ResponseWriter, r *http.Request) {
|
||||
// - pinOrder > 0 && already pinned: reorder (shift
|
||||
// neighbors, clamp to [1, count]).
|
||||
// - pinOrder > 0 && not pinned: append to end. The
|
||||
// requested value is intentionally ignored because
|
||||
// PinChatByID also bumps updated_at to keep the
|
||||
// chat visible in the paginated sidebar.
|
||||
// requested value is intentionally ignored; the
|
||||
// SQL ORDER BY sorts pinned chats first so they
|
||||
// appear on page 1 of the paginated sidebar.
|
||||
var err error
|
||||
errMsg := "Failed to pin chat."
|
||||
switch {
|
||||
@@ -2176,6 +2167,7 @@ func (api *API) streamChat(rw http.ResponseWriter, r *http.Request) {
|
||||
ctx := r.Context()
|
||||
chat := httpmw.ChatParam(r)
|
||||
chatID := chat.ID
|
||||
logger := api.Logger.Named("chat_streamer").With(slog.F("chat_id", chatID))
|
||||
|
||||
if api.chatDaemon == nil {
|
||||
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
|
||||
@@ -2198,7 +2190,22 @@ func (api *API) streamChat(rw http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
}
|
||||
|
||||
sendEvent, senderClosed, err := httpapi.OneWayWebSocketEventSender(api.Logger)(rw, r)
|
||||
// Subscribe before accepting the WebSocket so that failures
|
||||
// can still be reported as normal HTTP errors.
|
||||
snapshot, events, cancelSub, ok := api.chatDaemon.Subscribe(ctx, chatID, r.Header, afterMessageID)
|
||||
// Subscribe only fails today when the receiver is nil, which
|
||||
// the chatDaemon == nil guard above already catches. This is
|
||||
// defensive against future Subscribe failure modes.
|
||||
if !ok {
|
||||
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
|
||||
Message: "Chat streaming is not available.",
|
||||
Detail: "Chat stream state is not configured.",
|
||||
})
|
||||
return
|
||||
}
|
||||
defer cancelSub()
|
||||
|
||||
conn, err := websocket.Accept(rw, r, nil)
|
||||
if err != nil {
|
||||
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
|
||||
Message: "Failed to open chat stream.",
|
||||
@@ -2206,41 +2213,30 @@ func (api *API) streamChat(rw http.ResponseWriter, r *http.Request) {
|
||||
})
|
||||
return
|
||||
}
|
||||
snapshot, events, cancel, ok := api.chatDaemon.Subscribe(ctx, chatID, r.Header, afterMessageID)
|
||||
if !ok {
|
||||
if err := sendEvent(codersdk.ServerSentEvent{
|
||||
Type: codersdk.ServerSentEventTypeError,
|
||||
Data: codersdk.Response{
|
||||
Message: "Chat streaming is not available.",
|
||||
Detail: "Chat stream state is not configured.",
|
||||
},
|
||||
}); err != nil {
|
||||
api.Logger.Debug(ctx, "failed to send chat stream unavailable event", slog.Error(err))
|
||||
}
|
||||
// Ensure the WebSocket is closed so senderClosed
|
||||
// completes and the handler can return.
|
||||
<-senderClosed
|
||||
return
|
||||
}
|
||||
defer func() {
|
||||
<-senderClosed
|
||||
}()
|
||||
|
||||
ctx, cancel := context.WithCancel(ctx)
|
||||
defer cancel()
|
||||
|
||||
_ = conn.CloseRead(context.Background())
|
||||
|
||||
ctx, wsNetConn := codersdk.WebsocketNetConn(ctx, conn, websocket.MessageText)
|
||||
defer wsNetConn.Close()
|
||||
|
||||
go httpapi.HeartbeatClose(ctx, logger, cancel, conn)
|
||||
|
||||
// Mark the chat as read when the stream connects and again
|
||||
// when it disconnects so we avoid per-message API calls while
|
||||
// messages are actively streaming.
|
||||
api.markChatAsRead(ctx, chatID)
|
||||
defer api.markChatAsRead(context.WithoutCancel(ctx), chatID)
|
||||
|
||||
encoder := json.NewEncoder(wsNetConn)
|
||||
|
||||
sendChatStreamBatch := func(batch []codersdk.ChatStreamEvent) error {
|
||||
if len(batch) == 0 {
|
||||
return nil
|
||||
}
|
||||
return sendEvent(codersdk.ServerSentEvent{
|
||||
Type: codersdk.ServerSentEventTypeData,
|
||||
Data: batch,
|
||||
})
|
||||
return encoder.Encode(batch)
|
||||
}
|
||||
|
||||
drainChatStreamBatch := func(
|
||||
@@ -2273,7 +2269,7 @@ func (api *API) streamChat(rw http.ResponseWriter, r *http.Request) {
|
||||
end = len(snapshot)
|
||||
}
|
||||
if err := sendChatStreamBatch(snapshot[start:end]); err != nil {
|
||||
api.Logger.Debug(ctx, "failed to send chat stream snapshot", slog.Error(err))
|
||||
logger.Debug(ctx, "failed to send chat stream snapshot", slog.Error(err))
|
||||
return
|
||||
}
|
||||
}
|
||||
@@ -2282,8 +2278,6 @@ func (api *API) streamChat(rw http.ResponseWriter, r *http.Request) {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case <-senderClosed:
|
||||
return
|
||||
case firstEvent, ok := <-events:
|
||||
if !ok {
|
||||
return
|
||||
@@ -2293,7 +2287,7 @@ func (api *API) streamChat(rw http.ResponseWriter, r *http.Request) {
|
||||
chatStreamBatchSize,
|
||||
)
|
||||
if err := sendChatStreamBatch(batch); err != nil {
|
||||
api.Logger.Debug(ctx, "failed to send chat stream event", slog.Error(err))
|
||||
logger.Debug(ctx, "failed to send chat stream event", slog.Error(err))
|
||||
return
|
||||
}
|
||||
if streamClosed {
|
||||
@@ -2308,6 +2302,7 @@ func (api *API) interruptChat(rw http.ResponseWriter, r *http.Request) {
|
||||
ctx := r.Context()
|
||||
chat := httpmw.ChatParam(r)
|
||||
chatID := chat.ID
|
||||
logger := api.Logger.Named("chat_interrupt").With(slog.F("chat_id", chatID))
|
||||
|
||||
if api.chatDaemon != nil {
|
||||
chat = api.chatDaemon.InterruptChat(ctx, chat)
|
||||
@@ -2321,8 +2316,7 @@ func (api *API) interruptChat(rw http.ResponseWriter, r *http.Request) {
|
||||
LastError: sql.NullString{},
|
||||
})
|
||||
if updateErr != nil {
|
||||
api.Logger.Error(ctx, "failed to mark chat as waiting",
|
||||
slog.F("chat_id", chatID), slog.Error(updateErr))
|
||||
logger.Error(ctx, "failed to mark chat as waiting", slog.Error(updateErr))
|
||||
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
|
||||
Message: "Failed to interrupt chat.",
|
||||
Detail: updateErr.Error(),
|
||||
@@ -5632,7 +5626,7 @@ func (api *API) prInsights(rw http.ResponseWriter, r *http.Request) {
|
||||
previousSummary database.GetPRInsightsSummaryRow
|
||||
timeSeries []database.GetPRInsightsTimeSeriesRow
|
||||
byModel []database.GetPRInsightsPerModelRow
|
||||
recentPRs []database.GetPRInsightsRecentPRsRow
|
||||
recentPRs []database.GetPRInsightsPullRequestsRow
|
||||
)
|
||||
|
||||
eg, egCtx := errgroup.WithContext(ctx)
|
||||
@@ -5680,11 +5674,10 @@ func (api *API) prInsights(rw http.ResponseWriter, r *http.Request) {
|
||||
|
||||
eg.Go(func() error {
|
||||
var err error
|
||||
recentPRs, err = api.Database.GetPRInsightsRecentPRs(egCtx, database.GetPRInsightsRecentPRsParams{
|
||||
recentPRs, err = api.Database.GetPRInsightsPullRequests(egCtx, database.GetPRInsightsPullRequestsParams{
|
||||
StartDate: startDate,
|
||||
EndDate: endDate,
|
||||
OwnerID: ownerID,
|
||||
LimitVal: 20,
|
||||
})
|
||||
return err
|
||||
})
|
||||
@@ -5794,10 +5787,10 @@ func (api *API) prInsights(rw http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
|
||||
httpapi.Write(ctx, rw, http.StatusOK, codersdk.PRInsightsResponse{
|
||||
Summary: summary,
|
||||
TimeSeries: tsEntries,
|
||||
ByModel: modelEntries,
|
||||
RecentPRs: prEntries,
|
||||
Summary: summary,
|
||||
TimeSeries: tsEntries,
|
||||
ByModel: modelEntries,
|
||||
PullRequests: prEntries,
|
||||
})
|
||||
}
|
||||
|
||||
|
||||
+199
-98
@@ -876,6 +876,186 @@ func TestListChats(t *testing.T) {
|
||||
require.NoError(t, err)
|
||||
require.Len(t, allChats, totalChats)
|
||||
})
|
||||
|
||||
// Test that a pinned chat with an old updated_at appears on page 1.
|
||||
t.Run("PinnedOnFirstPage", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctx := testutil.Context(t, testutil.WaitLong)
|
||||
client, _ := newChatClientWithDatabase(t)
|
||||
_ = coderdtest.CreateFirstUser(t, client.Client)
|
||||
_ = createChatModelConfig(t, client)
|
||||
|
||||
// Create the chat that will later be pinned. It gets the
|
||||
// earliest updated_at because it is inserted first.
|
||||
pinnedChat, err := client.CreateChat(ctx, codersdk.CreateChatRequest{
|
||||
Content: []codersdk.ChatInputPart{{
|
||||
Type: codersdk.ChatInputPartTypeText,
|
||||
Text: "pinned-chat",
|
||||
}},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
// Fill page 1 with newer chats so the pinned chat would
|
||||
// normally be pushed off the first page (default limit 50).
|
||||
const fillerCount = 51
|
||||
fillerChats := make([]codersdk.Chat, 0, fillerCount)
|
||||
for i := range fillerCount {
|
||||
c, createErr := client.CreateChat(ctx, codersdk.CreateChatRequest{
|
||||
Content: []codersdk.ChatInputPart{{
|
||||
Type: codersdk.ChatInputPartTypeText,
|
||||
Text: fmt.Sprintf("filler-%d", i),
|
||||
}},
|
||||
})
|
||||
require.NoError(t, createErr)
|
||||
fillerChats = append(fillerChats, c)
|
||||
}
|
||||
|
||||
// Wait for all chats to reach a terminal status so
|
||||
// updated_at is stable before paginating. A single
|
||||
// polling loop checks every chat per tick to avoid
|
||||
// O(N) separate Eventually loops.
|
||||
allCreated := append([]codersdk.Chat{pinnedChat}, fillerChats...)
|
||||
pending := make(map[uuid.UUID]struct{}, len(allCreated))
|
||||
for _, c := range allCreated {
|
||||
pending[c.ID] = struct{}{}
|
||||
}
|
||||
testutil.Eventually(ctx, t, func(_ context.Context) bool {
|
||||
all, listErr := client.ListChats(ctx, &codersdk.ListChatsOptions{
|
||||
Pagination: codersdk.Pagination{Limit: fillerCount + 10},
|
||||
})
|
||||
if listErr != nil {
|
||||
return false
|
||||
}
|
||||
for _, ch := range all {
|
||||
if _, ok := pending[ch.ID]; ok && ch.Status != codersdk.ChatStatusPending && ch.Status != codersdk.ChatStatusRunning {
|
||||
delete(pending, ch.ID)
|
||||
}
|
||||
}
|
||||
return len(pending) == 0
|
||||
}, testutil.IntervalFast)
|
||||
|
||||
// Pin the earliest chat.
|
||||
err = client.UpdateChat(ctx, pinnedChat.ID, codersdk.UpdateChatRequest{
|
||||
PinOrder: ptr.Ref(int32(1)),
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
// Fetch page 1 with default limit (50).
|
||||
page1, err := client.ListChats(ctx, &codersdk.ListChatsOptions{
|
||||
Pagination: codersdk.Pagination{Limit: 50},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
// The pinned chat must appear on page 1.
|
||||
page1IDs := make(map[uuid.UUID]struct{}, len(page1))
|
||||
for _, c := range page1 {
|
||||
page1IDs[c.ID] = struct{}{}
|
||||
}
|
||||
_, found := page1IDs[pinnedChat.ID]
|
||||
require.True(t, found, "pinned chat should appear on page 1")
|
||||
|
||||
// The pinned chat should be the first item in the list.
|
||||
require.Equal(t, pinnedChat.ID, page1[0].ID, "pinned chat should be first")
|
||||
})
|
||||
|
||||
// Test cursor pagination with a mix of pinned and unpinned chats.
|
||||
t.Run("CursorWithPins", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctx := testutil.Context(t, testutil.WaitLong)
|
||||
client, _ := newChatClientWithDatabase(t)
|
||||
_ = coderdtest.CreateFirstUser(t, client.Client)
|
||||
_ = createChatModelConfig(t, client)
|
||||
|
||||
// Create 5 chats: 2 will be pinned, 3 unpinned.
|
||||
const totalChats = 5
|
||||
createdChats := make([]codersdk.Chat, 0, totalChats)
|
||||
for i := range totalChats {
|
||||
c, createErr := client.CreateChat(ctx, codersdk.CreateChatRequest{
|
||||
Content: []codersdk.ChatInputPart{{
|
||||
Type: codersdk.ChatInputPartTypeText,
|
||||
Text: fmt.Sprintf("cursor-pin-chat-%d", i),
|
||||
}},
|
||||
})
|
||||
require.NoError(t, createErr)
|
||||
createdChats = append(createdChats, c)
|
||||
}
|
||||
|
||||
// Wait for all chats to reach terminal status.
|
||||
// Check each chat by ID rather than fetching the full list.
|
||||
testutil.Eventually(ctx, t, func(_ context.Context) bool {
|
||||
for _, c := range createdChats {
|
||||
ch, err := client.GetChat(ctx, c.ID)
|
||||
require.NoError(t, err, "GetChat should succeed for just-created chat %s", c.ID)
|
||||
if ch.Status == codersdk.ChatStatusPending || ch.Status == codersdk.ChatStatusRunning {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
}, testutil.IntervalFast)
|
||||
|
||||
// Pin the first two chats (oldest updated_at).
|
||||
err := client.UpdateChat(ctx, createdChats[0].ID, codersdk.UpdateChatRequest{
|
||||
PinOrder: ptr.Ref(int32(1)),
|
||||
})
|
||||
require.NoError(t, err)
|
||||
err = client.UpdateChat(ctx, createdChats[1].ID, codersdk.UpdateChatRequest{
|
||||
PinOrder: ptr.Ref(int32(1)),
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
// Paginate with limit=2 using cursor (after_id).
|
||||
const pageSize = 2
|
||||
maxPages := totalChats/pageSize + 2
|
||||
var allPaginated []codersdk.Chat
|
||||
var afterID uuid.UUID
|
||||
for range maxPages {
|
||||
opts := &codersdk.ListChatsOptions{
|
||||
Pagination: codersdk.Pagination{Limit: pageSize},
|
||||
}
|
||||
if afterID != uuid.Nil {
|
||||
opts.Pagination.AfterID = afterID
|
||||
}
|
||||
page, listErr := client.ListChats(ctx, opts)
|
||||
require.NoError(t, listErr)
|
||||
if len(page) == 0 {
|
||||
break
|
||||
}
|
||||
allPaginated = append(allPaginated, page...)
|
||||
afterID = page[len(page)-1].ID
|
||||
}
|
||||
|
||||
// All chats should appear exactly once.
|
||||
seenIDs := make(map[uuid.UUID]struct{}, len(allPaginated))
|
||||
for _, c := range allPaginated {
|
||||
_, dup := seenIDs[c.ID]
|
||||
require.False(t, dup, "chat %s appeared more than once", c.ID)
|
||||
seenIDs[c.ID] = struct{}{}
|
||||
}
|
||||
require.Len(t, seenIDs, totalChats, "all chats should appear in paginated results")
|
||||
|
||||
// Pinned chats should come before unpinned ones, and
|
||||
// within the pinned group, lower pin_order sorts first.
|
||||
pinnedSeen := false
|
||||
unpinnedSeen := false
|
||||
for _, c := range allPaginated {
|
||||
if c.PinOrder > 0 {
|
||||
require.False(t, unpinnedSeen, "pinned chat %s appeared after unpinned chat", c.ID)
|
||||
pinnedSeen = true
|
||||
} else {
|
||||
unpinnedSeen = true
|
||||
}
|
||||
}
|
||||
require.True(t, pinnedSeen, "at least one pinned chat should exist")
|
||||
|
||||
// Verify within-pinned ordering: pin_order=1 before
|
||||
// pin_order=2 (the -pin_order DESC column).
|
||||
require.Equal(t, createdChats[0].ID, allPaginated[0].ID,
|
||||
"pin_order=1 chat should be first")
|
||||
require.Equal(t, createdChats[1].ID, allPaginated[1].ID,
|
||||
"pin_order=2 chat should be second")
|
||||
})
|
||||
}
|
||||
|
||||
func TestListChatModels(t *testing.T) {
|
||||
@@ -1114,17 +1294,6 @@ func TestWatchChats(t *testing.T) {
|
||||
require.NoError(t, err)
|
||||
defer conn.Close(websocket.StatusNormalClosure, "done")
|
||||
|
||||
type watchEvent struct {
|
||||
Type codersdk.ServerSentEventType `json:"type"`
|
||||
Data json.RawMessage `json:"data,omitempty"`
|
||||
}
|
||||
|
||||
var event watchEvent
|
||||
err = wsjson.Read(ctx, conn, &event)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, codersdk.ServerSentEventTypePing, event.Type)
|
||||
require.True(t, len(event.Data) == 0 || string(event.Data) == "null")
|
||||
|
||||
createdChat, err := client.CreateChat(ctx, codersdk.CreateChatRequest{
|
||||
Content: []codersdk.ChatInputPart{
|
||||
{
|
||||
@@ -1136,25 +1305,16 @@ func TestWatchChats(t *testing.T) {
|
||||
require.NoError(t, err)
|
||||
|
||||
for {
|
||||
var update watchEvent
|
||||
err = wsjson.Read(ctx, conn, &update)
|
||||
var payload codersdk.ChatWatchEvent
|
||||
err = wsjson.Read(ctx, conn, &payload)
|
||||
require.NoError(t, err)
|
||||
|
||||
if update.Type == codersdk.ServerSentEventTypePing {
|
||||
continue
|
||||
}
|
||||
require.Equal(t, codersdk.ServerSentEventTypeData, update.Type)
|
||||
|
||||
var payload coderdpubsub.ChatEvent
|
||||
err = json.Unmarshal(update.Data, &payload)
|
||||
require.NoError(t, err)
|
||||
if payload.Kind == coderdpubsub.ChatEventKindCreated &&
|
||||
if payload.Kind == codersdk.ChatWatchEventKindCreated &&
|
||||
payload.Chat.ID == createdChat.ID {
|
||||
break
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("CreatedEventIncludesAllChatFields", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
@@ -1174,18 +1334,6 @@ func TestWatchChats(t *testing.T) {
|
||||
require.NoError(t, err)
|
||||
defer conn.Close(websocket.StatusNormalClosure, "done")
|
||||
|
||||
type watchEvent struct {
|
||||
Type codersdk.ServerSentEventType `json:"type"`
|
||||
Data json.RawMessage `json:"data,omitempty"`
|
||||
}
|
||||
|
||||
// Skip the initial ping.
|
||||
var event watchEvent
|
||||
err = wsjson.Read(ctx, conn, &event)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, codersdk.ServerSentEventTypePing, event.Type)
|
||||
require.True(t, len(event.Data) == 0 || string(event.Data) == "null")
|
||||
|
||||
createdChat, err := client.CreateChat(ctx, codersdk.CreateChatRequest{
|
||||
Content: []codersdk.ChatInputPart{
|
||||
{
|
||||
@@ -1198,18 +1346,11 @@ func TestWatchChats(t *testing.T) {
|
||||
|
||||
var got codersdk.Chat
|
||||
testutil.Eventually(ctx, t, func(_ context.Context) bool {
|
||||
var update watchEvent
|
||||
if readErr := wsjson.Read(ctx, conn, &update); readErr != nil {
|
||||
var payload codersdk.ChatWatchEvent
|
||||
if readErr := wsjson.Read(ctx, conn, &payload); readErr != nil {
|
||||
return false
|
||||
}
|
||||
if update.Type != codersdk.ServerSentEventTypeData {
|
||||
return false
|
||||
}
|
||||
var payload coderdpubsub.ChatEvent
|
||||
if unmarshalErr := json.Unmarshal(update.Data, &payload); unmarshalErr != nil {
|
||||
return false
|
||||
}
|
||||
if payload.Kind == coderdpubsub.ChatEventKindCreated &&
|
||||
if payload.Kind == codersdk.ChatWatchEventKindCreated &&
|
||||
payload.Chat.ID == createdChat.ID {
|
||||
got = payload.Chat
|
||||
return true
|
||||
@@ -1282,25 +1423,14 @@ func TestWatchChats(t *testing.T) {
|
||||
require.NoError(t, err)
|
||||
defer conn.Close(websocket.StatusNormalClosure, "done")
|
||||
|
||||
type watchEvent struct {
|
||||
Type codersdk.ServerSentEventType `json:"type"`
|
||||
Data json.RawMessage `json:"data,omitempty"`
|
||||
}
|
||||
|
||||
// Read the initial ping.
|
||||
var ping watchEvent
|
||||
err = wsjson.Read(ctx, conn, &ping)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, codersdk.ServerSentEventTypePing, ping.Type)
|
||||
|
||||
// Publish a diff_status_change event via pubsub,
|
||||
// mimicking what PublishDiffStatusChange does after
|
||||
// it reads the diff status from the DB.
|
||||
dbStatus, err := db.GetChatDiffStatusByChatID(dbauthz.AsSystemRestricted(ctx), chat.ID)
|
||||
require.NoError(t, err)
|
||||
sdkDiffStatus := db2sdk.ChatDiffStatus(chat.ID, &dbStatus)
|
||||
event := coderdpubsub.ChatEvent{
|
||||
Kind: coderdpubsub.ChatEventKindDiffStatusChange,
|
||||
event := codersdk.ChatWatchEvent{
|
||||
Kind: codersdk.ChatWatchEventKindDiffStatusChange,
|
||||
Chat: codersdk.Chat{
|
||||
ID: chat.ID,
|
||||
OwnerID: chat.OwnerID,
|
||||
@@ -1313,25 +1443,15 @@ func TestWatchChats(t *testing.T) {
|
||||
}
|
||||
payload, err := json.Marshal(event)
|
||||
require.NoError(t, err)
|
||||
err = api.Pubsub.Publish(coderdpubsub.ChatEventChannel(user.UserID), payload)
|
||||
err = api.Pubsub.Publish(coderdpubsub.ChatWatchEventChannel(user.UserID), payload)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Read events until we find the diff_status_change.
|
||||
for {
|
||||
var update watchEvent
|
||||
err = wsjson.Read(ctx, conn, &update)
|
||||
var received codersdk.ChatWatchEvent
|
||||
err = wsjson.Read(ctx, conn, &received)
|
||||
require.NoError(t, err)
|
||||
|
||||
if update.Type == codersdk.ServerSentEventTypePing {
|
||||
continue
|
||||
}
|
||||
require.Equal(t, codersdk.ServerSentEventTypeData, update.Type)
|
||||
|
||||
var received coderdpubsub.ChatEvent
|
||||
err = json.Unmarshal(update.Data, &received)
|
||||
require.NoError(t, err)
|
||||
|
||||
if received.Kind != coderdpubsub.ChatEventKindDiffStatusChange ||
|
||||
if received.Kind != codersdk.ChatWatchEventKindDiffStatusChange ||
|
||||
received.Chat.ID != chat.ID {
|
||||
continue
|
||||
}
|
||||
@@ -1350,7 +1470,6 @@ func TestWatchChats(t *testing.T) {
|
||||
break
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("ArchiveAndUnarchiveEmitEventsForDescendants", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
@@ -1393,31 +1512,13 @@ func TestWatchChats(t *testing.T) {
|
||||
require.NoError(t, err)
|
||||
defer conn.Close(websocket.StatusNormalClosure, "done")
|
||||
|
||||
type watchEvent struct {
|
||||
Type codersdk.ServerSentEventType `json:"type"`
|
||||
Data json.RawMessage `json:"data,omitempty"`
|
||||
}
|
||||
|
||||
var ping watchEvent
|
||||
err = wsjson.Read(ctx, conn, &ping)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, codersdk.ServerSentEventTypePing, ping.Type)
|
||||
|
||||
collectLifecycleEvents := func(expectedKind coderdpubsub.ChatEventKind) map[uuid.UUID]coderdpubsub.ChatEvent {
|
||||
collectLifecycleEvents := func(expectedKind codersdk.ChatWatchEventKind) map[uuid.UUID]codersdk.ChatWatchEvent {
|
||||
t.Helper()
|
||||
|
||||
events := make(map[uuid.UUID]coderdpubsub.ChatEvent, 3)
|
||||
events := make(map[uuid.UUID]codersdk.ChatWatchEvent, 3)
|
||||
for len(events) < 3 {
|
||||
var update watchEvent
|
||||
err = wsjson.Read(ctx, conn, &update)
|
||||
require.NoError(t, err)
|
||||
if update.Type == codersdk.ServerSentEventTypePing {
|
||||
continue
|
||||
}
|
||||
require.Equal(t, codersdk.ServerSentEventTypeData, update.Type)
|
||||
|
||||
var payload coderdpubsub.ChatEvent
|
||||
err = json.Unmarshal(update.Data, &payload)
|
||||
var payload codersdk.ChatWatchEvent
|
||||
err = wsjson.Read(ctx, conn, &payload)
|
||||
require.NoError(t, err)
|
||||
if payload.Kind != expectedKind {
|
||||
continue
|
||||
@@ -1427,7 +1528,7 @@ func TestWatchChats(t *testing.T) {
|
||||
return events
|
||||
}
|
||||
|
||||
assertLifecycleEvents := func(events map[uuid.UUID]coderdpubsub.ChatEvent, archived bool) {
|
||||
assertLifecycleEvents := func(events map[uuid.UUID]codersdk.ChatWatchEvent, archived bool) {
|
||||
t.Helper()
|
||||
|
||||
require.Len(t, events, 3)
|
||||
@@ -1440,12 +1541,12 @@ func TestWatchChats(t *testing.T) {
|
||||
|
||||
err = client.UpdateChat(ctx, parentChat.ID, codersdk.UpdateChatRequest{Archived: ptr.Ref(true)})
|
||||
require.NoError(t, err)
|
||||
deletedEvents := collectLifecycleEvents(coderdpubsub.ChatEventKindDeleted)
|
||||
deletedEvents := collectLifecycleEvents(codersdk.ChatWatchEventKindDeleted)
|
||||
assertLifecycleEvents(deletedEvents, true)
|
||||
|
||||
err = client.UpdateChat(ctx, parentChat.ID, codersdk.UpdateChatRequest{Archived: ptr.Ref(false)})
|
||||
require.NoError(t, err)
|
||||
createdEvents := collectLifecycleEvents(coderdpubsub.ChatEventKindCreated)
|
||||
createdEvents := collectLifecycleEvents(codersdk.ChatWatchEventKindCreated)
|
||||
assertLifecycleEvents(createdEvents, false)
|
||||
})
|
||||
|
||||
|
||||
@@ -0,0 +1,4 @@
|
||||
package coderd
|
||||
|
||||
// InsertAgentChatTestModelConfig exposes insertAgentChatTestModelConfig for external tests.
|
||||
var InsertAgentChatTestModelConfig = insertAgentChatTestModelConfig
|
||||
@@ -5,6 +5,7 @@ import (
|
||||
"database/sql"
|
||||
"encoding/hex"
|
||||
"errors"
|
||||
htmltemplate "html/template"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strings"
|
||||
@@ -146,12 +147,35 @@ func ShowAuthorizePage(accessURL *url.URL) http.HandlerFunc {
|
||||
cancel := params.redirectURL
|
||||
cancelQuery := params.redirectURL.Query()
|
||||
cancelQuery.Add("error", "access_denied")
|
||||
cancelQuery.Add("error_description", "The resource owner or authorization server denied the request")
|
||||
if params.state != "" {
|
||||
cancelQuery.Add("state", params.state)
|
||||
}
|
||||
cancel.RawQuery = cancelQuery.Encode()
|
||||
|
||||
cancelURI := cancel.String()
|
||||
if err := codersdk.ValidateRedirectURIScheme(cancel); err != nil {
|
||||
site.RenderStaticErrorPage(rw, r, site.ErrorPageData{
|
||||
Status: http.StatusBadRequest,
|
||||
HideStatus: false,
|
||||
Title: "Invalid Callback URL",
|
||||
Description: "The application's registered callback URL has an invalid scheme.",
|
||||
Actions: []site.Action{
|
||||
{
|
||||
URL: accessURL.String(),
|
||||
Text: "Back to site",
|
||||
},
|
||||
},
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
site.RenderOAuthAllowPage(rw, r, site.RenderOAuthAllowData{
|
||||
AppIcon: app.Icon,
|
||||
AppName: app.Name,
|
||||
CancelURI: cancel.String(),
|
||||
AppIcon: app.Icon,
|
||||
AppName: app.Name,
|
||||
// #nosec G203 -- The scheme is validated by
|
||||
// codersdk.ValidateRedirectURIScheme above.
|
||||
CancelURI: htmltemplate.URL(cancelURI),
|
||||
RedirectURI: r.URL.String(),
|
||||
CSRFToken: nosurf.Token(r),
|
||||
Username: ua.FriendlyName,
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package oauth2provider_test
|
||||
|
||||
import (
|
||||
htmltemplate "html/template"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
@@ -20,7 +21,7 @@ func TestOAuthConsentFormIncludesCSRFToken(t *testing.T) {
|
||||
|
||||
site.RenderOAuthAllowPage(rec, req, site.RenderOAuthAllowData{
|
||||
AppName: "Test OAuth App",
|
||||
CancelURI: "https://coder.com/cancel",
|
||||
CancelURI: htmltemplate.URL("https://coder.com/cancel"),
|
||||
RedirectURI: "https://coder.com/oauth2/authorize?client_id=test",
|
||||
CSRFToken: csrfFieldValue,
|
||||
Username: "test-user",
|
||||
|
||||
@@ -435,6 +435,9 @@ func convertProvisionerJobWithQueuePosition(pj database.GetProvisionerJobsByOrga
|
||||
if pj.WorkspaceID.Valid {
|
||||
job.Metadata.WorkspaceID = &pj.WorkspaceID.UUID
|
||||
}
|
||||
if pj.WorkspaceBuildTransition.Valid {
|
||||
job.Metadata.WorkspaceBuildTransition = codersdk.WorkspaceTransition(pj.WorkspaceBuildTransition.WorkspaceTransition)
|
||||
}
|
||||
return job
|
||||
}
|
||||
|
||||
|
||||
@@ -97,13 +97,14 @@ func TestProvisionerJobs(t *testing.T) {
|
||||
|
||||
// Verify that job metadata is correct.
|
||||
assert.Equal(t, job2.Metadata, codersdk.ProvisionerJobMetadata{
|
||||
TemplateVersionName: version.Name,
|
||||
TemplateID: template.ID,
|
||||
TemplateName: template.Name,
|
||||
TemplateDisplayName: template.DisplayName,
|
||||
TemplateIcon: template.Icon,
|
||||
WorkspaceID: &w.ID,
|
||||
WorkspaceName: w.Name,
|
||||
TemplateVersionName: version.Name,
|
||||
TemplateID: template.ID,
|
||||
TemplateName: template.Name,
|
||||
TemplateDisplayName: template.DisplayName,
|
||||
TemplateIcon: template.Icon,
|
||||
WorkspaceID: &w.ID,
|
||||
WorkspaceName: w.Name,
|
||||
WorkspaceBuildTransition: codersdk.WorkspaceTransitionStart,
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
@@ -14,7 +14,7 @@ import (
|
||||
const ChatConfigEventChannel = "chat:config_change"
|
||||
|
||||
// HandleChatConfigEvent wraps a typed callback for ChatConfigEvent
|
||||
// messages, following the same pattern as HandleChatEvent.
|
||||
// messages, following the same pattern as HandleChatWatchEvent.
|
||||
func HandleChatConfigEvent(cb func(ctx context.Context, payload ChatConfigEvent, err error)) func(ctx context.Context, message []byte, err error) {
|
||||
return func(ctx context.Context, message []byte, err error) {
|
||||
if err != nil {
|
||||
|
||||
@@ -1,49 +0,0 @@
|
||||
package pubsub
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"golang.org/x/xerrors"
|
||||
|
||||
"github.com/coder/coder/v2/codersdk"
|
||||
)
|
||||
|
||||
func ChatEventChannel(ownerID uuid.UUID) string {
|
||||
return fmt.Sprintf("chat:owner:%s", ownerID)
|
||||
}
|
||||
|
||||
func HandleChatEvent(cb func(ctx context.Context, payload ChatEvent, err error)) func(ctx context.Context, message []byte, err error) {
|
||||
return func(ctx context.Context, message []byte, err error) {
|
||||
if err != nil {
|
||||
cb(ctx, ChatEvent{}, xerrors.Errorf("chat event pubsub: %w", err))
|
||||
return
|
||||
}
|
||||
var payload ChatEvent
|
||||
if err := json.Unmarshal(message, &payload); err != nil {
|
||||
cb(ctx, ChatEvent{}, xerrors.Errorf("unmarshal chat event: %w", err))
|
||||
return
|
||||
}
|
||||
|
||||
cb(ctx, payload, err)
|
||||
}
|
||||
}
|
||||
|
||||
type ChatEvent struct {
|
||||
Kind ChatEventKind `json:"kind"`
|
||||
Chat codersdk.Chat `json:"chat"`
|
||||
ToolCalls []codersdk.ChatStreamToolCall `json:"tool_calls,omitempty"`
|
||||
}
|
||||
|
||||
type ChatEventKind string
|
||||
|
||||
const (
|
||||
ChatEventKindStatusChange ChatEventKind = "status_change"
|
||||
ChatEventKindTitleChange ChatEventKind = "title_change"
|
||||
ChatEventKindCreated ChatEventKind = "created"
|
||||
ChatEventKindDeleted ChatEventKind = "deleted"
|
||||
ChatEventKindDiffStatusChange ChatEventKind = "diff_status_change"
|
||||
ChatEventKindActionRequired ChatEventKind = "action_required"
|
||||
)
|
||||
@@ -0,0 +1,36 @@
|
||||
package pubsub
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"golang.org/x/xerrors"
|
||||
|
||||
"github.com/coder/coder/v2/codersdk"
|
||||
)
|
||||
|
||||
// ChatWatchEventChannel returns the pubsub channel for chat
|
||||
// lifecycle events scoped to a single user.
|
||||
func ChatWatchEventChannel(ownerID uuid.UUID) string {
|
||||
return fmt.Sprintf("chat:owner:%s", ownerID)
|
||||
}
|
||||
|
||||
// HandleChatWatchEvent wraps a typed callback for
|
||||
// ChatWatchEvent messages delivered via pubsub.
|
||||
func HandleChatWatchEvent(cb func(ctx context.Context, payload codersdk.ChatWatchEvent, err error)) func(ctx context.Context, message []byte, err error) {
|
||||
return func(ctx context.Context, message []byte, err error) {
|
||||
if err != nil {
|
||||
cb(ctx, codersdk.ChatWatchEvent{}, xerrors.Errorf("chat watch event pubsub: %w", err))
|
||||
return
|
||||
}
|
||||
var payload codersdk.ChatWatchEvent
|
||||
if err := json.Unmarshal(message, &payload); err != nil {
|
||||
cb(ctx, codersdk.ChatWatchEvent{}, xerrors.Errorf("unmarshal chat watch event: %w", err))
|
||||
return
|
||||
}
|
||||
|
||||
cb(ctx, payload, err)
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,280 @@
|
||||
package coderd
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"errors"
|
||||
"net/http"
|
||||
|
||||
"github.com/go-chi/chi/v5"
|
||||
"github.com/google/uuid"
|
||||
|
||||
"github.com/coder/coder/v2/coderd/database"
|
||||
"github.com/coder/coder/v2/coderd/database/db2sdk"
|
||||
"github.com/coder/coder/v2/coderd/httpapi"
|
||||
"github.com/coder/coder/v2/coderd/httpmw"
|
||||
"github.com/coder/coder/v2/codersdk"
|
||||
)
|
||||
|
||||
// @Summary Create a new user secret
|
||||
// @ID create-a-new-user-secret
|
||||
// @Security CoderSessionToken
|
||||
// @Accept json
|
||||
// @Produce json
|
||||
// @Tags Secrets
|
||||
// @Param user path string true "User ID, username, or me"
|
||||
// @Param request body codersdk.CreateUserSecretRequest true "Create secret request"
|
||||
// @Success 201 {object} codersdk.UserSecret
|
||||
// @Router /users/{user}/secrets [post]
|
||||
func (api *API) postUserSecret(rw http.ResponseWriter, r *http.Request) {
|
||||
ctx := r.Context()
|
||||
user := httpmw.UserParam(r)
|
||||
|
||||
var req codersdk.CreateUserSecretRequest
|
||||
if !httpapi.Read(ctx, rw, r, &req) {
|
||||
return
|
||||
}
|
||||
|
||||
if req.Name == "" {
|
||||
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
|
||||
Message: "Name is required.",
|
||||
})
|
||||
return
|
||||
}
|
||||
if req.Value == "" {
|
||||
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
|
||||
Message: "Value is required.",
|
||||
})
|
||||
return
|
||||
}
|
||||
envOpts := codersdk.UserSecretEnvValidationOptions{
|
||||
AIGatewayEnabled: api.DeploymentValues.AI.BridgeConfig.Enabled.Value(),
|
||||
}
|
||||
if err := codersdk.UserSecretEnvNameValid(req.EnvName, envOpts); err != nil {
|
||||
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
|
||||
Message: "Invalid environment variable name.",
|
||||
Detail: err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
if err := codersdk.UserSecretFilePathValid(req.FilePath); err != nil {
|
||||
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
|
||||
Message: "Invalid file path.",
|
||||
Detail: err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
secret, err := api.Database.CreateUserSecret(ctx, database.CreateUserSecretParams{
|
||||
ID: uuid.New(),
|
||||
UserID: user.ID,
|
||||
Name: req.Name,
|
||||
Description: req.Description,
|
||||
Value: req.Value,
|
||||
ValueKeyID: sql.NullString{},
|
||||
EnvName: req.EnvName,
|
||||
FilePath: req.FilePath,
|
||||
})
|
||||
if err != nil {
|
||||
if database.IsUniqueViolation(err) {
|
||||
httpapi.Write(ctx, rw, http.StatusConflict, codersdk.Response{
|
||||
Message: "A secret with that name, environment variable, or file path already exists.",
|
||||
Detail: err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
|
||||
Message: "Internal error creating secret.",
|
||||
Detail: err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
httpapi.Write(ctx, rw, http.StatusCreated, db2sdk.UserSecretFromFull(secret))
|
||||
}
|
||||
|
||||
// @Summary List user secrets
|
||||
// @ID list-user-secrets
|
||||
// @Security CoderSessionToken
|
||||
// @Produce json
|
||||
// @Tags Secrets
|
||||
// @Param user path string true "User ID, username, or me"
|
||||
// @Success 200 {array} codersdk.UserSecret
|
||||
// @Router /users/{user}/secrets [get]
|
||||
func (api *API) getUserSecrets(rw http.ResponseWriter, r *http.Request) { //nolint:revive // Method name matches route.
|
||||
ctx := r.Context()
|
||||
user := httpmw.UserParam(r)
|
||||
|
||||
secrets, err := api.Database.ListUserSecrets(ctx, user.ID)
|
||||
if err != nil {
|
||||
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
|
||||
Message: "Internal error listing secrets.",
|
||||
Detail: err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
httpapi.Write(ctx, rw, http.StatusOK, db2sdk.UserSecrets(secrets))
|
||||
}
|
||||
|
||||
// @Summary Get a user secret by name
|
||||
// @ID get-a-user-secret-by-name
|
||||
// @Security CoderSessionToken
|
||||
// @Produce json
|
||||
// @Tags Secrets
|
||||
// @Param user path string true "User ID, username, or me"
|
||||
// @Param name path string true "Secret name"
|
||||
// @Success 200 {object} codersdk.UserSecret
|
||||
// @Router /users/{user}/secrets/{name} [get]
|
||||
func (api *API) getUserSecret(rw http.ResponseWriter, r *http.Request) { //nolint:revive // Method name matches route.
|
||||
ctx := r.Context()
|
||||
user := httpmw.UserParam(r)
|
||||
name := chi.URLParam(r, "name")
|
||||
|
||||
secret, err := api.Database.GetUserSecretByUserIDAndName(ctx, database.GetUserSecretByUserIDAndNameParams{
|
||||
UserID: user.ID,
|
||||
Name: name,
|
||||
})
|
||||
if err != nil {
|
||||
if errors.Is(err, sql.ErrNoRows) {
|
||||
httpapi.ResourceNotFound(rw)
|
||||
return
|
||||
}
|
||||
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
|
||||
Message: "Internal error fetching secret.",
|
||||
Detail: err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
httpapi.Write(ctx, rw, http.StatusOK, db2sdk.UserSecretFromFull(secret))
|
||||
}
|
||||
|
||||
// @Summary Update a user secret
|
||||
// @ID update-a-user-secret
|
||||
// @Security CoderSessionToken
|
||||
// @Accept json
|
||||
// @Produce json
|
||||
// @Tags Secrets
|
||||
// @Param user path string true "User ID, username, or me"
|
||||
// @Param name path string true "Secret name"
|
||||
// @Param request body codersdk.UpdateUserSecretRequest true "Update secret request"
|
||||
// @Success 200 {object} codersdk.UserSecret
|
||||
// @Router /users/{user}/secrets/{name} [patch]
|
||||
func (api *API) patchUserSecret(rw http.ResponseWriter, r *http.Request) {
|
||||
ctx := r.Context()
|
||||
user := httpmw.UserParam(r)
|
||||
name := chi.URLParam(r, "name")
|
||||
|
||||
var req codersdk.UpdateUserSecretRequest
|
||||
if !httpapi.Read(ctx, rw, r, &req) {
|
||||
return
|
||||
}
|
||||
|
||||
if req.Value == nil && req.Description == nil && req.EnvName == nil && req.FilePath == nil {
|
||||
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
|
||||
Message: "At least one field must be provided.",
|
||||
})
|
||||
return
|
||||
}
|
||||
if req.EnvName != nil {
|
||||
envOpts := codersdk.UserSecretEnvValidationOptions{
|
||||
AIGatewayEnabled: api.DeploymentValues.AI.BridgeConfig.Enabled.Value(),
|
||||
}
|
||||
if err := codersdk.UserSecretEnvNameValid(*req.EnvName, envOpts); err != nil {
|
||||
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
|
||||
Message: "Invalid environment variable name.",
|
||||
Detail: err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
}
|
||||
if req.FilePath != nil {
|
||||
if err := codersdk.UserSecretFilePathValid(*req.FilePath); err != nil {
|
||||
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
|
||||
Message: "Invalid file path.",
|
||||
Detail: err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
params := database.UpdateUserSecretByUserIDAndNameParams{
|
||||
UserID: user.ID,
|
||||
Name: name,
|
||||
UpdateValue: req.Value != nil,
|
||||
Value: "",
|
||||
ValueKeyID: sql.NullString{},
|
||||
UpdateDescription: req.Description != nil,
|
||||
Description: "",
|
||||
UpdateEnvName: req.EnvName != nil,
|
||||
EnvName: "",
|
||||
UpdateFilePath: req.FilePath != nil,
|
||||
FilePath: "",
|
||||
}
|
||||
if req.Value != nil {
|
||||
params.Value = *req.Value
|
||||
}
|
||||
if req.Description != nil {
|
||||
params.Description = *req.Description
|
||||
}
|
||||
if req.EnvName != nil {
|
||||
params.EnvName = *req.EnvName
|
||||
}
|
||||
if req.FilePath != nil {
|
||||
params.FilePath = *req.FilePath
|
||||
}
|
||||
|
||||
secret, err := api.Database.UpdateUserSecretByUserIDAndName(ctx, params)
|
||||
if err != nil {
|
||||
if errors.Is(err, sql.ErrNoRows) {
|
||||
httpapi.ResourceNotFound(rw)
|
||||
return
|
||||
}
|
||||
if database.IsUniqueViolation(err) {
|
||||
httpapi.Write(ctx, rw, http.StatusConflict, codersdk.Response{
|
||||
Message: "Update would conflict with an existing secret.",
|
||||
Detail: err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
|
||||
Message: "Internal error updating secret.",
|
||||
Detail: err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
httpapi.Write(ctx, rw, http.StatusOK, db2sdk.UserSecretFromFull(secret))
|
||||
}
|
||||
|
||||
// @Summary Delete a user secret
|
||||
// @ID delete-a-user-secret
|
||||
// @Security CoderSessionToken
|
||||
// @Tags Secrets
|
||||
// @Param user path string true "User ID, username, or me"
|
||||
// @Param name path string true "Secret name"
|
||||
// @Success 204
|
||||
// @Router /users/{user}/secrets/{name} [delete]
|
||||
func (api *API) deleteUserSecret(rw http.ResponseWriter, r *http.Request) {
|
||||
ctx := r.Context()
|
||||
user := httpmw.UserParam(r)
|
||||
name := chi.URLParam(r, "name")
|
||||
|
||||
rowsAffected, err := api.Database.DeleteUserSecretByUserIDAndName(ctx, database.DeleteUserSecretByUserIDAndNameParams{
|
||||
UserID: user.ID,
|
||||
Name: name,
|
||||
})
|
||||
if err != nil {
|
||||
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
|
||||
Message: "Internal error deleting secret.",
|
||||
Detail: err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
if rowsAffected == 0 {
|
||||
httpapi.ResourceNotFound(rw)
|
||||
return
|
||||
}
|
||||
|
||||
rw.WriteHeader(http.StatusNoContent)
|
||||
}
|
||||
@@ -0,0 +1,413 @@
|
||||
package coderd_test
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/coder/coder/v2/coderd/coderdtest"
|
||||
"github.com/coder/coder/v2/codersdk"
|
||||
"github.com/coder/coder/v2/testutil"
|
||||
)
|
||||
|
||||
func TestPostUserSecret(t *testing.T) {
|
||||
t.Parallel()
|
||||
client := coderdtest.New(t, nil)
|
||||
_ = coderdtest.CreateFirstUser(t, client)
|
||||
|
||||
t.Run("Success", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := testutil.Context(t, testutil.WaitMedium)
|
||||
|
||||
secret, err := client.CreateUserSecret(ctx, codersdk.Me, codersdk.CreateUserSecretRequest{
|
||||
Name: "github-token",
|
||||
Value: "ghp_xxxxxxxxxxxx",
|
||||
Description: "Personal GitHub PAT",
|
||||
EnvName: "GITHUB_TOKEN",
|
||||
FilePath: "~/.github-token",
|
||||
})
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "github-token", secret.Name)
|
||||
assert.Equal(t, "Personal GitHub PAT", secret.Description)
|
||||
assert.Equal(t, "GITHUB_TOKEN", secret.EnvName)
|
||||
assert.Equal(t, "~/.github-token", secret.FilePath)
|
||||
assert.NotZero(t, secret.ID)
|
||||
assert.NotZero(t, secret.CreatedAt)
|
||||
})
|
||||
|
||||
t.Run("MissingName", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := testutil.Context(t, testutil.WaitMedium)
|
||||
|
||||
_, err := client.CreateUserSecret(ctx, codersdk.Me, codersdk.CreateUserSecretRequest{
|
||||
Value: "some-value",
|
||||
})
|
||||
require.Error(t, err)
|
||||
var sdkErr *codersdk.Error
|
||||
require.ErrorAs(t, err, &sdkErr)
|
||||
assert.Equal(t, http.StatusBadRequest, sdkErr.StatusCode())
|
||||
assert.Contains(t, sdkErr.Message, "Name is required")
|
||||
})
|
||||
|
||||
t.Run("MissingValue", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := testutil.Context(t, testutil.WaitMedium)
|
||||
|
||||
_, err := client.CreateUserSecret(ctx, codersdk.Me, codersdk.CreateUserSecretRequest{
|
||||
Name: "missing-value-secret",
|
||||
})
|
||||
require.Error(t, err)
|
||||
var sdkErr *codersdk.Error
|
||||
require.ErrorAs(t, err, &sdkErr)
|
||||
assert.Equal(t, http.StatusBadRequest, sdkErr.StatusCode())
|
||||
assert.Contains(t, sdkErr.Message, "Value is required")
|
||||
})
|
||||
|
||||
t.Run("DuplicateName", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := testutil.Context(t, testutil.WaitMedium)
|
||||
|
||||
_, err := client.CreateUserSecret(ctx, codersdk.Me, codersdk.CreateUserSecretRequest{
|
||||
Name: "dup-secret",
|
||||
Value: "value1",
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
_, err = client.CreateUserSecret(ctx, codersdk.Me, codersdk.CreateUserSecretRequest{
|
||||
Name: "dup-secret",
|
||||
Value: "value2",
|
||||
})
|
||||
require.Error(t, err)
|
||||
var sdkErr *codersdk.Error
|
||||
require.ErrorAs(t, err, &sdkErr)
|
||||
assert.Equal(t, http.StatusConflict, sdkErr.StatusCode())
|
||||
})
|
||||
|
||||
t.Run("DuplicateEnvName", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := testutil.Context(t, testutil.WaitMedium)
|
||||
|
||||
_, err := client.CreateUserSecret(ctx, codersdk.Me, codersdk.CreateUserSecretRequest{
|
||||
Name: "env-dup-1",
|
||||
Value: "value1",
|
||||
EnvName: "DUPLICATE_ENV",
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
_, err = client.CreateUserSecret(ctx, codersdk.Me, codersdk.CreateUserSecretRequest{
|
||||
Name: "env-dup-2",
|
||||
Value: "value2",
|
||||
EnvName: "DUPLICATE_ENV",
|
||||
})
|
||||
require.Error(t, err)
|
||||
var sdkErr *codersdk.Error
|
||||
require.ErrorAs(t, err, &sdkErr)
|
||||
assert.Equal(t, http.StatusConflict, sdkErr.StatusCode())
|
||||
})
|
||||
|
||||
t.Run("DuplicateFilePath", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := testutil.Context(t, testutil.WaitMedium)
|
||||
|
||||
_, err := client.CreateUserSecret(ctx, codersdk.Me, codersdk.CreateUserSecretRequest{
|
||||
Name: "fp-dup-1",
|
||||
Value: "value1",
|
||||
FilePath: "/tmp/dup-file",
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
_, err = client.CreateUserSecret(ctx, codersdk.Me, codersdk.CreateUserSecretRequest{
|
||||
Name: "fp-dup-2",
|
||||
Value: "value2",
|
||||
FilePath: "/tmp/dup-file",
|
||||
})
|
||||
require.Error(t, err)
|
||||
var sdkErr *codersdk.Error
|
||||
require.ErrorAs(t, err, &sdkErr)
|
||||
assert.Equal(t, http.StatusConflict, sdkErr.StatusCode())
|
||||
})
|
||||
|
||||
t.Run("InvalidEnvName", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := testutil.Context(t, testutil.WaitMedium)
|
||||
|
||||
_, err := client.CreateUserSecret(ctx, codersdk.Me, codersdk.CreateUserSecretRequest{
|
||||
Name: "invalid-env-secret",
|
||||
Value: "value",
|
||||
EnvName: "1INVALID",
|
||||
})
|
||||
require.Error(t, err)
|
||||
var sdkErr *codersdk.Error
|
||||
require.ErrorAs(t, err, &sdkErr)
|
||||
assert.Equal(t, http.StatusBadRequest, sdkErr.StatusCode())
|
||||
})
|
||||
|
||||
t.Run("ReservedEnvName", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := testutil.Context(t, testutil.WaitMedium)
|
||||
|
||||
_, err := client.CreateUserSecret(ctx, codersdk.Me, codersdk.CreateUserSecretRequest{
|
||||
Name: "reserved-env-secret",
|
||||
Value: "value",
|
||||
EnvName: "PATH",
|
||||
})
|
||||
require.Error(t, err)
|
||||
var sdkErr *codersdk.Error
|
||||
require.ErrorAs(t, err, &sdkErr)
|
||||
assert.Equal(t, http.StatusBadRequest, sdkErr.StatusCode())
|
||||
})
|
||||
|
||||
t.Run("CoderPrefixEnvName", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := testutil.Context(t, testutil.WaitMedium)
|
||||
|
||||
_, err := client.CreateUserSecret(ctx, codersdk.Me, codersdk.CreateUserSecretRequest{
|
||||
Name: "coder-prefix-secret",
|
||||
Value: "value",
|
||||
EnvName: "CODER_AGENT_TOKEN",
|
||||
})
|
||||
require.Error(t, err)
|
||||
var sdkErr *codersdk.Error
|
||||
require.ErrorAs(t, err, &sdkErr)
|
||||
assert.Equal(t, http.StatusBadRequest, sdkErr.StatusCode())
|
||||
})
|
||||
|
||||
t.Run("InvalidFilePath", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := testutil.Context(t, testutil.WaitMedium)
|
||||
|
||||
_, err := client.CreateUserSecret(ctx, codersdk.Me, codersdk.CreateUserSecretRequest{
|
||||
Name: "bad-path-secret",
|
||||
Value: "value",
|
||||
FilePath: "relative/path",
|
||||
})
|
||||
require.Error(t, err)
|
||||
var sdkErr *codersdk.Error
|
||||
require.ErrorAs(t, err, &sdkErr)
|
||||
assert.Equal(t, http.StatusBadRequest, sdkErr.StatusCode())
|
||||
})
|
||||
}
|
||||
|
||||
func TestGetUserSecrets(t *testing.T) {
|
||||
t.Parallel()
|
||||
client := coderdtest.New(t, nil)
|
||||
_ = coderdtest.CreateFirstUser(t, client)
|
||||
|
||||
// Verify no secrets exist on a fresh user.
|
||||
ctx := testutil.Context(t, testutil.WaitMedium)
|
||||
secrets, err := client.UserSecrets(ctx, codersdk.Me)
|
||||
require.NoError(t, err)
|
||||
assert.Empty(t, secrets)
|
||||
|
||||
t.Run("WithSecrets", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := testutil.Context(t, testutil.WaitMedium)
|
||||
|
||||
_, err := client.CreateUserSecret(ctx, codersdk.Me, codersdk.CreateUserSecretRequest{
|
||||
Name: "list-secret-a",
|
||||
Value: "value-a",
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
_, err = client.CreateUserSecret(ctx, codersdk.Me, codersdk.CreateUserSecretRequest{
|
||||
Name: "list-secret-b",
|
||||
Value: "value-b",
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
secrets, err := client.UserSecrets(ctx, codersdk.Me)
|
||||
require.NoError(t, err)
|
||||
require.Len(t, secrets, 2)
|
||||
// Sorted by name.
|
||||
assert.Equal(t, "list-secret-a", secrets[0].Name)
|
||||
assert.Equal(t, "list-secret-b", secrets[1].Name)
|
||||
})
|
||||
}
|
||||
|
||||
func TestGetUserSecret(t *testing.T) {
|
||||
t.Parallel()
|
||||
client := coderdtest.New(t, nil)
|
||||
_ = coderdtest.CreateFirstUser(t, client)
|
||||
|
||||
t.Run("Found", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := testutil.Context(t, testutil.WaitMedium)
|
||||
|
||||
created, err := client.CreateUserSecret(ctx, codersdk.Me, codersdk.CreateUserSecretRequest{
|
||||
Name: "get-found-secret",
|
||||
Value: "my-value",
|
||||
EnvName: "GET_FOUND_SECRET",
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
got, err := client.UserSecretByName(ctx, codersdk.Me, "get-found-secret")
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, created.ID, got.ID)
|
||||
assert.Equal(t, "get-found-secret", got.Name)
|
||||
assert.Equal(t, "GET_FOUND_SECRET", got.EnvName)
|
||||
})
|
||||
|
||||
t.Run("NotFound", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := testutil.Context(t, testutil.WaitMedium)
|
||||
|
||||
_, err := client.UserSecretByName(ctx, codersdk.Me, "nonexistent")
|
||||
require.Error(t, err)
|
||||
var sdkErr *codersdk.Error
|
||||
require.ErrorAs(t, err, &sdkErr)
|
||||
assert.Equal(t, http.StatusNotFound, sdkErr.StatusCode())
|
||||
})
|
||||
}
|
||||
|
||||
func TestPatchUserSecret(t *testing.T) {
|
||||
t.Parallel()
|
||||
client := coderdtest.New(t, nil)
|
||||
_ = coderdtest.CreateFirstUser(t, client)
|
||||
|
||||
t.Run("UpdateDescription", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := testutil.Context(t, testutil.WaitMedium)
|
||||
|
||||
_, err := client.CreateUserSecret(ctx, codersdk.Me, codersdk.CreateUserSecretRequest{
|
||||
Name: "patch-desc-secret",
|
||||
Value: "my-value",
|
||||
Description: "original",
|
||||
EnvName: "PATCH_DESC_ENV",
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
newDesc := "updated"
|
||||
updated, err := client.UpdateUserSecret(ctx, codersdk.Me, "patch-desc-secret", codersdk.UpdateUserSecretRequest{
|
||||
Description: &newDesc,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "updated", updated.Description)
|
||||
// Other fields unchanged.
|
||||
assert.Equal(t, "PATCH_DESC_ENV", updated.EnvName)
|
||||
})
|
||||
|
||||
t.Run("NoFields", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := testutil.Context(t, testutil.WaitMedium)
|
||||
|
||||
_, err := client.CreateUserSecret(ctx, codersdk.Me, codersdk.CreateUserSecretRequest{
|
||||
Name: "patch-nofields-secret",
|
||||
Value: "my-value",
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
_, err = client.UpdateUserSecret(ctx, codersdk.Me, "patch-nofields-secret", codersdk.UpdateUserSecretRequest{})
|
||||
require.Error(t, err)
|
||||
var sdkErr *codersdk.Error
|
||||
require.ErrorAs(t, err, &sdkErr)
|
||||
assert.Equal(t, http.StatusBadRequest, sdkErr.StatusCode())
|
||||
})
|
||||
|
||||
t.Run("NotFound", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := testutil.Context(t, testutil.WaitMedium)
|
||||
|
||||
newVal := "new-value"
|
||||
_, err := client.UpdateUserSecret(ctx, codersdk.Me, "nonexistent", codersdk.UpdateUserSecretRequest{
|
||||
Value: &newVal,
|
||||
})
|
||||
require.Error(t, err)
|
||||
var sdkErr *codersdk.Error
|
||||
require.ErrorAs(t, err, &sdkErr)
|
||||
assert.Equal(t, http.StatusNotFound, sdkErr.StatusCode())
|
||||
})
|
||||
|
||||
t.Run("ConflictEnvName", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := testutil.Context(t, testutil.WaitMedium)
|
||||
|
||||
_, err := client.CreateUserSecret(ctx, codersdk.Me, codersdk.CreateUserSecretRequest{
|
||||
Name: "conflict-env-1",
|
||||
Value: "value1",
|
||||
EnvName: "CONFLICT_TAKEN_ENV",
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
_, err = client.CreateUserSecret(ctx, codersdk.Me, codersdk.CreateUserSecretRequest{
|
||||
Name: "conflict-env-2",
|
||||
Value: "value2",
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
taken := "CONFLICT_TAKEN_ENV"
|
||||
_, err = client.UpdateUserSecret(ctx, codersdk.Me, "conflict-env-2", codersdk.UpdateUserSecretRequest{
|
||||
EnvName: &taken,
|
||||
})
|
||||
require.Error(t, err)
|
||||
var sdkErr *codersdk.Error
|
||||
require.ErrorAs(t, err, &sdkErr)
|
||||
assert.Equal(t, http.StatusConflict, sdkErr.StatusCode())
|
||||
})
|
||||
|
||||
t.Run("ConflictFilePath", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := testutil.Context(t, testutil.WaitMedium)
|
||||
|
||||
_, err := client.CreateUserSecret(ctx, codersdk.Me, codersdk.CreateUserSecretRequest{
|
||||
Name: "conflict-fp-1",
|
||||
Value: "value1",
|
||||
FilePath: "/tmp/conflict-taken",
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
_, err = client.CreateUserSecret(ctx, codersdk.Me, codersdk.CreateUserSecretRequest{
|
||||
Name: "conflict-fp-2",
|
||||
Value: "value2",
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
taken := "/tmp/conflict-taken"
|
||||
_, err = client.UpdateUserSecret(ctx, codersdk.Me, "conflict-fp-2", codersdk.UpdateUserSecretRequest{
|
||||
FilePath: &taken,
|
||||
})
|
||||
require.Error(t, err)
|
||||
var sdkErr *codersdk.Error
|
||||
require.ErrorAs(t, err, &sdkErr)
|
||||
assert.Equal(t, http.StatusConflict, sdkErr.StatusCode())
|
||||
})
|
||||
}
|
||||
|
||||
func TestDeleteUserSecret(t *testing.T) {
|
||||
t.Parallel()
|
||||
client := coderdtest.New(t, nil)
|
||||
_ = coderdtest.CreateFirstUser(t, client)
|
||||
|
||||
t.Run("Success", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := testutil.Context(t, testutil.WaitMedium)
|
||||
|
||||
_, err := client.CreateUserSecret(ctx, codersdk.Me, codersdk.CreateUserSecretRequest{
|
||||
Name: "delete-me-secret",
|
||||
Value: "my-value",
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
err = client.DeleteUserSecret(ctx, codersdk.Me, "delete-me-secret")
|
||||
require.NoError(t, err)
|
||||
|
||||
// Verify it's gone.
|
||||
_, err = client.UserSecretByName(ctx, codersdk.Me, "delete-me-secret")
|
||||
require.Error(t, err)
|
||||
var sdkErr *codersdk.Error
|
||||
require.ErrorAs(t, err, &sdkErr)
|
||||
assert.Equal(t, http.StatusNotFound, sdkErr.StatusCode())
|
||||
})
|
||||
|
||||
t.Run("NotFound", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := testutil.Context(t, testutil.WaitMedium)
|
||||
|
||||
err := client.DeleteUserSecret(ctx, codersdk.Me, "nonexistent")
|
||||
require.Error(t, err)
|
||||
var sdkErr *codersdk.Error
|
||||
require.ErrorAs(t, err, &sdkErr)
|
||||
assert.Equal(t, http.StatusNotFound, sdkErr.StatusCode())
|
||||
})
|
||||
}
|
||||
+600
-2
@@ -42,6 +42,8 @@ import (
|
||||
"github.com/coder/coder/v2/coderd/telemetry"
|
||||
maputil "github.com/coder/coder/v2/coderd/util/maps"
|
||||
"github.com/coder/coder/v2/coderd/wspubsub"
|
||||
"github.com/coder/coder/v2/coderd/x/chatd"
|
||||
"github.com/coder/coder/v2/coderd/x/chatd/chatprompt"
|
||||
"github.com/coder/coder/v2/coderd/x/gitsync"
|
||||
"github.com/coder/coder/v2/codersdk"
|
||||
"github.com/coder/coder/v2/codersdk/agentsdk"
|
||||
@@ -181,8 +183,9 @@ func (api *API) patchWorkspaceAgentLogs(rw http.ResponseWriter, r *http.Request)
|
||||
level := make([]database.LogLevel, 0)
|
||||
outputLength := 0
|
||||
for _, logEntry := range req.Logs {
|
||||
output = append(output, logEntry.Output)
|
||||
outputLength += len(logEntry.Output)
|
||||
sanitizedOutput := agentsdk.SanitizeLogOutput(logEntry.Output)
|
||||
output = append(output, sanitizedOutput)
|
||||
outputLength += len(sanitizedOutput)
|
||||
if logEntry.Level == "" {
|
||||
// Default to "info" to support older agents that didn't have the level field.
|
||||
logEntry.Level = codersdk.LogLevelInfo
|
||||
@@ -2392,3 +2395,598 @@ func convertWorkspaceAgentLogs(logs []database.WorkspaceAgentLog) []codersdk.Wor
|
||||
}
|
||||
return sdk
|
||||
}
|
||||
|
||||
// maxChatContextParts caps the number of parts per request to
|
||||
// prevent unbounded message payloads.
|
||||
const maxChatContextParts = 100
|
||||
|
||||
// maxChatContextFileBytes caps each context-file part to the same
|
||||
// 64KiB budget used when the agent reads instruction files from disk.
|
||||
const maxChatContextFileBytes = 64 * 1024
|
||||
|
||||
// maxChatContextRequestBodyBytes caps the JSON request body size for
|
||||
// agent-added context to roughly the same per-part budget used when
|
||||
// reading instruction files from disk.
|
||||
const maxChatContextRequestBodyBytes int64 = maxChatContextParts * maxChatContextFileBytes
|
||||
|
||||
// sanitizeWorkspaceAgentContextFileContent applies prompt
|
||||
// sanitization, then enforces the 64KiB per-file budget. The
|
||||
// truncated flag is preserved when the caller already capped the
|
||||
// file before sending it.
|
||||
func sanitizeWorkspaceAgentContextFileContent(
|
||||
content string,
|
||||
truncated bool,
|
||||
) (string, bool) {
|
||||
content = chatd.SanitizePromptText(content)
|
||||
if len(content) > maxChatContextFileBytes {
|
||||
content = content[:maxChatContextFileBytes]
|
||||
truncated = true
|
||||
}
|
||||
return content, truncated
|
||||
}
|
||||
|
||||
// readChatContextBody reads and validates the request body for chat
|
||||
// context endpoints. It handles MaxBytesReader wrapping, error
|
||||
// responses, and body rewind. If the body is empty or whitespace-only
|
||||
// and allowEmpty is true, it returns false without writing an error.
|
||||
//
|
||||
//nolint:revive // Add and clear endpoints only differ by empty-body handling.
|
||||
func readChatContextBody(ctx context.Context, rw http.ResponseWriter, r *http.Request, dst any, allowEmpty bool) bool {
|
||||
r.Body = http.MaxBytesReader(rw, r.Body, maxChatContextRequestBodyBytes)
|
||||
body, err := io.ReadAll(r.Body)
|
||||
if err != nil {
|
||||
var maxBytesErr *http.MaxBytesError
|
||||
if errors.As(err, &maxBytesErr) {
|
||||
httpapi.Write(ctx, rw, http.StatusRequestEntityTooLarge, codersdk.Response{
|
||||
Message: "Request body too large.",
|
||||
Detail: fmt.Sprintf("Maximum request body size is %d bytes.", maxChatContextRequestBodyBytes),
|
||||
})
|
||||
return false
|
||||
}
|
||||
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
|
||||
Message: "Failed to read request body.",
|
||||
Detail: err.Error(),
|
||||
})
|
||||
return false
|
||||
}
|
||||
if allowEmpty && len(bytes.TrimSpace(body)) == 0 {
|
||||
r.Body = http.NoBody
|
||||
return false
|
||||
}
|
||||
|
||||
r.Body = io.NopCloser(bytes.NewReader(body))
|
||||
return httpapi.Read(ctx, rw, r, dst)
|
||||
}
|
||||
|
||||
// @x-apidocgen {"skip": true}
|
||||
func (api *API) workspaceAgentAddChatContext(rw http.ResponseWriter, r *http.Request) {
|
||||
ctx := r.Context()
|
||||
workspaceAgent := httpmw.WorkspaceAgent(r)
|
||||
|
||||
var req agentsdk.AddChatContextRequest
|
||||
if !readChatContextBody(ctx, rw, r, &req, false) {
|
||||
return
|
||||
}
|
||||
|
||||
if len(req.Parts) == 0 {
|
||||
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
|
||||
Message: "No context parts provided.",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
if len(req.Parts) > maxChatContextParts {
|
||||
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
|
||||
Message: fmt.Sprintf("Too many context parts (%d). Maximum is %d.", len(req.Parts), maxChatContextParts),
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// Filter to only non-empty context-file and skill parts.
|
||||
filtered := chatd.FilterContextParts(req.Parts, false)
|
||||
if len(filtered) == 0 {
|
||||
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
|
||||
Message: "No context-file or skill parts provided.",
|
||||
})
|
||||
return
|
||||
}
|
||||
req.Parts = filtered
|
||||
responsePartCount := 0
|
||||
|
||||
// Use system context for chat operations since the
|
||||
// workspace agent scope does not include chat resources.
|
||||
// We verify agent-to-chat ownership explicitly below.
|
||||
//nolint:gocritic // Agent needs system access to read/write chat resources.
|
||||
sysCtx := dbauthz.AsSystemRestricted(ctx)
|
||||
workspace, err := api.Database.GetWorkspaceByAgentID(sysCtx, workspaceAgent.ID)
|
||||
if err != nil {
|
||||
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
|
||||
Message: "Failed to determine workspace from agent token.",
|
||||
Detail: err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
chat, err := resolveAgentChat(sysCtx, api.Database, workspaceAgent.ID, workspace.OwnerID, req.ChatID)
|
||||
if err != nil {
|
||||
writeAgentChatError(ctx, rw, err)
|
||||
return
|
||||
}
|
||||
|
||||
// Stamp each persisted part with the agent identity. Context-file
|
||||
// parts also get server-authoritative workspace metadata.
|
||||
directory := workspaceAgent.ExpandedDirectory
|
||||
if directory == "" {
|
||||
directory = workspaceAgent.Directory
|
||||
}
|
||||
for i := range req.Parts {
|
||||
req.Parts[i].ContextFileAgentID = uuid.NullUUID{
|
||||
UUID: workspaceAgent.ID,
|
||||
Valid: true,
|
||||
}
|
||||
if req.Parts[i].Type != codersdk.ChatMessagePartTypeContextFile {
|
||||
continue
|
||||
}
|
||||
req.Parts[i].ContextFileContent, req.Parts[i].ContextFileTruncated = sanitizeWorkspaceAgentContextFileContent(
|
||||
req.Parts[i].ContextFileContent,
|
||||
req.Parts[i].ContextFileTruncated,
|
||||
)
|
||||
req.Parts[i].ContextFileOS = workspaceAgent.OperatingSystem
|
||||
req.Parts[i].ContextFileDirectory = directory
|
||||
}
|
||||
req.Parts = chatd.FilterContextParts(req.Parts, false)
|
||||
if len(req.Parts) == 0 {
|
||||
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
|
||||
Message: "No context-file or skill parts provided.",
|
||||
})
|
||||
return
|
||||
}
|
||||
responsePartCount = len(req.Parts)
|
||||
|
||||
// Skill-only messages need a sentinel context-file part so the turn
|
||||
// pipeline trusts the associated skill metadata.
|
||||
req.Parts = prependAgentChatContextSentinelIfNeeded(
|
||||
req.Parts,
|
||||
workspaceAgent.ID,
|
||||
workspaceAgent.OperatingSystem,
|
||||
directory,
|
||||
)
|
||||
|
||||
content, err := chatprompt.MarshalParts(req.Parts)
|
||||
if err != nil {
|
||||
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
|
||||
Message: "Failed to marshal context parts.",
|
||||
Detail: err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
err = api.Database.InTx(func(tx database.Store) error {
|
||||
locked, err := tx.GetChatByIDForUpdate(sysCtx, chat.ID)
|
||||
if err != nil {
|
||||
return xerrors.Errorf("lock chat: %w", err)
|
||||
}
|
||||
if !isActiveAgentChat(locked) {
|
||||
return errChatNotActive
|
||||
}
|
||||
if !locked.AgentID.Valid || locked.AgentID.UUID != workspaceAgent.ID {
|
||||
return errChatDoesNotBelongToAgent
|
||||
}
|
||||
if locked.OwnerID != workspace.OwnerID {
|
||||
return errChatDoesNotBelongToWorkspaceOwner
|
||||
}
|
||||
if _, err := tx.InsertChatMessages(sysCtx, chatd.BuildSingleChatMessageInsertParams(
|
||||
chat.ID,
|
||||
database.ChatMessageRoleUser,
|
||||
content,
|
||||
database.ChatMessageVisibilityBoth,
|
||||
locked.LastModelConfigID,
|
||||
chatprompt.CurrentContentVersion,
|
||||
uuid.Nil,
|
||||
)); err != nil {
|
||||
return xerrors.Errorf("insert context message: %w", err)
|
||||
}
|
||||
if err := updateAgentChatLastInjectedContextFromMessages(sysCtx, api.Logger, tx, chat.ID); err != nil {
|
||||
return xerrors.Errorf("rebuild injected context cache: %w", err)
|
||||
}
|
||||
return nil
|
||||
}, nil)
|
||||
if err != nil {
|
||||
if errors.Is(err, errChatNotActive) || errors.Is(err, errChatDoesNotBelongToAgent) || errors.Is(err, errChatDoesNotBelongToWorkspaceOwner) {
|
||||
writeAgentChatError(ctx, rw, err)
|
||||
return
|
||||
}
|
||||
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
|
||||
Message: "Failed to persist context message.",
|
||||
Detail: err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
httpapi.Write(ctx, rw, http.StatusOK, agentsdk.AddChatContextResponse{
|
||||
ChatID: chat.ID,
|
||||
Count: responsePartCount,
|
||||
})
|
||||
}
|
||||
|
||||
// @x-apidocgen {"skip": true}
|
||||
func (api *API) workspaceAgentClearChatContext(rw http.ResponseWriter, r *http.Request) {
|
||||
ctx := r.Context()
|
||||
workspaceAgent := httpmw.WorkspaceAgent(r)
|
||||
|
||||
var req agentsdk.ClearChatContextRequest
|
||||
populated := readChatContextBody(ctx, rw, r, &req, true)
|
||||
if !populated && r.Body != http.NoBody {
|
||||
return
|
||||
}
|
||||
|
||||
// Use system context for chat operations since the
|
||||
// workspace agent scope does not include chat resources.
|
||||
//nolint:gocritic // Agent needs system access to read/write chat resources.
|
||||
sysCtx := dbauthz.AsSystemRestricted(ctx)
|
||||
workspace, err := api.Database.GetWorkspaceByAgentID(sysCtx, workspaceAgent.ID)
|
||||
if err != nil {
|
||||
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
|
||||
Message: "Failed to determine workspace from agent token.",
|
||||
Detail: err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
chat, err := resolveAgentChat(sysCtx, api.Database, workspaceAgent.ID, workspace.OwnerID, req.ChatID)
|
||||
if err != nil {
|
||||
// Zero active chats is not an error for clear.
|
||||
if errors.Is(err, errNoActiveChats) {
|
||||
httpapi.Write(ctx, rw, http.StatusOK, agentsdk.ClearChatContextResponse{})
|
||||
return
|
||||
}
|
||||
writeAgentChatError(ctx, rw, err)
|
||||
return
|
||||
}
|
||||
|
||||
err = clearAgentChatContext(sysCtx, api.Database, chat.ID, workspaceAgent.ID, workspace.OwnerID)
|
||||
if err != nil {
|
||||
if errors.Is(err, errChatNotActive) || errors.Is(err, errChatDoesNotBelongToAgent) || errors.Is(err, errChatDoesNotBelongToWorkspaceOwner) {
|
||||
writeAgentChatError(ctx, rw, err)
|
||||
return
|
||||
}
|
||||
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
|
||||
Message: "Failed to clear context from chat.",
|
||||
Detail: err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
httpapi.Write(ctx, rw, http.StatusOK, agentsdk.ClearChatContextResponse{
|
||||
ChatID: chat.ID,
|
||||
})
|
||||
}
|
||||
|
||||
var (
|
||||
errNoActiveChats = xerrors.New("no active chats found")
|
||||
errChatNotFound = xerrors.New("chat not found")
|
||||
errChatNotActive = xerrors.New("chat is not active")
|
||||
errChatDoesNotBelongToAgent = xerrors.New("chat does not belong to this agent")
|
||||
errChatDoesNotBelongToWorkspaceOwner = xerrors.New("chat does not belong to this workspace owner")
|
||||
)
|
||||
|
||||
type multipleActiveChatsError struct {
|
||||
count int
|
||||
}
|
||||
|
||||
func (e *multipleActiveChatsError) Error() string {
|
||||
return fmt.Sprintf(
|
||||
"multiple active chats (%d) found for this agent, specify a chat ID",
|
||||
e.count,
|
||||
)
|
||||
}
|
||||
|
||||
func resolveDefaultAgentChat(chats []database.Chat) (database.Chat, error) {
|
||||
switch len(chats) {
|
||||
case 0:
|
||||
return database.Chat{}, errNoActiveChats
|
||||
case 1:
|
||||
return chats[0], nil
|
||||
}
|
||||
|
||||
var rootChat *database.Chat
|
||||
for i := range chats {
|
||||
chat := &chats[i]
|
||||
if chat.ParentChatID.Valid {
|
||||
continue
|
||||
}
|
||||
if rootChat != nil {
|
||||
return database.Chat{}, &multipleActiveChatsError{count: len(chats)}
|
||||
}
|
||||
rootChat = chat
|
||||
}
|
||||
if rootChat != nil {
|
||||
return *rootChat, nil
|
||||
}
|
||||
return database.Chat{}, &multipleActiveChatsError{count: len(chats)}
|
||||
}
|
||||
|
||||
// resolveAgentChat finds the target chat from either an explicit ID
|
||||
// or auto-detection via the agent's active chats.
|
||||
func resolveAgentChat(
|
||||
ctx context.Context,
|
||||
db database.Store,
|
||||
agentID uuid.UUID,
|
||||
workspaceOwnerID uuid.UUID,
|
||||
explicitChatID uuid.UUID,
|
||||
) (database.Chat, error) {
|
||||
if explicitChatID == uuid.Nil {
|
||||
chats, err := db.GetActiveChatsByAgentID(ctx, agentID)
|
||||
if err != nil {
|
||||
return database.Chat{}, xerrors.Errorf("list active chats: %w", err)
|
||||
}
|
||||
ownerChats := make([]database.Chat, 0, len(chats))
|
||||
for _, chat := range chats {
|
||||
if chat.OwnerID != workspaceOwnerID {
|
||||
continue
|
||||
}
|
||||
ownerChats = append(ownerChats, chat)
|
||||
}
|
||||
return resolveDefaultAgentChat(ownerChats)
|
||||
}
|
||||
|
||||
chat, err := db.GetChatByID(ctx, explicitChatID)
|
||||
if err != nil {
|
||||
if errors.Is(err, sql.ErrNoRows) {
|
||||
return database.Chat{}, errChatNotFound
|
||||
}
|
||||
return database.Chat{}, xerrors.Errorf("get chat by id: %w", err)
|
||||
}
|
||||
if !chat.AgentID.Valid || chat.AgentID.UUID != agentID {
|
||||
return database.Chat{}, errChatDoesNotBelongToAgent
|
||||
}
|
||||
if chat.OwnerID != workspaceOwnerID {
|
||||
return database.Chat{}, errChatDoesNotBelongToWorkspaceOwner
|
||||
}
|
||||
if !isActiveAgentChat(chat) {
|
||||
return database.Chat{}, errChatNotActive
|
||||
}
|
||||
return chat, nil
|
||||
}
|
||||
|
||||
func isActiveAgentChat(chat database.Chat) bool {
|
||||
if chat.Archived {
|
||||
return false
|
||||
}
|
||||
|
||||
switch chat.Status {
|
||||
case database.ChatStatusWaiting,
|
||||
database.ChatStatusPending,
|
||||
database.ChatStatusRunning,
|
||||
database.ChatStatusPaused,
|
||||
database.ChatStatusRequiresAction:
|
||||
return true
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
func clearAgentChatContext(
|
||||
ctx context.Context,
|
||||
db database.Store,
|
||||
chatID uuid.UUID,
|
||||
agentID uuid.UUID,
|
||||
workspaceOwnerID uuid.UUID,
|
||||
) error {
|
||||
return db.InTx(func(tx database.Store) error {
|
||||
locked, err := tx.GetChatByIDForUpdate(ctx, chatID)
|
||||
if err != nil {
|
||||
return xerrors.Errorf("lock chat: %w", err)
|
||||
}
|
||||
if !isActiveAgentChat(locked) {
|
||||
return errChatNotActive
|
||||
}
|
||||
if !locked.AgentID.Valid || locked.AgentID.UUID != agentID {
|
||||
return errChatDoesNotBelongToAgent
|
||||
}
|
||||
if locked.OwnerID != workspaceOwnerID {
|
||||
return errChatDoesNotBelongToWorkspaceOwner
|
||||
}
|
||||
messages, err := tx.GetChatMessagesByChatID(ctx, database.GetChatMessagesByChatIDParams{
|
||||
ChatID: chatID,
|
||||
AfterID: 0,
|
||||
})
|
||||
if err != nil {
|
||||
return xerrors.Errorf("get chat messages: %w", err)
|
||||
}
|
||||
hadInjectedContext := locked.LastInjectedContext.Valid
|
||||
var skillOnlyMessageIDs []int64
|
||||
for _, msg := range messages {
|
||||
if !msg.Content.Valid {
|
||||
continue
|
||||
}
|
||||
hasContextFile := messageHasPartTypes(msg.Content.RawMessage, codersdk.ChatMessagePartTypeContextFile)
|
||||
hasSkill := messageHasPartTypes(msg.Content.RawMessage, codersdk.ChatMessagePartTypeSkill)
|
||||
if hasContextFile || hasSkill {
|
||||
hadInjectedContext = true
|
||||
}
|
||||
if hasSkill && !hasContextFile {
|
||||
skillOnlyMessageIDs = append(skillOnlyMessageIDs, msg.ID)
|
||||
}
|
||||
}
|
||||
if !hadInjectedContext {
|
||||
return nil
|
||||
}
|
||||
if err := tx.SoftDeleteContextFileMessages(ctx, chatID); err != nil {
|
||||
return xerrors.Errorf("soft delete context-file messages: %w", err)
|
||||
}
|
||||
for _, messageID := range skillOnlyMessageIDs {
|
||||
if err := tx.SoftDeleteChatMessageByID(ctx, messageID); err != nil {
|
||||
return xerrors.Errorf("soft delete context message %d: %w", messageID, err)
|
||||
}
|
||||
}
|
||||
// Reset provider-side Responses chaining so the next turn replays
|
||||
// the post-clear history instead of inheriting cleared context.
|
||||
if err := tx.ClearChatMessageProviderResponseIDsByChatID(ctx, chatID); err != nil {
|
||||
return xerrors.Errorf("clear provider response chain: %w", err)
|
||||
}
|
||||
// Clear the injected-context cache inside the transaction so it is
|
||||
// atomic with the soft-deletes.
|
||||
param, err := chatd.BuildLastInjectedContext(nil)
|
||||
if err != nil {
|
||||
return xerrors.Errorf("clear injected context cache: %w", err)
|
||||
}
|
||||
if _, err := tx.UpdateChatLastInjectedContext(ctx, database.UpdateChatLastInjectedContextParams{
|
||||
ID: chatID,
|
||||
LastInjectedContext: param,
|
||||
}); err != nil {
|
||||
return xerrors.Errorf("clear injected context cache: %w", err)
|
||||
}
|
||||
return nil
|
||||
}, nil)
|
||||
}
|
||||
|
||||
// prependAgentChatContextSentinelIfNeeded adds an empty context-file
|
||||
// part when the request only carries skills. The turn pipeline uses
|
||||
// the sentinel's agent metadata to trust the skill parts.
|
||||
func prependAgentChatContextSentinelIfNeeded(
|
||||
parts []codersdk.ChatMessagePart,
|
||||
agentID uuid.UUID,
|
||||
operatingSystem string,
|
||||
directory string,
|
||||
) []codersdk.ChatMessagePart {
|
||||
hasContextFile := false
|
||||
hasSkill := false
|
||||
for _, part := range parts {
|
||||
switch part.Type {
|
||||
case codersdk.ChatMessagePartTypeContextFile:
|
||||
hasContextFile = true
|
||||
case codersdk.ChatMessagePartTypeSkill:
|
||||
hasSkill = true
|
||||
}
|
||||
if hasContextFile && hasSkill {
|
||||
return parts
|
||||
}
|
||||
}
|
||||
if !hasSkill || hasContextFile {
|
||||
return parts
|
||||
}
|
||||
return append([]codersdk.ChatMessagePart{{
|
||||
Type: codersdk.ChatMessagePartTypeContextFile,
|
||||
ContextFilePath: chatd.AgentChatContextSentinelPath,
|
||||
ContextFileAgentID: uuid.NullUUID{
|
||||
UUID: agentID,
|
||||
Valid: true,
|
||||
},
|
||||
ContextFileOS: operatingSystem,
|
||||
ContextFileDirectory: directory,
|
||||
}}, parts...)
|
||||
}
|
||||
|
||||
func sortChatMessagesByCreatedAtAndID(messages []database.ChatMessage) {
|
||||
sort.SliceStable(messages, func(i, j int) bool {
|
||||
if messages[i].CreatedAt.Equal(messages[j].CreatedAt) {
|
||||
return messages[i].ID < messages[j].ID
|
||||
}
|
||||
return messages[i].CreatedAt.Before(messages[j].CreatedAt)
|
||||
})
|
||||
}
|
||||
|
||||
// updateAgentChatLastInjectedContextFromMessages rebuilds the
|
||||
// injected-context cache from all persisted context-file and skill parts.
|
||||
func updateAgentChatLastInjectedContextFromMessages(
|
||||
ctx context.Context,
|
||||
logger slog.Logger,
|
||||
db database.Store,
|
||||
chatID uuid.UUID,
|
||||
) error {
|
||||
messages, err := db.GetChatMessagesByChatID(ctx, database.GetChatMessagesByChatIDParams{
|
||||
ChatID: chatID,
|
||||
AfterID: 0,
|
||||
})
|
||||
if err != nil {
|
||||
return xerrors.Errorf("load context messages for injected context: %w", err)
|
||||
}
|
||||
|
||||
sortChatMessagesByCreatedAtAndID(messages)
|
||||
|
||||
parts, err := chatd.CollectContextPartsFromMessages(ctx, logger, messages, true)
|
||||
if err != nil {
|
||||
return xerrors.Errorf("collect injected context parts: %w", err)
|
||||
}
|
||||
parts = chatd.FilterContextPartsToLatestAgent(parts)
|
||||
|
||||
param, err := chatd.BuildLastInjectedContext(parts)
|
||||
if err != nil {
|
||||
return xerrors.Errorf("update injected context: %w", err)
|
||||
}
|
||||
if _, err := db.UpdateChatLastInjectedContext(ctx, database.UpdateChatLastInjectedContextParams{
|
||||
ID: chatID,
|
||||
LastInjectedContext: param,
|
||||
}); err != nil {
|
||||
return xerrors.Errorf("update injected context: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func messageHasPartTypes(raw []byte, types ...codersdk.ChatMessagePartType) bool {
|
||||
var parts []codersdk.ChatMessagePart
|
||||
if err := json.Unmarshal(raw, &parts); err != nil {
|
||||
return false
|
||||
}
|
||||
for _, part := range parts {
|
||||
for _, typ := range types {
|
||||
if part.Type == typ {
|
||||
return true
|
||||
}
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// writeAgentChatError translates resolveAgentChat errors to HTTP
|
||||
// responses.
|
||||
func writeAgentChatError(
|
||||
ctx context.Context,
|
||||
rw http.ResponseWriter,
|
||||
err error,
|
||||
) {
|
||||
if errors.Is(err, errNoActiveChats) {
|
||||
httpapi.Write(ctx, rw, http.StatusNotFound, codersdk.Response{
|
||||
Message: "No active chats found for this agent.",
|
||||
})
|
||||
return
|
||||
}
|
||||
if errors.Is(err, errChatNotFound) {
|
||||
httpapi.Write(ctx, rw, http.StatusNotFound, codersdk.Response{
|
||||
Message: "Chat not found.",
|
||||
})
|
||||
return
|
||||
}
|
||||
if errors.Is(err, errChatDoesNotBelongToAgent) {
|
||||
httpapi.Write(ctx, rw, http.StatusForbidden, codersdk.Response{
|
||||
Message: "Chat does not belong to this agent.",
|
||||
})
|
||||
return
|
||||
}
|
||||
if errors.Is(err, errChatDoesNotBelongToWorkspaceOwner) {
|
||||
httpapi.Write(ctx, rw, http.StatusForbidden, codersdk.Response{
|
||||
Message: "Chat does not belong to this workspace owner.",
|
||||
})
|
||||
return
|
||||
}
|
||||
if errors.Is(err, errChatNotActive) {
|
||||
httpapi.Write(ctx, rw, http.StatusConflict, codersdk.Response{
|
||||
Message: "Cannot modify context: this chat is no longer active.",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
var multipleErr *multipleActiveChatsError
|
||||
if errors.As(err, &multipleErr) {
|
||||
httpapi.Write(ctx, rw, http.StatusConflict, codersdk.Response{
|
||||
Message: err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
|
||||
Message: "Failed to resolve chat.",
|
||||
Detail: err.Error(),
|
||||
})
|
||||
}
|
||||
|
||||
@@ -0,0 +1,76 @@
|
||||
package coderd
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"testing"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/coder/coder/v2/coderd/database"
|
||||
"github.com/coder/coder/v2/coderd/database/dbauthz"
|
||||
"github.com/coder/coder/v2/coderd/database/dbfake"
|
||||
"github.com/coder/coder/v2/coderd/database/dbgen"
|
||||
"github.com/coder/coder/v2/coderd/database/dbtestutil"
|
||||
"github.com/coder/coder/v2/testutil"
|
||||
)
|
||||
|
||||
func TestActiveAgentChatDefinitionsAgree(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctx := dbauthz.AsSystemRestricted(testutil.Context(t, testutil.WaitMedium))
|
||||
db, _ := dbtestutil.NewDB(t)
|
||||
|
||||
org, err := db.GetDefaultOrganization(ctx)
|
||||
require.NoError(t, err)
|
||||
|
||||
owner := dbgen.User(t, db, database.User{})
|
||||
workspace := dbfake.WorkspaceBuild(t, db, database.WorkspaceTable{
|
||||
OrganizationID: org.ID,
|
||||
OwnerID: owner.ID,
|
||||
}).WithAgent().Do()
|
||||
modelConfig := insertAgentChatTestModelConfig(ctx, t, db, owner.ID)
|
||||
|
||||
insertedChats := make([]database.Chat, 0, len(database.AllChatStatusValues())*2)
|
||||
for _, archived := range []bool{false, true} {
|
||||
for _, status := range database.AllChatStatusValues() {
|
||||
chat, err := db.InsertChat(ctx, database.InsertChatParams{
|
||||
Status: status,
|
||||
OwnerID: owner.ID,
|
||||
LastModelConfigID: modelConfig.ID,
|
||||
Title: fmt.Sprintf("%s-archived-%t", status, archived),
|
||||
AgentID: uuid.NullUUID{UUID: workspace.Agents[0].ID, Valid: true},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
if archived {
|
||||
_, err = db.ArchiveChatByID(ctx, chat.ID)
|
||||
require.NoError(t, err)
|
||||
|
||||
chat, err = db.GetChatByID(ctx, chat.ID)
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
insertedChats = append(insertedChats, chat)
|
||||
}
|
||||
}
|
||||
|
||||
activeChats, err := db.GetActiveChatsByAgentID(ctx, workspace.Agents[0].ID)
|
||||
require.NoError(t, err)
|
||||
|
||||
activeByID := make(map[uuid.UUID]bool, len(activeChats))
|
||||
for _, chat := range activeChats {
|
||||
activeByID[chat.ID] = true
|
||||
}
|
||||
|
||||
for _, chat := range insertedChats {
|
||||
require.Equalf(
|
||||
t,
|
||||
isActiveAgentChat(chat),
|
||||
activeByID[chat.ID],
|
||||
"status=%s archived=%t",
|
||||
chat.Status,
|
||||
chat.Archived,
|
||||
)
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,128 @@
|
||||
package coderd
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"encoding/json"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/sqlc-dev/pqtype"
|
||||
"github.com/stretchr/testify/require"
|
||||
"go.uber.org/mock/gomock"
|
||||
|
||||
"cdr.dev/slog/v3/sloggers/slogtest"
|
||||
"github.com/coder/coder/v2/coderd/database"
|
||||
"github.com/coder/coder/v2/coderd/database/dbauthz"
|
||||
"github.com/coder/coder/v2/coderd/database/dbmock"
|
||||
"github.com/coder/coder/v2/codersdk"
|
||||
)
|
||||
|
||||
func TestUpdateAgentChatLastInjectedContextFromMessagesUsesMessageIDTieBreaker(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctrl := gomock.NewController(t)
|
||||
db := dbmock.NewMockStore(ctrl)
|
||||
chatID := uuid.New()
|
||||
createdAt := time.Date(2026, time.April, 9, 13, 0, 0, 0, time.UTC)
|
||||
oldAgentID := uuid.New()
|
||||
newAgentID := uuid.New()
|
||||
|
||||
oldContent, err := json.Marshal([]codersdk.ChatMessagePart{{
|
||||
Type: codersdk.ChatMessagePartTypeContextFile,
|
||||
ContextFilePath: "/old/AGENTS.md",
|
||||
ContextFileContent: "old instructions",
|
||||
ContextFileAgentID: uuid.NullUUID{UUID: oldAgentID, Valid: true},
|
||||
}})
|
||||
require.NoError(t, err)
|
||||
newContent, err := json.Marshal([]codersdk.ChatMessagePart{{
|
||||
Type: codersdk.ChatMessagePartTypeContextFile,
|
||||
ContextFilePath: "/new/AGENTS.md",
|
||||
ContextFileContent: "new instructions",
|
||||
ContextFileAgentID: uuid.NullUUID{UUID: newAgentID, Valid: true},
|
||||
}})
|
||||
require.NoError(t, err)
|
||||
|
||||
db.EXPECT().GetChatMessagesByChatID(gomock.Any(), database.GetChatMessagesByChatIDParams{
|
||||
ChatID: chatID,
|
||||
AfterID: 0,
|
||||
}).Return([]database.ChatMessage{
|
||||
{
|
||||
ID: 2,
|
||||
CreatedAt: createdAt,
|
||||
Content: pqtype.NullRawMessage{
|
||||
RawMessage: newContent,
|
||||
Valid: true,
|
||||
},
|
||||
},
|
||||
{
|
||||
ID: 1,
|
||||
CreatedAt: createdAt,
|
||||
Content: pqtype.NullRawMessage{
|
||||
RawMessage: oldContent,
|
||||
Valid: true,
|
||||
},
|
||||
},
|
||||
}, nil)
|
||||
|
||||
db.EXPECT().UpdateChatLastInjectedContext(gomock.Any(), gomock.Any()).DoAndReturn(
|
||||
func(_ context.Context, arg database.UpdateChatLastInjectedContextParams) (database.Chat, error) {
|
||||
require.Equal(t, chatID, arg.ID)
|
||||
require.True(t, arg.LastInjectedContext.Valid)
|
||||
var cached []codersdk.ChatMessagePart
|
||||
require.NoError(t, json.Unmarshal(arg.LastInjectedContext.RawMessage, &cached))
|
||||
require.Len(t, cached, 1)
|
||||
require.Equal(t, "/new/AGENTS.md", cached[0].ContextFilePath)
|
||||
require.Equal(t, uuid.NullUUID{UUID: newAgentID, Valid: true}, cached[0].ContextFileAgentID)
|
||||
return database.Chat{}, nil
|
||||
},
|
||||
)
|
||||
|
||||
err = updateAgentChatLastInjectedContextFromMessages(
|
||||
context.Background(),
|
||||
slogtest.Make(t, &slogtest.Options{IgnoreErrors: true}),
|
||||
db,
|
||||
chatID,
|
||||
)
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
func insertAgentChatTestModelConfig(
|
||||
ctx context.Context,
|
||||
t testing.TB,
|
||||
db database.Store,
|
||||
userID uuid.UUID,
|
||||
) database.ChatModelConfig {
|
||||
t.Helper()
|
||||
|
||||
sysCtx := dbauthz.AsSystemRestricted(ctx)
|
||||
createdBy := uuid.NullUUID{UUID: userID, Valid: true}
|
||||
|
||||
_, err := db.InsertChatProvider(sysCtx, database.InsertChatProviderParams{
|
||||
Provider: "openai",
|
||||
DisplayName: "OpenAI",
|
||||
APIKey: "test-api-key",
|
||||
ApiKeyKeyID: sql.NullString{},
|
||||
CreatedBy: createdBy,
|
||||
Enabled: true,
|
||||
CentralApiKeyEnabled: true,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
model, err := db.InsertChatModelConfig(sysCtx, database.InsertChatModelConfigParams{
|
||||
Provider: "openai",
|
||||
Model: "gpt-4o-mini",
|
||||
DisplayName: "Test Model",
|
||||
CreatedBy: createdBy,
|
||||
UpdatedBy: createdBy,
|
||||
Enabled: true,
|
||||
IsDefault: true,
|
||||
ContextLimit: 128000,
|
||||
CompressionThreshold: 70,
|
||||
Options: json.RawMessage(`{}`),
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
return model
|
||||
}
|
||||
File diff suppressed because it is too large
Load Diff
@@ -91,7 +91,7 @@ func TestWorkspaceAgent(t *testing.T) {
|
||||
require.Equal(t, tmpDir, workspace.LatestBuild.Resources[0].Agents[0].Directory)
|
||||
_, err = anotherClient.WorkspaceAgent(ctx, workspace.LatestBuild.Resources[0].Agents[0].ID)
|
||||
require.NoError(t, err)
|
||||
require.True(t, workspace.LatestBuild.Resources[0].Agents[0].Health.Healthy)
|
||||
require.False(t, workspace.LatestBuild.Resources[0].Agents[0].Health.Healthy)
|
||||
})
|
||||
t.Run("HasFallbackTroubleshootingURL", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
@@ -260,6 +260,50 @@ func TestWorkspaceAgentLogs(t *testing.T) {
|
||||
require.Equal(t, "testing", logChunk[0].Output)
|
||||
require.Equal(t, "testing2", logChunk[1].Output)
|
||||
})
|
||||
t.Run("SanitizesNulBytesAndTracksSanitizedLength", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := testutil.Context(t, testutil.WaitMedium)
|
||||
client, db := coderdtest.NewWithDatabase(t, nil)
|
||||
user := coderdtest.CreateFirstUser(t, client)
|
||||
r := dbfake.WorkspaceBuild(t, db, database.WorkspaceTable{
|
||||
OrganizationID: user.OrganizationID,
|
||||
OwnerID: user.UserID,
|
||||
}).WithAgent().Do()
|
||||
|
||||
rawOutput := "before\x00after"
|
||||
sanitizedOutput := agentsdk.SanitizeLogOutput(rawOutput)
|
||||
agentClient := agentsdk.New(client.URL, agentsdk.WithFixedToken(r.AgentToken))
|
||||
err := agentClient.PatchLogs(ctx, agentsdk.PatchLogs{
|
||||
Logs: []agentsdk.Log{
|
||||
{
|
||||
CreatedAt: dbtime.Now(),
|
||||
Output: rawOutput,
|
||||
},
|
||||
},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
agent, err := db.GetWorkspaceAgentByID(dbauthz.AsSystemRestricted(ctx), r.Agents[0].ID)
|
||||
require.NoError(t, err)
|
||||
require.EqualValues(t, len(sanitizedOutput), agent.LogsLength)
|
||||
|
||||
workspace, err := client.Workspace(ctx, r.Workspace.ID)
|
||||
require.NoError(t, err)
|
||||
logs, closer, err := client.WorkspaceAgentLogsAfter(ctx, workspace.LatestBuild.Resources[0].Agents[0].ID, 0, true)
|
||||
require.NoError(t, err)
|
||||
defer func() {
|
||||
_ = closer.Close()
|
||||
}()
|
||||
|
||||
var logChunk []codersdk.WorkspaceAgentLog
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
case logChunk = <-logs:
|
||||
}
|
||||
require.NoError(t, ctx.Err())
|
||||
require.Len(t, logChunk, 1)
|
||||
require.Equal(t, sanitizedOutput, logChunk[0].Output)
|
||||
})
|
||||
t.Run("Close logs on outdated build", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := testutil.Context(t, testutil.WaitMedium)
|
||||
|
||||
+69
-12
@@ -213,6 +213,39 @@ func TestWorkspace(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
t.Run("Healthy", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
client := coderdtest.New(t, &coderdtest.Options{IncludeProvisionerDaemon: true})
|
||||
user := coderdtest.CreateFirstUser(t, client)
|
||||
authToken := uuid.NewString()
|
||||
version := coderdtest.CreateTemplateVersion(t, client, user.OrganizationID, &echo.Responses{
|
||||
Parse: echo.ParseComplete,
|
||||
ProvisionGraph: echo.ProvisionGraphWithAgent(authToken),
|
||||
})
|
||||
coderdtest.AwaitTemplateVersionJobCompleted(t, client, version.ID)
|
||||
template := coderdtest.CreateTemplate(t, client, user.OrganizationID, version.ID)
|
||||
workspace := coderdtest.CreateWorkspace(t, client, template.ID)
|
||||
coderdtest.AwaitWorkspaceBuildJobCompleted(t, client, workspace.LatestBuild.ID)
|
||||
|
||||
_ = agenttest.New(t, client.URL, authToken)
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong)
|
||||
defer cancel()
|
||||
|
||||
var err error
|
||||
testutil.Eventually(ctx, t, func(ctx context.Context) bool {
|
||||
workspace, err = client.Workspace(ctx, workspace.ID)
|
||||
return assert.NoError(t, err) && workspace.Health.Healthy
|
||||
}, testutil.IntervalMedium)
|
||||
|
||||
agent := workspace.LatestBuild.Resources[0].Agents[0]
|
||||
|
||||
assert.True(t, workspace.Health.Healthy)
|
||||
assert.Equal(t, []uuid.UUID{}, workspace.Health.FailingAgents)
|
||||
assert.True(t, agent.Health.Healthy)
|
||||
assert.Empty(t, agent.Health.Reason)
|
||||
})
|
||||
|
||||
t.Run("Connecting", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
client := coderdtest.New(t, &coderdtest.Options{IncludeProvisionerDaemon: true})
|
||||
user := coderdtest.CreateFirstUser(t, client)
|
||||
@@ -247,10 +280,10 @@ func TestWorkspace(t *testing.T) {
|
||||
|
||||
agent := workspace.LatestBuild.Resources[0].Agents[0]
|
||||
|
||||
assert.True(t, workspace.Health.Healthy)
|
||||
assert.Equal(t, []uuid.UUID{}, workspace.Health.FailingAgents)
|
||||
assert.True(t, agent.Health.Healthy)
|
||||
assert.Empty(t, agent.Health.Reason)
|
||||
assert.False(t, workspace.Health.Healthy)
|
||||
assert.Equal(t, []uuid.UUID{agent.ID}, workspace.Health.FailingAgents)
|
||||
assert.False(t, agent.Health.Healthy)
|
||||
assert.Equal(t, "agent has not yet connected", agent.Health.Reason)
|
||||
})
|
||||
|
||||
t.Run("Unhealthy", func(t *testing.T) {
|
||||
@@ -302,6 +335,7 @@ func TestWorkspace(t *testing.T) {
|
||||
t.Parallel()
|
||||
client := coderdtest.New(t, &coderdtest.Options{IncludeProvisionerDaemon: true})
|
||||
user := coderdtest.CreateFirstUser(t, client)
|
||||
a1AuthToken := uuid.NewString()
|
||||
version := coderdtest.CreateTemplateVersion(t, client, user.OrganizationID, &echo.Responses{
|
||||
Parse: echo.ParseComplete,
|
||||
ProvisionGraph: []*proto.Response{{
|
||||
@@ -313,7 +347,9 @@ func TestWorkspace(t *testing.T) {
|
||||
Agents: []*proto.Agent{{
|
||||
Id: uuid.NewString(),
|
||||
Name: "a1",
|
||||
Auth: &proto.Agent_Token{},
|
||||
Auth: &proto.Agent_Token{
|
||||
Token: a1AuthToken,
|
||||
},
|
||||
}, {
|
||||
Id: uuid.NewString(),
|
||||
Name: "a2",
|
||||
@@ -330,13 +366,21 @@ func TestWorkspace(t *testing.T) {
|
||||
workspace := coderdtest.CreateWorkspace(t, client, template.ID)
|
||||
coderdtest.AwaitWorkspaceBuildJobCompleted(t, client, workspace.LatestBuild.ID)
|
||||
|
||||
_ = agenttest.New(t, client.URL, a1AuthToken)
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong)
|
||||
defer cancel()
|
||||
|
||||
var err error
|
||||
testutil.Eventually(ctx, t, func(ctx context.Context) bool {
|
||||
workspace, err = client.Workspace(ctx, workspace.ID)
|
||||
return assert.NoError(t, err) && !workspace.Health.Healthy
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
// Wait for the mixed state: a1 connected (healthy)
|
||||
// and workspace unhealthy (because a2 timed out).
|
||||
agent1 := workspace.LatestBuild.Resources[0].Agents[0]
|
||||
return agent1.Health.Healthy && !workspace.Health.Healthy
|
||||
}, testutil.IntervalMedium)
|
||||
|
||||
assert.False(t, workspace.Health.Healthy)
|
||||
@@ -360,6 +404,7 @@ func TestWorkspace(t *testing.T) {
|
||||
// disconnected, but this should not make the workspace unhealthy.
|
||||
client, db := coderdtest.NewWithDatabase(t, &coderdtest.Options{IncludeProvisionerDaemon: true})
|
||||
user := coderdtest.CreateFirstUser(t, client)
|
||||
authToken := uuid.NewString()
|
||||
version := coderdtest.CreateTemplateVersion(t, client, user.OrganizationID, &echo.Responses{
|
||||
Parse: echo.ParseComplete,
|
||||
ProvisionGraph: []*proto.Response{{
|
||||
@@ -371,7 +416,9 @@ func TestWorkspace(t *testing.T) {
|
||||
Agents: []*proto.Agent{{
|
||||
Id: uuid.NewString(),
|
||||
Name: "parent",
|
||||
Auth: &proto.Agent_Token{},
|
||||
Auth: &proto.Agent_Token{
|
||||
Token: authToken,
|
||||
},
|
||||
}},
|
||||
}},
|
||||
},
|
||||
@@ -383,14 +430,23 @@ func TestWorkspace(t *testing.T) {
|
||||
workspace := coderdtest.CreateWorkspace(t, client, template.ID)
|
||||
coderdtest.AwaitWorkspaceBuildJobCompleted(t, client, workspace.LatestBuild.ID)
|
||||
|
||||
_ = agenttest.New(t, client.URL, authToken)
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong)
|
||||
defer cancel()
|
||||
|
||||
// Get the workspace and parent agent.
|
||||
workspace, err := client.Workspace(ctx, workspace.ID)
|
||||
require.NoError(t, err)
|
||||
parentAgent := workspace.LatestBuild.Resources[0].Agents[0]
|
||||
require.True(t, parentAgent.Health.Healthy, "parent agent should be healthy initially")
|
||||
// Wait for the parent agent to connect and be healthy.
|
||||
var parentAgent codersdk.WorkspaceAgent
|
||||
testutil.Eventually(ctx, t, func(ctx context.Context) bool {
|
||||
var err error
|
||||
workspace, err = client.Workspace(ctx, workspace.ID)
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
parentAgent = workspace.LatestBuild.Resources[0].Agents[0]
|
||||
return parentAgent.Health.Healthy
|
||||
}, testutil.IntervalMedium)
|
||||
require.True(t, parentAgent.Health.Healthy, "parent agent should be healthy")
|
||||
|
||||
// Create a sub-agent with a short connection timeout so it becomes
|
||||
// unhealthy quickly (simulating a devcontainer rebuild scenario).
|
||||
@@ -404,6 +460,7 @@ func TestWorkspace(t *testing.T) {
|
||||
// Wait for the sub-agent to become unhealthy due to timeout.
|
||||
var subAgentUnhealthy bool
|
||||
require.Eventually(t, func() bool {
|
||||
var err error
|
||||
workspace, err = client.Workspace(ctx, workspace.ID)
|
||||
if err != nil {
|
||||
return false
|
||||
|
||||
+129
-53
@@ -73,9 +73,9 @@ const (
|
||||
|
||||
// maxConcurrentRecordingUploads caps the number of recording
|
||||
// stop-and-store operations that can run concurrently. Each
|
||||
// slot buffers up to MaxRecordingSize (100 MB) in memory, so
|
||||
// this value implicitly bounds memory to roughly
|
||||
// maxConcurrentRecordingUploads * 100 MB.
|
||||
// slot buffers up to MaxRecordingSize + MaxThumbnailSize
|
||||
// (110 MB) in memory, so this value implicitly bounds memory
|
||||
// to roughly maxConcurrentRecordingUploads * 110 MB.
|
||||
maxConcurrentRecordingUploads = 25
|
||||
|
||||
// staleRecoveryIntervalDivisor determines how often the stale
|
||||
@@ -996,7 +996,7 @@ func (p *Server) CreateChat(ctx context.Context, opts CreateOptions) (database.C
|
||||
return database.Chat{}, txErr
|
||||
}
|
||||
|
||||
p.publishChatPubsubEvent(chat, coderdpubsub.ChatEventKindCreated, nil)
|
||||
p.publishChatPubsubEvent(chat, codersdk.ChatWatchEventKindCreated, nil)
|
||||
p.signalWake()
|
||||
return chat, nil
|
||||
}
|
||||
@@ -1158,7 +1158,7 @@ func (p *Server) SendMessage(
|
||||
|
||||
p.publishMessage(opts.ChatID, result.Message)
|
||||
p.publishStatus(opts.ChatID, result.Chat.Status, result.Chat.WorkerID)
|
||||
p.publishChatPubsubEvent(result.Chat, coderdpubsub.ChatEventKindStatusChange, nil)
|
||||
p.publishChatPubsubEvent(result.Chat, codersdk.ChatWatchEventKindStatusChange, nil)
|
||||
p.signalWake()
|
||||
return result, nil
|
||||
}
|
||||
@@ -1301,7 +1301,7 @@ func (p *Server) EditMessage(
|
||||
QueueUpdate: true,
|
||||
})
|
||||
p.publishStatus(opts.ChatID, result.Chat.Status, result.Chat.WorkerID)
|
||||
p.publishChatPubsubEvent(result.Chat, coderdpubsub.ChatEventKindStatusChange, nil)
|
||||
p.publishChatPubsubEvent(result.Chat, codersdk.ChatWatchEventKindStatusChange, nil)
|
||||
p.signalWake()
|
||||
|
||||
return result, nil
|
||||
@@ -1355,10 +1355,10 @@ func (p *Server) ArchiveChat(ctx context.Context, chat database.Chat) error {
|
||||
|
||||
if interrupted {
|
||||
p.publishStatus(chat.ID, statusChat.Status, statusChat.WorkerID)
|
||||
p.publishChatPubsubEvent(statusChat, coderdpubsub.ChatEventKindStatusChange, nil)
|
||||
p.publishChatPubsubEvent(statusChat, codersdk.ChatWatchEventKindStatusChange, nil)
|
||||
}
|
||||
|
||||
p.publishChatPubsubEvents(archivedChats, coderdpubsub.ChatEventKindDeleted)
|
||||
p.publishChatPubsubEvents(archivedChats, codersdk.ChatWatchEventKindDeleted)
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -1373,7 +1373,7 @@ func (p *Server) UnarchiveChat(ctx context.Context, chat database.Chat) error {
|
||||
ctx,
|
||||
chat.ID,
|
||||
"unarchive",
|
||||
coderdpubsub.ChatEventKindCreated,
|
||||
codersdk.ChatWatchEventKindCreated,
|
||||
p.db.UnarchiveChatByID,
|
||||
)
|
||||
}
|
||||
@@ -1382,7 +1382,7 @@ func (p *Server) applyChatLifecycleTransition(
|
||||
ctx context.Context,
|
||||
chatID uuid.UUID,
|
||||
action string,
|
||||
kind coderdpubsub.ChatEventKind,
|
||||
kind codersdk.ChatWatchEventKind,
|
||||
transition func(context.Context, uuid.UUID) ([]database.Chat, error),
|
||||
) error {
|
||||
updatedChats, err := transition(ctx, chatID)
|
||||
@@ -1545,7 +1545,7 @@ func (p *Server) PromoteQueued(
|
||||
})
|
||||
p.publishMessage(opts.ChatID, promoted)
|
||||
p.publishStatus(opts.ChatID, updatedChat.Status, updatedChat.WorkerID)
|
||||
p.publishChatPubsubEvent(updatedChat, coderdpubsub.ChatEventKindStatusChange, nil)
|
||||
p.publishChatPubsubEvent(updatedChat, codersdk.ChatWatchEventKindStatusChange, nil)
|
||||
p.signalWake()
|
||||
|
||||
return result, nil
|
||||
@@ -2092,7 +2092,7 @@ func (p *Server) regenerateChatTitleWithStore(
|
||||
return updatedChat, nil
|
||||
}
|
||||
|
||||
p.publishChatPubsubEvent(updatedChat, coderdpubsub.ChatEventKindTitleChange, nil)
|
||||
p.publishChatPubsubEvent(updatedChat, codersdk.ChatWatchEventKindTitleChange, nil)
|
||||
return updatedChat, nil
|
||||
}
|
||||
|
||||
@@ -2347,7 +2347,7 @@ func (p *Server) setChatWaiting(ctx context.Context, chatID uuid.UUID) (database
|
||||
return database.Chat{}, err
|
||||
}
|
||||
p.publishStatus(chatID, updatedChat.Status, updatedChat.WorkerID)
|
||||
p.publishChatPubsubEvent(updatedChat, coderdpubsub.ChatEventKindStatusChange, nil)
|
||||
p.publishChatPubsubEvent(updatedChat, codersdk.ChatWatchEventKindStatusChange, nil)
|
||||
return updatedChat, nil
|
||||
}
|
||||
|
||||
@@ -2461,6 +2461,33 @@ type chainModeInfo struct {
|
||||
// trailingUserCount is the number of contiguous user messages
|
||||
// at the end of the conversation that form the current turn.
|
||||
trailingUserCount int
|
||||
// contributingTrailingUserCount counts the trailing user
|
||||
// messages that materially change the provider input.
|
||||
contributingTrailingUserCount int
|
||||
}
|
||||
|
||||
func userMessageContributesToChainMode(msg database.ChatMessage) bool {
|
||||
parts, err := chatprompt.ParseContent(msg)
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
for _, part := range parts {
|
||||
switch part.Type {
|
||||
case codersdk.ChatMessagePartTypeText,
|
||||
codersdk.ChatMessagePartTypeReasoning:
|
||||
if strings.TrimSpace(part.Text) != "" {
|
||||
return true
|
||||
}
|
||||
case codersdk.ChatMessagePartTypeFile,
|
||||
codersdk.ChatMessagePartTypeFileReference:
|
||||
return true
|
||||
case codersdk.ChatMessagePartTypeContextFile:
|
||||
if part.ContextFileContent != "" {
|
||||
return true
|
||||
}
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// resolveChainMode scans DB messages from the end to count trailing user
|
||||
@@ -2470,11 +2497,13 @@ func resolveChainMode(messages []database.ChatMessage) chainModeInfo {
|
||||
var info chainModeInfo
|
||||
i := len(messages) - 1
|
||||
for ; i >= 0; i-- {
|
||||
if messages[i].Role == database.ChatMessageRoleUser {
|
||||
info.trailingUserCount++
|
||||
continue
|
||||
if messages[i].Role != database.ChatMessageRoleUser {
|
||||
break
|
||||
}
|
||||
info.trailingUserCount++
|
||||
if userMessageContributesToChainMode(messages[i]) {
|
||||
info.contributingTrailingUserCount++
|
||||
}
|
||||
break
|
||||
}
|
||||
for ; i >= 0; i-- {
|
||||
switch messages[i].Role {
|
||||
@@ -2497,15 +2526,15 @@ func resolveChainMode(messages []database.ChatMessage) chainModeInfo {
|
||||
return info
|
||||
}
|
||||
|
||||
// filterPromptForChainMode keeps only system messages and the last
|
||||
// trailingUserCount user messages from the prompt. Assistant and tool
|
||||
// messages are dropped because the provider already has them via the
|
||||
// previous_response_id chain.
|
||||
// filterPromptForChainMode keeps only system messages and the trailing
|
||||
// user messages that still contribute model-visible content to the
|
||||
// current turn. Assistant and tool messages are dropped because the
|
||||
// provider already has them via the previous_response_id chain.
|
||||
func filterPromptForChainMode(
|
||||
prompt []fantasy.Message,
|
||||
trailingUserCount int,
|
||||
info chainModeInfo,
|
||||
) []fantasy.Message {
|
||||
if trailingUserCount <= 0 {
|
||||
if info.contributingTrailingUserCount <= 0 {
|
||||
return prompt
|
||||
}
|
||||
|
||||
@@ -2516,7 +2545,12 @@ func filterPromptForChainMode(
|
||||
}
|
||||
}
|
||||
|
||||
usersToSkip := totalUsers - trailingUserCount
|
||||
// Prompt construction already drops user turns with no model-visible
|
||||
// content, such as skill-only sentinel messages. That means the user
|
||||
// count here stays aligned with contributingTrailingUserCount even
|
||||
// when non-contributing DB turns are interleaved in the trailing
|
||||
// block.
|
||||
usersToSkip := totalUsers - info.contributingTrailingUserCount
|
||||
if usersToSkip < 0 {
|
||||
usersToSkip = 0
|
||||
}
|
||||
@@ -2562,6 +2596,28 @@ func appendChatMessage(
|
||||
params.ProviderResponseID = append(params.ProviderResponseID, msg.providerResponseID)
|
||||
}
|
||||
|
||||
// BuildSingleChatMessageInsertParams creates batch insert params for one
|
||||
// message using the shared chat message builder.
|
||||
func BuildSingleChatMessageInsertParams(
|
||||
chatID uuid.UUID,
|
||||
role database.ChatMessageRole,
|
||||
content pqtype.NullRawMessage,
|
||||
visibility database.ChatMessageVisibility,
|
||||
modelConfigID uuid.UUID,
|
||||
contentVersion int16,
|
||||
createdBy uuid.UUID,
|
||||
) database.InsertChatMessagesParams {
|
||||
params := database.InsertChatMessagesParams{ //nolint:exhaustruct // Fields populated by appendChatMessage.
|
||||
ChatID: chatID,
|
||||
}
|
||||
msg := newChatMessage(role, content, visibility, modelConfigID, contentVersion)
|
||||
if createdBy != uuid.Nil {
|
||||
msg = msg.withCreatedBy(createdBy)
|
||||
}
|
||||
appendChatMessage(¶ms, msg)
|
||||
return params
|
||||
}
|
||||
|
||||
func insertUserMessageAndSetPending(
|
||||
ctx context.Context,
|
||||
store database.Store,
|
||||
@@ -3571,7 +3627,7 @@ func (p *Server) publishChatStreamNotify(chatID uuid.UUID, notify coderdpubsub.C
|
||||
}
|
||||
|
||||
// publishChatPubsubEvents broadcasts a lifecycle event for each affected chat.
|
||||
func (p *Server) publishChatPubsubEvents(chats []database.Chat, kind coderdpubsub.ChatEventKind) {
|
||||
func (p *Server) publishChatPubsubEvents(chats []database.Chat, kind codersdk.ChatWatchEventKind) {
|
||||
for _, chat := range chats {
|
||||
p.publishChatPubsubEvent(chat, kind, nil)
|
||||
}
|
||||
@@ -3579,7 +3635,7 @@ func (p *Server) publishChatPubsubEvents(chats []database.Chat, kind coderdpubsu
|
||||
|
||||
// publishChatPubsubEvent broadcasts a chat lifecycle event via PostgreSQL
|
||||
// pubsub so that all replicas can push updates to watching clients.
|
||||
func (p *Server) publishChatPubsubEvent(chat database.Chat, kind coderdpubsub.ChatEventKind, diffStatus *codersdk.ChatDiffStatus) {
|
||||
func (p *Server) publishChatPubsubEvent(chat database.Chat, kind codersdk.ChatWatchEventKind, diffStatus *codersdk.ChatDiffStatus) {
|
||||
if p.pubsub == nil {
|
||||
return
|
||||
}
|
||||
@@ -3591,7 +3647,7 @@ func (p *Server) publishChatPubsubEvent(chat database.Chat, kind coderdpubsub.Ch
|
||||
if diffStatus != nil {
|
||||
sdkChat.DiffStatus = diffStatus
|
||||
}
|
||||
event := coderdpubsub.ChatEvent{
|
||||
event := codersdk.ChatWatchEvent{
|
||||
Kind: kind,
|
||||
Chat: sdkChat,
|
||||
}
|
||||
@@ -3603,7 +3659,7 @@ func (p *Server) publishChatPubsubEvent(chat database.Chat, kind coderdpubsub.Ch
|
||||
)
|
||||
return
|
||||
}
|
||||
if err := p.pubsub.Publish(coderdpubsub.ChatEventChannel(chat.OwnerID), payload); err != nil {
|
||||
if err := p.pubsub.Publish(coderdpubsub.ChatWatchEventChannel(chat.OwnerID), payload); err != nil {
|
||||
p.logger.Error(context.Background(), "failed to publish chat pubsub event",
|
||||
slog.F("chat_id", chat.ID),
|
||||
slog.F("kind", kind),
|
||||
@@ -3636,8 +3692,8 @@ func (p *Server) publishChatActionRequired(chat database.Chat, pending []chatloo
|
||||
toolCalls := pendingToStreamToolCalls(pending)
|
||||
sdkChat := db2sdk.Chat(chat, nil, nil)
|
||||
|
||||
event := coderdpubsub.ChatEvent{
|
||||
Kind: coderdpubsub.ChatEventKindActionRequired,
|
||||
event := codersdk.ChatWatchEvent{
|
||||
Kind: codersdk.ChatWatchEventKindActionRequired,
|
||||
Chat: sdkChat,
|
||||
ToolCalls: toolCalls,
|
||||
}
|
||||
@@ -3649,7 +3705,7 @@ func (p *Server) publishChatActionRequired(chat database.Chat, pending []chatloo
|
||||
)
|
||||
return
|
||||
}
|
||||
if err := p.pubsub.Publish(coderdpubsub.ChatEventChannel(chat.OwnerID), payload); err != nil {
|
||||
if err := p.pubsub.Publish(coderdpubsub.ChatWatchEventChannel(chat.OwnerID), payload); err != nil {
|
||||
p.logger.Error(context.Background(), "failed to publish chat action_required pubsub event",
|
||||
slog.F("chat_id", chat.ID),
|
||||
slog.Error(err),
|
||||
@@ -3677,7 +3733,7 @@ func (p *Server) PublishDiffStatusChange(ctx context.Context, chatID uuid.UUID)
|
||||
}
|
||||
|
||||
sdkStatus := db2sdk.ChatDiffStatus(chatID, &dbStatus)
|
||||
p.publishChatPubsubEvent(chat, coderdpubsub.ChatEventKindDiffStatusChange, &sdkStatus)
|
||||
p.publishChatPubsubEvent(chat, codersdk.ChatWatchEventKindDiffStatusChange, &sdkStatus)
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -4159,7 +4215,7 @@ func (p *Server) processChat(ctx context.Context, chat database.Chat) {
|
||||
if title, ok := generatedTitle.Load(); ok {
|
||||
updatedChat.Title = title
|
||||
}
|
||||
p.publishChatPubsubEvent(updatedChat, coderdpubsub.ChatEventKindStatusChange, nil)
|
||||
p.publishChatPubsubEvent(updatedChat, codersdk.ChatWatchEventKindStatusChange, nil)
|
||||
|
||||
// When the chat is parked in requires_action,
|
||||
// publish the stream event and global pubsub event
|
||||
@@ -4430,13 +4486,21 @@ func (p *Server) runChat(
|
||||
// the workspace agent has changed (e.g. workspace rebuilt).
|
||||
needsInstructionPersist := false
|
||||
hasContextFiles := false
|
||||
persistedSkills := skillsFromParts(messages)
|
||||
latestInjectedAgentID, hasLatestInjectedAgent := latestContextAgentID(messages)
|
||||
currentWorkspaceAgentID := uuid.Nil
|
||||
hasCurrentWorkspaceAgent := false
|
||||
if chat.WorkspaceID.Valid {
|
||||
if agent, agentErr := workspaceCtx.getWorkspaceAgent(ctx); agentErr == nil {
|
||||
currentWorkspaceAgentID = agent.ID
|
||||
hasCurrentWorkspaceAgent = true
|
||||
}
|
||||
persistedAgentID, found := contextFileAgentID(messages)
|
||||
hasContextFiles = found
|
||||
if !hasContextFiles {
|
||||
if !hasPersistedInstructionFiles(messages) {
|
||||
needsInstructionPersist = true
|
||||
} else if agent, agentErr := workspaceCtx.getWorkspaceAgent(ctx); agentErr == nil && agent.ID != persistedAgentID {
|
||||
// Agent changed — persist fresh instruction files.
|
||||
} else if hasCurrentWorkspaceAgent && currentWorkspaceAgentID != persistedAgentID {
|
||||
// Agent changed. Persist fresh instruction files.
|
||||
// Old context-file messages remain in the conversation
|
||||
// to preserve the prompt cache prefix.
|
||||
needsInstructionPersist = true
|
||||
@@ -4459,7 +4523,8 @@ func (p *Server) runChat(
|
||||
if needsInstructionPersist {
|
||||
g2.Go(func() error {
|
||||
var persistErr error
|
||||
instruction, skills, persistErr = p.persistInstructionFiles(
|
||||
var discoveredSkills []chattool.SkillMeta
|
||||
instruction, discoveredSkills, persistErr = p.persistInstructionFiles(
|
||||
ctx,
|
||||
chat,
|
||||
modelConfig.ID,
|
||||
@@ -4471,6 +4536,12 @@ func (p *Server) runChat(
|
||||
return workspaceCtx.getWorkspaceConn(instructionCtx)
|
||||
},
|
||||
)
|
||||
skills = selectSkillMetasForInstructionRefresh(
|
||||
persistedSkills,
|
||||
discoveredSkills,
|
||||
uuid.NullUUID{UUID: currentWorkspaceAgentID, Valid: hasCurrentWorkspaceAgent},
|
||||
uuid.NullUUID{UUID: latestInjectedAgentID, Valid: hasLatestInjectedAgent},
|
||||
)
|
||||
if persistErr != nil {
|
||||
p.logger.Warn(ctx, "failed to persist instruction files",
|
||||
slog.F("chat_id", chat.ID),
|
||||
@@ -4485,7 +4556,7 @@ func (p *Server) runChat(
|
||||
// re-injected via InsertSystem after compaction drops
|
||||
// those messages. No workspace dial needed.
|
||||
instruction = instructionFromContextFiles(messages)
|
||||
skills = skillsFromParts(messages)
|
||||
skills = persistedSkills
|
||||
}
|
||||
g2.Go(func() error {
|
||||
resolvedUserPrompt = p.resolveUserPrompt(ctx, chat.OwnerID)
|
||||
@@ -5103,14 +5174,14 @@ func (p *Server) runChat(
|
||||
// assistant and tool messages that the provider already has.
|
||||
chainModeActive := chatprovider.IsResponsesStoreEnabled(providerOptions) &&
|
||||
chainInfo.previousResponseID != "" &&
|
||||
chainInfo.trailingUserCount > 0 &&
|
||||
chainInfo.contributingTrailingUserCount > 0 &&
|
||||
chainInfo.modelConfigID == modelConfig.ID
|
||||
if chainModeActive {
|
||||
providerOptions = chatprovider.CloneWithPreviousResponseID(
|
||||
providerOptions,
|
||||
chainInfo.previousResponseID,
|
||||
)
|
||||
prompt = filterPromptForChainMode(prompt, chainInfo.trailingUserCount)
|
||||
prompt = filterPromptForChainMode(prompt, chainInfo)
|
||||
}
|
||||
err = chatloop.Run(ctx, chatloop.RunOptions{
|
||||
Model: model,
|
||||
@@ -5164,7 +5235,7 @@ func (p *Server) runChat(
|
||||
if chainModeActive {
|
||||
reloadedPrompt = filterPromptForChainMode(
|
||||
reloadedPrompt,
|
||||
chainInfo.trailingUserCount,
|
||||
chainInfo,
|
||||
)
|
||||
}
|
||||
return reloadedPrompt, nil
|
||||
@@ -5537,8 +5608,9 @@ func refreshChatWorkspaceSnapshot(
|
||||
}
|
||||
|
||||
// contextFileAgentID extracts the workspace agent ID from the most
|
||||
// recent persisted context-file parts. Returns uuid.Nil, false if no
|
||||
// context-file parts exist.
|
||||
// recent persisted instruction-file parts. The skill-only sentinel is
|
||||
// ignored because it does not represent persisted instruction content.
|
||||
// Returns uuid.Nil, false if no instruction-file parts exist.
|
||||
func contextFileAgentID(messages []database.ChatMessage) (uuid.UUID, bool) {
|
||||
var lastID uuid.UUID
|
||||
found := false
|
||||
@@ -5551,11 +5623,14 @@ func contextFileAgentID(messages []database.ChatMessage) (uuid.UUID, bool) {
|
||||
continue
|
||||
}
|
||||
for _, p := range parts {
|
||||
if p.Type == codersdk.ChatMessagePartTypeContextFile && p.ContextFileAgentID.Valid {
|
||||
lastID = p.ContextFileAgentID.UUID
|
||||
found = true
|
||||
break
|
||||
if p.Type != codersdk.ChatMessagePartTypeContextFile ||
|
||||
!p.ContextFileAgentID.Valid ||
|
||||
p.ContextFilePath == AgentChatContextSentinelPath {
|
||||
continue
|
||||
}
|
||||
lastID = p.ContextFileAgentID.UUID
|
||||
found = true
|
||||
break
|
||||
}
|
||||
}
|
||||
return lastID, found
|
||||
@@ -5625,13 +5700,14 @@ func (p *Server) persistInstructionFiles(
|
||||
// agent cannot know its own UUID, OS metadata, or
|
||||
// directory — those are added here at the trust boundary.
|
||||
var discoveredSkills []chattool.SkillMeta
|
||||
var hasContent bool
|
||||
var hasContent, hasContextFilePart bool
|
||||
agentID := uuid.NullUUID{UUID: agent.ID, Valid: true}
|
||||
|
||||
for i := range agentParts {
|
||||
agentParts[i].ContextFileAgentID = agentID
|
||||
switch agentParts[i].Type {
|
||||
case codersdk.ChatMessagePartTypeContextFile:
|
||||
hasContextFilePart = true
|
||||
agentParts[i].ContextFileContent = SanitizePromptText(agentParts[i].ContextFileContent)
|
||||
agentParts[i].ContextFileOS = agent.OperatingSystem
|
||||
agentParts[i].ContextFileDirectory = directory
|
||||
@@ -5652,13 +5728,13 @@ func (p *Server) persistInstructionFiles(
|
||||
if !workspaceConnOK {
|
||||
return "", nil, nil
|
||||
}
|
||||
// Persist a sentinel (plus any skill-only parts) so
|
||||
// subsequent turns skip the workspace agent dial.
|
||||
if len(agentParts) == 0 {
|
||||
agentParts = []codersdk.ChatMessagePart{{
|
||||
// Persist a blank context-file marker (plus any skill-only
|
||||
// parts) so subsequent turns skip the workspace agent dial.
|
||||
if !hasContextFilePart {
|
||||
agentParts = append([]codersdk.ChatMessagePart{{
|
||||
Type: codersdk.ChatMessagePartTypeContextFile,
|
||||
ContextFileAgentID: agentID,
|
||||
}}
|
||||
}}, agentParts...)
|
||||
}
|
||||
content, err := chatprompt.MarshalParts(agentParts)
|
||||
if err != nil {
|
||||
|
||||
@@ -8,6 +8,7 @@ import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"charm.land/fantasy"
|
||||
"github.com/google/uuid"
|
||||
"github.com/sqlc-dev/pqtype"
|
||||
"github.com/stretchr/testify/require"
|
||||
@@ -70,14 +71,14 @@ func TestRegenerateChatTitle_PersistsAndBroadcasts(t *testing.T) {
|
||||
updatedChat.Title = wantTitle
|
||||
|
||||
messageEvents := make(chan struct {
|
||||
payload coderdpubsub.ChatEvent
|
||||
payload codersdk.ChatWatchEvent
|
||||
err error
|
||||
}, 1)
|
||||
cancelSub, err := pubsub.SubscribeWithErr(
|
||||
coderdpubsub.ChatEventChannel(ownerID),
|
||||
coderdpubsub.HandleChatEvent(func(_ context.Context, payload coderdpubsub.ChatEvent, err error) {
|
||||
coderdpubsub.ChatWatchEventChannel(ownerID),
|
||||
coderdpubsub.HandleChatWatchEvent(func(_ context.Context, payload codersdk.ChatWatchEvent, err error) {
|
||||
messageEvents <- struct {
|
||||
payload coderdpubsub.ChatEvent
|
||||
payload codersdk.ChatWatchEvent
|
||||
err error
|
||||
}{payload: payload, err: err}
|
||||
}),
|
||||
@@ -183,7 +184,7 @@ func TestRegenerateChatTitle_PersistsAndBroadcasts(t *testing.T) {
|
||||
select {
|
||||
case event := <-messageEvents:
|
||||
require.NoError(t, event.err)
|
||||
require.Equal(t, coderdpubsub.ChatEventKindTitleChange, event.payload.Kind)
|
||||
require.Equal(t, codersdk.ChatWatchEventKindTitleChange, event.payload.Kind)
|
||||
require.Equal(t, chatID, event.payload.Chat.ID)
|
||||
require.Equal(t, wantTitle, event.payload.Chat.Title)
|
||||
case <-time.After(time.Second):
|
||||
@@ -233,14 +234,14 @@ func TestRegenerateChatTitle_PersistsAndBroadcasts_IdleChatReleasesManualLock(t
|
||||
unlockedChat.StartedAt = sql.NullTime{}
|
||||
|
||||
messageEvents := make(chan struct {
|
||||
payload coderdpubsub.ChatEvent
|
||||
payload codersdk.ChatWatchEvent
|
||||
err error
|
||||
}, 1)
|
||||
cancelSub, err := pubsub.SubscribeWithErr(
|
||||
coderdpubsub.ChatEventChannel(ownerID),
|
||||
coderdpubsub.HandleChatEvent(func(_ context.Context, payload coderdpubsub.ChatEvent, err error) {
|
||||
coderdpubsub.ChatWatchEventChannel(ownerID),
|
||||
coderdpubsub.HandleChatWatchEvent(func(_ context.Context, payload codersdk.ChatWatchEvent, err error) {
|
||||
messageEvents <- struct {
|
||||
payload coderdpubsub.ChatEvent
|
||||
payload codersdk.ChatWatchEvent
|
||||
err error
|
||||
}{payload: payload, err: err}
|
||||
}),
|
||||
@@ -372,7 +373,7 @@ func TestRegenerateChatTitle_PersistsAndBroadcasts_IdleChatReleasesManualLock(t
|
||||
select {
|
||||
case event := <-messageEvents:
|
||||
require.NoError(t, event.err)
|
||||
require.Equal(t, coderdpubsub.ChatEventKindTitleChange, event.payload.Kind)
|
||||
require.Equal(t, codersdk.ChatWatchEventKindTitleChange, event.payload.Kind)
|
||||
require.Equal(t, chatID, event.payload.Chat.ID)
|
||||
require.Equal(t, wantTitle, event.payload.Chat.Title)
|
||||
case <-time.After(time.Second):
|
||||
@@ -703,7 +704,33 @@ func TestPersistInstructionFilesSentinelWithSkills(t *testing.T) {
|
||||
gomock.Any(),
|
||||
agentID,
|
||||
).Return(workspaceAgent, nil).Times(1)
|
||||
db.EXPECT().InsertChatMessages(gomock.Any(), gomock.Any()).Return(nil, nil).AnyTimes()
|
||||
db.EXPECT().InsertChatMessages(gomock.Any(),
|
||||
gomock.Cond(func(x any) bool {
|
||||
arg, ok := x.(database.InsertChatMessagesParams)
|
||||
if !ok || arg.ChatID != chat.ID || len(arg.Content) != 1 {
|
||||
return false
|
||||
}
|
||||
var parts []codersdk.ChatMessagePart
|
||||
if err := json.Unmarshal([]byte(arg.Content[0]), &parts); err != nil {
|
||||
return false
|
||||
}
|
||||
foundMarker := false
|
||||
foundSkill := false
|
||||
for _, p := range parts {
|
||||
switch p.Type {
|
||||
case codersdk.ChatMessagePartTypeContextFile:
|
||||
if p.ContextFileAgentID == (uuid.NullUUID{UUID: agentID, Valid: true}) && p.ContextFileContent == "" {
|
||||
foundMarker = true
|
||||
}
|
||||
case codersdk.ChatMessagePartTypeSkill:
|
||||
if p.SkillName == "my-skill" && p.ContextFileAgentID == (uuid.NullUUID{UUID: agentID, Valid: true}) {
|
||||
foundSkill = true
|
||||
}
|
||||
}
|
||||
}
|
||||
return foundMarker && foundSkill
|
||||
}),
|
||||
).Return(nil, nil).Times(1)
|
||||
db.EXPECT().UpdateChatLastInjectedContext(gomock.Any(),
|
||||
gomock.Cond(func(x any) bool {
|
||||
arg, ok := x.(database.UpdateChatLastInjectedContextParams)
|
||||
@@ -2020,6 +2047,30 @@ func TestContextFileAgentID(t *testing.T) {
|
||||
require.True(t, ok)
|
||||
})
|
||||
|
||||
t.Run("IgnoresSkillOnlySentinel", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
instructionAgentID := uuid.New()
|
||||
sentinelAgentID := uuid.New()
|
||||
msgs := []database.ChatMessage{
|
||||
chatMessageWithParts([]codersdk.ChatMessagePart{{
|
||||
Type: codersdk.ChatMessagePartTypeContextFile,
|
||||
ContextFilePath: "/workspace/AGENTS.md",
|
||||
ContextFileAgentID: uuid.NullUUID{UUID: instructionAgentID, Valid: true},
|
||||
}}),
|
||||
chatMessageWithParts([]codersdk.ChatMessagePart{{
|
||||
Type: codersdk.ChatMessagePartTypeContextFile,
|
||||
ContextFilePath: AgentChatContextSentinelPath,
|
||||
ContextFileAgentID: uuid.NullUUID{
|
||||
UUID: sentinelAgentID,
|
||||
Valid: true,
|
||||
},
|
||||
}}),
|
||||
}
|
||||
id, ok := contextFileAgentID(msgs)
|
||||
require.Equal(t, instructionAgentID, id)
|
||||
require.True(t, ok)
|
||||
})
|
||||
|
||||
t.Run("SentinelWithoutAgentID", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
msgs := []database.ChatMessage{
|
||||
@@ -2036,6 +2087,492 @@ func TestContextFileAgentID(t *testing.T) {
|
||||
})
|
||||
}
|
||||
|
||||
func TestHasPersistedInstructionFiles(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
t.Run("IgnoresAgentChatContextSentinel", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
agentID := uuid.New()
|
||||
msgs := []database.ChatMessage{
|
||||
chatMessageWithParts([]codersdk.ChatMessagePart{{
|
||||
Type: codersdk.ChatMessagePartTypeContextFile,
|
||||
ContextFilePath: AgentChatContextSentinelPath,
|
||||
ContextFileAgentID: uuid.NullUUID{
|
||||
UUID: agentID,
|
||||
Valid: true,
|
||||
},
|
||||
}}),
|
||||
}
|
||||
require.False(t, hasPersistedInstructionFiles(msgs))
|
||||
})
|
||||
|
||||
t.Run("AcceptsPersistedInstructionFile", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
agentID := uuid.New()
|
||||
msgs := []database.ChatMessage{
|
||||
chatMessageWithParts([]codersdk.ChatMessagePart{{
|
||||
Type: codersdk.ChatMessagePartTypeContextFile,
|
||||
ContextFilePath: "/workspace/AGENTS.md",
|
||||
ContextFileContent: "repo instructions",
|
||||
ContextFileAgentID: uuid.NullUUID{UUID: agentID, Valid: true},
|
||||
}}),
|
||||
}
|
||||
require.True(t, hasPersistedInstructionFiles(msgs))
|
||||
})
|
||||
}
|
||||
|
||||
func TestInstructionFromContextFilesUsesLatestContextAgent(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
oldAgentID := uuid.New()
|
||||
newAgentID := uuid.New()
|
||||
msgs := []database.ChatMessage{
|
||||
chatMessageWithParts([]codersdk.ChatMessagePart{{
|
||||
Type: codersdk.ChatMessagePartTypeContextFile,
|
||||
ContextFilePath: "/old/AGENTS.md",
|
||||
ContextFileContent: "old instructions",
|
||||
ContextFileOS: "darwin",
|
||||
ContextFileDirectory: "/old",
|
||||
ContextFileAgentID: uuid.NullUUID{UUID: oldAgentID, Valid: true},
|
||||
}}),
|
||||
chatMessageWithParts([]codersdk.ChatMessagePart{{
|
||||
Type: codersdk.ChatMessagePartTypeContextFile,
|
||||
ContextFilePath: "/new/AGENTS.md",
|
||||
ContextFileContent: "new instructions",
|
||||
ContextFileOS: "linux",
|
||||
ContextFileDirectory: "/new",
|
||||
ContextFileAgentID: uuid.NullUUID{UUID: newAgentID, Valid: true},
|
||||
}}),
|
||||
}
|
||||
|
||||
got := instructionFromContextFiles(msgs)
|
||||
require.Contains(t, got, "new instructions")
|
||||
require.Contains(t, got, "Operating System: linux")
|
||||
require.Contains(t, got, "Working Directory: /new")
|
||||
require.NotContains(t, got, "old instructions")
|
||||
require.NotContains(t, got, "Operating System: darwin")
|
||||
}
|
||||
|
||||
func TestInstructionFromContextFilesKeepsLegacyUnstampedParts(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
oldAgentID := uuid.New()
|
||||
newAgentID := uuid.New()
|
||||
msgs := []database.ChatMessage{
|
||||
chatMessageWithParts([]codersdk.ChatMessagePart{{
|
||||
Type: codersdk.ChatMessagePartTypeContextFile,
|
||||
ContextFilePath: "/legacy/AGENTS.md",
|
||||
ContextFileContent: "legacy instructions",
|
||||
}}),
|
||||
chatMessageWithParts([]codersdk.ChatMessagePart{{
|
||||
Type: codersdk.ChatMessagePartTypeContextFile,
|
||||
ContextFilePath: "/old/AGENTS.md",
|
||||
ContextFileContent: "old instructions",
|
||||
ContextFileOS: "darwin",
|
||||
ContextFileDirectory: "/old",
|
||||
ContextFileAgentID: uuid.NullUUID{UUID: oldAgentID, Valid: true},
|
||||
}}),
|
||||
chatMessageWithParts([]codersdk.ChatMessagePart{{
|
||||
Type: codersdk.ChatMessagePartTypeContextFile,
|
||||
ContextFilePath: "/new/AGENTS.md",
|
||||
ContextFileContent: "new instructions",
|
||||
ContextFileOS: "linux",
|
||||
ContextFileDirectory: "/new",
|
||||
ContextFileAgentID: uuid.NullUUID{UUID: newAgentID, Valid: true},
|
||||
}}),
|
||||
}
|
||||
|
||||
got := instructionFromContextFiles(msgs)
|
||||
require.Contains(t, got, "legacy instructions")
|
||||
require.Contains(t, got, "new instructions")
|
||||
require.Contains(t, got, "Operating System: linux")
|
||||
require.Contains(t, got, "Working Directory: /new")
|
||||
require.NotContains(t, got, "old instructions")
|
||||
require.NotContains(t, got, "Operating System: darwin")
|
||||
}
|
||||
|
||||
func TestSkillsFromPartsKeepsLegacyUnstampedParts(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
oldAgentID := uuid.New()
|
||||
newAgentID := uuid.New()
|
||||
msgs := []database.ChatMessage{
|
||||
chatMessageWithParts([]codersdk.ChatMessagePart{{
|
||||
Type: codersdk.ChatMessagePartTypeSkill,
|
||||
SkillName: "repo-helper-legacy",
|
||||
SkillDir: "/skills/repo-helper-legacy",
|
||||
}}),
|
||||
chatMessageWithParts([]codersdk.ChatMessagePart{
|
||||
{
|
||||
Type: codersdk.ChatMessagePartTypeContextFile,
|
||||
ContextFilePath: "/old/AGENTS.md",
|
||||
ContextFileAgentID: uuid.NullUUID{UUID: oldAgentID, Valid: true},
|
||||
},
|
||||
{
|
||||
Type: codersdk.ChatMessagePartTypeSkill,
|
||||
SkillName: "repo-helper-old",
|
||||
SkillDir: "/skills/repo-helper-old",
|
||||
ContextFileAgentID: uuid.NullUUID{UUID: oldAgentID, Valid: true},
|
||||
},
|
||||
}),
|
||||
chatMessageWithParts([]codersdk.ChatMessagePart{
|
||||
{
|
||||
Type: codersdk.ChatMessagePartTypeContextFile,
|
||||
ContextFilePath: AgentChatContextSentinelPath,
|
||||
ContextFileAgentID: uuid.NullUUID{
|
||||
UUID: newAgentID,
|
||||
Valid: true,
|
||||
},
|
||||
},
|
||||
{
|
||||
Type: codersdk.ChatMessagePartTypeSkill,
|
||||
SkillName: "repo-helper-new",
|
||||
SkillDir: "/skills/repo-helper-new",
|
||||
ContextFileAgentID: uuid.NullUUID{UUID: newAgentID, Valid: true},
|
||||
},
|
||||
}),
|
||||
}
|
||||
|
||||
got := skillsFromParts(msgs)
|
||||
require.Equal(t, []chattool.SkillMeta{
|
||||
{Name: "repo-helper-legacy", Dir: "/skills/repo-helper-legacy"},
|
||||
{Name: "repo-helper-new", Dir: "/skills/repo-helper-new"},
|
||||
}, got)
|
||||
}
|
||||
|
||||
func TestSkillsFromPartsUsesLatestContextAgent(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
oldAgentID := uuid.New()
|
||||
newAgentID := uuid.New()
|
||||
msgs := []database.ChatMessage{
|
||||
chatMessageWithParts([]codersdk.ChatMessagePart{
|
||||
{
|
||||
Type: codersdk.ChatMessagePartTypeContextFile,
|
||||
ContextFilePath: "/old/AGENTS.md",
|
||||
ContextFileAgentID: uuid.NullUUID{UUID: oldAgentID, Valid: true},
|
||||
},
|
||||
{
|
||||
Type: codersdk.ChatMessagePartTypeSkill,
|
||||
SkillName: "repo-helper-old",
|
||||
SkillDir: "/skills/repo-helper-old",
|
||||
ContextFileAgentID: uuid.NullUUID{UUID: oldAgentID, Valid: true},
|
||||
},
|
||||
}),
|
||||
chatMessageWithParts([]codersdk.ChatMessagePart{
|
||||
{
|
||||
Type: codersdk.ChatMessagePartTypeContextFile,
|
||||
ContextFilePath: AgentChatContextSentinelPath,
|
||||
ContextFileAgentID: uuid.NullUUID{
|
||||
UUID: newAgentID,
|
||||
Valid: true,
|
||||
},
|
||||
},
|
||||
{
|
||||
Type: codersdk.ChatMessagePartTypeSkill,
|
||||
SkillName: "repo-helper-new",
|
||||
SkillDir: "/skills/repo-helper-new",
|
||||
ContextFileAgentID: uuid.NullUUID{UUID: newAgentID, Valid: true},
|
||||
},
|
||||
}),
|
||||
}
|
||||
|
||||
got := skillsFromParts(msgs)
|
||||
require.Equal(t, []chattool.SkillMeta{{
|
||||
Name: "repo-helper-new",
|
||||
Dir: "/skills/repo-helper-new",
|
||||
}}, got)
|
||||
}
|
||||
|
||||
func TestMergeSkillMetas(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
persisted := []chattool.SkillMeta{{
|
||||
Name: "repo-helper",
|
||||
Description: "Persisted skill",
|
||||
Dir: "/skills/repo-helper-old",
|
||||
}}
|
||||
discovered := []chattool.SkillMeta{
|
||||
{
|
||||
Name: "repo-helper",
|
||||
Description: "Discovered replacement",
|
||||
Dir: "/skills/repo-helper-new",
|
||||
MetaFile: "SKILL.md",
|
||||
},
|
||||
{
|
||||
Name: "deep-review",
|
||||
Description: "Discovered skill",
|
||||
Dir: "/skills/deep-review",
|
||||
},
|
||||
}
|
||||
|
||||
got := mergeSkillMetas(persisted, discovered)
|
||||
require.Equal(t, []chattool.SkillMeta{
|
||||
discovered[0],
|
||||
discovered[1],
|
||||
}, got)
|
||||
}
|
||||
|
||||
func TestSelectSkillMetasForInstructionRefresh(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
persisted := []chattool.SkillMeta{{Name: "persisted", Dir: "/skills/persisted"}}
|
||||
discovered := []chattool.SkillMeta{{Name: "discovered", Dir: "/skills/discovered"}}
|
||||
currentAgentID := uuid.New()
|
||||
otherAgentID := uuid.New()
|
||||
|
||||
t.Run("MergesCurrentAgentSkills", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
got := selectSkillMetasForInstructionRefresh(
|
||||
persisted,
|
||||
discovered,
|
||||
uuid.NullUUID{UUID: currentAgentID, Valid: true},
|
||||
uuid.NullUUID{UUID: currentAgentID, Valid: true},
|
||||
)
|
||||
require.Equal(t, []chattool.SkillMeta{discovered[0], persisted[0]}, got)
|
||||
})
|
||||
|
||||
t.Run("DropsStalePersistedSkillsWhenAgentChanged", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
got := selectSkillMetasForInstructionRefresh(
|
||||
persisted,
|
||||
discovered,
|
||||
uuid.NullUUID{UUID: currentAgentID, Valid: true},
|
||||
uuid.NullUUID{UUID: otherAgentID, Valid: true},
|
||||
)
|
||||
require.Equal(t, discovered, got)
|
||||
})
|
||||
|
||||
t.Run("PreservesPersistedSkillsWhenAgentLookupFails", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
got := selectSkillMetasForInstructionRefresh(
|
||||
persisted,
|
||||
nil,
|
||||
uuid.NullUUID{},
|
||||
uuid.NullUUID{UUID: otherAgentID, Valid: true},
|
||||
)
|
||||
require.Equal(t, persisted, got)
|
||||
})
|
||||
}
|
||||
|
||||
func TestResolveChainModeIgnoresSkillOnlySentinelMessages(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
modelConfigID := uuid.New()
|
||||
assistant := database.ChatMessage{
|
||||
Role: database.ChatMessageRoleAssistant,
|
||||
ProviderResponseID: sql.NullString{String: "resp-123", Valid: true},
|
||||
ModelConfigID: uuid.NullUUID{UUID: modelConfigID, Valid: true},
|
||||
}
|
||||
skillOnly := chatMessageWithParts([]codersdk.ChatMessagePart{
|
||||
{
|
||||
Type: codersdk.ChatMessagePartTypeContextFile,
|
||||
ContextFilePath: AgentChatContextSentinelPath,
|
||||
ContextFileAgentID: uuid.NullUUID{
|
||||
UUID: uuid.New(),
|
||||
Valid: true,
|
||||
},
|
||||
},
|
||||
{
|
||||
Type: codersdk.ChatMessagePartTypeSkill,
|
||||
SkillName: "repo-helper",
|
||||
SkillDir: "/skills/repo-helper",
|
||||
},
|
||||
})
|
||||
skillOnly.Role = database.ChatMessageRoleUser
|
||||
user := chatMessageWithParts([]codersdk.ChatMessagePart{{
|
||||
Type: codersdk.ChatMessagePartTypeText,
|
||||
Text: "latest user message",
|
||||
}})
|
||||
user.Role = database.ChatMessageRoleUser
|
||||
|
||||
got := resolveChainMode([]database.ChatMessage{assistant, skillOnly, user})
|
||||
require.Equal(t, "resp-123", got.previousResponseID)
|
||||
require.Equal(t, modelConfigID, got.modelConfigID)
|
||||
require.Equal(t, 2, got.trailingUserCount)
|
||||
require.Equal(t, 1, got.contributingTrailingUserCount)
|
||||
}
|
||||
|
||||
func TestFilterPromptForChainModeKeepsContributingUsersAcrossSkippedSentinelTurns(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
modelConfigID := uuid.New()
|
||||
priorUser := chatMessageWithParts([]codersdk.ChatMessagePart{{
|
||||
Type: codersdk.ChatMessagePartTypeText,
|
||||
Text: "prior user message",
|
||||
}})
|
||||
priorUser.Role = database.ChatMessageRoleUser
|
||||
assistant := database.ChatMessage{
|
||||
Role: database.ChatMessageRoleAssistant,
|
||||
ProviderResponseID: sql.NullString{String: "resp-123", Valid: true},
|
||||
ModelConfigID: uuid.NullUUID{UUID: modelConfigID, Valid: true},
|
||||
}
|
||||
firstTrailingUser := chatMessageWithParts([]codersdk.ChatMessagePart{{
|
||||
Type: codersdk.ChatMessagePartTypeText,
|
||||
Text: "first trailing user",
|
||||
}})
|
||||
firstTrailingUser.Role = database.ChatMessageRoleUser
|
||||
skillOnly := chatMessageWithParts([]codersdk.ChatMessagePart{
|
||||
{
|
||||
Type: codersdk.ChatMessagePartTypeContextFile,
|
||||
ContextFilePath: AgentChatContextSentinelPath,
|
||||
ContextFileAgentID: uuid.NullUUID{
|
||||
UUID: uuid.New(),
|
||||
Valid: true,
|
||||
},
|
||||
},
|
||||
{
|
||||
Type: codersdk.ChatMessagePartTypeSkill,
|
||||
SkillName: "repo-helper",
|
||||
SkillDir: "/skills/repo-helper",
|
||||
},
|
||||
})
|
||||
skillOnly.Role = database.ChatMessageRoleUser
|
||||
lastTrailingUser := chatMessageWithParts([]codersdk.ChatMessagePart{{
|
||||
Type: codersdk.ChatMessagePartTypeText,
|
||||
Text: "last trailing user",
|
||||
}})
|
||||
lastTrailingUser.Role = database.ChatMessageRoleUser
|
||||
|
||||
chainInfo := resolveChainMode([]database.ChatMessage{
|
||||
priorUser,
|
||||
assistant,
|
||||
firstTrailingUser,
|
||||
skillOnly,
|
||||
lastTrailingUser,
|
||||
})
|
||||
require.Equal(t, 3, chainInfo.trailingUserCount)
|
||||
require.Equal(t, 2, chainInfo.contributingTrailingUserCount)
|
||||
|
||||
prompt := []fantasy.Message{
|
||||
{
|
||||
Role: fantasy.MessageRoleSystem,
|
||||
Content: []fantasy.MessagePart{
|
||||
fantasy.TextPart{Text: "system instruction"},
|
||||
},
|
||||
},
|
||||
{
|
||||
Role: fantasy.MessageRoleUser,
|
||||
Content: []fantasy.MessagePart{
|
||||
fantasy.TextPart{Text: "prior user message"},
|
||||
},
|
||||
},
|
||||
{
|
||||
Role: fantasy.MessageRoleAssistant,
|
||||
Content: []fantasy.MessagePart{
|
||||
fantasy.TextPart{Text: "assistant reply"},
|
||||
},
|
||||
},
|
||||
{
|
||||
Role: fantasy.MessageRoleUser,
|
||||
Content: []fantasy.MessagePart{
|
||||
fantasy.TextPart{Text: "first trailing user"},
|
||||
},
|
||||
},
|
||||
{
|
||||
Role: fantasy.MessageRoleUser,
|
||||
Content: []fantasy.MessagePart{
|
||||
fantasy.TextPart{Text: "last trailing user"},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
got := filterPromptForChainMode(prompt, chainInfo)
|
||||
require.Len(t, got, 3)
|
||||
require.Equal(t, fantasy.MessageRoleSystem, got[0].Role)
|
||||
require.Equal(t, fantasy.MessageRoleUser, got[1].Role)
|
||||
require.Equal(t, fantasy.MessageRoleUser, got[2].Role)
|
||||
|
||||
firstPart, ok := fantasy.AsMessagePart[fantasy.TextPart](got[1].Content[0])
|
||||
require.True(t, ok)
|
||||
require.Equal(t, "first trailing user", firstPart.Text)
|
||||
lastPart, ok := fantasy.AsMessagePart[fantasy.TextPart](got[2].Content[0])
|
||||
require.True(t, ok)
|
||||
require.Equal(t, "last trailing user", lastPart.Text)
|
||||
}
|
||||
|
||||
func TestFilterPromptForChainModeUsesContributingTrailingUsers(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
modelConfigID := uuid.New()
|
||||
priorUser := chatMessageWithParts([]codersdk.ChatMessagePart{{
|
||||
Type: codersdk.ChatMessagePartTypeText,
|
||||
Text: "prior user message",
|
||||
}})
|
||||
priorUser.Role = database.ChatMessageRoleUser
|
||||
assistant := database.ChatMessage{
|
||||
Role: database.ChatMessageRoleAssistant,
|
||||
ProviderResponseID: sql.NullString{String: "resp-123", Valid: true},
|
||||
ModelConfigID: uuid.NullUUID{UUID: modelConfigID, Valid: true},
|
||||
}
|
||||
skillOnly := chatMessageWithParts([]codersdk.ChatMessagePart{
|
||||
{
|
||||
Type: codersdk.ChatMessagePartTypeContextFile,
|
||||
ContextFilePath: AgentChatContextSentinelPath,
|
||||
ContextFileAgentID: uuid.NullUUID{
|
||||
UUID: uuid.New(),
|
||||
Valid: true,
|
||||
},
|
||||
},
|
||||
{
|
||||
Type: codersdk.ChatMessagePartTypeSkill,
|
||||
SkillName: "repo-helper",
|
||||
SkillDir: "/skills/repo-helper",
|
||||
},
|
||||
})
|
||||
skillOnly.Role = database.ChatMessageRoleUser
|
||||
latestUser := chatMessageWithParts([]codersdk.ChatMessagePart{{
|
||||
Type: codersdk.ChatMessagePartTypeText,
|
||||
Text: "latest user message",
|
||||
}})
|
||||
latestUser.Role = database.ChatMessageRoleUser
|
||||
|
||||
chainInfo := resolveChainMode([]database.ChatMessage{
|
||||
priorUser,
|
||||
assistant,
|
||||
skillOnly,
|
||||
latestUser,
|
||||
})
|
||||
require.Equal(t, 2, chainInfo.trailingUserCount)
|
||||
require.Equal(t, 1, chainInfo.contributingTrailingUserCount)
|
||||
|
||||
prompt := []fantasy.Message{
|
||||
{
|
||||
Role: fantasy.MessageRoleSystem,
|
||||
Content: []fantasy.MessagePart{
|
||||
fantasy.TextPart{Text: "system instruction"},
|
||||
},
|
||||
},
|
||||
{
|
||||
Role: fantasy.MessageRoleUser,
|
||||
Content: []fantasy.MessagePart{
|
||||
fantasy.TextPart{Text: "prior user message"},
|
||||
},
|
||||
},
|
||||
{
|
||||
Role: fantasy.MessageRoleAssistant,
|
||||
Content: []fantasy.MessagePart{
|
||||
fantasy.TextPart{Text: "assistant reply"},
|
||||
},
|
||||
},
|
||||
{
|
||||
Role: fantasy.MessageRoleUser,
|
||||
Content: []fantasy.MessagePart{
|
||||
fantasy.TextPart{Text: "latest user message"},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
got := filterPromptForChainMode(prompt, chainInfo)
|
||||
require.Len(t, got, 2)
|
||||
require.Equal(t, fantasy.MessageRoleSystem, got[0].Role)
|
||||
require.Equal(t, fantasy.MessageRoleUser, got[1].Role)
|
||||
|
||||
part, ok := fantasy.AsMessagePart[fantasy.TextPart](got[1].Content[0])
|
||||
require.True(t, ok)
|
||||
require.Equal(t, "latest user message", part.Text)
|
||||
}
|
||||
|
||||
func chatMessageWithParts(parts []codersdk.ChatMessagePart) database.ChatMessage {
|
||||
raw, _ := json.Marshal(parts)
|
||||
return database.ChatMessage{
|
||||
|
||||
@@ -24,6 +24,7 @@ import (
|
||||
"github.com/coder/coder/v2/coderd/x/chatd/chatprompt"
|
||||
"github.com/coder/coder/v2/coderd/x/chatd/chatretry"
|
||||
"github.com/coder/coder/v2/codersdk"
|
||||
"github.com/coder/quartz"
|
||||
)
|
||||
|
||||
const (
|
||||
@@ -100,6 +101,10 @@ type RunOptions struct {
|
||||
// first stream part before the attempt is canceled and
|
||||
// retried. Zero uses the production default.
|
||||
StartupTimeout time.Duration
|
||||
// Clock creates startup guard timers. In production use a
|
||||
// real clock; tests can inject quartz.NewMock(t) to make
|
||||
// startup timeout behavior deterministic.
|
||||
Clock quartz.Clock
|
||||
|
||||
ActiveTools []string
|
||||
ContextLimitFallback int64
|
||||
@@ -289,6 +294,9 @@ func Run(ctx context.Context, opts RunOptions) error {
|
||||
if opts.StartupTimeout <= 0 {
|
||||
opts.StartupTimeout = defaultStartupTimeout
|
||||
}
|
||||
if opts.Clock == nil {
|
||||
opts.Clock = quartz.NewReal()
|
||||
}
|
||||
|
||||
publishMessagePart := func(role codersdk.ChatMessageRole, part codersdk.ChatMessagePart) {
|
||||
if opts.PublishMessagePart == nil {
|
||||
@@ -364,6 +372,7 @@ func Run(ctx context.Context, opts RunOptions) error {
|
||||
attempt, streamErr := guardedStream(
|
||||
retryCtx,
|
||||
opts.Model.Provider(),
|
||||
opts.Clock,
|
||||
opts.StartupTimeout,
|
||||
func(attemptCtx context.Context) (fantasy.StreamResponse, error) {
|
||||
return opts.Model.Stream(attemptCtx, call)
|
||||
@@ -660,17 +669,18 @@ type guardedAttempt struct {
|
||||
// stream startup. Exactly one outcome wins: the timer cancels
|
||||
// the attempt, or the first-part path disarms the timer.
|
||||
type startupGuard struct {
|
||||
timer *time.Timer
|
||||
timer *quartz.Timer
|
||||
cancel context.CancelCauseFunc
|
||||
once sync.Once
|
||||
}
|
||||
|
||||
func newStartupGuard(
|
||||
clock quartz.Clock,
|
||||
timeout time.Duration,
|
||||
cancel context.CancelCauseFunc,
|
||||
) *startupGuard {
|
||||
guard := &startupGuard{cancel: cancel}
|
||||
guard.timer = time.AfterFunc(timeout, guard.onTimeout)
|
||||
guard.timer = clock.AfterFunc(timeout, guard.onTimeout, "startupGuard")
|
||||
return guard
|
||||
}
|
||||
|
||||
@@ -707,11 +717,12 @@ func classifyStartupTimeout(
|
||||
func guardedStream(
|
||||
parent context.Context,
|
||||
provider string,
|
||||
clock quartz.Clock,
|
||||
timeout time.Duration,
|
||||
openStream func(context.Context) (fantasy.StreamResponse, error),
|
||||
) (guardedAttempt, error) {
|
||||
attemptCtx, cancelAttempt := context.WithCancelCause(parent)
|
||||
guard := newStartupGuard(timeout, cancelAttempt)
|
||||
guard := newStartupGuard(clock, timeout, cancelAttempt)
|
||||
var releaseOnce sync.Once
|
||||
release := func() {
|
||||
releaseOnce.Do(func() {
|
||||
|
||||
@@ -19,10 +19,24 @@ import (
|
||||
"github.com/coder/coder/v2/coderd/x/chatd/chaterror"
|
||||
"github.com/coder/coder/v2/coderd/x/chatd/chatretry"
|
||||
"github.com/coder/coder/v2/codersdk"
|
||||
"github.com/coder/coder/v2/testutil"
|
||||
"github.com/coder/quartz"
|
||||
)
|
||||
|
||||
const activeToolName = "read_file"
|
||||
|
||||
func awaitRunResult(ctx context.Context, t *testing.T, done <-chan error) error {
|
||||
t.Helper()
|
||||
|
||||
select {
|
||||
case err := <-done:
|
||||
return err
|
||||
case <-ctx.Done():
|
||||
t.Fatal("timed out waiting for Run to complete")
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
func TestRun_ActiveToolsPrepareBehavior(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
@@ -202,7 +216,7 @@ func TestStartupGuard_DisarmAndFireRace(t *testing.T) {
|
||||
|
||||
for range 128 {
|
||||
var cancels atomic.Int32
|
||||
guard := newStartupGuard(time.Hour, func(err error) {
|
||||
guard := newStartupGuard(quartz.NewReal(), time.Hour, func(err error) {
|
||||
if errors.Is(err, errStartupTimeout) {
|
||||
cancels.Add(1)
|
||||
}
|
||||
@@ -240,7 +254,7 @@ func TestStartupGuard_DisarmPreservesPermanentError(t *testing.T) {
|
||||
attemptCtx, cancelAttempt := context.WithCancelCause(context.Background())
|
||||
defer cancelAttempt(nil)
|
||||
|
||||
guard := newStartupGuard(time.Hour, cancelAttempt)
|
||||
guard := newStartupGuard(quartz.NewReal(), time.Hour, cancelAttempt)
|
||||
guard.Disarm()
|
||||
guard.onTimeout()
|
||||
|
||||
@@ -259,6 +273,16 @@ func TestRun_RetriesStartupTimeoutWhileOpeningStream(t *testing.T) {
|
||||
|
||||
const startupTimeout = 5 * time.Millisecond
|
||||
|
||||
ctx, cancel := context.WithTimeout(
|
||||
context.Background(),
|
||||
testutil.WaitShort,
|
||||
)
|
||||
defer cancel()
|
||||
|
||||
mClock := quartz.NewMock(t)
|
||||
trap := mClock.Trap().AfterFunc("startupGuard")
|
||||
defer trap.Close()
|
||||
|
||||
attempts := 0
|
||||
attemptCause := make(chan error, 1)
|
||||
var retries []chatretry.ClassifiedError
|
||||
@@ -278,23 +302,32 @@ func TestRun_RetriesStartupTimeoutWhileOpeningStream(t *testing.T) {
|
||||
},
|
||||
}
|
||||
|
||||
err := Run(context.Background(), RunOptions{
|
||||
Model: model,
|
||||
MaxSteps: 1,
|
||||
StartupTimeout: startupTimeout,
|
||||
PersistStep: func(_ context.Context, _ PersistedStep) error {
|
||||
return nil
|
||||
},
|
||||
OnRetry: func(
|
||||
_ int,
|
||||
_ error,
|
||||
classified chatretry.ClassifiedError,
|
||||
_ time.Duration,
|
||||
) {
|
||||
retries = append(retries, classified)
|
||||
},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
done := make(chan error, 1)
|
||||
go func() {
|
||||
done <- Run(context.Background(), RunOptions{
|
||||
Model: model,
|
||||
MaxSteps: 1,
|
||||
StartupTimeout: startupTimeout,
|
||||
Clock: mClock,
|
||||
PersistStep: func(_ context.Context, _ PersistedStep) error {
|
||||
return nil
|
||||
},
|
||||
OnRetry: func(
|
||||
_ int,
|
||||
_ error,
|
||||
classified chatretry.ClassifiedError,
|
||||
_ time.Duration,
|
||||
) {
|
||||
retries = append(retries, classified)
|
||||
},
|
||||
})
|
||||
}()
|
||||
|
||||
trap.MustWait(ctx).MustRelease(ctx)
|
||||
mClock.Advance(startupTimeout).MustWait(ctx)
|
||||
trap.MustWait(ctx).MustRelease(ctx)
|
||||
|
||||
require.NoError(t, awaitRunResult(ctx, t, done))
|
||||
require.Equal(t, 2, attempts)
|
||||
require.Len(t, retries, 1)
|
||||
require.Equal(t, chaterror.KindStartupTimeout, retries[0].Kind)
|
||||
@@ -305,7 +338,12 @@ func TestRun_RetriesStartupTimeoutWhileOpeningStream(t *testing.T) {
|
||||
"OpenAI did not start responding in time.",
|
||||
retries[0].Message,
|
||||
)
|
||||
require.ErrorIs(t, <-attemptCause, errStartupTimeout)
|
||||
select {
|
||||
case cause := <-attemptCause:
|
||||
require.ErrorIs(t, cause, errStartupTimeout)
|
||||
case <-ctx.Done():
|
||||
t.Fatal("timed out waiting for startup timeout cause")
|
||||
}
|
||||
}
|
||||
|
||||
func TestRun_RetriesStartupTimeoutBeforeFirstPart(t *testing.T) {
|
||||
@@ -313,6 +351,16 @@ func TestRun_RetriesStartupTimeoutBeforeFirstPart(t *testing.T) {
|
||||
|
||||
const startupTimeout = 5 * time.Millisecond
|
||||
|
||||
ctx, cancel := context.WithTimeout(
|
||||
context.Background(),
|
||||
testutil.WaitShort,
|
||||
)
|
||||
defer cancel()
|
||||
|
||||
mClock := quartz.NewMock(t)
|
||||
trap := mClock.Trap().AfterFunc("startupGuard")
|
||||
defer trap.Close()
|
||||
|
||||
attempts := 0
|
||||
attemptCause := make(chan error, 1)
|
||||
var retries []chatretry.ClassifiedError
|
||||
@@ -337,23 +385,32 @@ func TestRun_RetriesStartupTimeoutBeforeFirstPart(t *testing.T) {
|
||||
},
|
||||
}
|
||||
|
||||
err := Run(context.Background(), RunOptions{
|
||||
Model: model,
|
||||
MaxSteps: 1,
|
||||
StartupTimeout: startupTimeout,
|
||||
PersistStep: func(_ context.Context, _ PersistedStep) error {
|
||||
return nil
|
||||
},
|
||||
OnRetry: func(
|
||||
_ int,
|
||||
_ error,
|
||||
classified chatretry.ClassifiedError,
|
||||
_ time.Duration,
|
||||
) {
|
||||
retries = append(retries, classified)
|
||||
},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
done := make(chan error, 1)
|
||||
go func() {
|
||||
done <- Run(context.Background(), RunOptions{
|
||||
Model: model,
|
||||
MaxSteps: 1,
|
||||
StartupTimeout: startupTimeout,
|
||||
Clock: mClock,
|
||||
PersistStep: func(_ context.Context, _ PersistedStep) error {
|
||||
return nil
|
||||
},
|
||||
OnRetry: func(
|
||||
_ int,
|
||||
_ error,
|
||||
classified chatretry.ClassifiedError,
|
||||
_ time.Duration,
|
||||
) {
|
||||
retries = append(retries, classified)
|
||||
},
|
||||
})
|
||||
}()
|
||||
|
||||
trap.MustWait(ctx).MustRelease(ctx)
|
||||
mClock.Advance(startupTimeout).MustWait(ctx)
|
||||
trap.MustWait(ctx).MustRelease(ctx)
|
||||
|
||||
require.NoError(t, awaitRunResult(ctx, t, done))
|
||||
require.Equal(t, 2, attempts)
|
||||
require.Len(t, retries, 1)
|
||||
require.Equal(t, chaterror.KindStartupTimeout, retries[0].Kind)
|
||||
@@ -364,7 +421,12 @@ func TestRun_RetriesStartupTimeoutBeforeFirstPart(t *testing.T) {
|
||||
"OpenAI did not start responding in time.",
|
||||
retries[0].Message,
|
||||
)
|
||||
require.ErrorIs(t, <-attemptCause, errStartupTimeout)
|
||||
select {
|
||||
case cause := <-attemptCause:
|
||||
require.ErrorIs(t, cause, errStartupTimeout)
|
||||
case <-ctx.Done():
|
||||
t.Fatal("timed out waiting for startup timeout cause")
|
||||
}
|
||||
}
|
||||
|
||||
func TestRun_FirstPartDisarmsStartupTimeout(t *testing.T) {
|
||||
@@ -372,8 +434,19 @@ func TestRun_FirstPartDisarmsStartupTimeout(t *testing.T) {
|
||||
|
||||
const startupTimeout = 5 * time.Millisecond
|
||||
|
||||
ctx, cancel := context.WithTimeout(
|
||||
context.Background(),
|
||||
testutil.WaitShort,
|
||||
)
|
||||
defer cancel()
|
||||
|
||||
mClock := quartz.NewMock(t)
|
||||
trap := mClock.Trap().AfterFunc("startupGuard")
|
||||
|
||||
attempts := 0
|
||||
retried := false
|
||||
firstPartYielded := make(chan struct{}, 1)
|
||||
continueStream := make(chan struct{})
|
||||
model := &loopTestModel{
|
||||
provider: "openai",
|
||||
streamFn: func(ctx context.Context, _ fantasy.Call) (fantasy.StreamResponse, error) {
|
||||
@@ -382,18 +455,19 @@ func TestRun_FirstPartDisarmsStartupTimeout(t *testing.T) {
|
||||
if !yield(fantasy.StreamPart{Type: fantasy.StreamPartTypeTextStart, ID: "text-1"}) {
|
||||
return
|
||||
}
|
||||
|
||||
timer := time.NewTimer(startupTimeout * 2)
|
||||
defer timer.Stop()
|
||||
select {
|
||||
case firstPartYielded <- struct{}{}:
|
||||
default:
|
||||
}
|
||||
|
||||
select {
|
||||
case <-continueStream:
|
||||
case <-ctx.Done():
|
||||
_ = yield(fantasy.StreamPart{
|
||||
Type: fantasy.StreamPartTypeError,
|
||||
Error: ctx.Err(),
|
||||
})
|
||||
return
|
||||
case <-timer.C:
|
||||
}
|
||||
|
||||
parts := []fantasy.StreamPart{
|
||||
@@ -410,23 +484,40 @@ func TestRun_FirstPartDisarmsStartupTimeout(t *testing.T) {
|
||||
},
|
||||
}
|
||||
|
||||
err := Run(context.Background(), RunOptions{
|
||||
Model: model,
|
||||
MaxSteps: 1,
|
||||
StartupTimeout: startupTimeout,
|
||||
PersistStep: func(_ context.Context, _ PersistedStep) error {
|
||||
return nil
|
||||
},
|
||||
OnRetry: func(
|
||||
_ int,
|
||||
_ error,
|
||||
_ chatretry.ClassifiedError,
|
||||
_ time.Duration,
|
||||
) {
|
||||
retried = true
|
||||
},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
done := make(chan error, 1)
|
||||
go func() {
|
||||
done <- Run(context.Background(), RunOptions{
|
||||
Model: model,
|
||||
MaxSteps: 1,
|
||||
StartupTimeout: startupTimeout,
|
||||
Clock: mClock,
|
||||
PersistStep: func(_ context.Context, _ PersistedStep) error {
|
||||
return nil
|
||||
},
|
||||
OnRetry: func(
|
||||
_ int,
|
||||
_ error,
|
||||
_ chatretry.ClassifiedError,
|
||||
_ time.Duration,
|
||||
) {
|
||||
retried = true
|
||||
},
|
||||
})
|
||||
}()
|
||||
|
||||
trap.MustWait(ctx).MustRelease(ctx)
|
||||
trap.Close()
|
||||
|
||||
select {
|
||||
case <-firstPartYielded:
|
||||
case <-ctx.Done():
|
||||
t.Fatal("timed out waiting for first stream part")
|
||||
}
|
||||
|
||||
mClock.Advance(startupTimeout).MustWait(ctx)
|
||||
close(continueStream)
|
||||
|
||||
require.NoError(t, awaitRunResult(ctx, t, done))
|
||||
require.Equal(t, 1, attempts)
|
||||
require.False(t, retried)
|
||||
}
|
||||
@@ -479,6 +570,16 @@ func TestRun_RetriesStartupTimeoutWhenStreamClosesSilently(t *testing.T) {
|
||||
|
||||
const startupTimeout = 5 * time.Millisecond
|
||||
|
||||
ctx, cancel := context.WithTimeout(
|
||||
context.Background(),
|
||||
testutil.WaitShort,
|
||||
)
|
||||
defer cancel()
|
||||
|
||||
mClock := quartz.NewMock(t)
|
||||
trap := mClock.Trap().AfterFunc("startupGuard")
|
||||
defer trap.Close()
|
||||
|
||||
attempts := 0
|
||||
attemptCause := make(chan error, 1)
|
||||
var retries []chatretry.ClassifiedError
|
||||
@@ -499,23 +600,32 @@ func TestRun_RetriesStartupTimeoutWhenStreamClosesSilently(t *testing.T) {
|
||||
},
|
||||
}
|
||||
|
||||
err := Run(context.Background(), RunOptions{
|
||||
Model: model,
|
||||
MaxSteps: 1,
|
||||
StartupTimeout: startupTimeout,
|
||||
PersistStep: func(_ context.Context, _ PersistedStep) error {
|
||||
return nil
|
||||
},
|
||||
OnRetry: func(
|
||||
_ int,
|
||||
_ error,
|
||||
classified chatretry.ClassifiedError,
|
||||
_ time.Duration,
|
||||
) {
|
||||
retries = append(retries, classified)
|
||||
},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
done := make(chan error, 1)
|
||||
go func() {
|
||||
done <- Run(context.Background(), RunOptions{
|
||||
Model: model,
|
||||
MaxSteps: 1,
|
||||
StartupTimeout: startupTimeout,
|
||||
Clock: mClock,
|
||||
PersistStep: func(_ context.Context, _ PersistedStep) error {
|
||||
return nil
|
||||
},
|
||||
OnRetry: func(
|
||||
_ int,
|
||||
_ error,
|
||||
classified chatretry.ClassifiedError,
|
||||
_ time.Duration,
|
||||
) {
|
||||
retries = append(retries, classified)
|
||||
},
|
||||
})
|
||||
}()
|
||||
|
||||
trap.MustWait(ctx).MustRelease(ctx)
|
||||
mClock.Advance(startupTimeout).MustWait(ctx)
|
||||
trap.MustWait(ctx).MustRelease(ctx)
|
||||
|
||||
require.NoError(t, awaitRunResult(ctx, t, done))
|
||||
require.Equal(t, 2, attempts)
|
||||
require.Len(t, retries, 1)
|
||||
require.Equal(t, chaterror.KindStartupTimeout, retries[0].Kind)
|
||||
@@ -526,7 +636,12 @@ func TestRun_RetriesStartupTimeoutWhenStreamClosesSilently(t *testing.T) {
|
||||
"OpenAI did not start responding in time.",
|
||||
retries[0].Message,
|
||||
)
|
||||
require.ErrorIs(t, <-attemptCause, errStartupTimeout)
|
||||
select {
|
||||
case cause := <-attemptCause:
|
||||
require.ErrorIs(t, cause, errStartupTimeout)
|
||||
case <-ctx.Done():
|
||||
t.Fatal("timed out waiting for startup timeout cause")
|
||||
}
|
||||
}
|
||||
|
||||
func TestRun_InterruptedStepPersistsSyntheticToolResult(t *testing.T) {
|
||||
|
||||
@@ -0,0 +1,153 @@
|
||||
package chatd
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/sqlc-dev/pqtype"
|
||||
"golang.org/x/xerrors"
|
||||
|
||||
"cdr.dev/slog/v3"
|
||||
"github.com/coder/coder/v2/coderd/database"
|
||||
"github.com/coder/coder/v2/codersdk"
|
||||
)
|
||||
|
||||
// AgentChatContextSentinelPath marks the synthetic empty context-file
|
||||
// part used to preserve skill-only workspace-agent additions across
|
||||
// turns without treating them as persisted instruction files.
|
||||
const AgentChatContextSentinelPath = ".coder/agent-chat-context-sentinel"
|
||||
|
||||
// FilterContextParts keeps only context-file and skill parts from parts.
|
||||
// When keepEmptyContextFiles is false, context-file parts with empty
|
||||
// content are dropped. When keepEmptyContextFiles is true, empty
|
||||
// context-file parts are preserved.
|
||||
// revive:disable-next-line:flag-parameter // Required by shared helper callers.
|
||||
func FilterContextParts(
|
||||
parts []codersdk.ChatMessagePart,
|
||||
keepEmptyContextFiles bool,
|
||||
) []codersdk.ChatMessagePart {
|
||||
var filtered []codersdk.ChatMessagePart
|
||||
for _, part := range parts {
|
||||
switch part.Type {
|
||||
case codersdk.ChatMessagePartTypeContextFile:
|
||||
if !keepEmptyContextFiles && part.ContextFileContent == "" {
|
||||
continue
|
||||
}
|
||||
case codersdk.ChatMessagePartTypeSkill:
|
||||
default:
|
||||
continue
|
||||
}
|
||||
filtered = append(filtered, part)
|
||||
}
|
||||
return filtered
|
||||
}
|
||||
|
||||
// CollectContextPartsFromMessages unmarshals chat message content and
|
||||
// collects the context-file and skill parts it contains. When
|
||||
// keepEmptyContextFiles is false, empty context-file parts are skipped.
|
||||
// When it is true, empty context-file parts are included in the result.
|
||||
func CollectContextPartsFromMessages(
|
||||
ctx context.Context,
|
||||
logger slog.Logger,
|
||||
messages []database.ChatMessage,
|
||||
keepEmptyContextFiles bool,
|
||||
) ([]codersdk.ChatMessagePart, error) {
|
||||
var collected []codersdk.ChatMessagePart
|
||||
for _, msg := range messages {
|
||||
if !msg.Content.Valid {
|
||||
continue
|
||||
}
|
||||
|
||||
var parts []codersdk.ChatMessagePart
|
||||
if err := json.Unmarshal(msg.Content.RawMessage, &parts); err != nil {
|
||||
logger.Warn(ctx, "skipping malformed chat context message",
|
||||
slog.F("chat_message_id", msg.ID),
|
||||
slog.Error(err),
|
||||
)
|
||||
continue
|
||||
}
|
||||
|
||||
collected = append(
|
||||
collected,
|
||||
FilterContextParts(parts, keepEmptyContextFiles)...,
|
||||
)
|
||||
}
|
||||
|
||||
return collected, nil
|
||||
}
|
||||
|
||||
func latestContextAgentIDFromParts(parts []codersdk.ChatMessagePart) (uuid.UUID, bool) {
|
||||
var lastID uuid.UUID
|
||||
found := false
|
||||
for _, part := range parts {
|
||||
if part.Type != codersdk.ChatMessagePartTypeContextFile ||
|
||||
!part.ContextFileAgentID.Valid {
|
||||
continue
|
||||
}
|
||||
lastID = part.ContextFileAgentID.UUID
|
||||
found = true
|
||||
}
|
||||
return lastID, found
|
||||
}
|
||||
|
||||
// FilterContextPartsToLatestAgent keeps parts stamped with the latest
|
||||
// workspace-agent ID seen in the slice, plus legacy unstamped parts.
|
||||
// When no stamped context-file parts exist, it returns the original
|
||||
// slice unchanged.
|
||||
func FilterContextPartsToLatestAgent(parts []codersdk.ChatMessagePart) []codersdk.ChatMessagePart {
|
||||
latestAgentID, ok := latestContextAgentIDFromParts(parts)
|
||||
if !ok {
|
||||
return parts
|
||||
}
|
||||
|
||||
filtered := make([]codersdk.ChatMessagePart, 0, len(parts))
|
||||
for _, part := range parts {
|
||||
switch part.Type {
|
||||
case codersdk.ChatMessagePartTypeContextFile,
|
||||
codersdk.ChatMessagePartTypeSkill:
|
||||
if part.ContextFileAgentID.Valid &&
|
||||
part.ContextFileAgentID.UUID != latestAgentID {
|
||||
continue
|
||||
}
|
||||
default:
|
||||
continue
|
||||
}
|
||||
filtered = append(filtered, part)
|
||||
}
|
||||
return filtered
|
||||
}
|
||||
|
||||
// BuildLastInjectedContext filters parts down to non-empty context-file
|
||||
// and skill parts, strips their internal fields, and marshals the
|
||||
// result for LastInjectedContext. A nil or fully filtered input returns
|
||||
// an invalid NullRawMessage.
|
||||
func BuildLastInjectedContext(
|
||||
parts []codersdk.ChatMessagePart,
|
||||
) (pqtype.NullRawMessage, error) {
|
||||
if parts == nil {
|
||||
return pqtype.NullRawMessage{Valid: false}, nil
|
||||
}
|
||||
|
||||
filtered := FilterContextParts(parts, false)
|
||||
if len(filtered) == 0 {
|
||||
return pqtype.NullRawMessage{Valid: false}, nil
|
||||
}
|
||||
|
||||
stripped := make([]codersdk.ChatMessagePart, 0, len(filtered))
|
||||
for _, part := range filtered {
|
||||
cp := part
|
||||
cp.StripInternal()
|
||||
stripped = append(stripped, cp)
|
||||
}
|
||||
|
||||
raw, err := json.Marshal(stripped)
|
||||
if err != nil {
|
||||
return pqtype.NullRawMessage{}, xerrors.Errorf(
|
||||
"marshal injected context: %w",
|
||||
err,
|
||||
)
|
||||
}
|
||||
|
||||
return pqtype.NullRawMessage{RawMessage: raw, Valid: true}, nil
|
||||
}
|
||||
@@ -5,6 +5,8 @@ import (
|
||||
"encoding/json"
|
||||
"strings"
|
||||
|
||||
"github.com/google/uuid"
|
||||
|
||||
"github.com/coder/coder/v2/coderd/database"
|
||||
"github.com/coder/coder/v2/coderd/x/chatd/chattool"
|
||||
"github.com/coder/coder/v2/codersdk"
|
||||
@@ -57,6 +59,34 @@ func formatSystemInstructions(
|
||||
return b.String()
|
||||
}
|
||||
|
||||
// latestContextAgentID returns the most recent workspace-agent ID seen
|
||||
// on any persisted context-file part, including the skill-only sentinel.
|
||||
// Returns uuid.Nil, false when no stamped context-file parts exist.
|
||||
func latestContextAgentID(messages []database.ChatMessage) (uuid.UUID, bool) {
|
||||
var lastID uuid.UUID
|
||||
found := false
|
||||
for _, msg := range messages {
|
||||
if !msg.Content.Valid ||
|
||||
!bytes.Contains(msg.Content.RawMessage, []byte(`"context-file"`)) {
|
||||
continue
|
||||
}
|
||||
var parts []codersdk.ChatMessagePart
|
||||
if err := json.Unmarshal(msg.Content.RawMessage, &parts); err != nil {
|
||||
continue
|
||||
}
|
||||
for _, part := range parts {
|
||||
if part.Type != codersdk.ChatMessagePartTypeContextFile ||
|
||||
!part.ContextFileAgentID.Valid {
|
||||
continue
|
||||
}
|
||||
lastID = part.ContextFileAgentID.UUID
|
||||
found = true
|
||||
break
|
||||
}
|
||||
}
|
||||
return lastID, found
|
||||
}
|
||||
|
||||
// instructionFromContextFiles reconstructs the formatted instruction
|
||||
// string from persisted context-file parts. This is used on non-first
|
||||
// turns so the instruction can be re-injected after compaction
|
||||
@@ -64,6 +94,7 @@ func formatSystemInstructions(
|
||||
func instructionFromContextFiles(
|
||||
messages []database.ChatMessage,
|
||||
) string {
|
||||
filterAgentID, filterByAgent := latestContextAgentID(messages)
|
||||
var contextParts []codersdk.ChatMessagePart
|
||||
var os, dir string
|
||||
for _, msg := range messages {
|
||||
@@ -79,6 +110,10 @@ func instructionFromContextFiles(
|
||||
if part.Type != codersdk.ChatMessagePartTypeContextFile {
|
||||
continue
|
||||
}
|
||||
if filterByAgent && part.ContextFileAgentID.Valid &&
|
||||
part.ContextFileAgentID.UUID != filterAgentID {
|
||||
continue
|
||||
}
|
||||
if part.ContextFileOS != "" {
|
||||
os = part.ContextFileOS
|
||||
}
|
||||
@@ -93,6 +128,80 @@ func instructionFromContextFiles(
|
||||
return formatSystemInstructions(os, dir, contextParts)
|
||||
}
|
||||
|
||||
// hasPersistedInstructionFiles reports whether messages include a
|
||||
// persisted context-file part that should suppress another baseline
|
||||
// instruction-file lookup. The workspace-agent skill-only sentinel is
|
||||
// ignored so default instructions still load on fresh chats.
|
||||
func hasPersistedInstructionFiles(
|
||||
messages []database.ChatMessage,
|
||||
) bool {
|
||||
for _, msg := range messages {
|
||||
if !msg.Content.Valid ||
|
||||
!bytes.Contains(msg.Content.RawMessage, []byte(`"context-file"`)) {
|
||||
continue
|
||||
}
|
||||
var parts []codersdk.ChatMessagePart
|
||||
if err := json.Unmarshal(msg.Content.RawMessage, &parts); err != nil {
|
||||
continue
|
||||
}
|
||||
for _, part := range parts {
|
||||
if part.Type != codersdk.ChatMessagePartTypeContextFile ||
|
||||
!part.ContextFileAgentID.Valid ||
|
||||
part.ContextFilePath == AgentChatContextSentinelPath {
|
||||
continue
|
||||
}
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func mergeSkillMetas(
|
||||
persisted []chattool.SkillMeta,
|
||||
discovered []chattool.SkillMeta,
|
||||
) []chattool.SkillMeta {
|
||||
if len(persisted) == 0 {
|
||||
return discovered
|
||||
}
|
||||
if len(discovered) == 0 {
|
||||
return persisted
|
||||
}
|
||||
|
||||
seen := make(map[string]struct{}, len(persisted)+len(discovered))
|
||||
merged := make([]chattool.SkillMeta, 0, len(persisted)+len(discovered))
|
||||
appendUnique := func(skill chattool.SkillMeta) {
|
||||
if _, ok := seen[skill.Name]; ok {
|
||||
return
|
||||
}
|
||||
seen[skill.Name] = struct{}{}
|
||||
merged = append(merged, skill)
|
||||
}
|
||||
for _, skill := range discovered {
|
||||
appendUnique(skill)
|
||||
}
|
||||
for _, skill := range persisted {
|
||||
appendUnique(skill)
|
||||
}
|
||||
return merged
|
||||
}
|
||||
|
||||
// selectSkillMetasForInstructionRefresh chooses which skill metadata
|
||||
// should be injected on a turn that refreshes instruction files.
|
||||
func selectSkillMetasForInstructionRefresh(
|
||||
persisted []chattool.SkillMeta,
|
||||
discovered []chattool.SkillMeta,
|
||||
currentAgentID uuid.NullUUID,
|
||||
latestInjectedAgentID uuid.NullUUID,
|
||||
) []chattool.SkillMeta {
|
||||
if currentAgentID.Valid && latestInjectedAgentID.Valid && latestInjectedAgentID.UUID == currentAgentID.UUID {
|
||||
return mergeSkillMetas(persisted, discovered)
|
||||
}
|
||||
if !currentAgentID.Valid && len(discovered) == 0 {
|
||||
return persisted
|
||||
}
|
||||
return discovered
|
||||
}
|
||||
|
||||
// skillsFromParts reconstructs skill metadata from persisted
|
||||
// skill parts. This is analogous to instructionFromContextFiles
|
||||
// so the skill index can be re-injected after compaction without
|
||||
@@ -100,6 +209,7 @@ func instructionFromContextFiles(
|
||||
func skillsFromParts(
|
||||
messages []database.ChatMessage,
|
||||
) []chattool.SkillMeta {
|
||||
filterAgentID, filterByAgent := latestContextAgentID(messages)
|
||||
var skills []chattool.SkillMeta
|
||||
for _, msg := range messages {
|
||||
if !msg.Content.Valid ||
|
||||
@@ -114,6 +224,10 @@ func skillsFromParts(
|
||||
if part.Type != codersdk.ChatMessagePartTypeSkill {
|
||||
continue
|
||||
}
|
||||
if filterByAgent && part.ContextFileAgentID.Valid &&
|
||||
part.ContextFileAgentID.UUID != filterAgentID {
|
||||
continue
|
||||
}
|
||||
skills = append(skills, chattool.SkillMeta{
|
||||
Name: part.SkillName,
|
||||
Description: part.SkillDescription,
|
||||
|
||||
@@ -21,7 +21,6 @@ import (
|
||||
|
||||
"cdr.dev/slog/v3"
|
||||
"github.com/coder/coder/v2/coderd/database"
|
||||
coderdpubsub "github.com/coder/coder/v2/coderd/pubsub"
|
||||
"github.com/coder/coder/v2/coderd/x/chatd/chatprompt"
|
||||
"github.com/coder/coder/v2/coderd/x/chatd/chatprovider"
|
||||
"github.com/coder/coder/v2/coderd/x/chatd/chatretry"
|
||||
@@ -160,7 +159,7 @@ func (p *Server) maybeGenerateChatTitle(
|
||||
}
|
||||
chat.Title = title
|
||||
generatedTitle.Store(title)
|
||||
p.publishChatPubsubEvent(chat, coderdpubsub.ChatEventKindTitleChange, nil)
|
||||
p.publishChatPubsubEvent(chat, codersdk.ChatWatchEventKindTitleChange, nil)
|
||||
return
|
||||
}
|
||||
|
||||
|
||||
+148
-54
@@ -2,8 +2,11 @@ package chatd
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"mime"
|
||||
"mime/multipart"
|
||||
|
||||
"github.com/google/uuid"
|
||||
|
||||
@@ -13,71 +16,60 @@ import (
|
||||
"github.com/coder/coder/v2/codersdk/workspacesdk"
|
||||
)
|
||||
|
||||
type recordingResult struct {
|
||||
recordingFileID string
|
||||
thumbnailFileID string
|
||||
}
|
||||
|
||||
// stopAndStoreRecording stops the desktop recording, downloads the
|
||||
// MP4, and stores it in chat_files. Only called when the subagent
|
||||
// completed successfully. Returns the file ID on success, empty
|
||||
// string on any failure. All errors are logged but not propagated
|
||||
// — recording is best-effort.
|
||||
// multipart response containing the MP4 and optional thumbnail, and
|
||||
// stores them in chat_files. Only called when the subagent completed
|
||||
// successfully. Returns file IDs on success, empty fields on any
|
||||
// failure. All errors are logged but not propagated; recording is
|
||||
// best-effort.
|
||||
func (p *Server) stopAndStoreRecording(
|
||||
ctx context.Context,
|
||||
conn workspacesdk.AgentConn,
|
||||
recordingID string,
|
||||
ownerID uuid.UUID,
|
||||
workspaceID uuid.NullUUID,
|
||||
) string {
|
||||
) recordingResult {
|
||||
var result recordingResult
|
||||
|
||||
select {
|
||||
case p.recordingSem <- struct{}{}:
|
||||
defer func() { <-p.recordingSem }()
|
||||
case <-ctx.Done():
|
||||
p.logger.Warn(ctx, "context canceled waiting for recording semaphore", slog.Error(ctx.Err()))
|
||||
return ""
|
||||
return result
|
||||
}
|
||||
|
||||
body, err := conn.StopDesktopRecording(ctx,
|
||||
resp, err := conn.StopDesktopRecording(ctx,
|
||||
workspacesdk.StopDesktopRecordingRequest{RecordingID: recordingID})
|
||||
if err != nil {
|
||||
p.logger.Warn(ctx, "failed to stop desktop recording",
|
||||
slog.Error(err))
|
||||
return ""
|
||||
return result
|
||||
}
|
||||
type readResult struct {
|
||||
data []byte
|
||||
err error
|
||||
}
|
||||
ch := make(chan readResult, 1)
|
||||
go func() {
|
||||
data, err := io.ReadAll(io.LimitReader(body, workspacesdk.MaxRecordingSize+1))
|
||||
ch <- readResult{data, err}
|
||||
}()
|
||||
defer resp.Body.Close()
|
||||
|
||||
var data []byte
|
||||
select {
|
||||
case res := <-ch:
|
||||
body.Close()
|
||||
data = res.data
|
||||
if res.err != nil {
|
||||
p.logger.Warn(ctx, "failed to read recording data", slog.Error(res.err))
|
||||
return ""
|
||||
}
|
||||
case <-ctx.Done():
|
||||
body.Close()
|
||||
p.logger.Warn(ctx, "context canceled while reading recording data", slog.Error(ctx.Err()))
|
||||
return ""
|
||||
_, params, err := mime.ParseMediaType(resp.ContentType)
|
||||
if err != nil {
|
||||
p.logger.Warn(ctx, "failed to parse content type from recording response",
|
||||
slog.F("content_type", resp.ContentType),
|
||||
slog.Error(err))
|
||||
return result
|
||||
}
|
||||
if len(data) > workspacesdk.MaxRecordingSize {
|
||||
p.logger.Warn(ctx, "recording data exceeds maximum size, skipping store",
|
||||
slog.F("size", len(data)),
|
||||
slog.F("max_size", workspacesdk.MaxRecordingSize))
|
||||
return ""
|
||||
}
|
||||
if len(data) == 0 {
|
||||
p.logger.Warn(ctx, "recording data is empty, skipping store")
|
||||
return ""
|
||||
boundary := params["boundary"]
|
||||
if boundary == "" {
|
||||
p.logger.Warn(ctx, "missing boundary in recording response content type",
|
||||
slog.F("content_type", resp.ContentType))
|
||||
return result
|
||||
}
|
||||
|
||||
if !workspaceID.Valid {
|
||||
p.logger.Warn(ctx, "chat has no workspace, cannot store recording")
|
||||
return ""
|
||||
return result
|
||||
}
|
||||
|
||||
// The chatd actor is used here because the recording is stored on
|
||||
@@ -87,21 +79,123 @@ func (p *Server) stopAndStoreRecording(
|
||||
if err != nil {
|
||||
p.logger.Warn(ctx, "failed to resolve workspace for recording",
|
||||
slog.Error(err))
|
||||
return ""
|
||||
return result
|
||||
}
|
||||
|
||||
//nolint:gocritic // AsChatd is required to insert chat files from the recording pipeline.
|
||||
row, err := p.db.InsertChatFile(dbauthz.AsChatd(ctx), database.InsertChatFileParams{
|
||||
OwnerID: ownerID,
|
||||
OrganizationID: ws.OrganizationID,
|
||||
Name: fmt.Sprintf("recording-%s.mp4", p.clock.Now().UTC().Format("2006-01-02T15-04-05Z")),
|
||||
Mimetype: "video/mp4",
|
||||
Data: data,
|
||||
})
|
||||
if err != nil {
|
||||
p.logger.Warn(ctx, "failed to store recording in database",
|
||||
slog.Error(err))
|
||||
return ""
|
||||
mr := multipart.NewReader(resp.Body, boundary)
|
||||
// Context cancellation is checked between parts. Within a
|
||||
// part read, cancellation relies on Go's HTTP transport closing
|
||||
// the underlying connection when the context is done, which
|
||||
// interrupts the blocked io.ReadAll.
|
||||
// First pass: parse all multipart parts into memory.
|
||||
// The agent sends at most two parts: one video/mp4 and one
|
||||
// optional image/jpeg thumbnail. Cap the number of parts to
|
||||
// prevent a malicious or broken agent from forcing the server
|
||||
// into an unbounded parsing loop.
|
||||
const maxParts = 2
|
||||
var videoData, thumbnailData []byte
|
||||
for range maxParts {
|
||||
if ctx.Err() != nil {
|
||||
p.logger.Warn(ctx, "context canceled while reading recording parts", slog.Error(ctx.Err()))
|
||||
break
|
||||
}
|
||||
|
||||
part, err := mr.NextPart()
|
||||
if errors.Is(err, io.EOF) {
|
||||
break
|
||||
}
|
||||
if err != nil {
|
||||
p.logger.Warn(ctx, "error reading next multipart part", slog.Error(err))
|
||||
break
|
||||
}
|
||||
|
||||
contentType := part.Header.Get("Content-Type")
|
||||
|
||||
// Select the read limit based on content type so that
|
||||
// thumbnails (image/jpeg) do not allocate up to
|
||||
// MaxRecordingSize (100 MB) before the size check rejects
|
||||
// them. Unknown types use a small default since they are
|
||||
// discarded below.
|
||||
maxSize := int64(1 << 20) // 1 MB default for unknown types
|
||||
switch contentType {
|
||||
case "video/mp4":
|
||||
maxSize = int64(workspacesdk.MaxRecordingSize)
|
||||
case "image/jpeg":
|
||||
maxSize = int64(workspacesdk.MaxThumbnailSize)
|
||||
}
|
||||
|
||||
data, err := io.ReadAll(io.LimitReader(part, maxSize+1))
|
||||
if err != nil {
|
||||
p.logger.Warn(ctx, "failed to read recording part data",
|
||||
slog.F("content_type", contentType),
|
||||
slog.Error(err))
|
||||
continue
|
||||
}
|
||||
if int64(len(data)) > maxSize {
|
||||
p.logger.Warn(ctx, "recording part exceeds maximum size, skipping",
|
||||
slog.F("content_type", contentType),
|
||||
slog.F("size", len(data)),
|
||||
slog.F("max_size", maxSize))
|
||||
continue
|
||||
}
|
||||
if len(data) == 0 {
|
||||
p.logger.Warn(ctx, "recording part is empty, skipping",
|
||||
slog.F("content_type", contentType))
|
||||
continue
|
||||
}
|
||||
|
||||
switch contentType {
|
||||
case "video/mp4":
|
||||
if videoData != nil {
|
||||
p.logger.Warn(ctx, "duplicate video/mp4 part in recording response, skipping")
|
||||
continue
|
||||
}
|
||||
videoData = data
|
||||
case "image/jpeg":
|
||||
if thumbnailData != nil {
|
||||
p.logger.Warn(ctx, "duplicate image/jpeg part in recording response, skipping")
|
||||
continue
|
||||
}
|
||||
thumbnailData = data
|
||||
default:
|
||||
p.logger.Debug(ctx, "skipping unknown part content type",
|
||||
slog.F("content_type", contentType))
|
||||
}
|
||||
}
|
||||
return row.ID.String()
|
||||
|
||||
// Second pass: store the collected data in the database.
|
||||
if videoData != nil {
|
||||
//nolint:gocritic // AsChatd is required to insert chat files from the recording pipeline.
|
||||
row, err := p.db.InsertChatFile(dbauthz.AsChatd(ctx), database.InsertChatFileParams{
|
||||
OwnerID: ownerID,
|
||||
OrganizationID: ws.OrganizationID,
|
||||
Name: fmt.Sprintf("recording-%s.mp4", p.clock.Now().UTC().Format("2006-01-02T15-04-05Z")),
|
||||
Mimetype: "video/mp4",
|
||||
Data: videoData,
|
||||
})
|
||||
if err != nil {
|
||||
p.logger.Warn(ctx, "failed to store recording in database",
|
||||
slog.Error(err))
|
||||
} else {
|
||||
result.recordingFileID = row.ID.String()
|
||||
}
|
||||
}
|
||||
if thumbnailData != nil && result.recordingFileID != "" {
|
||||
//nolint:gocritic // AsChatd is required to insert chat files from the recording pipeline.
|
||||
row, err := p.db.InsertChatFile(dbauthz.AsChatd(ctx), database.InsertChatFileParams{
|
||||
OwnerID: ownerID,
|
||||
OrganizationID: ws.OrganizationID,
|
||||
Name: fmt.Sprintf("thumbnail-%s.jpg", p.clock.Now().UTC().Format("2006-01-02T15-04-05Z")),
|
||||
Mimetype: "image/jpeg",
|
||||
Data: thumbnailData,
|
||||
})
|
||||
if err != nil {
|
||||
p.logger.Warn(ctx, "failed to store thumbnail in database",
|
||||
slog.Error(err))
|
||||
} else {
|
||||
result.thumbnailFileID = row.ID.String()
|
||||
}
|
||||
}
|
||||
|
||||
return result
|
||||
}
|
||||
|
||||
@@ -5,6 +5,8 @@ import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"io"
|
||||
"mime/multipart"
|
||||
"net/textproto"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
@@ -34,6 +36,30 @@ func (zeroReader) Read(p []byte) (int, error) {
|
||||
return len(p), nil
|
||||
}
|
||||
|
||||
// partSpec describes a single part for buildMultipartResponse.
|
||||
type partSpec struct {
|
||||
contentType string
|
||||
data []byte
|
||||
}
|
||||
|
||||
// buildMultipartResponse constructs a StopDesktopRecordingResponse
|
||||
// with the given content type/data pairs encoded as multipart/mixed.
|
||||
func buildMultipartResponse(parts ...partSpec) workspacesdk.StopDesktopRecordingResponse {
|
||||
var buf bytes.Buffer
|
||||
mw := multipart.NewWriter(&buf)
|
||||
for _, p := range parts {
|
||||
partWriter, _ := mw.CreatePart(textproto.MIMEHeader{
|
||||
"Content-Type": {p.contentType},
|
||||
})
|
||||
_, _ = partWriter.Write(p.data)
|
||||
}
|
||||
_ = mw.Close()
|
||||
return workspacesdk.StopDesktopRecordingResponse{
|
||||
Body: io.NopCloser(bytes.NewReader(buf.Bytes())),
|
||||
ContentType: "multipart/mixed; boundary=" + mw.Boundary(),
|
||||
}
|
||||
}
|
||||
|
||||
// createComputerUseParentChild creates a parent chat and a
|
||||
// computer_use child chat bound to the given workspace/agent.
|
||||
// Both chats are inserted directly via DB to avoid triggering
|
||||
@@ -170,8 +196,7 @@ func TestWaitAgentComputerUseRecording(t *testing.T) {
|
||||
|
||||
mockConn.EXPECT().
|
||||
StopDesktopRecording(gomock.Any(), gomock.Any()).
|
||||
Return(io.NopCloser(bytes.NewReader(fakeMp4)), nil).
|
||||
Times(1)
|
||||
Return(buildMultipartResponse(partSpec{"video/mp4", fakeMp4}), nil).Times(1)
|
||||
|
||||
// Invoke wait_agent via the tool closure.
|
||||
resp, err := invokeWaitAgentTool(ctx, t, server, db, parent.ID, child.ID, 5)
|
||||
@@ -198,6 +223,87 @@ func TestWaitAgentComputerUseRecording(t *testing.T) {
|
||||
assert.Equal(t, fakeMp4, chatFile.Data)
|
||||
}
|
||||
|
||||
// TestWaitAgentComputerUseRecordingWithThumbnail verifies the
|
||||
// recording flow when the agent produces both video and thumbnail:
|
||||
// both file IDs appear in the wait_agent tool response.
|
||||
func TestWaitAgentComputerUseRecordingWithThumbnail(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
db, ps := dbtestutil.NewDB(t)
|
||||
ctx := chatdTestContext(t)
|
||||
|
||||
ctrl := gomock.NewController(t)
|
||||
mockConn := agentconnmock.NewMockAgentConn(ctrl)
|
||||
|
||||
user, model := seedInternalChatDeps(ctx, t, db)
|
||||
workspace, _, agent := seedWorkspaceBinding(t, db, user.ID)
|
||||
|
||||
server := newInternalTestServer(t, db, ps, chatprovider.ProviderAPIKeys{})
|
||||
|
||||
parent, child := createComputerUseParentChild(
|
||||
ctx, t, server, user, model, workspace, agent,
|
||||
"parent-recording-thumb", "computer-use-child-thumb",
|
||||
)
|
||||
|
||||
server.drainInflight()
|
||||
|
||||
server.agentConnFn = func(_ context.Context, agentID uuid.UUID) (workspacesdk.AgentConn, func(), error) {
|
||||
require.Equal(t, agent.ID, agentID)
|
||||
return mockConn, func() {}, nil
|
||||
}
|
||||
|
||||
insertAssistantMessage(ctx, t, db, child.ID, model.ID, "I opened Firefox and took a screenshot.")
|
||||
|
||||
setChatStatus(ctx, t, db, child.ID, database.ChatStatusWaiting, "")
|
||||
|
||||
fakeMp4 := []byte("fake-mp4-data-with-thumbnail-test")
|
||||
fakeThumb := []byte("fake-jpeg-thumbnail-data")
|
||||
|
||||
mockConn.EXPECT().
|
||||
StartDesktopRecording(gomock.Any(), gomock.Any()).
|
||||
DoAndReturn(func(_ context.Context, req workspacesdk.StartDesktopRecordingRequest) error {
|
||||
require.NotEmpty(t, req.RecordingID)
|
||||
return nil
|
||||
}).
|
||||
Times(1)
|
||||
|
||||
mockConn.EXPECT().
|
||||
StopDesktopRecording(gomock.Any(), gomock.Any()).
|
||||
Return(buildMultipartResponse(
|
||||
partSpec{"video/mp4", fakeMp4},
|
||||
partSpec{"image/jpeg", fakeThumb},
|
||||
), nil).Times(1)
|
||||
|
||||
resp, err := invokeWaitAgentTool(ctx, t, server, db, parent.ID, child.ID, 5)
|
||||
require.NoError(t, err)
|
||||
require.False(t, resp.IsError, "expected successful response, got: %s", resp.Content)
|
||||
|
||||
var result map[string]any
|
||||
require.NoError(t, json.Unmarshal([]byte(resp.Content), &result))
|
||||
|
||||
// Verify recording_file_id is present and valid.
|
||||
storedFileID, ok := result["recording_file_id"].(string)
|
||||
require.True(t, ok, "recording_file_id must be present in response")
|
||||
require.NotEmpty(t, storedFileID)
|
||||
fileUUID, err := uuid.Parse(storedFileID)
|
||||
require.NoError(t, err)
|
||||
chatFile, err := db.GetChatFileByID(ctx, fileUUID)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "video/mp4", chatFile.Mimetype)
|
||||
assert.Equal(t, fakeMp4, chatFile.Data)
|
||||
|
||||
// Verify thumbnail_file_id is present and valid.
|
||||
thumbFileID, ok := result["thumbnail_file_id"].(string)
|
||||
require.True(t, ok, "thumbnail_file_id must be present in response")
|
||||
require.NotEmpty(t, thumbFileID)
|
||||
thumbUUID, err := uuid.Parse(thumbFileID)
|
||||
require.NoError(t, err)
|
||||
thumbFile, err := db.GetChatFileByID(ctx, thumbUUID)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "image/jpeg", thumbFile.Mimetype)
|
||||
assert.Equal(t, fakeThumb, thumbFile.Data)
|
||||
}
|
||||
|
||||
// TestWaitAgentNonComputerUseNoRecording verifies that when the
|
||||
// child chat is NOT a computer_use chat, no recording is attempted.
|
||||
// StartDesktopRecording must never be called.
|
||||
@@ -342,7 +448,7 @@ func TestWaitAgentRecordingStopFails(t *testing.T) {
|
||||
|
||||
mockConn.EXPECT().
|
||||
StopDesktopRecording(gomock.Any(), gomock.Any()).
|
||||
Return(nil, xerrors.New("disk full")).
|
||||
Return(workspacesdk.StopDesktopRecordingResponse{}, xerrors.New("disk full")).
|
||||
Times(1)
|
||||
|
||||
// Invoke wait_agent via the tool closure.
|
||||
@@ -446,10 +552,10 @@ func TestWaitAgentTimeoutLeavesRecordingRunning(t *testing.T) {
|
||||
assert.Contains(t, result.resp.Content, "timed out")
|
||||
}
|
||||
|
||||
// TestStopAndStoreRecordingOversized verifies that when the recording
|
||||
// data exceeds MaxRecordingSize, stopAndStoreRecording returns an
|
||||
// empty string and does NOT call InsertChatFile.
|
||||
func TestStopAndStoreRecordingOversized(t *testing.T) {
|
||||
// TestStopAndStoreRecording_Oversized verifies that when the
|
||||
// recording data exceeds MaxRecordingSize, stopAndStoreRecording
|
||||
// returns an empty string and does NOT call InsertChatFile.
|
||||
func TestStopAndStoreRecording_Oversized(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
db, ps := dbtestutil.NewDB(t)
|
||||
@@ -463,29 +569,146 @@ func TestStopAndStoreRecordingOversized(t *testing.T) {
|
||||
|
||||
server := newInternalTestServer(t, db, ps, chatprovider.ProviderAPIKeys{})
|
||||
|
||||
// Create a reader that produces MaxRecordingSize+1 bytes without
|
||||
// allocating the full buffer in memory.
|
||||
oversizedReader := io.LimitReader(
|
||||
&zeroReader{},
|
||||
int64(workspacesdk.MaxRecordingSize+1),
|
||||
)
|
||||
// Build a streaming multipart response with a video/mp4 part
|
||||
// that exceeds MaxRecordingSize without allocating the full
|
||||
// buffer in memory.
|
||||
pr, pw := io.Pipe()
|
||||
mw := multipart.NewWriter(pw)
|
||||
go func() {
|
||||
partWriter, _ := mw.CreatePart(textproto.MIMEHeader{
|
||||
"Content-Type": {"video/mp4"},
|
||||
})
|
||||
// Stream MaxRecordingSize+1 zero bytes.
|
||||
_, _ = io.Copy(partWriter, io.LimitReader(&zeroReader{}, int64(workspacesdk.MaxRecordingSize+1)))
|
||||
_ = mw.Close()
|
||||
_ = pw.Close()
|
||||
}()
|
||||
|
||||
mockConn.EXPECT().
|
||||
StopDesktopRecording(gomock.Any(), gomock.Any()).
|
||||
Return(io.NopCloser(oversizedReader), nil).
|
||||
Return(workspacesdk.StopDesktopRecordingResponse{
|
||||
Body: pr,
|
||||
ContentType: "multipart/mixed; boundary=" + mw.Boundary(),
|
||||
}, nil).
|
||||
Times(1)
|
||||
|
||||
recordingID := uuid.New().String()
|
||||
storedFileID := server.stopAndStoreRecording(
|
||||
result := server.stopAndStoreRecording(
|
||||
ctx, mockConn, recordingID, user.ID,
|
||||
uuid.NullUUID{UUID: workspace.ID, Valid: true},
|
||||
)
|
||||
assert.Empty(t, storedFileID, "oversized recording should not be stored")
|
||||
assert.Empty(t, result.recordingFileID, "oversized recording should not be stored")
|
||||
}
|
||||
|
||||
// TestStopAndStoreRecordingEmpty verifies that when the recording
|
||||
// TestStopAndStoreRecording_OversizedThumbnail verifies that when the
|
||||
// thumbnail part exceeds MaxThumbnailSize it is skipped while the
|
||||
// normal-sized video part is still stored.
|
||||
func TestStopAndStoreRecording_OversizedThumbnail(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
db, ps := dbtestutil.NewDB(t)
|
||||
ctx := chatdTestContext(t)
|
||||
|
||||
ctrl := gomock.NewController(t)
|
||||
mockConn := agentconnmock.NewMockAgentConn(ctrl)
|
||||
|
||||
user, _ := seedInternalChatDeps(ctx, t, db)
|
||||
workspace, _, _ := seedWorkspaceBinding(t, db, user.ID)
|
||||
|
||||
server := newInternalTestServer(t, db, ps, chatprovider.ProviderAPIKeys{})
|
||||
|
||||
videoData := bytes.Repeat([]byte{0xAA}, 1024)
|
||||
|
||||
// Build a streaming multipart response with a normal video part
|
||||
// and an oversized thumbnail part.
|
||||
pr, pw := io.Pipe()
|
||||
mw := multipart.NewWriter(pw)
|
||||
go func() {
|
||||
vw, _ := mw.CreatePart(textproto.MIMEHeader{
|
||||
"Content-Type": {"video/mp4"},
|
||||
})
|
||||
_, _ = vw.Write(videoData)
|
||||
tw, _ := mw.CreatePart(textproto.MIMEHeader{
|
||||
"Content-Type": {"image/jpeg"},
|
||||
})
|
||||
// Stream MaxThumbnailSize+1 zero bytes for the thumbnail.
|
||||
_, _ = io.Copy(tw, io.LimitReader(&zeroReader{}, int64(workspacesdk.MaxThumbnailSize+1)))
|
||||
_ = mw.Close()
|
||||
_ = pw.Close()
|
||||
}()
|
||||
|
||||
mockConn.EXPECT().
|
||||
StopDesktopRecording(gomock.Any(), gomock.Any()).
|
||||
Return(workspacesdk.StopDesktopRecordingResponse{
|
||||
Body: pr,
|
||||
ContentType: "multipart/mixed; boundary=" + mw.Boundary(),
|
||||
}, nil).
|
||||
Times(1)
|
||||
|
||||
recordingID := uuid.New().String()
|
||||
result := server.stopAndStoreRecording(
|
||||
ctx, mockConn, recordingID, user.ID,
|
||||
uuid.NullUUID{UUID: workspace.ID, Valid: true},
|
||||
)
|
||||
|
||||
// Video should be stored.
|
||||
recUUID, err := uuid.Parse(result.recordingFileID)
|
||||
require.NoError(t, err, "RecordingFileID should be a valid UUID")
|
||||
recFile, err := db.GetChatFileByID(ctx, recUUID)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "video/mp4", recFile.Mimetype)
|
||||
assert.Equal(t, videoData, recFile.Data)
|
||||
|
||||
// Thumbnail should be skipped (oversized).
|
||||
assert.Empty(t, result.thumbnailFileID, "oversized thumbnail should not be stored")
|
||||
}
|
||||
|
||||
// TestStopAndStoreRecording_DuplicatePartsIgnored verifies that when
|
||||
// a multipart response contains two video/mp4 parts, only the first
|
||||
// is stored and the duplicate is skipped.
|
||||
func TestStopAndStoreRecording_DuplicatePartsIgnored(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
db, ps := dbtestutil.NewDB(t)
|
||||
ctx := chatdTestContext(t)
|
||||
|
||||
ctrl := gomock.NewController(t)
|
||||
mockConn := agentconnmock.NewMockAgentConn(ctrl)
|
||||
|
||||
user, _ := seedInternalChatDeps(ctx, t, db)
|
||||
workspace, _, _ := seedWorkspaceBinding(t, db, user.ID)
|
||||
|
||||
server := newInternalTestServer(t, db, ps, chatprovider.ProviderAPIKeys{})
|
||||
|
||||
firstVideo := bytes.Repeat([]byte{0x01}, 512)
|
||||
secondVideo := bytes.Repeat([]byte{0x02}, 512)
|
||||
|
||||
mockConn.EXPECT().
|
||||
StopDesktopRecording(gomock.Any(), gomock.Any()).
|
||||
Return(buildMultipartResponse(
|
||||
partSpec{"video/mp4", firstVideo},
|
||||
partSpec{"video/mp4", secondVideo},
|
||||
), nil).
|
||||
Times(1)
|
||||
|
||||
recordingID := uuid.New().String()
|
||||
result := server.stopAndStoreRecording(
|
||||
ctx, mockConn, recordingID, user.ID,
|
||||
uuid.NullUUID{UUID: workspace.ID, Valid: true},
|
||||
)
|
||||
|
||||
// Only the first video part should be stored.
|
||||
recUUID, err := uuid.Parse(result.recordingFileID)
|
||||
require.NoError(t, err)
|
||||
recFile, err := db.GetChatFileByID(ctx, recUUID)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, firstVideo, recFile.Data, "first video part should be stored, not the duplicate")
|
||||
}
|
||||
|
||||
// TestStopAndStoreRecording_Empty verifies that when the recording
|
||||
// data is empty, stopAndStoreRecording returns an empty string and
|
||||
// does NOT call InsertChatFile.
|
||||
func TestStopAndStoreRecordingEmpty(t *testing.T) {
|
||||
func TestStopAndStoreRecording_Empty(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
db, ps := dbtestutil.NewDB(t)
|
||||
@@ -499,16 +722,265 @@ func TestStopAndStoreRecordingEmpty(t *testing.T) {
|
||||
|
||||
server := newInternalTestServer(t, db, ps, chatprovider.ProviderAPIKeys{})
|
||||
|
||||
// Return empty data.
|
||||
// Build a multipart response with an empty video/mp4 part.
|
||||
mockConn.EXPECT().
|
||||
StopDesktopRecording(gomock.Any(), gomock.Any()).
|
||||
Return(io.NopCloser(bytes.NewReader(nil)), nil).
|
||||
Times(1)
|
||||
Return(buildMultipartResponse(partSpec{"video/mp4", nil}), nil).Times(1)
|
||||
|
||||
recordingID := uuid.New().String()
|
||||
storedFileID := server.stopAndStoreRecording(
|
||||
result := server.stopAndStoreRecording(
|
||||
ctx, mockConn, recordingID, user.ID,
|
||||
uuid.NullUUID{UUID: workspace.ID, Valid: true},
|
||||
)
|
||||
assert.Empty(t, storedFileID, "empty recording should not be stored")
|
||||
assert.Empty(t, result.recordingFileID, "empty recording should not be stored")
|
||||
}
|
||||
|
||||
// TestStopAndStoreRecording_WithThumbnail verifies that a multipart
|
||||
// response containing both a video/mp4 part and an image/jpeg part
|
||||
// results in both files being stored with correct mimetypes.
|
||||
func TestStopAndStoreRecording_WithThumbnail(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
db, ps := dbtestutil.NewDB(t)
|
||||
ctx := chatdTestContext(t)
|
||||
|
||||
ctrl := gomock.NewController(t)
|
||||
mockConn := agentconnmock.NewMockAgentConn(ctrl)
|
||||
|
||||
user, _ := seedInternalChatDeps(ctx, t, db)
|
||||
workspace, _, _ := seedWorkspaceBinding(t, db, user.ID)
|
||||
|
||||
server := newInternalTestServer(t, db, ps, chatprovider.ProviderAPIKeys{})
|
||||
|
||||
videoData := bytes.Repeat([]byte{0xDE, 0xAD}, 512) // 1024 bytes
|
||||
thumbData := bytes.Repeat([]byte{0xFF, 0xD8}, 256) // 512 bytes
|
||||
|
||||
mockConn.EXPECT().
|
||||
StopDesktopRecording(gomock.Any(), gomock.Any()).
|
||||
Return(buildMultipartResponse(
|
||||
partSpec{"video/mp4", videoData},
|
||||
partSpec{"image/jpeg", thumbData},
|
||||
), nil).
|
||||
Times(1)
|
||||
|
||||
recordingID := uuid.New().String()
|
||||
result := server.stopAndStoreRecording(
|
||||
ctx, mockConn, recordingID, user.ID,
|
||||
uuid.NullUUID{UUID: workspace.ID, Valid: true},
|
||||
)
|
||||
|
||||
// Both file IDs should be valid UUIDs.
|
||||
recUUID, err := uuid.Parse(result.recordingFileID)
|
||||
require.NoError(t, err, "RecordingFileID should be a valid UUID")
|
||||
|
||||
thumbUUID, err := uuid.Parse(result.thumbnailFileID)
|
||||
require.NoError(t, err, "ThumbnailFileID should be a valid UUID")
|
||||
// Verify the recording file in the database.
|
||||
recFile, err := db.GetChatFileByID(ctx, recUUID)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "video/mp4", recFile.Mimetype)
|
||||
assert.Equal(t, videoData, recFile.Data)
|
||||
|
||||
// Verify the thumbnail file in the database.
|
||||
thumbFile, err := db.GetChatFileByID(ctx, thumbUUID)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "image/jpeg", thumbFile.Mimetype)
|
||||
assert.Equal(t, thumbData, thumbFile.Data)
|
||||
}
|
||||
|
||||
// TestStopAndStoreRecording_VideoOnly verifies that a multipart
|
||||
// response with only a video/mp4 part stores the recording but
|
||||
// leaves thumbnailFileID empty.
|
||||
func TestStopAndStoreRecording_VideoOnly(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
db, ps := dbtestutil.NewDB(t)
|
||||
ctx := chatdTestContext(t)
|
||||
|
||||
ctrl := gomock.NewController(t)
|
||||
mockConn := agentconnmock.NewMockAgentConn(ctrl)
|
||||
|
||||
user, _ := seedInternalChatDeps(ctx, t, db)
|
||||
workspace, _, _ := seedWorkspaceBinding(t, db, user.ID)
|
||||
|
||||
server := newInternalTestServer(t, db, ps, chatprovider.ProviderAPIKeys{})
|
||||
|
||||
videoData := make([]byte, 1024)
|
||||
|
||||
mockConn.EXPECT().
|
||||
StopDesktopRecording(gomock.Any(), gomock.Any()).
|
||||
Return(buildMultipartResponse(partSpec{"video/mp4", videoData}), nil).Times(1)
|
||||
|
||||
recordingID := uuid.New().String()
|
||||
result := server.stopAndStoreRecording(
|
||||
ctx, mockConn, recordingID, user.ID,
|
||||
uuid.NullUUID{UUID: workspace.ID, Valid: true},
|
||||
)
|
||||
|
||||
// Recording should be stored.
|
||||
recUUID, err := uuid.Parse(result.recordingFileID)
|
||||
require.NoError(t, err, "RecordingFileID should be a valid UUID")
|
||||
|
||||
recFile, err := db.GetChatFileByID(ctx, recUUID)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "video/mp4", recFile.Mimetype)
|
||||
assert.Equal(t, videoData, recFile.Data)
|
||||
|
||||
// No thumbnail.
|
||||
assert.Empty(t, result.thumbnailFileID, "ThumbnailFileID should be empty when no thumbnail part is present")
|
||||
}
|
||||
|
||||
// TestStopAndStoreRecording_DownloadFailure verifies that when
|
||||
// StopDesktopRecording returns an error, stopAndStoreRecording
|
||||
// returns an empty recordingResult without panicking.
|
||||
func TestStopAndStoreRecording_DownloadFailure(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
db, ps := dbtestutil.NewDB(t)
|
||||
ctx := chatdTestContext(t)
|
||||
|
||||
ctrl := gomock.NewController(t)
|
||||
mockConn := agentconnmock.NewMockAgentConn(ctrl)
|
||||
|
||||
user, _ := seedInternalChatDeps(ctx, t, db)
|
||||
workspace, _, _ := seedWorkspaceBinding(t, db, user.ID)
|
||||
|
||||
server := newInternalTestServer(t, db, ps, chatprovider.ProviderAPIKeys{})
|
||||
|
||||
mockConn.EXPECT().
|
||||
StopDesktopRecording(gomock.Any(), gomock.Any()).
|
||||
Return(workspacesdk.StopDesktopRecordingResponse{}, xerrors.New("network error")).
|
||||
Times(1)
|
||||
|
||||
recordingID := uuid.New().String()
|
||||
result := server.stopAndStoreRecording(
|
||||
ctx, mockConn, recordingID, user.ID,
|
||||
uuid.NullUUID{UUID: workspace.ID, Valid: true},
|
||||
)
|
||||
|
||||
assert.Empty(t, result.recordingFileID, "RecordingFileID should be empty on download failure")
|
||||
assert.Empty(t, result.thumbnailFileID, "ThumbnailFileID should be empty on download failure")
|
||||
}
|
||||
|
||||
// TestStopAndStoreRecording_UnknownPartIgnored verifies that parts
|
||||
// with unrecognized content types are silently skipped while known
|
||||
// parts (video/mp4 and image/jpeg) are still stored.
|
||||
func TestStopAndStoreRecording_UnknownPartIgnored(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
db, ps := dbtestutil.NewDB(t)
|
||||
ctx := chatdTestContext(t)
|
||||
|
||||
ctrl := gomock.NewController(t)
|
||||
mockConn := agentconnmock.NewMockAgentConn(ctrl)
|
||||
|
||||
user, _ := seedInternalChatDeps(ctx, t, db)
|
||||
workspace, _, _ := seedWorkspaceBinding(t, db, user.ID)
|
||||
|
||||
server := newInternalTestServer(t, db, ps, chatprovider.ProviderAPIKeys{})
|
||||
|
||||
videoData := make([]byte, 1024)
|
||||
thumbData := make([]byte, 512)
|
||||
unknownData := make([]byte, 256)
|
||||
|
||||
mockConn.EXPECT().
|
||||
StopDesktopRecording(gomock.Any(), gomock.Any()).
|
||||
Return(buildMultipartResponse(
|
||||
partSpec{"video/mp4", videoData},
|
||||
partSpec{"image/jpeg", thumbData},
|
||||
partSpec{"application/octet-stream", unknownData},
|
||||
), nil).Times(1)
|
||||
|
||||
recordingID := uuid.New().String()
|
||||
result := server.stopAndStoreRecording(
|
||||
ctx, mockConn, recordingID, user.ID,
|
||||
uuid.NullUUID{UUID: workspace.ID, Valid: true},
|
||||
)
|
||||
|
||||
// Both known parts should be stored.
|
||||
recUUID, err := uuid.Parse(result.recordingFileID)
|
||||
require.NoError(t, err, "RecordingFileID should be a valid UUID")
|
||||
|
||||
thumbUUID, err := uuid.Parse(result.thumbnailFileID)
|
||||
require.NoError(t, err, "ThumbnailFileID should be a valid UUID")
|
||||
|
||||
// Verify only 2 files exist (unknown part was skipped).
|
||||
recFile, err := db.GetChatFileByID(ctx, recUUID)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "video/mp4", recFile.Mimetype)
|
||||
assert.Equal(t, videoData, recFile.Data)
|
||||
|
||||
thumbFile, err := db.GetChatFileByID(ctx, thumbUUID)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "image/jpeg", thumbFile.Mimetype)
|
||||
assert.Equal(t, thumbData, thumbFile.Data)
|
||||
}
|
||||
|
||||
// TestStopAndStoreRecording_MalformedContentType verifies that a
|
||||
// response with an unparseable Content-Type returns an empty result.
|
||||
func TestStopAndStoreRecording_MalformedContentType(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
db, ps := dbtestutil.NewDB(t)
|
||||
ctx := chatdTestContext(t)
|
||||
|
||||
ctrl := gomock.NewController(t)
|
||||
mockConn := agentconnmock.NewMockAgentConn(ctrl)
|
||||
|
||||
user, _ := seedInternalChatDeps(ctx, t, db)
|
||||
workspace, _, _ := seedWorkspaceBinding(t, db, user.ID)
|
||||
|
||||
server := newInternalTestServer(t, db, ps, chatprovider.ProviderAPIKeys{})
|
||||
|
||||
mockConn.EXPECT().
|
||||
StopDesktopRecording(gomock.Any(), gomock.Any()).
|
||||
Return(workspacesdk.StopDesktopRecordingResponse{
|
||||
Body: io.NopCloser(bytes.NewReader(nil)),
|
||||
ContentType: "",
|
||||
}, nil).
|
||||
Times(1)
|
||||
|
||||
recordingID := uuid.New().String()
|
||||
result := server.stopAndStoreRecording(
|
||||
ctx, mockConn, recordingID, user.ID,
|
||||
uuid.NullUUID{UUID: workspace.ID, Valid: true},
|
||||
)
|
||||
|
||||
assert.Empty(t, result.recordingFileID, "RecordingFileID should be empty for malformed content type")
|
||||
assert.Empty(t, result.thumbnailFileID, "ThumbnailFileID should be empty for malformed content type")
|
||||
}
|
||||
|
||||
// TestStopAndStoreRecording_MissingBoundary verifies that a
|
||||
// multipart response without a boundary parameter returns an empty
|
||||
// result.
|
||||
func TestStopAndStoreRecording_MissingBoundary(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
db, ps := dbtestutil.NewDB(t)
|
||||
ctx := chatdTestContext(t)
|
||||
|
||||
ctrl := gomock.NewController(t)
|
||||
mockConn := agentconnmock.NewMockAgentConn(ctrl)
|
||||
|
||||
user, _ := seedInternalChatDeps(ctx, t, db)
|
||||
workspace, _, _ := seedWorkspaceBinding(t, db, user.ID)
|
||||
|
||||
server := newInternalTestServer(t, db, ps, chatprovider.ProviderAPIKeys{})
|
||||
|
||||
mockConn.EXPECT().
|
||||
StopDesktopRecording(gomock.Any(), gomock.Any()).
|
||||
Return(workspacesdk.StopDesktopRecordingResponse{
|
||||
Body: io.NopCloser(bytes.NewReader(nil)),
|
||||
ContentType: "multipart/mixed",
|
||||
}, nil).
|
||||
Times(1)
|
||||
|
||||
recordingID := uuid.New().String()
|
||||
result := server.stopAndStoreRecording(
|
||||
ctx, mockConn, recordingID, user.ID,
|
||||
uuid.NullUUID{UUID: workspace.ID, Valid: true},
|
||||
)
|
||||
|
||||
assert.Empty(t, result.recordingFileID, "RecordingFileID should be empty when boundary is missing")
|
||||
assert.Empty(t, result.thumbnailFileID, "ThumbnailFileID should be empty when boundary is missing")
|
||||
}
|
||||
|
||||
+275
-65
@@ -11,6 +11,7 @@ import (
|
||||
|
||||
"charm.land/fantasy"
|
||||
"github.com/google/uuid"
|
||||
"github.com/sqlc-dev/pqtype"
|
||||
"golang.org/x/xerrors"
|
||||
|
||||
"cdr.dev/slog/v3"
|
||||
@@ -233,13 +234,13 @@ func (p *Server) subagentTools(ctx context.Context, currentChat func() database.
|
||||
}
|
||||
|
||||
// Only stop and store the recording on success.
|
||||
var storedFileID string
|
||||
var recResult recordingResult
|
||||
if recordingID != "" && agentConn != nil {
|
||||
// Use a fresh context for cleanup so a canceled
|
||||
// parent context doesn't prevent recording storage.
|
||||
stopCtx, stopCancel := context.WithTimeout(context.WithoutCancel(ctx), 90*time.Second)
|
||||
defer stopCancel()
|
||||
storedFileID = p.stopAndStoreRecording(stopCtx, agentConn,
|
||||
recResult = p.stopAndStoreRecording(stopCtx, agentConn,
|
||||
recordingID, parent.OwnerID, parent.WorkspaceID)
|
||||
}
|
||||
resp := map[string]any{
|
||||
@@ -248,8 +249,11 @@ func (p *Server) subagentTools(ctx context.Context, currentChat func() database.
|
||||
"report": report,
|
||||
"status": string(targetChat.Status),
|
||||
}
|
||||
if storedFileID != "" {
|
||||
resp["recording_file_id"] = storedFileID
|
||||
if recResult.recordingFileID != "" {
|
||||
resp["recording_file_id"] = recResult.recordingFileID
|
||||
}
|
||||
if recResult.thumbnailFileID != "" {
|
||||
resp["thumbnail_file_id"] = recResult.thumbnailFileID
|
||||
}
|
||||
return toolJSONResponse(resp), nil
|
||||
},
|
||||
@@ -358,48 +362,19 @@ func (p *Server) subagentTools(ctx context.Context, currentChat func() database.
|
||||
return fantasy.NewTextErrorResponse(err.Error()), nil
|
||||
}
|
||||
|
||||
prompt := strings.TrimSpace(args.Prompt)
|
||||
if prompt == "" {
|
||||
return fantasy.NewTextErrorResponse("prompt is required"), nil
|
||||
}
|
||||
|
||||
title := strings.TrimSpace(args.Title)
|
||||
if title == "" {
|
||||
title = subagentFallbackChatTitle(prompt)
|
||||
}
|
||||
|
||||
rootChatID := parent.ID
|
||||
if parent.RootChatID.Valid {
|
||||
rootChatID = parent.RootChatID.UUID
|
||||
}
|
||||
if parent.LastModelConfigID == uuid.Nil {
|
||||
return fantasy.NewTextErrorResponse("parent chat model config id is required"), nil
|
||||
}
|
||||
|
||||
// Create the child chat with Mode set to
|
||||
// computer_use. This signals runChat to use the
|
||||
// predefined computer use model and include the
|
||||
// computer tool.
|
||||
childChat, err := p.CreateChat(ctx, CreateOptions{
|
||||
OwnerID: parent.OwnerID,
|
||||
WorkspaceID: parent.WorkspaceID,
|
||||
BuildID: parent.BuildID,
|
||||
AgentID: parent.AgentID,
|
||||
ParentChatID: uuid.NullUUID{
|
||||
UUID: parent.ID,
|
||||
Valid: true,
|
||||
childChat, err := p.createChildSubagentChatWithOptions(
|
||||
ctx,
|
||||
parent,
|
||||
args.Prompt,
|
||||
args.Title,
|
||||
childSubagentChatOptions{
|
||||
chatMode: database.NullChatMode{
|
||||
ChatMode: database.ChatModeComputerUse,
|
||||
Valid: true,
|
||||
},
|
||||
systemPrompt: computerUseSubagentSystemPrompt + "\n\n" + strings.TrimSpace(args.Prompt),
|
||||
},
|
||||
RootChatID: uuid.NullUUID{
|
||||
UUID: rootChatID,
|
||||
Valid: true,
|
||||
},
|
||||
ModelConfigID: parent.LastModelConfigID,
|
||||
Title: title,
|
||||
ChatMode: database.NullChatMode{ChatMode: database.ChatModeComputerUse, Valid: true},
|
||||
SystemPrompt: computerUseSubagentSystemPrompt + "\n\n" + prompt,
|
||||
InitialUserContent: []codersdk.ChatMessagePart{codersdk.ChatMessageText(prompt)},
|
||||
MCPServerIDs: parent.MCPServerIDs,
|
||||
})
|
||||
)
|
||||
if err != nil {
|
||||
return fantasy.NewTextErrorResponse(err.Error()), nil
|
||||
}
|
||||
@@ -424,11 +399,26 @@ func parseSubagentToolChatID(raw string) (uuid.UUID, error) {
|
||||
return chatID, nil
|
||||
}
|
||||
|
||||
type childSubagentChatOptions struct {
|
||||
chatMode database.NullChatMode
|
||||
systemPrompt string
|
||||
}
|
||||
|
||||
func (p *Server) createChildSubagentChat(
|
||||
ctx context.Context,
|
||||
parent database.Chat,
|
||||
prompt string,
|
||||
title string,
|
||||
) (database.Chat, error) {
|
||||
return p.createChildSubagentChatWithOptions(ctx, parent, prompt, title, childSubagentChatOptions{})
|
||||
}
|
||||
|
||||
func (p *Server) createChildSubagentChatWithOptions(
|
||||
ctx context.Context,
|
||||
parent database.Chat,
|
||||
prompt string,
|
||||
title string,
|
||||
opts childSubagentChatOptions,
|
||||
) (database.Chat, error) {
|
||||
if parent.ParentChatID.Valid {
|
||||
return database.Chat{}, xerrors.New("delegated chats cannot create child subagents")
|
||||
@@ -452,31 +442,251 @@ func (p *Server) createChildSubagentChat(
|
||||
return database.Chat{}, xerrors.New("parent chat model config id is required")
|
||||
}
|
||||
|
||||
child, err := p.CreateChat(ctx, CreateOptions{
|
||||
OwnerID: parent.OwnerID,
|
||||
WorkspaceID: parent.WorkspaceID,
|
||||
BuildID: parent.BuildID,
|
||||
AgentID: parent.AgentID,
|
||||
ParentChatID: uuid.NullUUID{
|
||||
UUID: parent.ID,
|
||||
Valid: true,
|
||||
},
|
||||
RootChatID: uuid.NullUUID{
|
||||
UUID: rootChatID,
|
||||
Valid: true,
|
||||
},
|
||||
ModelConfigID: parent.LastModelConfigID,
|
||||
Title: title,
|
||||
InitialUserContent: []codersdk.ChatMessagePart{codersdk.ChatMessageText(prompt)},
|
||||
MCPServerIDs: parent.MCPServerIDs,
|
||||
})
|
||||
if err != nil {
|
||||
return database.Chat{}, xerrors.Errorf("create child chat: %w", err)
|
||||
mcpServerIDs := parent.MCPServerIDs
|
||||
if mcpServerIDs == nil {
|
||||
mcpServerIDs = []uuid.UUID{}
|
||||
}
|
||||
|
||||
labelsJSON, err := json.Marshal(database.StringMap{})
|
||||
if err != nil {
|
||||
return database.Chat{}, xerrors.Errorf("marshal labels: %w", err)
|
||||
}
|
||||
childSystemPrompt := SanitizePromptText(opts.systemPrompt)
|
||||
|
||||
var child database.Chat
|
||||
txErr := p.db.InTx(func(tx database.Store) error {
|
||||
if limitErr := p.checkUsageLimit(ctx, tx, parent.OwnerID); limitErr != nil {
|
||||
return limitErr
|
||||
}
|
||||
|
||||
insertedChat, err := tx.InsertChat(ctx, database.InsertChatParams{
|
||||
OwnerID: parent.OwnerID,
|
||||
WorkspaceID: parent.WorkspaceID,
|
||||
BuildID: parent.BuildID,
|
||||
AgentID: parent.AgentID,
|
||||
ParentChatID: uuid.NullUUID{UUID: parent.ID, Valid: true},
|
||||
RootChatID: uuid.NullUUID{UUID: rootChatID, Valid: true},
|
||||
LastModelConfigID: parent.LastModelConfigID,
|
||||
Title: title,
|
||||
Mode: opts.chatMode,
|
||||
Status: database.ChatStatusPending,
|
||||
MCPServerIDs: mcpServerIDs,
|
||||
Labels: pqtype.NullRawMessage{
|
||||
RawMessage: labelsJSON,
|
||||
Valid: true,
|
||||
},
|
||||
DynamicTools: pqtype.NullRawMessage{},
|
||||
})
|
||||
if err != nil {
|
||||
return xerrors.Errorf("insert child chat: %w", err)
|
||||
}
|
||||
|
||||
deploymentPrompt := p.resolveDeploymentSystemPrompt(ctx)
|
||||
workspaceAwareness := "There is no workspace associated with this chat yet. Create one using the create_workspace tool before using workspace tools like execute, read_file, write_file, etc."
|
||||
if insertedChat.WorkspaceID.Valid {
|
||||
workspaceAwareness = "This chat is attached to a workspace. You can use workspace tools like execute, read_file, write_file, etc."
|
||||
}
|
||||
workspaceAwarenessContent, err := chatprompt.MarshalParts([]codersdk.ChatMessagePart{
|
||||
codersdk.ChatMessageText(workspaceAwareness),
|
||||
})
|
||||
if err != nil {
|
||||
return xerrors.Errorf("marshal workspace awareness: %w", err)
|
||||
}
|
||||
userContent, err := chatprompt.MarshalParts([]codersdk.ChatMessagePart{codersdk.ChatMessageText(prompt)})
|
||||
if err != nil {
|
||||
return xerrors.Errorf("marshal initial user content: %w", err)
|
||||
}
|
||||
|
||||
systemParams := database.InsertChatMessagesParams{ //nolint:exhaustruct // Fields populated by appendChatMessage.
|
||||
ChatID: insertedChat.ID,
|
||||
}
|
||||
if deploymentPrompt != "" {
|
||||
deploymentContent, err := chatprompt.MarshalParts([]codersdk.ChatMessagePart{
|
||||
codersdk.ChatMessageText(deploymentPrompt),
|
||||
})
|
||||
if err != nil {
|
||||
return xerrors.Errorf("marshal deployment system prompt: %w", err)
|
||||
}
|
||||
appendChatMessage(&systemParams, newChatMessage(
|
||||
database.ChatMessageRoleSystem,
|
||||
deploymentContent,
|
||||
database.ChatMessageVisibilityModel,
|
||||
parent.LastModelConfigID,
|
||||
chatprompt.CurrentContentVersion,
|
||||
))
|
||||
}
|
||||
if childSystemPrompt != "" {
|
||||
childSystemPromptContent, err := chatprompt.MarshalParts([]codersdk.ChatMessagePart{
|
||||
codersdk.ChatMessageText(childSystemPrompt),
|
||||
})
|
||||
if err != nil {
|
||||
return xerrors.Errorf("marshal child system prompt: %w", err)
|
||||
}
|
||||
appendChatMessage(&systemParams, newChatMessage(
|
||||
database.ChatMessageRoleSystem,
|
||||
childSystemPromptContent,
|
||||
database.ChatMessageVisibilityModel,
|
||||
parent.LastModelConfigID,
|
||||
chatprompt.CurrentContentVersion,
|
||||
))
|
||||
}
|
||||
appendChatMessage(&systemParams, newChatMessage(
|
||||
database.ChatMessageRoleSystem,
|
||||
workspaceAwarenessContent,
|
||||
database.ChatMessageVisibilityModel,
|
||||
parent.LastModelConfigID,
|
||||
chatprompt.CurrentContentVersion,
|
||||
))
|
||||
if _, err := tx.InsertChatMessages(ctx, systemParams); err != nil {
|
||||
return xerrors.Errorf("insert initial child system messages: %w", err)
|
||||
}
|
||||
|
||||
child = insertedChat
|
||||
|
||||
// Copy persisted context before the initial child prompt so the
|
||||
// child cannot be acquired until its inherited context is in
|
||||
// place. signalWake runs only after commit.
|
||||
copiedContextParts, err := copyParentContextMessages(ctx, p.logger, tx, parent, child)
|
||||
if err != nil {
|
||||
return xerrors.Errorf("copy parent context messages: %w", err)
|
||||
}
|
||||
if err := updateChildLastInjectedContext(ctx, p.logger, tx, child.ID, copiedContextParts); err != nil {
|
||||
return xerrors.Errorf("update child injected context: %w", err)
|
||||
}
|
||||
|
||||
userParams := database.InsertChatMessagesParams{ //nolint:exhaustruct // Fields populated by appendChatMessage.
|
||||
ChatID: insertedChat.ID,
|
||||
}
|
||||
appendChatMessage(&userParams, newChatMessage(
|
||||
database.ChatMessageRoleUser,
|
||||
userContent,
|
||||
database.ChatMessageVisibilityBoth,
|
||||
parent.LastModelConfigID,
|
||||
chatprompt.CurrentContentVersion,
|
||||
).withCreatedBy(parent.OwnerID))
|
||||
if _, err := tx.InsertChatMessages(ctx, userParams); err != nil {
|
||||
return xerrors.Errorf("insert initial child user message: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}, nil)
|
||||
if txErr != nil {
|
||||
return database.Chat{}, xerrors.Errorf("create child chat: %w", txErr)
|
||||
}
|
||||
|
||||
p.publishChatPubsubEvent(child, codersdk.ChatWatchEventKindCreated, nil)
|
||||
p.signalWake()
|
||||
return child, nil
|
||||
}
|
||||
|
||||
// copyParentContextMessages reads persisted context-file and skill
|
||||
// messages from the parent chat and inserts copies into the child
|
||||
// chat. This ensures sub-agents inherit the same instruction and
|
||||
// skill context as their parent without independently re-fetching
|
||||
// from the agent.
|
||||
func copyParentContextMessages(
|
||||
ctx context.Context,
|
||||
logger slog.Logger,
|
||||
store database.Store,
|
||||
parent database.Chat,
|
||||
child database.Chat,
|
||||
) ([]codersdk.ChatMessagePart, error) {
|
||||
parentMessages, err := store.GetChatMessagesByChatID(ctx, database.GetChatMessagesByChatIDParams{
|
||||
ChatID: parent.ID,
|
||||
AfterID: 0,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, xerrors.Errorf("get parent messages: %w", err)
|
||||
}
|
||||
|
||||
var (
|
||||
copiedParts []codersdk.ChatMessagePart
|
||||
copiedRole database.ChatMessageRole
|
||||
copiedVisibility database.ChatMessageVisibility
|
||||
copiedVersion int16
|
||||
)
|
||||
for _, msg := range parentMessages {
|
||||
if !msg.Content.Valid {
|
||||
continue
|
||||
}
|
||||
var parts []codersdk.ChatMessagePart
|
||||
if err := json.Unmarshal(msg.Content.RawMessage, &parts); err != nil {
|
||||
logger.Warn(ctx, "failed to unmarshal parent context message",
|
||||
slog.F("parent_chat_id", parent.ID),
|
||||
slog.F("message_id", msg.ID),
|
||||
slog.Error(err),
|
||||
)
|
||||
continue
|
||||
}
|
||||
|
||||
messageContextParts := FilterContextParts(parts, true)
|
||||
if len(messageContextParts) == 0 {
|
||||
continue
|
||||
}
|
||||
if copiedParts == nil {
|
||||
copiedRole = msg.Role
|
||||
copiedVisibility = msg.Visibility
|
||||
copiedVersion = msg.ContentVersion
|
||||
}
|
||||
copiedParts = append(copiedParts, messageContextParts...)
|
||||
}
|
||||
if len(copiedParts) == 0 {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
copiedParts = FilterContextPartsToLatestAgent(copiedParts)
|
||||
filteredContent, err := chatprompt.MarshalParts(copiedParts)
|
||||
if err != nil {
|
||||
return nil, xerrors.Errorf("marshal filtered context parts: %w", err)
|
||||
}
|
||||
|
||||
msgParams := database.InsertChatMessagesParams{ //nolint:exhaustruct // Fields populated by appendChatMessage.
|
||||
ChatID: child.ID,
|
||||
}
|
||||
appendChatMessage(&msgParams, newChatMessage(
|
||||
copiedRole,
|
||||
filteredContent,
|
||||
copiedVisibility,
|
||||
child.LastModelConfigID,
|
||||
copiedVersion,
|
||||
))
|
||||
if _, err := store.InsertChatMessages(ctx, msgParams); err != nil {
|
||||
return nil, xerrors.Errorf("insert context message: %w", err)
|
||||
}
|
||||
|
||||
return copiedParts, nil
|
||||
}
|
||||
|
||||
func updateChildLastInjectedContext(
|
||||
ctx context.Context,
|
||||
logger slog.Logger,
|
||||
store database.Store,
|
||||
chatID uuid.UUID,
|
||||
parts []codersdk.ChatMessagePart,
|
||||
) error {
|
||||
parts = FilterContextPartsToLatestAgent(parts)
|
||||
param, err := BuildLastInjectedContext(parts)
|
||||
if err != nil {
|
||||
logger.Warn(ctx, "failed to marshal inherited injected context",
|
||||
slog.F("chat_id", chatID),
|
||||
slog.Error(err),
|
||||
)
|
||||
return xerrors.Errorf("marshal inherited injected context: %w", err)
|
||||
}
|
||||
if _, err := store.UpdateChatLastInjectedContext(ctx, database.UpdateChatLastInjectedContextParams{
|
||||
ID: chatID,
|
||||
LastInjectedContext: param,
|
||||
}); err != nil {
|
||||
logger.Warn(ctx, "failed to update inherited injected context",
|
||||
slog.F("chat_id", chatID),
|
||||
slog.Error(err),
|
||||
)
|
||||
return xerrors.Errorf("update inherited injected context: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (p *Server) sendSubagentMessage(
|
||||
ctx context.Context,
|
||||
parentChatID uuid.UUID,
|
||||
|
||||
@@ -0,0 +1,506 @@
|
||||
package chatd
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"testing"
|
||||
|
||||
"charm.land/fantasy"
|
||||
"github.com/google/uuid"
|
||||
"github.com/sqlc-dev/pqtype"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"cdr.dev/slog/v3"
|
||||
"github.com/coder/coder/v2/coderd/database"
|
||||
"github.com/coder/coder/v2/coderd/database/dbtestutil"
|
||||
"github.com/coder/coder/v2/coderd/x/chatd/chatprompt"
|
||||
"github.com/coder/coder/v2/coderd/x/chatd/chatprovider"
|
||||
"github.com/coder/coder/v2/codersdk"
|
||||
)
|
||||
|
||||
func TestCollectContextPartsFromMessagesSkipsSentinelContextFiles(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
content, err := json.Marshal([]codersdk.ChatMessagePart{
|
||||
{
|
||||
Type: codersdk.ChatMessagePartTypeContextFile,
|
||||
ContextFilePath: "/home/coder/project/.agents/skills/my-skill/SKILL.md",
|
||||
},
|
||||
{
|
||||
Type: codersdk.ChatMessagePartTypeSkill,
|
||||
SkillName: "my-skill",
|
||||
SkillDescription: "A test skill",
|
||||
},
|
||||
{
|
||||
Type: codersdk.ChatMessagePartTypeContextFile,
|
||||
ContextFilePath: "/home/coder/project/AGENTS.md",
|
||||
ContextFileContent: "# Project instructions",
|
||||
},
|
||||
codersdk.ChatMessageText("ignored"),
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
parts, err := CollectContextPartsFromMessages(context.Background(), slog.Make(), []database.ChatMessage{ //nolint:exhaustruct // Only content fields matter for this unit test.
|
||||
{
|
||||
ID: 1,
|
||||
Content: pqtype.NullRawMessage{
|
||||
RawMessage: content,
|
||||
Valid: true,
|
||||
},
|
||||
},
|
||||
}, false)
|
||||
require.NoError(t, err)
|
||||
require.Len(t, parts, 2)
|
||||
require.Equal(t, codersdk.ChatMessagePartTypeSkill, parts[0].Type)
|
||||
require.Equal(t, "my-skill", parts[0].SkillName)
|
||||
require.Equal(t, codersdk.ChatMessagePartTypeContextFile, parts[1].Type)
|
||||
require.Equal(t, "/home/coder/project/AGENTS.md", parts[1].ContextFilePath)
|
||||
require.Equal(t, "# Project instructions", parts[1].ContextFileContent)
|
||||
}
|
||||
|
||||
func TestCollectContextPartsFromMessagesKeepsEmptyContextFilesWhenRequested(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
content, err := json.Marshal([]codersdk.ChatMessagePart{
|
||||
{
|
||||
Type: codersdk.ChatMessagePartTypeContextFile,
|
||||
ContextFilePath: AgentChatContextSentinelPath,
|
||||
ContextFileAgentID: uuid.NullUUID{
|
||||
UUID: uuid.New(),
|
||||
Valid: true,
|
||||
},
|
||||
},
|
||||
{
|
||||
Type: codersdk.ChatMessagePartTypeSkill,
|
||||
SkillName: "my-skill",
|
||||
},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
parts, err := CollectContextPartsFromMessages(context.Background(), slog.Make(), []database.ChatMessage{ //nolint:exhaustruct // Only content fields matter for this unit test.
|
||||
{
|
||||
ID: 1,
|
||||
Content: pqtype.NullRawMessage{
|
||||
RawMessage: content,
|
||||
Valid: true,
|
||||
},
|
||||
},
|
||||
}, true)
|
||||
require.NoError(t, err)
|
||||
require.Len(t, parts, 2)
|
||||
require.Equal(t, AgentChatContextSentinelPath, parts[0].ContextFilePath)
|
||||
require.Equal(t, "my-skill", parts[1].SkillName)
|
||||
}
|
||||
|
||||
func TestFilterContextPartsToLatestAgent(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
oldAgentID := uuid.New()
|
||||
newAgentID := uuid.New()
|
||||
parts := []codersdk.ChatMessagePart{
|
||||
{
|
||||
Type: codersdk.ChatMessagePartTypeContextFile,
|
||||
ContextFilePath: "/legacy/AGENTS.md",
|
||||
ContextFileContent: "legacy instructions",
|
||||
},
|
||||
{
|
||||
Type: codersdk.ChatMessagePartTypeSkill,
|
||||
SkillName: "repo-helper-legacy",
|
||||
},
|
||||
{
|
||||
Type: codersdk.ChatMessagePartTypeContextFile,
|
||||
ContextFilePath: "/old/AGENTS.md",
|
||||
ContextFileAgentID: uuid.NullUUID{UUID: oldAgentID, Valid: true},
|
||||
},
|
||||
{
|
||||
Type: codersdk.ChatMessagePartTypeSkill,
|
||||
SkillName: "repo-helper-old",
|
||||
ContextFileAgentID: uuid.NullUUID{UUID: oldAgentID, Valid: true},
|
||||
},
|
||||
{
|
||||
Type: codersdk.ChatMessagePartTypeContextFile,
|
||||
ContextFilePath: AgentChatContextSentinelPath,
|
||||
ContextFileAgentID: uuid.NullUUID{
|
||||
UUID: newAgentID,
|
||||
Valid: true,
|
||||
},
|
||||
},
|
||||
{
|
||||
Type: codersdk.ChatMessagePartTypeSkill,
|
||||
SkillName: "repo-helper-new",
|
||||
ContextFileAgentID: uuid.NullUUID{UUID: newAgentID, Valid: true},
|
||||
},
|
||||
}
|
||||
|
||||
got := FilterContextPartsToLatestAgent(parts)
|
||||
require.Len(t, got, 4)
|
||||
require.Equal(t, "/legacy/AGENTS.md", got[0].ContextFilePath)
|
||||
require.Equal(t, "repo-helper-legacy", got[1].SkillName)
|
||||
require.Equal(t, AgentChatContextSentinelPath, got[2].ContextFilePath)
|
||||
require.Equal(t, "repo-helper-new", got[3].SkillName)
|
||||
}
|
||||
|
||||
func createParentChatWithInheritedContext(
|
||||
ctx context.Context,
|
||||
t *testing.T,
|
||||
db database.Store,
|
||||
server *Server,
|
||||
) database.Chat {
|
||||
t.Helper()
|
||||
|
||||
user, model := seedInternalChatDeps(ctx, t, db)
|
||||
|
||||
parent, err := server.CreateChat(ctx, CreateOptions{
|
||||
OwnerID: user.ID,
|
||||
Title: "parent-with-context",
|
||||
ModelConfigID: model.ID,
|
||||
InitialUserContent: []codersdk.ChatMessagePart{codersdk.ChatMessageText("hello")},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
inheritedParts := []codersdk.ChatMessagePart{
|
||||
{
|
||||
Type: codersdk.ChatMessagePartTypeContextFile,
|
||||
ContextFilePath: "/home/coder/project/AGENTS.md",
|
||||
ContextFileContent: "# Project instructions",
|
||||
ContextFileOS: "linux",
|
||||
ContextFileDirectory: "/home/coder/project",
|
||||
},
|
||||
{
|
||||
Type: codersdk.ChatMessagePartTypeSkill,
|
||||
SkillName: "my-skill",
|
||||
SkillDescription: "A test skill",
|
||||
SkillDir: "/home/coder/project/.agents/skills/my-skill",
|
||||
ContextFileSkillMetaFile: "SKILL.md",
|
||||
},
|
||||
{
|
||||
Type: codersdk.ChatMessagePartTypeContextFile,
|
||||
ContextFilePath: "/home/coder/project/.agents/skills/my-skill/SKILL.md",
|
||||
},
|
||||
}
|
||||
content, err := json.Marshal(inheritedParts)
|
||||
require.NoError(t, err)
|
||||
|
||||
_, err = db.InsertChatMessages(ctx, database.InsertChatMessagesParams{
|
||||
ChatID: parent.ID,
|
||||
CreatedBy: []uuid.UUID{user.ID},
|
||||
ModelConfigID: []uuid.UUID{model.ID},
|
||||
Role: []database.ChatMessageRole{database.ChatMessageRoleUser},
|
||||
Content: []string{string(content)},
|
||||
ContentVersion: []int16{chatprompt.CurrentContentVersion},
|
||||
Visibility: []database.ChatMessageVisibility{database.ChatMessageVisibilityBoth},
|
||||
InputTokens: []int64{0},
|
||||
OutputTokens: []int64{0},
|
||||
TotalTokens: []int64{0},
|
||||
ReasoningTokens: []int64{0},
|
||||
CacheCreationTokens: []int64{0},
|
||||
CacheReadTokens: []int64{0},
|
||||
ContextLimit: []int64{0},
|
||||
Compressed: []bool{false},
|
||||
TotalCostMicros: []int64{0},
|
||||
RuntimeMs: []int64{0},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
parentChat, err := db.GetChatByID(ctx, parent.ID)
|
||||
require.NoError(t, err)
|
||||
return parentChat
|
||||
}
|
||||
|
||||
func assertChildInheritedContext(
|
||||
ctx context.Context,
|
||||
t *testing.T,
|
||||
db database.Store,
|
||||
childID uuid.UUID,
|
||||
prompt string,
|
||||
) {
|
||||
t.Helper()
|
||||
|
||||
childChat, err := db.GetChatByID(ctx, childID)
|
||||
require.NoError(t, err)
|
||||
require.True(t, childChat.LastInjectedContext.Valid)
|
||||
|
||||
var cached []codersdk.ChatMessagePart
|
||||
require.NoError(t, json.Unmarshal(childChat.LastInjectedContext.RawMessage, &cached))
|
||||
require.Len(t, cached, 2)
|
||||
|
||||
var sawContextFile bool
|
||||
var sawSkill bool
|
||||
for _, part := range cached {
|
||||
switch part.Type {
|
||||
case codersdk.ChatMessagePartTypeContextFile:
|
||||
sawContextFile = true
|
||||
require.Equal(t, "/home/coder/project/AGENTS.md", part.ContextFilePath)
|
||||
require.Empty(t, part.ContextFileContent)
|
||||
require.Empty(t, part.ContextFileOS)
|
||||
require.Empty(t, part.ContextFileDirectory)
|
||||
case codersdk.ChatMessagePartTypeSkill:
|
||||
sawSkill = true
|
||||
require.Equal(t, "my-skill", part.SkillName)
|
||||
require.Equal(t, "A test skill", part.SkillDescription)
|
||||
require.Empty(t, part.SkillDir)
|
||||
require.Empty(t, part.ContextFileSkillMetaFile)
|
||||
default:
|
||||
t.Fatalf("unexpected cached part type %q", part.Type)
|
||||
}
|
||||
}
|
||||
require.True(t, sawContextFile)
|
||||
require.True(t, sawSkill)
|
||||
|
||||
childMessages, err := db.GetChatMessagesByChatID(ctx, database.GetChatMessagesByChatIDParams{
|
||||
ChatID: childID,
|
||||
AfterID: 0,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
var (
|
||||
contextMessageIndexes []int
|
||||
userPromptIndex = -1
|
||||
sawDBAgentsContextFile bool
|
||||
sawDBSkillCompanionContext bool
|
||||
sawDBSkill bool
|
||||
)
|
||||
for i, msg := range childMessages {
|
||||
if !msg.Content.Valid {
|
||||
continue
|
||||
}
|
||||
|
||||
var parts []codersdk.ChatMessagePart
|
||||
require.NoError(t, json.Unmarshal(msg.Content.RawMessage, &parts))
|
||||
|
||||
if len(parts) == 1 && parts[0].Type == codersdk.ChatMessagePartTypeText && parts[0].Text == prompt {
|
||||
require.Equal(t, database.ChatMessageRoleUser, msg.Role)
|
||||
userPromptIndex = i
|
||||
continue
|
||||
}
|
||||
|
||||
hasInheritedContext := false
|
||||
for _, part := range parts {
|
||||
switch part.Type {
|
||||
case codersdk.ChatMessagePartTypeContextFile:
|
||||
hasInheritedContext = true
|
||||
switch part.ContextFilePath {
|
||||
case "/home/coder/project/AGENTS.md":
|
||||
sawDBAgentsContextFile = true
|
||||
require.Equal(t, "# Project instructions", part.ContextFileContent)
|
||||
require.Equal(t, "linux", part.ContextFileOS)
|
||||
require.Equal(t, "/home/coder/project", part.ContextFileDirectory)
|
||||
case "/home/coder/project/.agents/skills/my-skill/SKILL.md":
|
||||
sawDBSkillCompanionContext = true
|
||||
require.Empty(t, part.ContextFileContent)
|
||||
require.Empty(t, part.ContextFileOS)
|
||||
require.Empty(t, part.ContextFileDirectory)
|
||||
default:
|
||||
t.Fatalf("unexpected child inherited context file path %q", part.ContextFilePath)
|
||||
}
|
||||
case codersdk.ChatMessagePartTypeSkill:
|
||||
hasInheritedContext = true
|
||||
sawDBSkill = true
|
||||
require.Equal(t, "my-skill", part.SkillName)
|
||||
require.Equal(t, "A test skill", part.SkillDescription)
|
||||
require.Equal(t, "/home/coder/project/.agents/skills/my-skill", part.SkillDir)
|
||||
require.Equal(t, "SKILL.md", part.ContextFileSkillMetaFile)
|
||||
default:
|
||||
t.Fatalf("unexpected child inherited part type %q", part.Type)
|
||||
}
|
||||
}
|
||||
if hasInheritedContext {
|
||||
require.Equal(t, database.ChatMessageRoleUser, msg.Role)
|
||||
contextMessageIndexes = append(contextMessageIndexes, i)
|
||||
}
|
||||
}
|
||||
|
||||
require.NotEmpty(t, contextMessageIndexes)
|
||||
require.NotEqual(t, -1, userPromptIndex)
|
||||
for _, idx := range contextMessageIndexes {
|
||||
require.Less(t, idx, userPromptIndex)
|
||||
}
|
||||
require.True(t, sawDBAgentsContextFile)
|
||||
require.True(t, sawDBSkillCompanionContext)
|
||||
require.True(t, sawDBSkill)
|
||||
}
|
||||
|
||||
func createParentChatWithRotatedInheritedContext(
|
||||
ctx context.Context,
|
||||
t *testing.T,
|
||||
db database.Store,
|
||||
server *Server,
|
||||
) database.Chat {
|
||||
t.Helper()
|
||||
|
||||
user, model := seedInternalChatDeps(ctx, t, db)
|
||||
|
||||
parent, err := server.CreateChat(ctx, CreateOptions{
|
||||
OwnerID: user.ID,
|
||||
Title: "parent-with-rotated-context",
|
||||
ModelConfigID: model.ID,
|
||||
InitialUserContent: []codersdk.ChatMessagePart{codersdk.ChatMessageText("hello")},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
oldAgentID := uuid.New()
|
||||
newAgentID := uuid.New()
|
||||
oldContent, err := json.Marshal([]codersdk.ChatMessagePart{
|
||||
{
|
||||
Type: codersdk.ChatMessagePartTypeContextFile,
|
||||
ContextFilePath: "/home/coder/project-old/AGENTS.md",
|
||||
ContextFileContent: "# Old instructions",
|
||||
ContextFileOS: "darwin",
|
||||
ContextFileDirectory: "/home/coder/project-old",
|
||||
ContextFileAgentID: uuid.NullUUID{UUID: oldAgentID, Valid: true},
|
||||
},
|
||||
{
|
||||
Type: codersdk.ChatMessagePartTypeSkill,
|
||||
SkillName: "old-skill",
|
||||
SkillDescription: "Old skill",
|
||||
SkillDir: "/home/coder/project-old/.agents/skills/old-skill",
|
||||
ContextFileAgentID: uuid.NullUUID{UUID: oldAgentID, Valid: true},
|
||||
},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
newContent, err := json.Marshal([]codersdk.ChatMessagePart{
|
||||
{
|
||||
Type: codersdk.ChatMessagePartTypeContextFile,
|
||||
ContextFilePath: "/home/coder/project-new/AGENTS.md",
|
||||
ContextFileContent: "# New instructions",
|
||||
ContextFileOS: "linux",
|
||||
ContextFileDirectory: "/home/coder/project-new",
|
||||
ContextFileAgentID: uuid.NullUUID{UUID: newAgentID, Valid: true},
|
||||
},
|
||||
{
|
||||
Type: codersdk.ChatMessagePartTypeSkill,
|
||||
SkillName: "new-skill",
|
||||
SkillDescription: "New skill",
|
||||
SkillDir: "/home/coder/project-new/.agents/skills/new-skill",
|
||||
ContextFileAgentID: uuid.NullUUID{UUID: newAgentID, Valid: true},
|
||||
},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
_, err = db.InsertChatMessages(ctx, database.InsertChatMessagesParams{
|
||||
ChatID: parent.ID,
|
||||
CreatedBy: []uuid.UUID{user.ID, user.ID},
|
||||
ModelConfigID: []uuid.UUID{model.ID, model.ID},
|
||||
Role: []database.ChatMessageRole{database.ChatMessageRoleUser, database.ChatMessageRoleUser},
|
||||
Content: []string{string(oldContent), string(newContent)},
|
||||
ContentVersion: []int16{chatprompt.CurrentContentVersion, chatprompt.CurrentContentVersion},
|
||||
Visibility: []database.ChatMessageVisibility{database.ChatMessageVisibilityBoth, database.ChatMessageVisibilityBoth},
|
||||
InputTokens: []int64{0, 0},
|
||||
OutputTokens: []int64{0, 0},
|
||||
TotalTokens: []int64{0, 0},
|
||||
ReasoningTokens: []int64{0, 0},
|
||||
CacheCreationTokens: []int64{0, 0},
|
||||
CacheReadTokens: []int64{0, 0},
|
||||
ContextLimit: []int64{0, 0},
|
||||
Compressed: []bool{false, false},
|
||||
TotalCostMicros: []int64{0, 0},
|
||||
RuntimeMs: []int64{0, 0},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
parentChat, err := db.GetChatByID(ctx, parent.ID)
|
||||
require.NoError(t, err)
|
||||
return parentChat
|
||||
}
|
||||
|
||||
func TestCreateChildSubagentChatCopiesOnlyLatestAgentContext(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
db, ps := dbtestutil.NewDB(t)
|
||||
server := newInternalTestServer(t, db, ps, chatprovider.ProviderAPIKeys{})
|
||||
|
||||
ctx := chatdTestContext(t)
|
||||
parentChat := createParentChatWithRotatedInheritedContext(ctx, t, db, server)
|
||||
|
||||
child, err := server.createChildSubagentChat(ctx, parentChat, "inspect bindings", "")
|
||||
require.NoError(t, err)
|
||||
|
||||
childChat, err := db.GetChatByID(ctx, child.ID)
|
||||
require.NoError(t, err)
|
||||
require.True(t, childChat.LastInjectedContext.Valid)
|
||||
|
||||
var cached []codersdk.ChatMessagePart
|
||||
require.NoError(t, json.Unmarshal(childChat.LastInjectedContext.RawMessage, &cached))
|
||||
require.Len(t, cached, 2)
|
||||
require.Equal(t, "/home/coder/project-new/AGENTS.md", cached[0].ContextFilePath)
|
||||
require.Equal(t, "new-skill", cached[1].SkillName)
|
||||
|
||||
childMessages, err := db.GetChatMessagesByChatID(ctx, database.GetChatMessagesByChatIDParams{
|
||||
ChatID: child.ID,
|
||||
AfterID: 0,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
var inherited [][]codersdk.ChatMessagePart
|
||||
for _, msg := range childMessages {
|
||||
if !msg.Content.Valid {
|
||||
continue
|
||||
}
|
||||
var parts []codersdk.ChatMessagePart
|
||||
require.NoError(t, json.Unmarshal(msg.Content.RawMessage, &parts))
|
||||
if len(parts) == 0 || parts[0].Type == codersdk.ChatMessagePartTypeText {
|
||||
continue
|
||||
}
|
||||
inherited = append(inherited, parts)
|
||||
}
|
||||
require.Len(t, inherited, 1)
|
||||
require.Len(t, inherited[0], 2)
|
||||
require.Equal(t, "/home/coder/project-new/AGENTS.md", inherited[0][0].ContextFilePath)
|
||||
require.Equal(t, "# New instructions", inherited[0][0].ContextFileContent)
|
||||
require.Equal(t, "new-skill", inherited[0][1].SkillName)
|
||||
}
|
||||
|
||||
func TestCreateChildSubagentChatUpdatesInheritedLastInjectedContext(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
db, ps := dbtestutil.NewDB(t)
|
||||
server := newInternalTestServer(t, db, ps, chatprovider.ProviderAPIKeys{})
|
||||
|
||||
ctx := chatdTestContext(t)
|
||||
parentChat := createParentChatWithInheritedContext(ctx, t, db, server)
|
||||
|
||||
child, err := server.createChildSubagentChat(ctx, parentChat, "inspect bindings", "")
|
||||
require.NoError(t, err)
|
||||
|
||||
assertChildInheritedContext(ctx, t, db, child.ID, "inspect bindings")
|
||||
}
|
||||
|
||||
func TestSpawnComputerUseAgentInheritsContext(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
db, ps := dbtestutil.NewDB(t)
|
||||
require.NoError(t, db.UpsertChatDesktopEnabled(chatdTestContext(t), true))
|
||||
server := newInternalTestServer(t, db, ps, chatprovider.ProviderAPIKeys{
|
||||
Anthropic: "test-anthropic-key",
|
||||
})
|
||||
|
||||
ctx := chatdTestContext(t)
|
||||
parentChat := createParentChatWithInheritedContext(ctx, t, db, server)
|
||||
|
||||
tools := server.subagentTools(ctx, func() database.Chat { return parentChat })
|
||||
tool := findToolByName(tools, "spawn_computer_use_agent")
|
||||
require.NotNil(t, tool)
|
||||
|
||||
resp, err := tool.Run(ctx, fantasy.ToolCall{
|
||||
ID: "call-context",
|
||||
Name: "spawn_computer_use_agent",
|
||||
Input: `{"prompt":"inspect bindings"}`,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.False(t, resp.IsError, "expected success but got: %s", resp.Content)
|
||||
|
||||
var result map[string]any
|
||||
require.NoError(t, json.Unmarshal([]byte(resp.Content), &result))
|
||||
childIDStr, ok := result["chat_id"].(string)
|
||||
require.True(t, ok)
|
||||
|
||||
childID, err := uuid.Parse(childIDStr)
|
||||
require.NoError(t, err)
|
||||
|
||||
childChat, err := db.GetChatByID(ctx, childID)
|
||||
require.NoError(t, err)
|
||||
require.True(t, childChat.Mode.Valid)
|
||||
require.Equal(t, database.ChatModeComputerUse, childChat.Mode.ChatMode)
|
||||
|
||||
assertChildInheritedContext(ctx, t, db, childID, "inspect bindings")
|
||||
}
|
||||
@@ -892,3 +892,66 @@ func (s *SSEAgentReinitReceiver) Receive(ctx context.Context) (*Reinitialization
|
||||
return &reinitEvent, nil
|
||||
}
|
||||
}
|
||||
|
||||
// AddChatContextRequest is the request body for adding chat context.
|
||||
type AddChatContextRequest struct {
|
||||
// ChatID optionally identifies the chat to add context to.
|
||||
// If empty, auto-detection is used (CODER_CHAT_ID env, the
|
||||
// only active chat, or the only top-level active chat for this
|
||||
// agent).
|
||||
ChatID uuid.UUID `json:"chat_id,omitempty"`
|
||||
// Parts are the context-file and skill parts to add.
|
||||
Parts []codersdk.ChatMessagePart `json:"parts"`
|
||||
}
|
||||
|
||||
// AddChatContextResponse is the response for adding chat context.
|
||||
type AddChatContextResponse struct {
|
||||
ChatID uuid.UUID `json:"chat_id"`
|
||||
Count int `json:"count"`
|
||||
}
|
||||
|
||||
// ClearChatContextRequest is the request body for clearing chat context.
|
||||
type ClearChatContextRequest struct {
|
||||
// ChatID optionally identifies the chat to clear context from.
|
||||
// If empty, auto-detection is used (CODER_CHAT_ID env, the
|
||||
// only active chat, or the only top-level active chat for this
|
||||
// agent).
|
||||
ChatID uuid.UUID `json:"chat_id,omitempty"`
|
||||
}
|
||||
|
||||
// ClearChatContextResponse is the response for clearing chat context.
|
||||
type ClearChatContextResponse struct {
|
||||
ChatID uuid.UUID `json:"chat_id"`
|
||||
}
|
||||
|
||||
// AddChatContext adds context-file and skill parts to an active chat.
|
||||
func (c *Client) AddChatContext(ctx context.Context, req AddChatContextRequest) (AddChatContextResponse, error) {
|
||||
res, err := c.SDK.Request(ctx, http.MethodPost, "/api/v2/workspaceagents/me/experimental/chat-context", req)
|
||||
if err != nil {
|
||||
return AddChatContextResponse{}, xerrors.Errorf("execute request: %w", err)
|
||||
}
|
||||
defer res.Body.Close()
|
||||
|
||||
if res.StatusCode != http.StatusOK {
|
||||
return AddChatContextResponse{}, codersdk.ReadBodyAsError(res)
|
||||
}
|
||||
|
||||
var resp AddChatContextResponse
|
||||
return resp, json.NewDecoder(res.Body).Decode(&resp)
|
||||
}
|
||||
|
||||
// ClearChatContext soft-deletes context-file and skill messages from an active chat.
|
||||
func (c *Client) ClearChatContext(ctx context.Context, req ClearChatContextRequest) (ClearChatContextResponse, error) {
|
||||
res, err := c.SDK.Request(ctx, http.MethodDelete, "/api/v2/workspaceagents/me/experimental/chat-context", req)
|
||||
if err != nil {
|
||||
return ClearChatContextResponse{}, xerrors.Errorf("execute request: %w", err)
|
||||
}
|
||||
defer res.Body.Close()
|
||||
|
||||
if res.StatusCode != http.StatusOK {
|
||||
return ClearChatContextResponse{}, codersdk.ReadBodyAsError(res)
|
||||
}
|
||||
|
||||
var resp ClearChatContextResponse
|
||||
return resp, json.NewDecoder(res.Body).Decode(&resp)
|
||||
}
|
||||
|
||||
@@ -376,7 +376,7 @@ func ProtoFromLog(log Log) (*proto.Log, error) {
|
||||
}
|
||||
return &proto.Log{
|
||||
CreatedAt: timestamppb.New(log.CreatedAt),
|
||||
Output: strings.ToValidUTF8(log.Output, "❌"),
|
||||
Output: SanitizeLogOutput(log.Output),
|
||||
Level: proto.Log_Level(lvl),
|
||||
}, nil
|
||||
}
|
||||
|
||||
@@ -229,7 +229,7 @@ func TestLogSender_SkipHugeLog(t *testing.T) {
|
||||
require.ErrorIs(t, err, context.Canceled)
|
||||
}
|
||||
|
||||
func TestLogSender_InvalidUTF8(t *testing.T) {
|
||||
func TestLogSender_SanitizeOutput(t *testing.T) {
|
||||
t.Parallel()
|
||||
testCtx := testutil.Context(t, testutil.WaitShort)
|
||||
ctx, cancel := context.WithCancel(testCtx)
|
||||
@@ -243,7 +243,7 @@ func TestLogSender_InvalidUTF8(t *testing.T) {
|
||||
uut.Enqueue(ls1,
|
||||
Log{
|
||||
CreatedAt: t0,
|
||||
Output: "test log 0, src 1\xc3\x28",
|
||||
Output: "test log 0, src 1\x00\xc3\x28",
|
||||
Level: codersdk.LogLevelInfo,
|
||||
},
|
||||
Log{
|
||||
@@ -260,10 +260,10 @@ func TestLogSender_InvalidUTF8(t *testing.T) {
|
||||
|
||||
req := testutil.TryReceive(ctx, t, fDest.reqs)
|
||||
require.NotNil(t, req)
|
||||
require.Len(t, req.Logs, 2, "it should sanitize invalid UTF-8, but still send")
|
||||
// the 0xc3, 0x28 is an invalid 2-byte sequence in UTF-8. The sanitizer replaces 0xc3 with ❌, and then
|
||||
// interprets 0x28 as a 1-byte sequence "("
|
||||
require.Equal(t, "test log 0, src 1❌(", req.Logs[0].GetOutput())
|
||||
require.Len(t, req.Logs, 2, "it should sanitize invalid output, but still send")
|
||||
// The sanitizer replaces the NUL byte and invalid UTF-8 with ❌ while
|
||||
// preserving the valid "(" byte that follows 0xc3.
|
||||
require.Equal(t, "test log 0, src 1❌❌(", req.Logs[0].GetOutput())
|
||||
require.Equal(t, proto.Log_INFO, req.Logs[0].GetLevel())
|
||||
require.Equal(t, "test log 1, src 1", req.Logs[1].GetOutput())
|
||||
require.Equal(t, proto.Log_INFO, req.Logs[1].GetLevel())
|
||||
|
||||
@@ -0,0 +1,11 @@
|
||||
package agentsdk
|
||||
|
||||
import "strings"
|
||||
|
||||
// SanitizeLogOutput replaces invalid UTF-8 and NUL characters in log output.
|
||||
// Invalid UTF-8 cannot be transported in protobuf string fields, and PostgreSQL
|
||||
// rejects NUL bytes in text columns.
|
||||
func SanitizeLogOutput(s string) string {
|
||||
s = strings.ToValidUTF8(s, "❌")
|
||||
return strings.ReplaceAll(s, "\x00", "❌")
|
||||
}
|
||||
@@ -17,6 +17,54 @@ import (
|
||||
"github.com/coder/coder/v2/testutil"
|
||||
)
|
||||
|
||||
func TestSanitizeLogOutput(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
in string
|
||||
want string
|
||||
}{
|
||||
{
|
||||
name: "valid",
|
||||
in: "hello world",
|
||||
want: "hello world",
|
||||
},
|
||||
{
|
||||
name: "invalid utf8",
|
||||
in: "test log\xc3\x28",
|
||||
want: "test log❌(",
|
||||
},
|
||||
{
|
||||
name: "nul byte",
|
||||
in: "before\x00after",
|
||||
want: "before❌after",
|
||||
},
|
||||
{
|
||||
name: "invalid utf8 and nul byte",
|
||||
in: "before\x00middle\xc3\x28after",
|
||||
want: "before❌middle❌(after",
|
||||
},
|
||||
{
|
||||
name: "nul byte at edges",
|
||||
in: "\x00middle\x00",
|
||||
want: "❌middle❌",
|
||||
},
|
||||
{
|
||||
name: "invalid utf8 at edges",
|
||||
in: "\xc3middle\xc3",
|
||||
want: "❌middle❌",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
require.Equal(t, tt.want, agentsdk.SanitizeLogOutput(tt.in))
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestStartupLogsWriter_Write(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
|
||||
@@ -127,6 +127,8 @@ type AIBridgeThread struct {
|
||||
Prompt *string `json:"prompt,omitempty"`
|
||||
Model string `json:"model"`
|
||||
Provider string `json:"provider"`
|
||||
CredentialKind string `json:"credential_kind"`
|
||||
CredentialHint string `json:"credential_hint"`
|
||||
StartedAt time.Time `json:"started_at" format:"date-time"`
|
||||
EndedAt *time.Time `json:"ended_at,omitempty" format:"date-time"`
|
||||
TokenUsage AIBridgeSessionThreadsTokenUsage `json:"token_usage"`
|
||||
|
||||
+13
-97
@@ -1130,11 +1130,6 @@ type ChatStreamEvent struct {
|
||||
ActionRequired *ChatStreamActionRequired `json:"action_required,omitempty"`
|
||||
}
|
||||
|
||||
type chatStreamEnvelope struct {
|
||||
Type ServerSentEventType `json:"type"`
|
||||
Data json.RawMessage `json:"data,omitempty"`
|
||||
}
|
||||
|
||||
// ChatCostSummaryOptions are optional query parameters for GetChatCostSummary.
|
||||
type ChatCostSummaryOptions struct {
|
||||
StartDate time.Time
|
||||
@@ -1987,8 +1982,8 @@ func (c *ExperimentalClient) StreamChat(ctx context.Context, chatID uuid.UUID, o
|
||||
}()
|
||||
|
||||
for {
|
||||
var envelope chatStreamEnvelope
|
||||
if err := wsjson.Read(streamCtx, conn, &envelope); err != nil {
|
||||
var batch []ChatStreamEvent
|
||||
if err := wsjson.Read(streamCtx, conn, &batch); err != nil {
|
||||
if streamCtx.Err() != nil {
|
||||
return
|
||||
}
|
||||
@@ -2005,61 +2000,10 @@ func (c *ExperimentalClient) StreamChat(ctx context.Context, chatID uuid.UUID, o
|
||||
return
|
||||
}
|
||||
|
||||
switch envelope.Type {
|
||||
case ServerSentEventTypePing:
|
||||
continue
|
||||
case ServerSentEventTypeData:
|
||||
var batch []ChatStreamEvent
|
||||
decodeErr := json.Unmarshal(envelope.Data, &batch)
|
||||
if decodeErr == nil {
|
||||
for _, streamedEvent := range batch {
|
||||
if !send(streamedEvent) {
|
||||
return
|
||||
}
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
{
|
||||
_ = send(ChatStreamEvent{
|
||||
Type: ChatStreamEventTypeError,
|
||||
Error: &ChatStreamError{
|
||||
Message: fmt.Sprintf(
|
||||
"decode chat stream event batch: %v",
|
||||
decodeErr,
|
||||
),
|
||||
},
|
||||
})
|
||||
for _, event := range batch {
|
||||
if !send(event) {
|
||||
return
|
||||
}
|
||||
case ServerSentEventTypeError:
|
||||
message := "chat stream returned an error"
|
||||
if len(envelope.Data) > 0 {
|
||||
var response Response
|
||||
if err := json.Unmarshal(envelope.Data, &response); err == nil {
|
||||
message = formatChatStreamResponseError(response)
|
||||
} else {
|
||||
trimmed := strings.TrimSpace(string(envelope.Data))
|
||||
if trimmed != "" {
|
||||
message = trimmed
|
||||
}
|
||||
}
|
||||
}
|
||||
_ = send(ChatStreamEvent{
|
||||
Type: ChatStreamEventTypeError,
|
||||
Error: &ChatStreamError{
|
||||
Message: message,
|
||||
},
|
||||
})
|
||||
return
|
||||
default:
|
||||
_ = send(ChatStreamEvent{
|
||||
Type: ChatStreamEventTypeError,
|
||||
Error: &ChatStreamError{
|
||||
Message: fmt.Sprintf("unknown chat stream event type %q", envelope.Type),
|
||||
},
|
||||
})
|
||||
return
|
||||
}
|
||||
}
|
||||
}()
|
||||
@@ -2098,8 +2042,8 @@ func (c *ExperimentalClient) WatchChats(ctx context.Context) (<-chan ChatWatchEv
|
||||
}()
|
||||
|
||||
for {
|
||||
var envelope chatStreamEnvelope
|
||||
if err := wsjson.Read(streamCtx, conn, &envelope); err != nil {
|
||||
var event ChatWatchEvent
|
||||
if err := wsjson.Read(streamCtx, conn, &event); err != nil {
|
||||
if streamCtx.Err() != nil {
|
||||
return
|
||||
}
|
||||
@@ -2110,23 +2054,10 @@ func (c *ExperimentalClient) WatchChats(ctx context.Context) (<-chan ChatWatchEv
|
||||
return
|
||||
}
|
||||
|
||||
switch envelope.Type {
|
||||
case ServerSentEventTypePing:
|
||||
continue
|
||||
case ServerSentEventTypeData:
|
||||
var event ChatWatchEvent
|
||||
if err := json.Unmarshal(envelope.Data, &event); err != nil {
|
||||
return
|
||||
}
|
||||
select {
|
||||
case <-streamCtx.Done():
|
||||
return
|
||||
case events <- event:
|
||||
}
|
||||
case ServerSentEventTypeError:
|
||||
return
|
||||
default:
|
||||
select {
|
||||
case <-streamCtx.Done():
|
||||
return
|
||||
case events <- event:
|
||||
}
|
||||
}
|
||||
}()
|
||||
@@ -2478,27 +2409,12 @@ func (c *ExperimentalClient) GetChatsByWorkspace(ctx context.Context, workspaceI
|
||||
return result, json.NewDecoder(res.Body).Decode(&result)
|
||||
}
|
||||
|
||||
func formatChatStreamResponseError(response Response) string {
|
||||
message := strings.TrimSpace(response.Message)
|
||||
detail := strings.TrimSpace(response.Detail)
|
||||
switch {
|
||||
case message == "" && detail == "":
|
||||
return "chat stream returned an error"
|
||||
case message == "":
|
||||
return detail
|
||||
case detail == "":
|
||||
return message
|
||||
default:
|
||||
return fmt.Sprintf("%s: %s", message, detail)
|
||||
}
|
||||
}
|
||||
|
||||
// PRInsightsResponse is the response from the PR insights endpoint.
|
||||
type PRInsightsResponse struct {
|
||||
Summary PRInsightsSummary `json:"summary"`
|
||||
TimeSeries []PRInsightsTimeSeriesEntry `json:"time_series"`
|
||||
ByModel []PRInsightsModelBreakdown `json:"by_model"`
|
||||
RecentPRs []PRInsightsPullRequest `json:"recent_prs"`
|
||||
Summary PRInsightsSummary `json:"summary"`
|
||||
TimeSeries []PRInsightsTimeSeriesEntry `json:"time_series"`
|
||||
ByModel []PRInsightsModelBreakdown `json:"by_model"`
|
||||
PullRequests []PRInsightsPullRequest `json:"recent_prs"`
|
||||
}
|
||||
|
||||
// PRInsightsSummary contains aggregate PR metrics for a time period,
|
||||
|
||||
@@ -75,6 +75,49 @@ func (req *OAuth2ClientRegistrationRequest) Validate() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// ValidateRedirectURIScheme reports whether the callback URL's scheme is
|
||||
// safe to use as a redirect target. It returns an error when the scheme
|
||||
// is empty, an unsupported URN, or one of the schemes that are dangerous
|
||||
// in browser/HTML contexts (javascript, data, file, ftp).
|
||||
//
|
||||
// Legitimate custom schemes for native apps (e.g. vscode://, jetbrains://)
|
||||
// are allowed.
|
||||
// ValidateRedirectURIScheme reports whether the callback URL's scheme is
|
||||
// safe to use as a redirect target. It returns an error when the scheme
|
||||
// is empty, an unsupported URN, or one of the schemes that are dangerous
|
||||
// in browser/HTML contexts (javascript, data, file, ftp).
|
||||
//
|
||||
// Legitimate custom schemes for native apps (e.g. vscode://, jetbrains://)
|
||||
// are allowed.
|
||||
func ValidateRedirectURIScheme(u *url.URL) error {
|
||||
return validateScheme(u)
|
||||
}
|
||||
|
||||
func validateScheme(u *url.URL) error {
|
||||
if u.Scheme == "" {
|
||||
return xerrors.New("redirect URI must have a scheme")
|
||||
}
|
||||
|
||||
// Handle special URNs (RFC 6749 section 3.1.2.1).
|
||||
if u.Scheme == "urn" {
|
||||
if u.String() == "urn:ietf:wg:oauth:2.0:oob" {
|
||||
return nil
|
||||
}
|
||||
return xerrors.New("redirect URI uses unsupported URN scheme")
|
||||
}
|
||||
|
||||
// Block dangerous schemes for security (not allowed by RFCs
|
||||
// for OAuth2).
|
||||
dangerousSchemes := []string{"javascript", "data", "file", "ftp"}
|
||||
for _, dangerous := range dangerousSchemes {
|
||||
if strings.EqualFold(u.Scheme, dangerous) {
|
||||
return xerrors.Errorf("redirect URI uses dangerous scheme %s which is not allowed", dangerous)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// validateRedirectURIs validates redirect URIs according to RFC 7591, 8252
|
||||
func validateRedirectURIs(uris []string, tokenEndpointAuthMethod OAuth2TokenEndpointAuthMethod) error {
|
||||
if len(uris) == 0 {
|
||||
@@ -91,27 +134,14 @@ func validateRedirectURIs(uris []string, tokenEndpointAuthMethod OAuth2TokenEndp
|
||||
return xerrors.Errorf("redirect URI at index %d is not a valid URL: %w", i, err)
|
||||
}
|
||||
|
||||
// Validate schemes according to RFC requirements
|
||||
if uri.Scheme == "" {
|
||||
return xerrors.Errorf("redirect URI at index %d must have a scheme", i)
|
||||
if err := validateScheme(uri); err != nil {
|
||||
return xerrors.Errorf("redirect URI at index %d: %w", i, err)
|
||||
}
|
||||
|
||||
// Handle special URNs (RFC 6749 section 3.1.2.1)
|
||||
// The urn:ietf:wg:oauth:2.0:oob scheme passed validation
|
||||
// above but needs no further checks.
|
||||
if uri.Scheme == "urn" {
|
||||
// Allow the out-of-band redirect URI for native apps
|
||||
if uriStr == "urn:ietf:wg:oauth:2.0:oob" {
|
||||
continue // This is valid for native apps
|
||||
}
|
||||
// Other URNs are not standard for OAuth2
|
||||
return xerrors.Errorf("redirect URI at index %d uses unsupported URN scheme", i)
|
||||
}
|
||||
|
||||
// Block dangerous schemes for security (not allowed by RFCs for OAuth2)
|
||||
dangerousSchemes := []string{"javascript", "data", "file", "ftp"}
|
||||
for _, dangerous := range dangerousSchemes {
|
||||
if strings.EqualFold(uri.Scheme, dangerous) {
|
||||
return xerrors.Errorf("redirect URI at index %d uses dangerous scheme %s which is not allowed", i, dangerous)
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
// Determine if this is a public client based on token endpoint auth method
|
||||
|
||||
@@ -143,13 +143,14 @@ type ProvisionerJobInput struct {
|
||||
|
||||
// ProvisionerJobMetadata contains metadata for the job.
|
||||
type ProvisionerJobMetadata struct {
|
||||
TemplateVersionName string `json:"template_version_name" table:"template version name"`
|
||||
TemplateID uuid.UUID `json:"template_id" format:"uuid" table:"template id"`
|
||||
TemplateName string `json:"template_name" table:"template name"`
|
||||
TemplateDisplayName string `json:"template_display_name" table:"template display name"`
|
||||
TemplateIcon string `json:"template_icon" table:"template icon"`
|
||||
WorkspaceID *uuid.UUID `json:"workspace_id,omitempty" format:"uuid" table:"workspace id"`
|
||||
WorkspaceName string `json:"workspace_name,omitempty" table:"workspace name"`
|
||||
TemplateVersionName string `json:"template_version_name" table:"template version name"`
|
||||
TemplateID uuid.UUID `json:"template_id" format:"uuid" table:"template id"`
|
||||
TemplateName string `json:"template_name" table:"template name"`
|
||||
TemplateDisplayName string `json:"template_display_name" table:"template display name"`
|
||||
TemplateIcon string `json:"template_icon" table:"template icon"`
|
||||
WorkspaceID *uuid.UUID `json:"workspace_id,omitempty" format:"uuid" table:"workspace id"`
|
||||
WorkspaceName string `json:"workspace_name,omitempty" table:"workspace name"`
|
||||
WorkspaceBuildTransition WorkspaceTransition `json:"workspace_build_transition,omitempty" table:"workspace build transition"`
|
||||
}
|
||||
|
||||
// ProvisionerJobType represents the type of job.
|
||||
|
||||
@@ -0,0 +1,109 @@
|
||||
package codersdk
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
// UserSecret represents a user secret's metadata. The secret value
|
||||
// is never included in API responses.
|
||||
type UserSecret struct {
|
||||
ID uuid.UUID `json:"id" format:"uuid"`
|
||||
Name string `json:"name"`
|
||||
Description string `json:"description"`
|
||||
EnvName string `json:"env_name"`
|
||||
FilePath string `json:"file_path"`
|
||||
CreatedAt time.Time `json:"created_at" format:"date-time"`
|
||||
UpdatedAt time.Time `json:"updated_at" format:"date-time"`
|
||||
}
|
||||
|
||||
// CreateUserSecretRequest is the payload for creating a new user
|
||||
// secret. Name and Value are required. All other fields are optional
|
||||
// and default to empty string.
|
||||
type CreateUserSecretRequest struct {
|
||||
Name string `json:"name"`
|
||||
Value string `json:"value"`
|
||||
Description string `json:"description,omitempty"`
|
||||
EnvName string `json:"env_name,omitempty"`
|
||||
FilePath string `json:"file_path,omitempty"`
|
||||
}
|
||||
|
||||
// UpdateUserSecretRequest is the payload for partially updating a
|
||||
// user secret. At least one field must be non-nil. Pointer fields
|
||||
// distinguish "not sent" (nil) from "set to empty string" (pointer
|
||||
// to empty string).
|
||||
type UpdateUserSecretRequest struct {
|
||||
Value *string `json:"value,omitempty"`
|
||||
Description *string `json:"description,omitempty"`
|
||||
EnvName *string `json:"env_name,omitempty"`
|
||||
FilePath *string `json:"file_path,omitempty"`
|
||||
}
|
||||
|
||||
func (c *Client) CreateUserSecret(ctx context.Context, user string, req CreateUserSecretRequest) (UserSecret, error) {
|
||||
res, err := c.Request(ctx, http.MethodPost, fmt.Sprintf("/api/v2/users/%s/secrets", user), req)
|
||||
if err != nil {
|
||||
return UserSecret{}, err
|
||||
}
|
||||
defer res.Body.Close()
|
||||
if res.StatusCode != http.StatusCreated {
|
||||
return UserSecret{}, ReadBodyAsError(res)
|
||||
}
|
||||
var secret UserSecret
|
||||
return secret, json.NewDecoder(res.Body).Decode(&secret)
|
||||
}
|
||||
|
||||
func (c *Client) UserSecrets(ctx context.Context, user string) ([]UserSecret, error) {
|
||||
res, err := c.Request(ctx, http.MethodGet, fmt.Sprintf("/api/v2/users/%s/secrets", user), nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer res.Body.Close()
|
||||
if res.StatusCode != http.StatusOK {
|
||||
return nil, ReadBodyAsError(res)
|
||||
}
|
||||
var secrets []UserSecret
|
||||
return secrets, json.NewDecoder(res.Body).Decode(&secrets)
|
||||
}
|
||||
|
||||
func (c *Client) UserSecretByName(ctx context.Context, user string, name string) (UserSecret, error) {
|
||||
res, err := c.Request(ctx, http.MethodGet, fmt.Sprintf("/api/v2/users/%s/secrets/%s", user, name), nil)
|
||||
if err != nil {
|
||||
return UserSecret{}, err
|
||||
}
|
||||
defer res.Body.Close()
|
||||
if res.StatusCode != http.StatusOK {
|
||||
return UserSecret{}, ReadBodyAsError(res)
|
||||
}
|
||||
var secret UserSecret
|
||||
return secret, json.NewDecoder(res.Body).Decode(&secret)
|
||||
}
|
||||
|
||||
func (c *Client) UpdateUserSecret(ctx context.Context, user string, name string, req UpdateUserSecretRequest) (UserSecret, error) {
|
||||
res, err := c.Request(ctx, http.MethodPatch, fmt.Sprintf("/api/v2/users/%s/secrets/%s", user, name), req)
|
||||
if err != nil {
|
||||
return UserSecret{}, err
|
||||
}
|
||||
defer res.Body.Close()
|
||||
if res.StatusCode != http.StatusOK {
|
||||
return UserSecret{}, ReadBodyAsError(res)
|
||||
}
|
||||
var secret UserSecret
|
||||
return secret, json.NewDecoder(res.Body).Decode(&secret)
|
||||
}
|
||||
|
||||
func (c *Client) DeleteUserSecret(ctx context.Context, user string, name string) error {
|
||||
res, err := c.Request(ctx, http.MethodDelete, fmt.Sprintf("/api/v2/users/%s/secrets/%s", user, name), nil)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer res.Body.Close()
|
||||
if res.StatusCode != http.StatusNoContent {
|
||||
return ReadBodyAsError(res)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
@@ -0,0 +1,191 @@
|
||||
package codersdk
|
||||
|
||||
import (
|
||||
"regexp"
|
||||
"strings"
|
||||
|
||||
"golang.org/x/xerrors"
|
||||
)
|
||||
|
||||
// UserSecretEnvValidationOptions controls deployment-aware behavior
|
||||
// in environment variable name validation.
|
||||
type UserSecretEnvValidationOptions struct {
|
||||
// AIGatewayEnabled indicates that the deployment has AI Gateway
|
||||
// configured. When true, AI Gateway environment variables
|
||||
// (OPENAI_API_KEY, etc.) are reserved to prevent conflicts.
|
||||
AIGatewayEnabled bool
|
||||
}
|
||||
|
||||
var (
|
||||
// posixEnvNameRegex matches valid POSIX environment variable names:
|
||||
// must start with a letter or underscore, followed by letters,
|
||||
// digits, or underscores.
|
||||
posixEnvNameRegex = regexp.MustCompile(`^[a-zA-Z_][a-zA-Z0-9_]*$`)
|
||||
|
||||
// reservedEnvNames are system environment variables that must not
|
||||
// be overridden by user secrets. This list is intentionally
|
||||
// aggressive because it is easier to remove entries later than
|
||||
// to add them after users have already created conflicting
|
||||
// secrets.
|
||||
reservedEnvNames = map[string]struct{}{
|
||||
// Core POSIX/login variables. Overriding these breaks
|
||||
// basic shell and session behavior.
|
||||
"PATH": {},
|
||||
"HOME": {},
|
||||
"SHELL": {},
|
||||
"USER": {},
|
||||
"LOGNAME": {},
|
||||
"PWD": {},
|
||||
"OLDPWD": {},
|
||||
|
||||
// Locale and terminal. Agents and IDEs depend on these
|
||||
// being set correctly by the system.
|
||||
"LANG": {},
|
||||
"TERM": {},
|
||||
|
||||
// Shell behavior. Overriding these can silently break
|
||||
// word splitting, directory resolution, and script
|
||||
// execution in every shell session and agent script.
|
||||
"IFS": {},
|
||||
"CDPATH": {},
|
||||
|
||||
// Shell startup files. ENV is sourced by POSIX sh for
|
||||
// interactive shells; BASH_ENV is sourced by bash for
|
||||
// every non-interactive invocation (scripts, subshells).
|
||||
// Allowing users to set these would inject arbitrary
|
||||
// code into every shell and script in the workspace.
|
||||
"ENV": {},
|
||||
"BASH_ENV": {},
|
||||
|
||||
// Temp directories. Overriding these is a security risk
|
||||
// (symlink attacks, world-readable paths).
|
||||
"TMPDIR": {},
|
||||
"TMP": {},
|
||||
"TEMP": {},
|
||||
|
||||
// Host identity.
|
||||
"HOSTNAME": {},
|
||||
|
||||
// SSH session variables. The Coder agent sets
|
||||
// SSH_AUTH_SOCK in agentssh.go; the others are set by
|
||||
// sshd and should never be faked.
|
||||
"SSH_AUTH_SOCK": {},
|
||||
"SSH_CLIENT": {},
|
||||
"SSH_CONNECTION": {},
|
||||
"SSH_TTY": {},
|
||||
|
||||
// Editor/pager. The Coder agent sets these so that git
|
||||
// operations inside workspaces work non-interactively.
|
||||
"EDITOR": {},
|
||||
"VISUAL": {},
|
||||
"PAGER": {},
|
||||
|
||||
// IDE integration. The agent sets these for code-server
|
||||
// and VS Code Remote proxying.
|
||||
"VSCODE_PROXY_URI": {},
|
||||
"CS_DISABLE_GETTING_STARTED_OVERRIDE": {},
|
||||
|
||||
// XDG base directories. Overriding these redirects
|
||||
// config, cache, and runtime data for every tool in the
|
||||
// workspace.
|
||||
"XDG_RUNTIME_DIR": {},
|
||||
"XDG_CONFIG_HOME": {},
|
||||
"XDG_DATA_HOME": {},
|
||||
"XDG_CACHE_HOME": {},
|
||||
"XDG_STATE_HOME": {},
|
||||
|
||||
// OIDC token. The Coder agent injects a short-lived
|
||||
// OIDC token for cloud auth flows (e.g. GCP workload
|
||||
// identity). Overriding it could break provisioner and
|
||||
// agent authentication.
|
||||
"OIDC_TOKEN": {},
|
||||
}
|
||||
|
||||
// aiGatewayReservedEnvNames are reserved only when AI Gateway
|
||||
// is enabled on the deployment. When AI Gateway is disabled,
|
||||
// users may legitimately want to inject their own API keys
|
||||
// via secrets.
|
||||
aiGatewayReservedEnvNames = map[string]struct{}{
|
||||
"OPENAI_API_KEY": {},
|
||||
"OPENAI_BASE_URL": {},
|
||||
"ANTHROPIC_AUTH_TOKEN": {},
|
||||
"ANTHROPIC_BASE_URL": {},
|
||||
}
|
||||
|
||||
// reservedEnvPrefixes are namespace prefixes where every
|
||||
// variable in the family is reserved. Checked after the
|
||||
// exact-name map. The CODER / CODER_* namespace is handled
|
||||
// separately with its own error message (see below).
|
||||
reservedEnvPrefixes = []string{
|
||||
// The Coder agent sets GIT_SSH_COMMAND, GIT_ASKPASS,
|
||||
// GIT_AUTHOR_*, GIT_COMMITTER_*, and several others.
|
||||
// Blocking the entire GIT_* namespace avoids an arms
|
||||
// race with new git env vars.
|
||||
"GIT_",
|
||||
|
||||
// Locale variables. LC_ALL, LC_CTYPE, LC_MESSAGES,
|
||||
// etc. control character encoding, sorting, and
|
||||
// formatting. Overriding them can break text
|
||||
// processing in agents and IDEs.
|
||||
"LC_",
|
||||
|
||||
// Dynamic linker variables. Allowing users to set
|
||||
// these would let a secret inject arbitrary shared
|
||||
// libraries into every process in the workspace.
|
||||
"LD_",
|
||||
"DYLD_",
|
||||
}
|
||||
)
|
||||
|
||||
// UserSecretEnvNameValid validates an environment variable name for
|
||||
// a user secret. Empty string is allowed (means no env injection).
|
||||
// The opts parameter controls deployment-aware checks such as AI
|
||||
// bridge variable reservation.
|
||||
func UserSecretEnvNameValid(s string, opts UserSecretEnvValidationOptions) error {
|
||||
if s == "" {
|
||||
return nil
|
||||
}
|
||||
|
||||
if !posixEnvNameRegex.MatchString(s) {
|
||||
return xerrors.New("must start with a letter or underscore, followed by letters, digits, or underscores")
|
||||
}
|
||||
|
||||
upper := strings.ToUpper(s)
|
||||
|
||||
if _, ok := reservedEnvNames[upper]; ok {
|
||||
return xerrors.Errorf("%s is a reserved environment variable name", upper)
|
||||
}
|
||||
|
||||
if upper == "CODER" || strings.HasPrefix(upper, "CODER_") {
|
||||
return xerrors.New("environment variable names starting with CODER_ are reserved for internal use")
|
||||
}
|
||||
|
||||
for _, prefix := range reservedEnvPrefixes {
|
||||
if strings.HasPrefix(upper, prefix) {
|
||||
return xerrors.Errorf("environment variables starting with %s are reserved", prefix)
|
||||
}
|
||||
}
|
||||
|
||||
if opts.AIGatewayEnabled {
|
||||
if _, ok := aiGatewayReservedEnvNames[upper]; ok {
|
||||
return xerrors.Errorf("%s is reserved when AI Gateway is enabled", upper)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// UserSecretFilePathValid validates a file path for a user secret.
|
||||
// Empty string is allowed (means no file injection). Non-empty paths
|
||||
// must start with ~/ or /.
|
||||
func UserSecretFilePathValid(s string) error {
|
||||
if s == "" {
|
||||
return nil
|
||||
}
|
||||
|
||||
if strings.HasPrefix(s, "~/") || strings.HasPrefix(s, "/") {
|
||||
return nil
|
||||
}
|
||||
|
||||
return xerrors.New("file path must start with ~/ or /")
|
||||
}
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user